implement sin_cos_pi_f32
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 10 May 2021 07:28:02 +0000 (00:28 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 10 May 2021 07:28:02 +0000 (00:28 -0700)
src/algorithms/trig_pi.rs
src/prim.rs

index a9b275c5b6c94f6ee8e961a7b6e510c730409108..fca45d9d2998ea173e2d9eb8870d2ba822f01b7a 100644 (file)
@@ -85,6 +85,30 @@ pub fn cos_pi_kernel_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16
     v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to()))
 }
 
+/// computes `sin(pi * x)` for `-0.25 <= x <= 0.25`
+/// not guaranteed to give correct sign for zero result
+/// has an error of up to 2ULP
+pub fn sin_pi_kernel_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
+    let x_sq = x * x;
+    let mut v: Ctx::VecF32 = ctx.make(consts::SINPI_KERNEL_TAYLOR_9.to());
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_7.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_5.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_3.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_1.to()));
+    v * x
+}
+
+/// computes `cos(pi * x)` for `-0.25 <= x <= 0.25`
+/// has an error of up to 2ULP
+pub fn cos_pi_kernel_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
+    let x_sq = x * x;
+    let mut v: Ctx::VecF32 = ctx.make(consts::COSPI_KERNEL_TAYLOR_8.to());
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_6.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_4.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_2.to()));
+    v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to()))
+}
+
 /// computes `(sin(pi * x), cos(pi * x))`
 /// not guaranteed to give correct sign for zero results
 /// inherits error from `sin_pi_kernel` and `cos_pi_kernel`
@@ -153,6 +177,27 @@ pub fn cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
     sin_cos_pi_f16(ctx, x).1
 }
 
+/// computes `(sin(pi * x), cos(pi * x))`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn sin_cos_pi_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> (Ctx::VecF32, Ctx::VecF32) {
+    sin_cos_pi_impl(ctx, x, sin_pi_kernel_f32, cos_pi_kernel_f32)
+}
+
+/// computes `sin(pi * x)`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn sin_pi_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
+    sin_cos_pi_f32(ctx, x).0
+}
+
+/// computes `cos(pi * x)`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn cos_pi_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
+    sin_cos_pi_f32(ctx, x).1
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -258,8 +303,44 @@ mod tests {
         }
     }
 
