add count_leading_zeros, count_trailing_zeros, and count_ones implementations
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 18 May 2021 04:13:12 +0000 (21:13 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 18 May 2021 04:13:12 +0000 (21:13 -0700)
src/algorithms.rs
src/algorithms/base.rs
src/algorithms/integer.rs [new file with mode: 0644]
src/prim.rs
src/stdsimd.rs

index cfa78b85f0dd756a250692bc6b110d9054cb5e53..4278ac2da94d1623c3e78f9f4dd81a51bad48821 100644 (file)
@@ -1,3 +1,4 @@
 pub mod base;
 pub mod ilogb;
+pub mod integer;
 pub mod trig_pi;
index b4ec103aac2f8ce2bc7d198753daab9d456d7fed..4ebd8493ee3d7dc12e0753e469db896253e6f14f 100644 (file)
@@ -88,7 +88,9 @@ pub fn floor<
     let offset_value: VecF = v.abs() + offset;
     let rounded = (offset_value - offset).copy_sign(v);
     let need_round_down = v.lt(rounded);
-    let in_range_value = need_round_down.select(rounded - ctx.make(1.to()), rounded).copy_sign(v);
+    let in_range_value = need_round_down
+        .select(rounded - ctx.make(1.to()), rounded)
+        .copy_sign(v);
     big.select(v, in_range_value)
 }
 
@@ -108,7 +110,9 @@ pub fn ceil<
     let offset_value: VecF = v.abs() + offset;
     let rounded = (offset_value - offset).copy_sign(v);
     let need_round_up = v.gt(rounded);
-    let in_range_value = need_round_up.select(rounded + ctx.make(1.to()), rounded).copy_sign(v);
+    let in_range_value = need_round_up
+        .select(rounded + ctx.make(1.to()), rounded)
+        .copy_sign(v);
     big.select(v, in_range_value)
 }
 
