add ceil and floor
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 14 May 2021 02:21:08 +0000 (19:21 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 14 May 2021 02:21:08 +0000 (19:21 -0700)
src/algorithms/base.rs
src/f16.rs
src/prim.rs
src/scalar.rs

index 28a641df564d73f46dcd84f20e07374353bea427..b4ec103aac2f8ce2bc7d198753daab9d456d7fed 100644 (file)
@@ -66,14 +66,50 @@ pub fn round_to_nearest_ties_to_even<
 ) -> VecF {
     let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
     let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
-    let small = v.abs().le(ctx.make(PrimF::cvt_from(0.5)));
-    let out_of_range = big | small;
-    let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
-    let out_of_range_value = small.select(small_value, v);
     let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
     let offset_value: VecF = v.abs() + offset;
     let in_range_value = (offset_value - offset).copy_sign(v);
-    out_of_range.select(out_of_range_value, in_range_value)
+    big.select(v, in_range_value)
+}
+
+pub fn floor<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
+    let offset_value: VecF = v.abs() + offset;
+    let rounded = (offset_value - offset).copy_sign(v);
+    let need_round_down = v.lt(rounded);
+    let in_range_value = need_round_down.select(rounded - ctx.make(1.to()), rounded).copy_sign(v);
+    big.select(v, in_range_value)
+}
+
+pub fn ceil<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
+    let offset_value: VecF = v.abs() + offset;
+    let rounded = (offset_value - offset).copy_sign(v);
+    let need_round_up = v.gt(rounded);
+    let in_range_value = need_round_up.select(rounded + ctx.make(1.to()), rounded).copy_sign(v);
+    big.select(v, in_range_value)
 }
 
 #[cfg(test)]
@@ -198,7 +234,7 @@ mod tests {
 
     fn same<F: PrimFloat>(a: F, b: F) -> bool {
         if a.is_finite() && b.is_finite() {
-            a == b
+            a.to_bits() == b.to_bits()
         } else {
             a == b || (a.is_nan() && b.is_nan())
         }
@@ -295,13 +331,8 @@ mod tests {
         #[track_caller]
         fn case(v: f32, expected: f32) {
             let result = reference_round_to_nearest_ties_to_even(v);
-            let same = if expected.is_nan() {
-                result.is_nan()
-            } else {
-                expected.to_bits() == result.to_bits()
-            };
             assert!(
-                same,
+                same(result, expected),
                 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                 v=v,
                 v_bits=v.to_bits(),
@@ -362,7 +393,7 @@ mod tests {
             let expected = reference_round_to_nearest_ties_to_even(v);
             let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
             assert!(
-                same(expected, result),
+                same(result, expected),
                 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                 v=v,
                 v_bits=v.to_bits(),
@@ -381,7 +412,7 @@ mod tests {
             let expected = reference_round_to_nearest_ties_to_even(v);
             let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
             assert!(
-                same(expected, result),
+                same(result, expected),
                 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                 v=v,
                 v_bits=v.to_bits(),
@@ -399,6 +430,128 @@ mod tests {
             let v = f64::from_bits(bits);
             let expected = reference_round_to_nearest_ties_to_even(v);
             let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+            assert!(
+                same(result, expected),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_floor_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = v.floor();
+            let result = floor(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_floor_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = v.floor();
+            let result = floor(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_floor_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = v.floor();
+            let result = floor(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_ceil_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = v.ceil();
+            let result = ceil(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_ceil_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = v.ceil();
+            let result = ceil(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_ceil_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = v.ceil();
+            let result = ceil(Scalar, Value(v)).0;
             assert!(
                 same(expected, result),
                 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
index 1ed978d25c59fcddaee617d04a3f179164a5633a..b5d84d5a5020a12f7615e2c644a215593779167c 100644 (file)
@@ -215,16 +215,10 @@ impl F16 {
         return PrimFloat::trunc(f32::from(self)).to();
     }
     pub fn ceil(self) -> Self {
-        #[cfg(feature = "std")]
-        return f32::from(self).ceil().to();
-        #[cfg(not(feature = "std"))]
-        todo!();
+        return PrimFloat::ceil(f32::from(self)).to();
     }
     pub fn floor(self) -> Self {
-        #[cfg(feature = "std")]
-        return f32::from(self).floor().to();
-        #[cfg(not(feature = "std"))]
-        todo!();
+        return PrimFloat::floor(f32::from(self)).to();
     }
     /// round to nearest, ties to unspecified
     pub fn round(self) -> Self {
index 4b3c6ca66f8584adf3e4ed98fb234d0583b2b3e3..ea7b6b94010b7bf78da9a1873dd8b158340d1285 100644 (file)
@@ -163,6 +163,8 @@ pub trait PrimFloat:
     fn trunc(self) -> Self;
     /// round to nearest, ties to unspecified
     fn round(self) -> Self;
+    fn floor(self) -> Self;
+    fn ceil(self) -> Self;
     fn copy_sign(self, sign: Self) -> Self;
 }
 
@@ -232,6 +234,18 @@ macro_rules! impl_float {
                 return crate::algorithms::base::round_to_nearest_ties_to_even(Scalar, Value(self))
                     .0;
             }
+            fn floor(self) -> Self {
+                #[cfg(feature = "std")]
+                return $float::floor(self);
+                #[cfg(not(feature = "std"))]
+                return crate::algorithms::base::floor(Scalar, Value(self)).0;
+            }
+            fn ceil(self) -> Self {
+                #[cfg(feature = "std")]
+                return $float::ceil(self);
+                #[cfg(not(feature = "std"))]
+                return crate::algorithms::base::ceil(Scalar, Value(self)).0;
+            }
             fn copy_sign(self, sign: Self) -> Self {
                 #[cfg(feature = "std")]
                 return $float::copysign(self, sign);
index 3011df707fc97c6e31cdcb430d504652c7cea0d5..b15388b24ef6473abab300da2fcfea8e5a3a83f1 100644 (file)
@@ -368,13 +368,13 @@ macro_rules! impl_float {
                 #[cfg(feature = "std")]
                 return Value(self.0.ceil());
                 #[cfg(not(feature = "std"))]
-                todo!();
+                return crate::algorithms::base::ceil(Scalar, self);
             }
             fn floor(self) -> Self {
                 #[cfg(feature = "std")]
                 return Value(self.0.floor());
                 #[cfg(not(feature = "std"))]
-                todo!();
+                return crate::algorithms::base::floor(Scalar, self);
             }
             fn round(self) -> Self {
                 #[cfg(feature = "std")]