--- /dev/null
+use crate::{
+ prim::PrimUInt,
+ traits::{Context, ConvertFrom, ConvertTo, Make, SInt, Select, UInt},
+};
+
+/// Counts the leading zero bits of `v`.
+///
+/// The count starts at the full bit width (`PrimU::BITS`) and is reduced by a
+/// branch-free binary search: for every halved width, a value that is at least
+/// `1 << width` still has bits set at or above that position, so `width` is
+/// taken off the count and the value is shifted right by `width`. After the
+/// loop one more is subtracted when the remaining value is non-zero, so a zero
+/// input reports the full bit width.
+///
+/// Every constant operand is built through `ctx.make`.
+pub fn count_leading_zeros_uint<Ctx, VecU, PrimU>(ctx: Ctx, mut v: VecU) -> VecU
+where
+    Ctx: Context,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimU: PrimUInt,
+{
+    let mut leading_zeros: VecU = ctx.make(PrimU::BITS);
+    let mut width = PrimU::BITS;
+    while width > 1.to() {
+        width /= 2.to();
+        // Anything at or above `1 << width` has a set bit in the upper part.
+        let threshold = PrimU::ONE << width;
+        let has_upper_bits = v.ge(ctx.make(threshold));
+        let step: VecU = has_upper_bits.select(ctx.make(width), ctx.make(0.to()));
+        leading_zeros -= step;
+        v >>= step;
+    }
+    // Account for the single bit that can still be set after the search.
+    let still_nonzero = v.ne(ctx.make(0.to()));
+    let last_step: VecU = still_nonzero.select(ctx.make(1.to()), ctx.make(0.to()));
+    leading_zeros - last_step
+}
+
+/// Signed counterpart of `count_leading_zeros_uint`.
+///
+/// Converts `v` to its unsigned type with `cvt_from`, counts the leading zeros
+/// there, and converts the count back to the signed type.
+pub fn count_leading_zeros_sint<Ctx, VecU, VecS>(ctx: Ctx, v: VecS) -> VecS
+where
+    Ctx: Context,
+    VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
+    VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
+{
+    let unsigned_value = VecU::cvt_from(v);
+    let leading_zeros = count_leading_zeros_uint(ctx, unsigned_value);
+    leading_zeros.to()
+}
+
+/// Counts the trailing zero bits of `v`.
+///
+/// The count starts at zero and grows by a branch-free binary search: for
+/// every halved width, a value whose low `width` bits are all clear gains
+/// `width` on its count and is shifted right by `width`. After the loop one
+/// more is added when the remaining value is zero, so a zero input reports the
+/// full bit width (`PrimU::BITS`).
+///
+/// Every constant operand is built through `ctx.make`.
+pub fn count_trailing_zeros_uint<Ctx, VecU, PrimU>(ctx: Ctx, mut v: VecU) -> VecU
+where
+    Ctx: Context,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimU: PrimUInt,
+{
+    let mut trailing_zeros: VecU = ctx.make(PrimU::ZERO);
+    let mut width = PrimU::BITS;
+    while width > 1.to() {
+        width /= 2.to();
+        // Mask selecting the low `width` bits.
+        let low_mask = (PrimU::ONE << width) - 1.to();
+        let low_bits: VecU = v & ctx.make(low_mask);
+        let low_part_clear = low_bits.eq(ctx.make(0.to()));
+        let step: VecU = low_part_clear.select(ctx.make(width), ctx.make(0.to()));
+        trailing_zeros += step;
+        v >>= step;
+    }
+    // Only an all-zero input is still zero here; it gets the last count.
+    let all_clear = v.eq(ctx.make(0.to()));
+    let last_step: VecU = all_clear.select(ctx.make(1.to()), ctx.make(0.to()));
+    trailing_zeros + last_step
+}
+
+/// Signed counterpart of `count_trailing_zeros_uint`.
+///
+/// Converts `v` to its unsigned type with `cvt_from`, counts the trailing zeros
+/// there, and converts the count back to the signed type.
+pub fn count_trailing_zeros_sint<Ctx, VecU, VecS>(ctx: Ctx, v: VecS) -> VecS
+where
+    Ctx: Context,
+    VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
+    VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
+{
+    let unsigned_value = VecU::cvt_from(v);
+    let trailing_zeros = count_trailing_zeros_uint(ctx, unsigned_value);
+    trailing_zeros.to()
+}
+
+/// Counts the set bits of `v`.
+///
+/// Adjacent bit groups are summed in place: first pairs of bits, then groups
+/// of four, then whole bytes. For widths above 8 bits the per-byte counts are
+/// then added together by multiplying with a constant that has `1` in every
+/// byte and shifting the top byte down.
+///
+/// # Panics
+///
+/// Panics when `PrimU::BITS` is greater than 64 or less than 8.
+pub fn count_ones_uint<Ctx, VecU, PrimU>(ctx: Ctx, mut v: VecU) -> VecU
+where
+    Ctx: Context,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimU: PrimUInt,
+{
+    assert!(PrimU::BITS <= 64.to());
+    assert!(PrimU::BITS >= 8.to());
+    // 64-bit patterns; `.to()` converts each one to `PrimU` before use.
+    const ONE_IN_EACH_BYTE: u64 = u64::from_le_bytes([1; 8]);
+    const ALTERNATE_BITS: u64 = 0x55 * ONE_IN_EACH_BYTE;
+    const LOW_PAIR_OF_FOUR: u64 = 0x33 * ONE_IN_EACH_BYTE;
+    const LOW_NIBBLE_OF_BYTE: u64 = 0x0F * ONE_IN_EACH_BYTE;
+    // algorithm derived from popcount64c at https://en.wikipedia.org/wiki/Hamming_weight
+    // Step 1: every 2-bit group ends up holding the number of bits set in it.
+    let upper_of_each_pair: VecU = (v >> ctx.make(1.to())) & ctx.make(ALTERNATE_BITS.to());
+    v -= upper_of_each_pair;
+    // Step 2: every 4-bit group holds the sum of its two 2-bit counts.
+    let low_pairs: VecU = v & ctx.make(LOW_PAIR_OF_FOUR.to());
+    let high_pairs: VecU = (v >> ctx.make(2.to())) & ctx.make(LOW_PAIR_OF_FOUR.to());
+    v = low_pairs + high_pairs;
+    // Step 3: every byte holds the sum of its two 4-bit counts.
+    let low_nibbles: VecU = v & ctx.make(LOW_NIBBLE_OF_BYTE.to());
+    let high_nibbles: VecU = (v >> ctx.make(4.to())) & ctx.make(LOW_NIBBLE_OF_BYTE.to());
+    v = low_nibbles + high_nibbles;
+    if PrimU::BITS > 8.to() {
+        // The multiply accumulates all byte counts into the top byte.
+        let byte_totals: VecU = v * ctx.make(ONE_IN_EACH_BYTE.to());
+        byte_totals >> ctx.make(PrimU::BITS - 8.to())
+    } else {
+        v
+    }
+}
+
+/// Signed counterpart of `count_ones_uint`.
+///
+/// Converts `v` to its unsigned type with `cvt_from`, counts the set bits
+/// there, and converts the count back to the signed type.
+pub fn count_ones_sint<Ctx, VecU, VecS>(ctx: Ctx, v: VecS) -> VecS
+where
+    Ctx: Context,
+    VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
+    VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
+{
+    let unsigned_value = VecU::cvt_from(v);
+    let set_bits = count_ones_uint(ctx, unsigned_value);
+    set_bits.to()
+}
+
+#[cfg(test)]
+mod tests {
+    //! Exhaustive checks of the scalar implementations against the matching
+    //! `u16` standard-library methods.
+    use super::*;
+    use crate::scalar::{Scalar, Value};
+
+    /// Every `u16` value must agree with `u16::leading_zeros`.
+    #[test]
+    fn test_count_leading_zeros_u16() {
+        for v in 0..=u16::MAX {
+            let expected = v.leading_zeros() as u16;
+            let actual = count_leading_zeros_uint(Scalar, Value(v)).0;
+            assert_eq!(expected, actual, "v = {:#X}", v);
+        }
+    }
+
+    /// Every `u16` value must agree with `u16::trailing_zeros`.
+    #[test]
+    fn test_count_trailing_zeros_u16() {
+        for v in 0..=u16::MAX {
+            let expected = v.trailing_zeros() as u16;
+            let actual = count_trailing_zeros_uint(Scalar, Value(v)).0;
+            assert_eq!(expected, actual, "v = {:#X}", v);
+        }
+    }
+
+    /// Every `u16` value must agree with `u16::count_ones`.
+    #[test]
+    fn test_count_ones_u16() {
+        for v in 0..=u16::MAX {
+            let expected = v.count_ones() as u16;
+            let actual = count_ones_uint(Scalar, Value(v)).0;
+            assert_eq!(expected, actual, "v = {:#X}", v);
+        }
+    }
+}
+
+#[cfg(all(feature = "ir", test))] // only compiled for test builds that enable the `ir` feature
+mod ir_tests { // compares the rendered IR of the bit-counting functions against exact expected text
+ use super::*;
+ use crate::ir::{IrContext, IrFunction, IrVecI64, IrVecU64, IrVecU8};
+ use std::{format, println};
+
+ #[test]
+ fn test_display_count_leading_zeros_i64() { // signed 64-bit leading-zero count
+ let ctx = IrContext::new();
+ fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { // helper that names the lifetime used below
+ let f: fn(&'ctx IrContext<'ctx>, IrVecI64<'ctx>) -> IrVecI64<'ctx> = // fn-pointer type selects the I64 instance
+ count_leading_zeros_sint;
+ IrFunction::make(ctx, f)
+ }
+ let text = format!("\n{}", make_it(&ctx)); // leading newline lines up with the newline opening the raw string
+ println!("{}", text); // print the rendered IR as well as comparing it
+ assert_eq!( // expected: Cast to U64, six halving steps (0x20 down to 0x1), a final CompareNe step, Cast back to I64
+ text,
+ r"
+function(in<arg_0>: vec<I64>) -> vec<I64> {
+ op_0: vec<U64> = Cast in<arg_0>
+ op_1: vec<Bool> = CompareGe op_0, splat(0x100000000_u64)
+ op_2: vec<U64> = Select op_1, splat(0x20_u64), splat(0x0_u64)
+ op_3: vec<U64> = Sub splat(0x40_u64), op_2
+ op_4: vec<U64> = Shr op_0, op_2
+ op_5: vec<Bool> = CompareGe op_4, splat(0x10000_u64)
+ op_6: vec<U64> = Select op_5, splat(0x10_u64), splat(0x0_u64)
+ op_7: vec<U64> = Sub op_3, op_6
+ op_8: vec<U64> = Shr op_4, op_6
+ op_9: vec<Bool> = CompareGe op_8, splat(0x100_u64)
+ op_10: vec<U64> = Select op_9, splat(0x8_u64), splat(0x0_u64)
+ op_11: vec<U64> = Sub op_7, op_10
+ op_12: vec<U64> = Shr op_8, op_10
+ op_13: vec<Bool> = CompareGe op_12, splat(0x10_u64)
+ op_14: vec<U64> = Select op_13, splat(0x4_u64), splat(0x0_u64)
+ op_15: vec<U64> = Sub op_11, op_14
+ op_16: vec<U64> = Shr op_12, op_14
+ op_17: vec<Bool> = CompareGe op_16, splat(0x4_u64)
+ op_18: vec<U64> = Select op_17, splat(0x2_u64), splat(0x0_u64)
+ op_19: vec<U64> = Sub op_15, op_18
+ op_20: vec<U64> = Shr op_16, op_18
+ op_21: vec<Bool> = CompareGe op_20, splat(0x2_u64)
+ op_22: vec<U64> = Select op_21, splat(0x1_u64), splat(0x0_u64)
+ op_23: vec<U64> = Sub op_19, op_22
+ op_24: vec<U64> = Shr op_20, op_22
+ op_25: vec<Bool> = CompareNe op_24, splat(0x0_u64)
+ op_26: vec<U64> = Select op_25, splat(0x1_u64), splat(0x0_u64)
+ op_27: vec<U64> = Sub op_23, op_26
+ op_28: vec<I64> = Cast op_27
+ Return op_28
+}
+"
+ );
+ }
+
+ #[test]
+ fn test_display_count_leading_zeros_u8() { // unsigned 8-bit leading-zero count
+ let ctx = IrContext::new();
+ fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { // helper that names the lifetime used below
+ let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = // fn-pointer type selects the U8 instance
+ count_leading_zeros_uint;
+ IrFunction::make(ctx, f)
+ }
+ let text = format!("\n{}", make_it(&ctx)); // leading newline lines up with the newline opening the raw string
+ println!("{}", text); // print the rendered IR as well as comparing it
+ assert_eq!( // expected: three halving steps (0x4, 0x2, 0x1) then a final CompareNe step, no casts
+ text,
+ r"
+function(in<arg_0>: vec<U8>) -> vec<U8> {
+ op_0: vec<Bool> = CompareGe in<arg_0>, splat(0x10_u8)
+ op_1: vec<U8> = Select op_0, splat(0x4_u8), splat(0x0_u8)
+ op_2: vec<U8> = Sub splat(0x8_u8), op_1
+ op_3: vec<U8> = Shr in<arg_0>, op_1
+ op_4: vec<Bool> = CompareGe op_3, splat(0x4_u8)
+ op_5: vec<U8> = Select op_4, splat(0x2_u8), splat(0x0_u8)
+ op_6: vec<U8> = Sub op_2, op_5
+ op_7: vec<U8> = Shr op_3, op_5
+ op_8: vec<Bool> = CompareGe op_7, splat(0x2_u8)
+ op_9: vec<U8> = Select op_8, splat(0x1_u8), splat(0x0_u8)
+ op_10: vec<U8> = Sub op_6, op_9
+ op_11: vec<U8> = Shr op_7, op_9
+ op_12: vec<Bool> = CompareNe op_11, splat(0x0_u8)
+ op_13: vec<U8> = Select op_12, splat(0x1_u8), splat(0x0_u8)
+ op_14: vec<U8> = Sub op_10, op_13
+ Return op_14
+}
+"
+ );
+ }
+
+ #[test]
+ fn test_display_count_trailing_zeros_u8() { // unsigned 8-bit trailing-zero count
+ let ctx = IrContext::new();
+ fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { // helper that names the lifetime used below
+ let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = // fn-pointer type selects the U8 instance
+ count_trailing_zeros_uint;
+ IrFunction::make(ctx, f)
+ }
+ let text = format!("\n{}", make_it(&ctx)); // leading newline lines up with the newline opening the raw string
+ println!("{}", text); // print the rendered IR as well as comparing it
+ assert_eq!( // expected: three mask-and-test steps (masks 0xF, 0x3, 0x1) then a final CompareEq against zero
+ text,
+ r"
+function(in<arg_0>: vec<U8>) -> vec<U8> {
+ op_0: vec<U8> = And in<arg_0>, splat(0xF_u8)
+ op_1: vec<Bool> = CompareEq op_0, splat(0x0_u8)
+ op_2: vec<U8> = Select op_1, splat(0x4_u8), splat(0x0_u8)
+ op_3: vec<U8> = Add splat(0x0_u8), op_2
+ op_4: vec<U8> = Shr in<arg_0>, op_2
+ op_5: vec<U8> = And op_4, splat(0x3_u8)
+ op_6: vec<Bool> = CompareEq op_5, splat(0x0_u8)
+ op_7: vec<U8> = Select op_6, splat(0x2_u8), splat(0x0_u8)
+ op_8: vec<U8> = Add op_3, op_7
+ op_9: vec<U8> = Shr op_4, op_7
+ op_10: vec<U8> = And op_9, splat(0x1_u8)
+ op_11: vec<Bool> = CompareEq op_10, splat(0x0_u8)
+ op_12: vec<U8> = Select op_11, splat(0x1_u8), splat(0x0_u8)
+ op_13: vec<U8> = Add op_8, op_12
+ op_14: vec<U8> = Shr op_9, op_12
+ op_15: vec<Bool> = CompareEq op_14, splat(0x0_u8)
+ op_16: vec<U8> = Select op_15, splat(0x1_u8), splat(0x0_u8)
+ op_17: vec<U8> = Add op_13, op_16
+ Return op_17
+}
+"
+ );
+ }
+
+ #[test]
+ fn test_display_count_ones_u8() { // unsigned 8-bit set-bit count
+ let ctx = IrContext::new();
+ fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { // helper that names the lifetime used below
+ let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = count_ones_uint;
+ IrFunction::make(ctx, f)
+ }
+ let text = format!("\n{}", make_it(&ctx)); // leading newline lines up with the newline opening the raw string
+ println!("{}", text); // print the rendered IR as well as comparing it
+ assert_eq!( // expected: masks truncated to 8 bits and no trailing Mul/Shr, since the width is not above 8
+ text,
+ r"
+function(in<arg_0>: vec<U8>) -> vec<U8> {
+ op_0: vec<U8> = Shr in<arg_0>, splat(0x1_u8)
+ op_1: vec<U8> = And op_0, splat(0x55_u8)
+ op_2: vec<U8> = Sub in<arg_0>, op_1
+ op_3: vec<U8> = And op_2, splat(0x33_u8)
+ op_4: vec<U8> = Shr op_2, splat(0x2_u8)
+ op_5: vec<U8> = And op_4, splat(0x33_u8)
+ op_6: vec<U8> = Add op_3, op_5
+ op_7: vec<U8> = And op_6, splat(0xF_u8)
+ op_8: vec<U8> = Shr op_6, splat(0x4_u8)
+ op_9: vec<U8> = And op_8, splat(0xF_u8)
+ op_10: vec<U8> = Add op_7, op_9
+ Return op_10
+}
+"
+ );
+ }
+
+ #[test]
+ fn test_display_count_ones_u64() { // unsigned 64-bit set-bit count
+ let ctx = IrContext::new();
+ fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { // helper that names the lifetime used below
+ let f: fn(&'ctx IrContext<'ctx>, IrVecU64<'ctx>) -> IrVecU64<'ctx> = count_ones_uint;
+ IrFunction::make(ctx, f)
+ }
+ let text = format!("\n{}", make_it(&ctx)); // leading newline lines up with the newline opening the raw string
+ println!("{}", text); // print the rendered IR as well as comparing it
+ assert_eq!( // expected: full-width masks plus the trailing Mul and Shr by 0x38 (64 - 8) that sum the bytes
+ text,
+ r"
+function(in<arg_0>: vec<U64>) -> vec<U64> {
+ op_0: vec<U64> = Shr in<arg_0>, splat(0x1_u64)
+ op_1: vec<U64> = And op_0, splat(0x5555555555555555_u64)
+ op_2: vec<U64> = Sub in<arg_0>, op_1
+ op_3: vec<U64> = And op_2, splat(0x3333333333333333_u64)
+ op_4: vec<U64> = Shr op_2, splat(0x2_u64)
+ op_5: vec<U64> = And op_4, splat(0x3333333333333333_u64)
+ op_6: vec<U64> = Add op_3, op_5
+ op_7: vec<U64> = And op_6, splat(0xF0F0F0F0F0F0F0F_u64)
+ op_8: vec<U64> = Shr op_6, splat(0x4_u64)
+ op_9: vec<U64> = And op_8, splat(0xF0F0F0F0F0F0F0F_u64)
+ op_10: vec<U64> = Add op_7, op_9
+ op_11: vec<U64> = Mul op_10, splat(0x101010101010101_u64)
+ op_12: vec<U64> = Shr op_11, splat(0x38_u64)
+ Return op_12
+}
+"
+ );
+ }
+}