diff --git a/src/algorithms/integer.rs b/src/algorithms/integer.rs
new file mode 100644 (file)
index 0000000..1091723
--- /dev/null
@@ -0,0 +1,341 @@
+use crate::{
+    prim::PrimUInt,
+    traits::{Context, ConvertFrom, ConvertTo, Make, SInt, Select, UInt},
+};
+
+pub fn count_leading_zeros_uint<
+    Ctx: Context,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    mut v: VecU,
+) -> VecU {
+    let mut retval: VecU = ctx.make(PrimU::BITS);
+    let mut bits = PrimU::BITS;
+    while bits > 1.to() {
+        bits /= 2.to();
+        let limit = PrimU::ONE << bits;
+        let found = v.ge(ctx.make(limit));
+        let shift: VecU = found.select(ctx.make(bits), ctx.make(0.to()));
+        retval -= shift;
+        v >>= shift;
+    }
+    let nonzero = v.ne(ctx.make(0.to()));
+    retval - nonzero.select(ctx.make(1.to()), ctx.make(0.to()))
+}
+
+pub fn count_leading_zeros_sint<
+    Ctx: Context,
+    VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
+    VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
+>(
+    ctx: Ctx,
+    v: VecS,
+) -> VecS {
+    count_leading_zeros_uint(ctx, VecU::cvt_from(v)).to()
+}
+
+pub fn count_trailing_zeros_uint<
+    Ctx: Context,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    mut v: VecU,
+) -> VecU {
+    let mut retval: VecU = ctx.make(PrimU::ZERO);
+    let mut bits = PrimU::BITS;
+    while bits > 1.to() {
+        bits /= 2.to();
+        let mask = (PrimU::ONE << bits) - 1.to();
+        let zero = (v & ctx.make(mask)).eq(ctx.make(0.to()));
+        let shift: VecU = zero.select(ctx.make(bits), ctx.make(0.to()));
+        retval += shift;
+        v >>= shift;
+    }
+    let zero = v.eq(ctx.make(0.to()));
+    retval + zero.select(ctx.make(1.to()), ctx.make(0.to()))
+}
+
+pub fn count_trailing_zeros_sint<
+    Ctx: Context,
+    VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
+    VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
+>(
+    ctx: Ctx,
+    v: VecS,
+) -> VecS {
+    count_trailing_zeros_uint(ctx, VecU::cvt_from(v)).to()
+}
+
+pub fn count_ones_uint<
+    Ctx: Context,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    mut v: VecU,
+) -> VecU {
+    assert!(PrimU::BITS <= 64.to());
+    assert!(PrimU::BITS >= 8.to());
+    const SPLAT_BYTES_MULTIPLIER: u64 = u64::from_le_bytes([1; 8]);
+    const EVERY_OTHER_BIT_MASK: u64 = 0x55 * SPLAT_BYTES_MULTIPLIER;
+    const TWO_OUT_OF_FOUR_BITS_MASK: u64 = 0x33 * SPLAT_BYTES_MULTIPLIER;
+    const FOUR_OUT_OF_EIGHT_BITS_MASK: u64 = 0x0F * SPLAT_BYTES_MULTIPLIER;
+    // algorithm derived from popcount64c at https://en.wikipedia.org/wiki/Hamming_weight
+    v -= (v >> ctx.make(1.to())) & ctx.make(EVERY_OTHER_BIT_MASK.to());
+    v = (v & ctx.make(TWO_OUT_OF_FOUR_BITS_MASK.to()))
+        + ((v >> ctx.make(2.to())) & ctx.make(TWO_OUT_OF_FOUR_BITS_MASK.to()));
+    v = (v & ctx.make(FOUR_OUT_OF_EIGHT_BITS_MASK.to()))
+        + ((v >> ctx.make(4.to())) & ctx.make(FOUR_OUT_OF_EIGHT_BITS_MASK.to()));
+    if PrimU::BITS > 8.to() {
+        v * ctx.make(SPLAT_BYTES_MULTIPLIER.to()) >> ctx.make(PrimU::BITS - 8.to())
+    } else {
+        v
+    }
+}
+
+pub fn count_ones_sint<
+    Ctx: Context,
+    VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
+    VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
+>(
+    ctx: Ctx,
+    v: VecS,
+) -> VecS {
+    count_ones_uint(ctx, VecU::cvt_from(v)).to()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::scalar::{Scalar, Value};
+
+    #[test]
+    fn test_count_leading_zeros_u16() {
+        for v in 0..=u16::MAX {
+            assert_eq!(
+                v.leading_zeros() as u16,
+                count_leading_zeros_uint(Scalar, Value(v)).0,
+                "v = {:#X}",
+                v,
+            );
+        }
+    }
+
+    #[test]
+    fn test_count_trailing_zeros_u16() {
+        for v in 0..=u16::MAX {
+            assert_eq!(
+                v.trailing_zeros() as u16,
+                count_trailing_zeros_uint(Scalar, Value(v)).0,
+                "v = {:#X}",
+                v,
+            );
+        }
+    }
+
+    #[test]
+    fn test_count_ones_u16() {
+        for v in 0..=u16::MAX {
+            assert_eq!(
+                v.count_ones() as u16,
+                count_ones_uint(Scalar, Value(v)).0,
+                "v = {:#X}",
+                v,
+            );
+        }
+    }
+}
+
+#[cfg(all(feature = "ir", test))]
+mod ir_tests {
+    use super::*;
+    use crate::ir::{IrContext, IrFunction, IrVecI64, IrVecU64, IrVecU8};
+    use std::{format, println};
+
+    #[test]
+    fn test_display_count_leading_zeros_i64() {
+        let ctx = IrContext::new();
+        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
+            let f: fn(&'ctx IrContext<'ctx>, IrVecI64<'ctx>) -> IrVecI64<'ctx> =
+                count_leading_zeros_sint;
+            IrFunction::make(ctx, f)
+        }
+        let text = format!("\n{}", make_it(&ctx));
+        println!("{}", text);
+        assert_eq!(
+            text,
+            r"
+function(in<arg_0>: vec<I64>) -> vec<I64> {
+    op_0: vec<U64> = Cast in<arg_0>
+    op_1: vec<Bool> = CompareGe op_0, splat(0x100000000_u64)
+    op_2: vec<U64> = Select op_1, splat(0x20_u64), splat(0x0_u64)
+    op_3: vec<U64> = Sub splat(0x40_u64), op_2
+    op_4: vec<U64> = Shr op_0, op_2
+    op_5: vec<Bool> = CompareGe op_4, splat(0x10000_u64)
+    op_6: vec<U64> = Select op_5, splat(0x10_u64), splat(0x0_u64)
+    op_7: vec<U64> = Sub op_3, op_6
+    op_8: vec<U64> = Shr op_4, op_6
+    op_9: vec<Bool> = CompareGe op_8, splat(0x100_u64)
+    op_10: vec<U64> = Select op_9, splat(0x8_u64), splat(0x0_u64)
+    op_11: vec<U64> = Sub op_7, op_10
+    op_12: vec<U64> = Shr op_8, op_10
+    op_13: vec<Bool> = CompareGe op_12, splat(0x10_u64)
+    op_14: vec<U64> = Select op_13, splat(0x4_u64), splat(0x0_u64)
+    op_15: vec<U64> = Sub op_11, op_14
+    op_16: vec<U64> = Shr op_12, op_14
+    op_17: vec<Bool> = CompareGe op_16, splat(0x4_u64)
+    op_18: vec<U64> = Select op_17, splat(0x2_u64), splat(0x0_u64)
+    op_19: vec<U64> = Sub op_15, op_18
+    op_20: vec<U64> = Shr op_16, op_18
+    op_21: vec<Bool> = CompareGe op_20, splat(0x2_u64)
+    op_22: vec<U64> = Select op_21, splat(0x1_u64), splat(0x0_u64)
+    op_23: vec<U64> = Sub op_19, op_22
+    op_24: vec<U64> = Shr op_20, op_22
+    op_25: vec<Bool> = CompareNe op_24, splat(0x0_u64)
+    op_26: vec<U64> = Select op_25, splat(0x1_u64), splat(0x0_u64)
+    op_27: vec<U64> = Sub op_23, op_26
+    op_28: vec<I64> = Cast op_27
+    Return op_28
+}
+"
+        );
+    }
+
+    #[test]
+    fn test_display_count_leading_zeros_u8() {
+        let ctx = IrContext::new();
+        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
+            let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> =
+                count_leading_zeros_uint;
+            IrFunction::make(ctx, f)
+        }
+        let text = format!("\n{}", make_it(&ctx));
+        println!("{}", text);
+        assert_eq!(
+            text,
+            r"
+function(in<arg_0>: vec<U8>) -> vec<U8> {
+    op_0: vec<Bool> = CompareGe in<arg_0>, splat(0x10_u8)
+    op_1: vec<U8> = Select op_0, splat(0x4_u8), splat(0x0_u8)
+    op_2: vec<U8> = Sub splat(0x8_u8), op_1
+    op_3: vec<U8> = Shr in<arg_0>, op_1
+    op_4: vec<Bool> = CompareGe op_3, splat(0x4_u8)
+    op_5: vec<U8> = Select op_4, splat(0x2_u8), splat(0x0_u8)
+    op_6: vec<U8> = Sub op_2, op_5
+    op_7: vec<U8> = Shr op_3, op_5
+    op_8: vec<Bool> = CompareGe op_7, splat(0x2_u8)
+    op_9: vec<U8> = Select op_8, splat(0x1_u8), splat(0x0_u8)
+    op_10: vec<U8> = Sub op_6, op_9
+    op_11: vec<U8> = Shr op_7, op_9
+    op_12: vec<Bool> = CompareNe op_11, splat(0x0_u8)
+    op_13: vec<U8> = Select op_12, splat(0x1_u8), splat(0x0_u8)
+    op_14: vec<U8> = Sub op_10, op_13
+    Return op_14
+}
+"
+        );
+    }
+
+    #[test]
+    fn test_display_count_trailing_zeros_u8() {
+        let ctx = IrContext::new();
+        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
+            let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> =
+                count_trailing_zeros_uint;
+            IrFunction::make(ctx, f)
+        }
+        let text = format!("\n{}", make_it(&ctx));
+        println!("{}", text);
+        assert_eq!(
+            text,
+            r"
+function(in<arg_0>: vec<U8>) -> vec<U8> {
+    op_0: vec<U8> = And in<arg_0>, splat(0xF_u8)
+    op_1: vec<Bool> = CompareEq op_0, splat(0x0_u8)
+    op_2: vec<U8> = Select op_1, splat(0x4_u8), splat(0x0_u8)
+    op_3: vec<U8> = Add splat(0x0_u8), op_2
+    op_4: vec<U8> = Shr in<arg_0>, op_2
+    op_5: vec<U8> = And op_4, splat(0x3_u8)
+    op_6: vec<Bool> = CompareEq op_5, splat(0x0_u8)
+    op_7: vec<U8> = Select op_6, splat(0x2_u8), splat(0x0_u8)
+    op_8: vec<U8> = Add op_3, op_7
+    op_9: vec<U8> = Shr op_4, op_7
+    op_10: vec<U8> = And op_9, splat(0x1_u8)
+    op_11: vec<Bool> = CompareEq op_10, splat(0x0_u8)
+    op_12: vec<U8> = Select op_11, splat(0x1_u8), splat(0x0_u8)
+    op_13: vec<U8> = Add op_8, op_12
+    op_14: vec<U8> = Shr op_9, op_12
+    op_15: vec<Bool> = CompareEq op_14, splat(0x0_u8)
+    op_16: vec<U8> = Select op_15, splat(0x1_u8), splat(0x0_u8)
+    op_17: vec<U8> = Add op_13, op_16
+    Return op_17
+}
+"
+        );
+    }
+
+    #[test]
+    fn test_display_count_ones_u8() {
+        let ctx = IrContext::new();
+        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
+            let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = count_ones_uint;
+            IrFunction::make(ctx, f)
+        }
+        let text = format!("\n{}", make_it(&ctx));
+        println!("{}", text);
+        assert_eq!(
+            text,
+            r"
+function(in<arg_0>: vec<U8>) -> vec<U8> {
+    op_0: vec<U8> = Shr in<arg_0>, splat(0x1_u8)
+    op_1: vec<U8> = And op_0, splat(0x55_u8)
+    op_2: vec<U8> = Sub in<arg_0>, op_1
+    op_3: vec<U8> = And op_2, splat(0x33_u8)
+    op_4: vec<U8> = Shr op_2, splat(0x2_u8)
+    op_5: vec<U8> = And op_4, splat(0x33_u8)
+    op_6: vec<U8> = Add op_3, op_5
+    op_7: vec<U8> = And op_6, splat(0xF_u8)
+    op_8: vec<U8> = Shr op_6, splat(0x4_u8)
+    op_9: vec<U8> = And op_8, splat(0xF_u8)
+    op_10: vec<U8> = Add op_7, op_9
+    Return op_10
+}
+"
+        );
+    }
+
+    #[test]
+    fn test_display_count_ones_u64() {
+        let ctx = IrContext::new();
+        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
+            let f: fn(&'ctx IrContext<'ctx>, IrVecU64<'ctx>) -> IrVecU64<'ctx> = count_ones_uint;
+            IrFunction::make(ctx, f)
+        }
+        let text = format!("\n{}", make_it(&ctx));
+        println!("{}", text);
+        assert_eq!(
+            text,
+            r"
+function(in<arg_0>: vec<U64>) -> vec<U64> {
+    op_0: vec<U64> = Shr in<arg_0>, splat(0x1_u64)
+    op_1: vec<U64> = And op_0, splat(0x5555555555555555_u64)
+    op_2: vec<U64> = Sub in<arg_0>, op_1
+    op_3: vec<U64> = And op_2, splat(0x3333333333333333_u64)
+    op_4: vec<U64> = Shr op_2, splat(0x2_u64)
+    op_5: vec<U64> = And op_4, splat(0x3333333333333333_u64)
+    op_6: vec<U64> = Add op_3, op_5
+    op_7: vec<U64> = And op_6, splat(0xF0F0F0F0F0F0F0F_u64)
+    op_8: vec<U64> = Shr op_6, splat(0x4_u64)
+    op_9: vec<U64> = And op_8, splat(0xF0F0F0F0F0F0F0F_u64)
+    op_10: vec<U64> = Add op_7, op_9
+    op_11: vec<U64> = Mul op_10, splat(0x101010101010101_u64)
+    op_12: vec<U64> = Shr op_11, splat(0x38_u64)
+    Return op_12
+}
+"
+        );
+    }
+}
index ea7b6b94010b7bf78da9a1873dd8b158340d1285..7ba23e5b394ba9fe9c9ea1f18a43ee9fbacd9a3b 100644 (file)
@@ -91,6 +91,7 @@ pub trait PrimInt:
     const ONE: Self;
     const MIN: Self;
     const MAX: Self;
