add sin_pi_f16, cos_pi_f16, and sin_cos_pi_f16
[vector-math.git] / src / algorithms / trig_pi.rs
index 28e67ef0a4f1327eec038b54cfc94d4672c2276e..38104a655ee23870c5c04cfc7c1dc259d6cd32f9 100644 (file)
@@ -1,4 +1,8 @@
-use crate::traits::{Context, ConvertTo, Float};
+use crate::{
+    f16::F16,
+    ieee754::FloatEncoding,
+    traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Select},
+};
 
 mod consts {
     #![allow(clippy::excessive_precision)]
@@ -84,18 +88,64 @@ pub fn cos_pi_kernel_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16
 /// computes `(sin(pi * x), cos(pi * x))`
 /// not guaranteed to give correct sign for zero results
 /// has an error of up to 2ULP
-pub fn sin_cos_pi_f16<Ctx: Context>(_ctx: Ctx, _x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) {
-    todo!()
+pub fn sin_cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) {
+    let two_f16: Ctx::VecF16 = ctx.make(2.0.to());
+    let one_half: Ctx::VecF16 = ctx.make(0.5.to());
+    let max_contiguous_integer: Ctx::VecF16 =
+        ctx.make((1u16 << (F16::MANTISSA_FIELD_WIDTH + 1)).to());
+    // if `x` is finite and bigger than `max_contiguous_integer`, then x is an even integer
+    let in_range = x.abs().lt(max_contiguous_integer); // use `lt` so nans are counted as out-of-range
+    let is_finite = x.is_finite();
+    let nan: Ctx::VecF16 = ctx.make(f32::NAN.to());
+    let zero_f16: Ctx::VecF16 = ctx.make(0.to());
+    let one_f16: Ctx::VecF16 = ctx.make(1.to());
+    let zero_i16: Ctx::VecI16 = ctx.make(0.to());
+    let one_i16: Ctx::VecI16 = ctx.make(1.to());
+    let two_i16: Ctx::VecI16 = ctx.make(2.to());
+    let out_of_range_sin = is_finite.select(zero_f16, nan);
+    let out_of_range_cos = is_finite.select(one_f16, nan);
+    let xi = (x * two_f16).round();
+    let xk = x - xi * one_half;
+    let sk = sin_pi_kernel_f16(ctx, xk);
+    let ck = cos_pi_kernel_f16(ctx, xk);
+    let xi = Ctx::VecI16::cvt_from(xi);
+    let bit_0_clear = (xi & one_i16).eq(zero_i16);
+    let st = bit_0_clear.select(sk, ck);
+    let ct = bit_0_clear.select(ck, sk);
+    let s = (xi & two_i16).eq(zero_i16).select(st, -st);
+    let c = ((xi + one_i16) & two_i16).eq(zero_i16).select(ct, -ct);
+    (
+        in_range.select(s, out_of_range_sin),
+        in_range.select(c, out_of_range_cos),
+    )
+}
+
+/// computes `sin(pi * x)`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn sin_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
+    sin_cos_pi_f16(ctx, x).0
+}
+
+/// computes `cos(pi * x)`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
+    sin_cos_pi_f16(ctx, x).1
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{f16::F16, scalar::Scalar};
+    use crate::{
+        f16::F16,
+        scalar::{Scalar, Value},
+    };
     use std::f64;
 
     struct CheckUlpCallbackArg<F, I> {
         distance_in_ulp: I,
+        x: F,
         expected: F,
         result: F,
     }
@@ -103,7 +153,7 @@ mod tests {
     #[track_caller]
     fn check_ulp_f16(
         x: F16,
-        is_ok: impl Fn(CheckUlpCallbackArg<F16, i16>) -> bool,
+        is_ok: impl Fn(CheckUlpCallbackArg<F16, u32>) -> bool,
         fn_f16: impl Fn(F16) -> F16,
         fn_f64: impl Fn(f64) -> f64,
     ) {
@@ -114,25 +164,34 @@ mod tests {
         if result == expected {
             return;
         }
-        let distance_in_ulp = (expected.to_bits() as i16).wrapping_sub(result.to_bits() as i16);
-        if is_ok(CheckUlpCallbackArg {
-            distance_in_ulp,
-            expected,
-            result,
-        }) {
+        if result.is_nan() && expected.is_nan() {
+            return;
+        }
+        let distance_in_ulp = (expected.to_bits() as i32 - result.to_bits() as i32).unsigned_abs();
+        if !result.is_nan()
+            && !expected.is_nan()
+            && is_ok(CheckUlpCallbackArg {
+                distance_in_ulp,
+                x,
+                expected,
+                result,
+            })
+        {
             return;
         }
         panic!(
             "error is too big: \
                 x = {x:?} {x_bits:#X}, \
                 result = {result:?} {result_bits:#X}, \
-                expected = {expected:?} {expected_bits:#X}",
+                expected = {expected:?} {expected_bits:#X}, \
+                distance_in_ulp = {distance_in_ulp}",
             x = x,
             x_bits = x.to_bits(),
             result = result,
             result_bits = result.to_bits(),
             expected = expected,
             expected_bits = expected.to_bits(),
+            distance_in_ulp = distance_in_ulp,
         );
     }
 
@@ -146,7 +205,7 @@ mod tests {
             check_ulp_f16(
                 x,
                 |arg| arg.distance_in_ulp <= if arg.expected == 0.to() { 0 } else { 2 },
-                |x| sin_pi_kernel_f16(Scalar, x),
+                |x| sin_pi_kernel_f16(Scalar, Value(x)).0,
                 |x| (f64::consts::PI * x).sin(),
             )
         };
@@ -167,7 +226,7 @@ mod tests {
             check_ulp_f16(
                 x,
                 |arg| arg.distance_in_ulp <= 2 && arg.result <= 1.to(),
-                |x| cos_pi_kernel_f16(Scalar, x),
+                |x| cos_pi_kernel_f16(Scalar, Value(x)).0,
                 |x| (f64::consts::PI * x).cos(),
             )
         };
@@ -177,4 +236,44 @@ mod tests {
             check(-F16::from_bits(bits));
         }
     }
+
+    fn sin_cos_pi_check_ulp_callback_f16(arg: CheckUlpCallbackArg<F16, u32>) -> bool {
+        if f32::cvt_from(arg.x) % 0.5 == 0.0 {
+            arg.distance_in_ulp == 0
+        } else {
+            arg.distance_in_ulp <= 2 && arg.result.abs() <= 1.to()
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_sin_pi_f16() {
+        for bits in 0..=u16::MAX {
+            check_ulp_f16(
+                F16::from_bits(bits),
+                sin_cos_pi_check_ulp_callback_f16,
+                |x| sin_pi_f16(Scalar, Value(x)).0,
+                |x| (f64::consts::PI * x).sin(),
+            );
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_cos_pi_f16() {
+        for bits in 0..=u16::MAX {
+            check_ulp_f16(
+                F16::from_bits(bits),
+                sin_cos_pi_check_ulp_callback_f16,
+                |x| cos_pi_f16(Scalar, Value(x)).0,
+                |x| (f64::consts::PI * x).cos(),
+            );
+        }
+    }
 }