add count_leading_zeros, count_trailing_zeros, and count_ones implementations
[vector-math.git] / src / algorithms / integer.rs
1 use crate::{
2 prim::PrimUInt,
3 traits::{Context, ConvertFrom, ConvertTo, Make, SInt, Select, UInt},
4 };
5
6 pub fn count_leading_zeros_uint<
7 Ctx: Context,
8 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
9 PrimU: PrimUInt,
10 >(
11 ctx: Ctx,
12 mut v: VecU,
13 ) -> VecU {
14 let mut retval: VecU = ctx.make(PrimU::BITS);
15 let mut bits = PrimU::BITS;
16 while bits > 1.to() {
17 bits /= 2.to();
18 let limit = PrimU::ONE << bits;
19 let found = v.ge(ctx.make(limit));
20 let shift: VecU = found.select(ctx.make(bits), ctx.make(0.to()));
21 retval -= shift;
22 v >>= shift;
23 }
24 let nonzero = v.ne(ctx.make(0.to()));
25 retval - nonzero.select(ctx.make(1.to()), ctx.make(0.to()))
26 }
27
28 pub fn count_leading_zeros_sint<
29 Ctx: Context,
30 VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
31 VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
32 >(
33 ctx: Ctx,
34 v: VecS,
35 ) -> VecS {
36 count_leading_zeros_uint(ctx, VecU::cvt_from(v)).to()
37 }
38
39 pub fn count_trailing_zeros_uint<
40 Ctx: Context,
41 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
42 PrimU: PrimUInt,
43 >(
44 ctx: Ctx,
45 mut v: VecU,
46 ) -> VecU {
47 let mut retval: VecU = ctx.make(PrimU::ZERO);
48 let mut bits = PrimU::BITS;
49 while bits > 1.to() {
50 bits /= 2.to();
51 let mask = (PrimU::ONE << bits) - 1.to();
52 let zero = (v & ctx.make(mask)).eq(ctx.make(0.to()));
53 let shift: VecU = zero.select(ctx.make(bits), ctx.make(0.to()));
54 retval += shift;
55 v >>= shift;
56 }
57 let zero = v.eq(ctx.make(0.to()));
58 retval + zero.select(ctx.make(1.to()), ctx.make(0.to()))
59 }
60
61 pub fn count_trailing_zeros_sint<
62 Ctx: Context,
63 VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
64 VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
65 >(
66 ctx: Ctx,
67 v: VecS,
68 ) -> VecS {
69 count_trailing_zeros_uint(ctx, VecU::cvt_from(v)).to()
70 }
71
72 pub fn count_ones_uint<
73 Ctx: Context,
74 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
75 PrimU: PrimUInt,
76 >(
77 ctx: Ctx,
78 mut v: VecU,
79 ) -> VecU {
80 assert!(PrimU::BITS <= 64.to());
81 assert!(PrimU::BITS >= 8.to());
82 const SPLAT_BYTES_MULTIPLIER: u64 = u64::from_le_bytes([1; 8]);
83 const EVERY_OTHER_BIT_MASK: u64 = 0x55 * SPLAT_BYTES_MULTIPLIER;
84 const TWO_OUT_OF_FOUR_BITS_MASK: u64 = 0x33 * SPLAT_BYTES_MULTIPLIER;
85 const FOUR_OUT_OF_EIGHT_BITS_MASK: u64 = 0x0F * SPLAT_BYTES_MULTIPLIER;
86 // algorithm derived from popcount64c at https://en.wikipedia.org/wiki/Hamming_weight
87 v -= (v >> ctx.make(1.to())) & ctx.make(EVERY_OTHER_BIT_MASK.to());
88 v = (v & ctx.make(TWO_OUT_OF_FOUR_BITS_MASK.to()))
89 + ((v >> ctx.make(2.to())) & ctx.make(TWO_OUT_OF_FOUR_BITS_MASK.to()));
90 v = (v & ctx.make(FOUR_OUT_OF_EIGHT_BITS_MASK.to()))
91 + ((v >> ctx.make(4.to())) & ctx.make(FOUR_OUT_OF_EIGHT_BITS_MASK.to()));
92 if PrimU::BITS > 8.to() {
93 v * ctx.make(SPLAT_BYTES_MULTIPLIER.to()) >> ctx.make(PrimU::BITS - 8.to())
94 } else {
95 v
96 }
97 }
98
99 pub fn count_ones_sint<
100 Ctx: Context,
101 VecU: UInt + Make<Context = Ctx> + ConvertFrom<VecS>,
102 VecS: SInt<UnsignedType = VecU> + ConvertFrom<VecU>,
103 >(
104 ctx: Ctx,
105 v: VecS,
106 ) -> VecS {
107 count_ones_uint(ctx, VecU::cvt_from(v)).to()
108 }
109
/// Exhaustive checks of the scalar instantiations against the matching
/// built-in `u16` methods.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalar::{Scalar, Value};

    /// Every `u16` value must agree with `u16::leading_zeros`.
    #[test]
    fn test_count_leading_zeros_u16() {
        (0..=u16::MAX).for_each(|v| {
            let expected = v.leading_zeros() as u16;
            let actual = count_leading_zeros_uint(Scalar, Value(v)).0;
            assert_eq!(expected, actual, "v = {:#X}", v);
        });
    }

    /// Every `u16` value must agree with `u16::trailing_zeros`.
    #[test]
    fn test_count_trailing_zeros_u16() {
        (0..=u16::MAX).for_each(|v| {
            let expected = v.trailing_zeros() as u16;
            let actual = count_trailing_zeros_uint(Scalar, Value(v)).0;
            assert_eq!(expected, actual, "v = {:#X}", v);
        });
    }

    /// Every `u16` value must agree with `u16::count_ones`.
    #[test]
    fn test_count_ones_u16() {
        (0..=u16::MAX).for_each(|v| {
            let expected = v.count_ones() as u16;
            let actual = count_ones_uint(Scalar, Value(v)).0;
            assert_eq!(expected, actual, "v = {:#X}", v);
        });
    }
}
151
// Snapshot tests for the IR generated from the algorithms above.
// Compiled only for test builds with the `ir` feature enabled.
//
// Every test follows the same steps:
//   1. coerce one generic algorithm to a concrete `fn` pointer over IR vector types,
//   2. build an `IrFunction` from it with `IrFunction::make`,
//   3. format it with `{}` behind a leading "\n" (so the expected raw string can
//      begin on its own line), print it, and compare against the expected listing.
//
// The listings pin the exact order and numbering of the emitted operations, so
// any reordering of vector operations inside the algorithms changes them.
152 #[cfg(all(feature = "ir", test))]
153 mod ir_tests {
154 use super::*;
155 use crate::ir::{IrContext, IrFunction, IrVecI64, IrVecU64, IrVecU8};
// NOTE(review): `format`/`println` are imported explicitly, presumably because
// the crate does not pull in the std prelude macros by default -- confirm.
156 use std::{format, println};
157
// Signed 64-bit leading-zero count: the listing shows a Cast to U64, six
// halving rounds (0x20 down to 0x1), the final non-zero check, and a Cast back to I64.
158 #[test]
159 fn test_display_count_leading_zeros_i64() {
160 let ctx = IrContext::new();
161 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
162 let f: fn(&'ctx IrContext<'ctx>, IrVecI64<'ctx>) -> IrVecI64<'ctx> =
163 count_leading_zeros_sint;
164 IrFunction::make(ctx, f)
165 }
166 let text = format!("\n{}", make_it(&ctx));
167 println!("{}", text);
168 assert_eq!(
169 text,
170 r"
171 function(in<arg_0>: vec<I64>) -> vec<I64> {
172 op_0: vec<U64> = Cast in<arg_0>
173 op_1: vec<Bool> = CompareGe op_0, splat(0x100000000_u64)
174 op_2: vec<U64> = Select op_1, splat(0x20_u64), splat(0x0_u64)
175 op_3: vec<U64> = Sub splat(0x40_u64), op_2
176 op_4: vec<U64> = Shr op_0, op_2
177 op_5: vec<Bool> = CompareGe op_4, splat(0x10000_u64)
178 op_6: vec<U64> = Select op_5, splat(0x10_u64), splat(0x0_u64)
179 op_7: vec<U64> = Sub op_3, op_6
180 op_8: vec<U64> = Shr op_4, op_6
181 op_9: vec<Bool> = CompareGe op_8, splat(0x100_u64)
182 op_10: vec<U64> = Select op_9, splat(0x8_u64), splat(0x0_u64)
183 op_11: vec<U64> = Sub op_7, op_10
184 op_12: vec<U64> = Shr op_8, op_10
185 op_13: vec<Bool> = CompareGe op_12, splat(0x10_u64)
186 op_14: vec<U64> = Select op_13, splat(0x4_u64), splat(0x0_u64)
187 op_15: vec<U64> = Sub op_11, op_14
188 op_16: vec<U64> = Shr op_12, op_14
189 op_17: vec<Bool> = CompareGe op_16, splat(0x4_u64)
190 op_18: vec<U64> = Select op_17, splat(0x2_u64), splat(0x0_u64)
191 op_19: vec<U64> = Sub op_15, op_18
192 op_20: vec<U64> = Shr op_16, op_18
193 op_21: vec<Bool> = CompareGe op_20, splat(0x2_u64)
194 op_22: vec<U64> = Select op_21, splat(0x1_u64), splat(0x0_u64)
195 op_23: vec<U64> = Sub op_19, op_22
196 op_24: vec<U64> = Shr op_20, op_22
197 op_25: vec<Bool> = CompareNe op_24, splat(0x0_u64)
198 op_26: vec<U64> = Select op_25, splat(0x1_u64), splat(0x0_u64)
199 op_27: vec<U64> = Sub op_23, op_26
200 op_28: vec<I64> = Cast op_27
201 Return op_28
202 }
203 "
204 );
205 }
206
// Unsigned 8-bit leading-zero count: three halving rounds (0x4, 0x2, 0x1)
// followed by the final non-zero check; no casts are involved.
207 #[test]
208 fn test_display_count_leading_zeros_u8() {
209 let ctx = IrContext::new();
210 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
211 let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> =
212 count_leading_zeros_uint;
213 IrFunction::make(ctx, f)
214 }
215 let text = format!("\n{}", make_it(&ctx));
216 println!("{}", text);
217 assert_eq!(
218 text,
219 r"
220 function(in<arg_0>: vec<U8>) -> vec<U8> {
221 op_0: vec<Bool> = CompareGe in<arg_0>, splat(0x10_u8)
222 op_1: vec<U8> = Select op_0, splat(0x4_u8), splat(0x0_u8)
223 op_2: vec<U8> = Sub splat(0x8_u8), op_1
224 op_3: vec<U8> = Shr in<arg_0>, op_1
225 op_4: vec<Bool> = CompareGe op_3, splat(0x4_u8)
226 op_5: vec<U8> = Select op_4, splat(0x2_u8), splat(0x0_u8)
227 op_6: vec<U8> = Sub op_2, op_5
228 op_7: vec<U8> = Shr op_3, op_5
229 op_8: vec<Bool> = CompareGe op_7, splat(0x2_u8)
230 op_9: vec<U8> = Select op_8, splat(0x1_u8), splat(0x0_u8)
231 op_10: vec<U8> = Sub op_6, op_9
232 op_11: vec<U8> = Shr op_7, op_9
233 op_12: vec<Bool> = CompareNe op_11, splat(0x0_u8)
234 op_13: vec<U8> = Select op_12, splat(0x1_u8), splat(0x0_u8)
235 op_14: vec<U8> = Sub op_10, op_13
236 Return op_14
237 }
238 "
239 );
240 }
241
// Unsigned 8-bit trailing-zero count: three mask-and-test rounds (masks 0xF,
// 0x3, 0x1) followed by the final is-zero check.
242 #[test]
243 fn test_display_count_trailing_zeros_u8() {
244 let ctx = IrContext::new();
245 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
246 let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> =
247 count_trailing_zeros_uint;
248 IrFunction::make(ctx, f)
249 }
250 let text = format!("\n{}", make_it(&ctx));
251 println!("{}", text);
252 assert_eq!(
253 text,
254 r"
255 function(in<arg_0>: vec<U8>) -> vec<U8> {
256 op_0: vec<U8> = And in<arg_0>, splat(0xF_u8)
257 op_1: vec<Bool> = CompareEq op_0, splat(0x0_u8)
258 op_2: vec<U8> = Select op_1, splat(0x4_u8), splat(0x0_u8)
259 op_3: vec<U8> = Add splat(0x0_u8), op_2
260 op_4: vec<U8> = Shr in<arg_0>, op_2
261 op_5: vec<U8> = And op_4, splat(0x3_u8)
262 op_6: vec<Bool> = CompareEq op_5, splat(0x0_u8)
263 op_7: vec<U8> = Select op_6, splat(0x2_u8), splat(0x0_u8)
264 op_8: vec<U8> = Add op_3, op_7
265 op_9: vec<U8> = Shr op_4, op_7
266 op_10: vec<U8> = And op_9, splat(0x1_u8)
267 op_11: vec<Bool> = CompareEq op_10, splat(0x0_u8)
268 op_12: vec<U8> = Select op_11, splat(0x1_u8), splat(0x0_u8)
269 op_13: vec<U8> = Add op_8, op_12
270 op_14: vec<U8> = Shr op_9, op_12
271 op_15: vec<Bool> = CompareEq op_14, splat(0x0_u8)
272 op_16: vec<U8> = Select op_15, splat(0x1_u8), splat(0x0_u8)
273 op_17: vec<U8> = Add op_13, op_16
274 Return op_17
275 }
276 "
277 );
278 }
279
// Unsigned 8-bit population count: only the three in-place summing steps are
// emitted; the listing contains no Mul or final Shr for this width.
280 #[test]
281 fn test_display_count_ones_u8() {
282 let ctx = IrContext::new();
283 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
284 let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = count_ones_uint;
285 IrFunction::make(ctx, f)
286 }
287 let text = format!("\n{}", make_it(&ctx));
288 println!("{}", text);
289 assert_eq!(
290 text,
291 r"
292 function(in<arg_0>: vec<U8>) -> vec<U8> {
293 op_0: vec<U8> = Shr in<arg_0>, splat(0x1_u8)
294 op_1: vec<U8> = And op_0, splat(0x55_u8)
295 op_2: vec<U8> = Sub in<arg_0>, op_1
296 op_3: vec<U8> = And op_2, splat(0x33_u8)
297 op_4: vec<U8> = Shr op_2, splat(0x2_u8)
298 op_5: vec<U8> = And op_4, splat(0x33_u8)
299 op_6: vec<U8> = Add op_3, op_5
300 op_7: vec<U8> = And op_6, splat(0xF_u8)
301 op_8: vec<U8> = Shr op_6, splat(0x4_u8)
302 op_9: vec<U8> = And op_8, splat(0xF_u8)
303 op_10: vec<U8> = Add op_7, op_9
304 Return op_10
305 }
306 "
307 );
308 }
309
// Unsigned 64-bit population count: the same summing steps with full-width
// masks, then a Mul by 0x0101010101010101 and a Shr by 0x38 (56) to extract the total.
310 #[test]
311 fn test_display_count_ones_u64() {
312 let ctx = IrContext::new();
313 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
314 let f: fn(&'ctx IrContext<'ctx>, IrVecU64<'ctx>) -> IrVecU64<'ctx> = count_ones_uint;
315 IrFunction::make(ctx, f)
316 }
317 let text = format!("\n{}", make_it(&ctx));
318 println!("{}", text);
319 assert_eq!(
320 text,
321 r"
322 function(in<arg_0>: vec<U64>) -> vec<U64> {
323 op_0: vec<U64> = Shr in<arg_0>, splat(0x1_u64)
324 op_1: vec<U64> = And op_0, splat(0x5555555555555555_u64)
325 op_2: vec<U64> = Sub in<arg_0>, op_1
326 op_3: vec<U64> = And op_2, splat(0x3333333333333333_u64)
327 op_4: vec<U64> = Shr op_2, splat(0x2_u64)
328 op_5: vec<U64> = And op_4, splat(0x3333333333333333_u64)
329 op_6: vec<U64> = Add op_3, op_5
330 op_7: vec<U64> = And op_6, splat(0xF0F0F0F0F0F0F0F_u64)
331 op_8: vec<U64> = Shr op_6, splat(0x4_u64)
332 op_9: vec<U64> = And op_8, splat(0xF0F0F0F0F0F0F0F_u64)
333 op_10: vec<U64> = Add op_7, op_9
334 op_11: vec<U64> = Mul op_10, splat(0x101010101010101_u64)
335 op_12: vec<U64> = Shr op_11, splat(0x38_u64)
336 Return op_12
337 }
338 "
339 );
340 }
341 }