+    const BITS: Self;
 }
 
 pub trait PrimUInt: PrimInt + ConvertFrom<Self::SignedType> {
@@ -110,12 +111,14 @@ macro_rules! impl_int {
             const ONE: Self = 1;
             const MIN: Self = 0;
             const MAX: Self = !0;
+            const BITS: Self = (0 as $uint).count_zeros() as $uint;
         }
         impl PrimInt for $sint {
             const ZERO: Self = 0;
             const ONE: Self = 1;
             const MIN: Self = $sint::MIN;
             const MAX: Self = $sint::MAX;
+            const BITS: Self = (0 as $sint).count_zeros() as $sint;
         }
         impl PrimUInt for $uint {
             type SignedType = $sint;
index 3d757cc21b7a0dee350115ecd9e282f53174d98a..35692c9246c2bd562cd007ddde475f90f8ce4682 100644 (file)
@@ -520,7 +520,7 @@ macro_rules! impl_int_scalar {
 }
 
 macro_rules! impl_int_vector {
-    ($ty:ident) => {
+    ($ty:ident, $count_leading_zeros:ident, $count_trailing_zeros:ident, $count_ones:ident) => {
         impl<const LANES: usize> Int for Wrapper<$ty<LANES>, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,
@@ -539,27 +539,15 @@ macro_rules! impl_int_vector {
             Mask64<LANES>: Mask,
         {
             fn leading_zeros(self) -> Self {
-                todo!()
+                crate::algorithms::integer::$count_leading_zeros(self.ctx(), self)
             }
 
             fn trailing_zeros(self) -> Self {
-                todo!()
+                crate::algorithms::integer::$count_trailing_zeros(self.ctx(), self)
             }
 
             fn count_ones(self) -> Self {
-                todo!()
-            }
-
-            fn leading_ones(self) -> Self {
-                todo!()
-            }
-
-            fn trailing_ones(self) -> Self {
-                todo!()
-            }
-
-            fn count_zeros(self) -> Self {
-                todo!()
+                crate::algorithms::integer::$count_ones(self.ctx(), self)
             }
         }
     };
@@ -567,8 +555,18 @@ macro_rules! impl_int_vector {
 
 macro_rules! impl_uint_sint_vector {
     ($uint:ident, $sint:ident) => {
-        impl_int_vector!($uint);
-        impl_int_vector!($sint);
+        impl_int_vector!(
+            $uint,
+            count_leading_zeros_uint,
+            count_trailing_zeros_uint,
+            count_ones_uint
+        );
+        impl_int_vector!(
+            $sint,
+            count_leading_zeros_sint,
+            count_trailing_zeros_sint,
+            count_ones_sint
+        );
         impl<const LANES: usize> UInt for Wrapper<$uint<LANES>, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,