add trunc implementation
[vector-math.git] / src / algorithms / base.rs
index 0b6dcb60993100b966d73f308ceb9084826cbfd1..d38734091a90b41ebabcb6d24c4ebbe4e49212d6 100644 (file)
@@ -1,6 +1,6 @@
 use crate::{
     prim::{PrimFloat, PrimUInt},
-    traits::{Context, Float, Make},
+    traits::{Context, ConvertTo, Float, Make, Select, UInt},
 };
 
 pub fn abs<
@@ -30,6 +30,30 @@ pub fn copy_sign<
     VecF::from_bits(mag_bits | sign_bit)
 }
 
+pub fn trunc<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
+    let out_of_range = big | small;
+    let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
+    let out_of_range_value = small.select(small_value, v);
+    let exponent_field = v.extract_exponent_field();
+    let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
+    let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
+    mask >>= right_shift_amount;
+    let in_range_value = VecF::from_bits(v.to_bits() & !mask);
+    out_of_range.select(out_of_range_value, in_range_value)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -147,4 +171,73 @@ mod tests {
             }
         }
     }
+
+    fn same<F: PrimFloat>(a: F, b: F) -> bool {
+        if a.is_finite() && b.is_finite() {
+            a == b
+        } else {
+            a == b || (a.is_nan() && b.is_nan())
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_trunc_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_trunc_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_trunc_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
 }