-    fn sin_cos_pi_check_ulp_callback_f16(arg: CheckUlpCallbackArg<F16, u64>) -> bool {
-        if f32::cvt_from(arg.x) % 0.5 == 0.0 {
+    #[test]
+    #[cfg(feature = "full_tests")]
+    fn test_sin_pi_kernel_f32() {
+        let check = |x| {
+            check_ulp(
+                x,
+                |arg| arg.distance_in_ulp <= if arg.expected == 0.to() { 0 } else { 2 },
+                |x| sin_pi_kernel_f32(Scalar, Value(x)).0,
+                |x| (f64::consts::PI * x).sin(),
+            )
+        };
+        let quarter = 0.25f32.to_bits();
+        for bits in (0..=quarter).rev() {
+            check(f32::from_bits(bits));
+            check(-f32::from_bits(bits));
+        }
+    }
+
+    #[test]
+    #[cfg(feature = "full_tests")]
+    fn test_cos_pi_kernel_f32() {
+        let check = |x| {
+            check_ulp(
+                x,
+                |arg| arg.distance_in_ulp <= 2 && arg.result <= 1.to(),
+                |x| cos_pi_kernel_f32(Scalar, Value(x)).0,
+                |x| (f64::consts::PI * x).cos(),
+            )
+        };
+        let quarter = 0.25f32.to_bits();
+        for bits in (0..=quarter).rev() {
+            check(f32::from_bits(bits));
+            check(-f32::from_bits(bits));
+        }
+    }
+
+    fn sin_cos_pi_check_ulp_callback<F: PrimFloat>(arg: CheckUlpCallbackArg<F, u64>) -> bool {
+        if arg.x % 0.5.to() == 0.0.to() {
             arg.distance_in_ulp == 0
         } else {
             arg.distance_in_ulp <= 2 && arg.result.abs() <= 1.to()
@@ -275,7 +356,7 @@ mod tests {
         for bits in 0..=u16::MAX {
             check_ulp(
                 F16::from_bits(bits),
-                sin_cos_pi_check_ulp_callback_f16,
+                sin_cos_pi_check_ulp_callback,
                 |x| sin_pi_f16(Scalar, Value(x)).0,
                 |x| (f64::consts::PI * x).sin(),
             );
@@ -291,10 +372,154 @@ mod tests {
         for bits in 0..=u16::MAX {
             check_ulp(
                 F16::from_bits(bits),
-                sin_cos_pi_check_ulp_callback_f16,
+                sin_cos_pi_check_ulp_callback,
                 |x| cos_pi_f16(Scalar, Value(x)).0,
                 |x| (f64::consts::PI * x).cos(),
             );
         }
     }
+
+    fn reference_sin_cos_pi_f32(mut v: f64) -> (f64, f64) {
+        if !v.is_finite() {
+            return (f64::NAN, f64::NAN);
+        }
+        v %= 2.0;
+        if v >= 1.0 {
+            v -= 2.0;
+        } else if v <= -1.0 {
+            v += 2.0;
+        }
+        v *= 2.0;
+        let part = v.round() as i32;
+        v -= part as f64;
+        v *= f64::consts::PI / 2.0;
+        let (sin, cos) = v.sin_cos();
+        match part {
+            0 => (sin, cos),
+            1 => (cos, -sin),
+            2 => (-sin, -cos),
+            -2 => (-sin, -cos),
+            -1 => (-cos, sin),
+            _ => panic!("not implemented: part={}", part),
+        }
+    }
+
+    #[test]
+    fn test_reference_sin_cos_pi_f32() {
+        fn approx_same(a: f32, b: f32) -> bool {
+            if a.is_finite() && b.is_finite() {
+                (a - b).abs() < 1e-6
+            } else {
+                a == b || (a.is_nan() && b.is_nan())
+            }
+        }
+        #[track_caller]
+        fn case(x: f32, expected_sin: f32, expected_cos: f32) {
+            let (ref_sin, ref_cos) = reference_sin_cos_pi_f32(x as f64);
+            assert!(
+                approx_same(ref_sin as f32, expected_sin)
+                    && approx_same(ref_cos as f32, expected_cos),
+                "case failed: x={x}, expected_sin={expected_sin}, expected_cos={expected_cos}, ref_sin={ref_sin}, ref_cos={ref_cos}",
+                x=x,
+                expected_sin=expected_sin,
+                expected_cos=expected_cos,
+                ref_sin=ref_sin,
+                ref_cos=ref_cos,
+            );
+        }
+        case(f32::NAN, f32::NAN, f32::NAN);
+        case(f32::INFINITY, f32::NAN, f32::NAN);
+        case(-f32::INFINITY, f32::NAN, f32::NAN);
+        case(-4.0, 0.0, 1.0);
+        case(-3.875, 0.3826834323650906, 0.9238795325112864);
+        case(-3.75, 0.7071067811865475, 0.7071067811865475);
+        case(-3.625, 0.9238795325112867, 0.3826834323650898);
+        case(-3.5, 1.0, 0.0);
+        case(-3.375, 0.9238795325112864, -0.3826834323650905);
+        case(-3.25, 0.7071067811865475, -0.7071067811865475);
+        case(-3.125, 0.3826834323650898, -0.9238795325112867);
+        case(-3.0, 0.0, -1.0);
+        case(-2.875, -0.3826834323650905, -0.9238795325112864);
+        case(-2.75, -0.7071067811865475, -0.7071067811865475);
+        case(-2.625, -0.9238795325112867, -0.3826834323650899);
+        case(-2.5, -1.0, 0.0);
+        case(-2.375, -0.9238795325112865, 0.3826834323650904);
+        case(-2.25, -0.7071067811865475, 0.7071067811865475);
+        case(-2.125, -0.3826834323650899, 0.9238795325112867);
+        case(-2.0, 0.0, 1.0);
+        case(-1.875, 0.3826834323650904, 0.9238795325112865);
+        case(-1.75, 0.7071067811865475, 0.7071067811865475);
+        case(-1.625, 0.9238795325112866, 0.38268343236509);
+        case(-1.5, 1.0, 0.0);
+        case(-1.375, 0.9238795325112865, -0.3826834323650903);
+        case(-1.25, 0.7071067811865475, -0.7071067811865475);
+        case(-1.125, 0.3826834323650896, -0.9238795325112869);
+        case(-1.0, 0.0, -1.0);
+        case(-0.875, -0.3826834323650899, -0.9238795325112867);
+        case(-0.75, -0.7071067811865475, -0.7071067811865475);
+        case(-0.625, -0.9238795325112867, -0.3826834323650897);
+        case(-0.5, -1.0, 0.0);
+        case(-0.375, -0.9238795325112867, 0.3826834323650898);
+        case(-0.25, -0.7071067811865475, 0.7071067811865475);
+        case(-0.125, -0.3826834323650898, 0.9238795325112867);
+        case(0.0, 0.0, 1.0);
+        case(0.125, 0.3826834323650898, 0.9238795325112867);
+        case(0.25, 0.7071067811865475, 0.7071067811865475);
+        case(0.375, 0.9238795325112867, 0.3826834323650898);
+        case(0.5, 1.0, 0.0);
+        case(0.625, 0.9238795325112867, -0.3826834323650897);
+        case(0.75, 0.7071067811865475, -0.7071067811865475);
+        case(0.875, 0.3826834323650899, -0.9238795325112867);
+        case(1.0, 0.0, -1.0);
+        case(1.125, -0.3826834323650896, -0.9238795325112869);
+        case(1.25, -0.7071067811865475, -0.7071067811865475);
+        case(1.375, -0.9238795325112865, -0.3826834323650903);
+        case(1.5, -1.0, 0.0);
+        case(1.625, -0.9238795325112866, 0.38268343236509);
+        case(1.75, -0.7071067811865475, 0.7071067811865475);
+        case(1.875, -0.3826834323650904, 0.9238795325112865);
+        case(2.0, 0.0, 1.0);
+        case(2.125, 0.3826834323650899, 0.9238795325112867);
+        case(2.25, 0.7071067811865475, 0.7071067811865475);
+        case(2.375, 0.9238795325112865, 0.3826834323650904);
+        case(2.5, 1.0, 0.0);
+        case(2.625, 0.9238795325112867, -0.3826834323650899);
+        case(2.75, 0.7071067811865475, -0.7071067811865475);
+        case(2.875, 0.3826834323650905, -0.9238795325112864);
+        case(3.0, 0.0, -1.0);
+        case(3.125, -0.3826834323650898, -0.9238795325112867);
+        case(3.25, -0.7071067811865475, -0.7071067811865475);
+        case(3.375, -0.9238795325112864, -0.3826834323650905);
+        case(3.5, -1.0, 0.0);
+        case(3.625, -0.9238795325112867, 0.3826834323650898);
+        case(3.75, -0.7071067811865475, 0.7071067811865475);
+        case(3.875, -0.3826834323650906, 0.9238795325112864);
+        case(4.0, 0.0, 1.0);
+    }
+
+    #[test]
+    #[cfg(feature = "full_tests")]
+    fn test_sin_pi_f32() {
+        for bits in 0..=u32::MAX {
+            check_ulp(
+                f32::from_bits(bits),
+                sin_cos_pi_check_ulp_callback,
+                |x| sin_pi_f32(Scalar, Value(x)).0,
+                |x| reference_sin_cos_pi_f32(x).0,
+            );
+        }
+    }
+
+    #[test]
+    #[cfg(feature = "full_tests")]
+    fn test_cos_pi_f32() {
+        for bits in 0..=u32::MAX {
+            check_ulp(
+                f32::from_bits(bits),
+                sin_cos_pi_check_ulp_callback,
+                |x| cos_pi_f32(Scalar, Value(x)).0,
+                |x| reference_sin_cos_pi_f32(x).1,
+            );
+        }
+    }
 }
index 08ede9e643616a73a8fe15fb189cfdf4a32e741b..184e5fcec9d5ac6d6a9a6037a14a5606b5a175e0 100644 (file)
@@ -139,6 +139,7 @@ pub trait PrimFloat:
     fn is_nan(self) -> bool;
     fn from_bits(bits: Self::BitsType) -> Self;
     fn to_bits(self) -> Self::BitsType;
+    fn abs(self) -> Self;
 }
 
 macro_rules! impl_float {
@@ -185,6 +186,9 @@ macro_rules! impl_float {
             fn to_bits(self) -> Self::BitsType {
                 self.to_bits()
             }
+            fn abs(self) -> Self {
+                $float::abs(self)
+            }
         }
     };
 }