From: Jacob Lifshay Date: Mon, 10 May 2021 00:41:54 +0000 (-0700) Subject: refactor to easily allow algorithms generic over f16/32/64 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=6953218b789361857e2c2dcc45ed64a75cf8f4fd;p=vector-math.git refactor to easily allow algorithms generic over f16/32/64 --- diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c7adb6b..4d9c541 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -11,6 +11,26 @@ rust-latest: matrix: - FEATURES: ["", "fma,ir", "f16,ir", "fma,f16,ir"] +rust-latest-release: + stage: build + image: rust:latest + script: + - cargo build --verbose --release --no-default-features --features="$FEATURES" + - cargo test --verbose --release --no-default-features --features="$FEATURES" + parallel: + matrix: + - FEATURES: + [ + "", + "fma,ir", + "f16,ir", + "fma,f16,ir", + "full_tests", + "full_tests,fma", + "full_tests,fma,f16", + "full_tests,f16", + ] + rust-nightly: stage: build image: rustlang/rust:nightly diff --git a/Cargo.toml b/Cargo.toml index d83abb8..e1b82f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ fma = ["std"] std = [] ir = ["std", "typed-arena"] stdsimd = ["core_simd"] +# enable slow tests full_tests = [] [workspace] diff --git a/src/algorithms/ilogb.rs b/src/algorithms/ilogb.rs index 7da2733..36d9d54 100644 --- a/src/algorithms/ilogb.rs +++ b/src/algorithms/ilogb.rs @@ -1,6 +1,6 @@ use crate::{ f16::F16, - ieee754::FloatEncoding, + prim::PrimFloat, traits::{Compare, Context, ConvertTo, Float, Select}, }; diff --git a/src/algorithms/trig_pi.rs b/src/algorithms/trig_pi.rs index 38104a6..5b07c2a 100644 --- a/src/algorithms/trig_pi.rs +++ b/src/algorithms/trig_pi.rs @@ -1,7 +1,7 @@ use crate::{ f16::F16, - ieee754::FloatEncoding, - traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Select}, + prim::PrimFloat,prim::PrimSInt,prim::PrimUInt, + traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Make, Select}, }; mod consts { @@ -87,39 +87,58 @@ pub fn cos_pi_kernel_f16(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 /// computes `(sin(pi * x), cos(pi * x))` /// not guaranteed to give correct sign for zero results -/// has an error of up to 2ULP -pub fn sin_cos_pi_f16(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) { - let two_f16: Ctx::VecF16 = ctx.make(2.0.to()); - let one_half: Ctx::VecF16 = ctx.make(0.5.to()); - let max_contiguous_integer: Ctx::VecF16 = - ctx.make((1u16 << (F16::MANTISSA_FIELD_WIDTH + 1)).to()); +/// inherits error from `sin_pi_kernel` and `cos_pi_kernel` +pub fn sin_cos_pi_impl< + Ctx: Context, + VecF: Float + Make, + PrimF: PrimFloat, + PrimU: PrimUInt, + SinPiKernel: FnOnce(Ctx, VecF) -> VecF, + CosPiKernel: FnOnce(Ctx, VecF) -> VecF, +>( + ctx: Ctx, + x: VecF, + sin_pi_kernel: SinPiKernel, + cos_pi_kernel: CosPiKernel, +) -> (VecF, VecF) { + let two_f: VecF = ctx.make(2.0.to()); + let one_half: VecF = ctx.make(0.5.to()); + let max_contiguous_integer: VecF = + ctx.make((PrimU::cvt_from(1) << (PrimF::MANTISSA_FIELD_WIDTH + 1.to())).to()); // if `x` is finite and bigger than `max_contiguous_integer`, then x is an even integer let in_range = x.abs().lt(max_contiguous_integer); // use `lt` so nans are counted as out-of-range let is_finite = x.is_finite(); - let nan: Ctx::VecF16 = ctx.make(f32::NAN.to()); - let zero_f16: Ctx::VecF16 = ctx.make(0.to()); - let one_f16: Ctx::VecF16 = ctx.make(1.to()); - let zero_i16: Ctx::VecI16 = ctx.make(0.to()); - let one_i16: Ctx::VecI16 = ctx.make(1.to()); - let two_i16: Ctx::VecI16 = ctx.make(2.to()); - let out_of_range_sin = is_finite.select(zero_f16, nan); - let out_of_range_cos = is_finite.select(one_f16, nan); - let xi = (x * two_f16).round(); + let nan: VecF = ctx.make(f32::NAN.to()); + let zero_f: VecF = ctx.make(0.to()); + let one_f: VecF = ctx.make(1.to()); + let zero_i: VecF::SignedBitsType = ctx.make(0.to()); + let one_i: VecF::SignedBitsType = ctx.make(1.to()); + let two_i: VecF::SignedBitsType = ctx.make(2.to()); + let out_of_range_sin = is_finite.select(zero_f, nan); + let out_of_range_cos = is_finite.select(one_f, nan); + let xi = (x * two_f).round(); let xk = x - xi * one_half; - let sk = sin_pi_kernel_f16(ctx, xk); - let ck = cos_pi_kernel_f16(ctx, xk); - let xi = Ctx::VecI16::cvt_from(xi); - let bit_0_clear = (xi & one_i16).eq(zero_i16); + let sk = sin_pi_kernel(ctx, xk); + let ck = cos_pi_kernel(ctx, xk); + let xi = VecF::SignedBitsType::cvt_from(xi); + let bit_0_clear = (xi & one_i).eq(zero_i); let st = bit_0_clear.select(sk, ck); let ct = bit_0_clear.select(ck, sk); - let s = (xi & two_i16).eq(zero_i16).select(st, -st); - let c = ((xi + one_i16) & two_i16).eq(zero_i16).select(ct, -ct); + let s = (xi & two_i).eq(zero_i).select(st, -st); + let c = ((xi + one_i) & two_i).eq(zero_i).select(ct, -ct); ( in_range.select(s, out_of_range_sin), in_range.select(c, out_of_range_cos), ) } +/// computes `(sin(pi * x), cos(pi * x))` +/// not guaranteed to give correct sign for zero results +/// has an error of up to 2ULP +pub fn sin_cos_pi_f16(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) { + sin_cos_pi_impl(ctx, x, sin_pi_kernel_f16, cos_pi_kernel_f16) +} + /// computes `sin(pi * x)` /// not guaranteed to give correct sign for zero results /// has an error of up to 2ULP diff --git a/src/f16.rs b/src/f16.rs index e9541b4..5253fef 100644 --- a/src/f16.rs +++ b/src/f16.rs @@ -1,11 +1,11 @@ -use core::ops::{ - Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, -}; - use crate::{ scalar::Value, traits::{ConvertFrom, ConvertTo, Float}, }; +use core::{ + fmt, + ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign}, +}; #[cfg(feature = "f16")] use half::f16 as F16Impl; @@ -13,7 +13,7 @@ use half::f16 as F16Impl; #[cfg(not(feature = "f16"))] type F16Impl = u16; -#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)] +#[derive(Clone, Copy, PartialEq, PartialOrd)] #[repr(transparent)] pub struct F16(F16Impl); @@ -40,6 +40,18 @@ macro_rules! f16_impl { }; } +impl fmt::Display for F16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f16_impl!(self.0.fmt(f), [f]) + } +} + +impl fmt::Debug for F16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f16_impl!(self.0.fmt(f), [f]) + } +} + impl Default for F16 { fn default() -> Self { f16_impl!(F16(F16Impl::default()), []) @@ -193,21 +205,29 @@ impl F16 { f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), []) } pub fn trunc(self) -> Self { - f32::from(self).trunc().to() + #[cfg(feature = "std")] + return f32::from(self).trunc().to(); + #[cfg(not(feature = "std"))] + todo!(); } - pub fn ceil(self) -> Self { - f32::from(self).ceil().to() + #[cfg(feature = "std")] + return f32::from(self).ceil().to(); + #[cfg(not(feature = "std"))] + todo!(); } - pub fn floor(self) -> Self { - f32::from(self).floor().to() + #[cfg(feature = "std")] + return f32::from(self).floor().to(); + #[cfg(not(feature = "std"))] + todo!(); } - pub fn round(self) -> Self { - f32::from(self).round().to() + #[cfg(feature = "std")] + return f32::from(self).round().to(); + #[cfg(not(feature = "std"))] + todo!(); } - #[cfg(feature = "fma")] pub fn fma(self, a: Self, b: Self) -> Self { (f64::from(self) * f64::from(a) + f64::from(b)).to() @@ -227,7 +247,7 @@ impl F16 { } impl Float for Value { - type FloatEncoding = F16; + type PrimFloat = F16; type BitsType = Value; type SignedBitsType = Value; diff --git a/src/ieee754.rs b/src/ieee754.rs deleted file mode 100644 index 3d70468..0000000 --- a/src/ieee754.rs +++ /dev/null @@ -1,95 +0,0 @@ -use crate::f16::F16; - -mod sealed { - use crate::f16::F16; - - pub trait Sealed {} - impl Sealed for F16 {} - impl Sealed for f32 {} - impl Sealed for f64 {} -} - -pub trait FloatEncoding: sealed::Sealed + Copy + 'static + Send + Sync { - type BitsType; - type SignedBitsType; - const EXPONENT_BIAS_UNSIGNED: Self::BitsType; - const EXPONENT_BIAS_SIGNED: Self::SignedBitsType; - const SIGN_FIELD_WIDTH: Self::BitsType; - const EXPONENT_FIELD_WIDTH: Self::BitsType; - const MANTISSA_FIELD_WIDTH: Self::BitsType; - const SIGN_FIELD_SHIFT: Self::BitsType; - const EXPONENT_FIELD_SHIFT: Self::BitsType; - const MANTISSA_FIELD_SHIFT: Self::BitsType; - const SIGN_FIELD_MASK: Self::BitsType; - const EXPONENT_FIELD_MASK: Self::BitsType; - const MANTISSA_FIELD_MASK: Self::BitsType; - const IMPLICIT_MANTISSA_BIT: Self::BitsType; - const ZERO_SUBNORMAL_EXPONENT: Self::BitsType; - const NAN_INFINITY_EXPONENT: Self::BitsType; - const INFINITY_BITS: Self::BitsType; - const NAN_BITS: Self::BitsType; -} - -macro_rules! impl_float_encoding { - ( - impl FloatEncoding for $float:ident { - type BitsType = $bits_type:ident; - type SignedBitsType = $signed_bits_type:ident; - const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width:literal; - const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width:literal; - } - ) => { - impl FloatEncoding for $float { - type BitsType = $bits_type; - type SignedBitsType = $signed_bits_type; - const EXPONENT_BIAS_UNSIGNED: Self::BitsType = - (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1; - const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _; - const SIGN_FIELD_WIDTH: Self::BitsType = 1; - const EXPONENT_FIELD_WIDTH: Self::BitsType = $exponent_field_width; - const MANTISSA_FIELD_WIDTH: Self::BitsType = $mantissa_field_width; - const SIGN_FIELD_SHIFT: Self::BitsType = - Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH; - const EXPONENT_FIELD_SHIFT: Self::BitsType = Self::MANTISSA_FIELD_WIDTH; - const MANTISSA_FIELD_SHIFT: Self::BitsType = 0; - const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT; - const EXPONENT_FIELD_MASK: Self::BitsType = - ((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT; - const MANTISSA_FIELD_MASK: Self::BitsType = (1 << Self::MANTISSA_FIELD_WIDTH) - 1; - const IMPLICIT_MANTISSA_BIT: Self::BitsType = 1 << Self::MANTISSA_FIELD_WIDTH; - const ZERO_SUBNORMAL_EXPONENT: Self::BitsType = 0; - const NAN_INFINITY_EXPONENT: Self::BitsType = (1 << Self::EXPONENT_FIELD_WIDTH) - 1; - const INFINITY_BITS: Self::BitsType = - Self::NAN_INFINITY_EXPONENT << Self::EXPONENT_FIELD_SHIFT; - const NAN_BITS: Self::BitsType = - Self::INFINITY_BITS | (1 << (Self::MANTISSA_FIELD_WIDTH - 1)); - } - }; -} - -impl_float_encoding! { - impl FloatEncoding for F16 { - type BitsType = u16; - type SignedBitsType = i16; - const EXPONENT_FIELD_WIDTH: u32 = 5; - const MANTISSA_FIELD_WIDTH: u32 = 10; - } -} - -impl_float_encoding! { - impl FloatEncoding for f32 { - type BitsType = u32; - type SignedBitsType = i32; - const EXPONENT_FIELD_WIDTH: u32 = 8; - const MANTISSA_FIELD_WIDTH: u32 = 23; - } -} - -impl_float_encoding! { - impl FloatEncoding for f64 { - type BitsType = u64; - type SignedBitsType = i64; - const EXPONENT_FIELD_WIDTH: u32 = 11; - const MANTISSA_FIELD_WIDTH: u32 = 52; - } -} diff --git a/src/ir.rs b/src/ir.rs index e2b4a0e..2799020 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -1220,40 +1220,41 @@ macro_rules! impl_integer_ops { }; } -macro_rules! impl_uint_ops { - ($scalar:ident, $vec:ident) => { - impl_integer_ops!($scalar, $vec); - - impl<'ctx> UInt for $scalar<'ctx> {} - impl<'ctx> UInt for $vec<'ctx> {} - }; -} +macro_rules! impl_uint_sint_ops { + ($uint_scalar:ident, $uint_vec:ident, $sint_scalar:ident, $sint_vec:ident) => { + impl_integer_ops!($uint_scalar, $uint_vec); + impl_integer_ops!($sint_scalar, $sint_vec); + impl_neg!($sint_scalar); + impl_neg!($sint_vec); -impl_uint_ops!(IrU8, IrVecU8); -impl_uint_ops!(IrU16, IrVecU16); -impl_uint_ops!(IrU32, IrVecU32); -impl_uint_ops!(IrU64, IrVecU64); - -macro_rules! impl_sint_ops { - ($scalar:ident, $vec:ident) => { - impl_integer_ops!($scalar, $vec); - impl_neg!($scalar); - impl_neg!($vec); - - impl<'ctx> SInt for $scalar<'ctx> {} - impl<'ctx> SInt for $vec<'ctx> {} + impl<'ctx> UInt for $uint_scalar<'ctx> { + type PrimUInt = Self::Prim; + type SignedType = $sint_scalar<'ctx>; + } + impl<'ctx> UInt for $uint_vec<'ctx> { + type PrimUInt = Self::Prim; + type SignedType = $sint_vec<'ctx>; + } + impl<'ctx> SInt for $sint_scalar<'ctx> { + type PrimSInt = Self::Prim; + type UnsignedType = $uint_scalar<'ctx>; + } + impl<'ctx> SInt for $sint_vec<'ctx> { + type PrimSInt = Self::Prim; + type UnsignedType = $uint_vec<'ctx>; + } }; } -impl_sint_ops!(IrI8, IrVecI8); -impl_sint_ops!(IrI16, IrVecI16); -impl_sint_ops!(IrI32, IrVecI32); -impl_sint_ops!(IrI64, IrVecI64); +impl_uint_sint_ops!(IrU8, IrVecU8, IrI8, IrVecI8); +impl_uint_sint_ops!(IrU16, IrVecU16, IrI16, IrVecI16); +impl_uint_sint_ops!(IrU32, IrVecU32, IrI32, IrVecI32); +impl_uint_sint_ops!(IrU64, IrVecU64, IrI64, IrVecI64); macro_rules! impl_float { ($float:ident, $bits:ident, $signed_bits:ident) => { impl<'ctx> Float for $float<'ctx> { - type FloatEncoding = <$float<'ctx> as Make>::Prim; + type PrimFloat = <$float<'ctx> as Make>::Prim; type BitsType = $bits<'ctx>; type SignedBitsType = $signed_bits<'ctx>; fn abs(self) -> Self { diff --git a/src/lib.rs b/src/lib.rs index cecc2e4..60aa082 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,9 +6,9 @@ extern crate std; pub mod algorithms; pub mod f16; -pub mod ieee754; #[cfg(feature = "ir")] pub mod ir; +pub mod prim; pub mod scalar; #[cfg(feature = "stdsimd")] pub mod stdsimd; diff --git a/src/prim.rs b/src/prim.rs new file mode 100644 index 0000000..b2d2ebb --- /dev/null +++ b/src/prim.rs @@ -0,0 +1,205 @@ +use crate::{ + f16::F16, + traits::{ConvertFrom, ConvertTo}, +}; +use core::{fmt, hash, ops}; + +mod sealed { + use crate::f16::F16; + + pub trait Sealed {} + impl Sealed for F16 {} + impl Sealed for f32 {} + impl Sealed for f64 {} + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} +} + +pub trait PrimBase: + sealed::Sealed + + Copy + + 'static + + Send + + Sync + + PartialOrd + + fmt::Debug + + fmt::Display + + ops::Add + + ops::Sub + + ops::Mul + + ops::Div + + ops::Rem + + ops::AddAssign + + ops::SubAssign + + ops::MulAssign + + ops::DivAssign + + ops::RemAssign + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertFrom + + ConvertTo + + ConvertTo + + ConvertTo + + ConvertTo + + ConvertTo + + ConvertTo + + ConvertTo + + ConvertTo + + ConvertTo + + ConvertTo + + ConvertTo +{ +} + +pub trait PrimInt: + PrimBase + + Ord + + hash::Hash + + fmt::Binary + + fmt::LowerHex + + fmt::Octal + + fmt::UpperHex + + ops::BitAnd + + ops::BitOr + + ops::BitXor + + ops::Shl + + ops::Shr + + ops::Not + + ops::BitAndAssign + + ops::BitOrAssign + + ops::BitXorAssign + + ops::ShlAssign + + ops::ShrAssign +{ +} + +pub trait PrimUInt: PrimInt + ConvertFrom { + type SignedType: PrimSInt + ConvertFrom; +} + +pub trait PrimSInt: PrimInt + ops::Neg + ConvertFrom { + type UnsignedType: PrimUInt + ConvertFrom; +} + +macro_rules! impl_int { + ($uint:ident, $sint:ident) => { + impl PrimBase for $uint {} + impl PrimBase for $sint {} + impl PrimInt for $uint {} + impl PrimInt for $sint {} + impl PrimUInt for $uint { + type SignedType = $sint; + } + impl PrimSInt for $sint { + type UnsignedType = $uint; + } + }; +} + +impl_int!(u8, i8); +impl_int!(u16, i16); +impl_int!(u32, i32); +impl_int!(u64, i64); + +pub trait PrimFloat: + PrimBase + ops::Neg + ConvertFrom + ConvertFrom +{ + type BitsType: PrimUInt + ConvertFrom; + type SignedBitsType: PrimSInt + ConvertFrom; + const EXPONENT_BIAS_UNSIGNED: Self::BitsType; + const EXPONENT_BIAS_SIGNED: Self::SignedBitsType; + const SIGN_FIELD_WIDTH: Self::BitsType; + const EXPONENT_FIELD_WIDTH: Self::BitsType; + const MANTISSA_FIELD_WIDTH: Self::BitsType; + const SIGN_FIELD_SHIFT: Self::BitsType; + const EXPONENT_FIELD_SHIFT: Self::BitsType; + const MANTISSA_FIELD_SHIFT: Self::BitsType; + const SIGN_FIELD_MASK: Self::BitsType; + const EXPONENT_FIELD_MASK: Self::BitsType; + const MANTISSA_FIELD_MASK: Self::BitsType; + const IMPLICIT_MANTISSA_BIT: Self::BitsType; + const ZERO_SUBNORMAL_EXPONENT: Self::BitsType; + const NAN_INFINITY_EXPONENT: Self::BitsType; + const INFINITY_BITS: Self::BitsType; + const NAN_BITS: Self::BitsType; +} + +macro_rules! impl_float { + ( + impl PrimFloat for $float:ident { + type BitsType = $bits_type:ident; + type SignedBitsType = $signed_bits_type:ident; + const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width:literal; + const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width:literal; + } + ) => { + impl PrimBase for $float {} + + impl PrimFloat for $float { + type BitsType = $bits_type; + type SignedBitsType = $signed_bits_type; + const EXPONENT_BIAS_UNSIGNED: Self::BitsType = + (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1; + const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _; + const SIGN_FIELD_WIDTH: Self::BitsType = 1; + const EXPONENT_FIELD_WIDTH: Self::BitsType = $exponent_field_width; + const MANTISSA_FIELD_WIDTH: Self::BitsType = $mantissa_field_width; + const SIGN_FIELD_SHIFT: Self::BitsType = + Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH; + const EXPONENT_FIELD_SHIFT: Self::BitsType = Self::MANTISSA_FIELD_WIDTH; + const MANTISSA_FIELD_SHIFT: Self::BitsType = 0; + const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT; + const EXPONENT_FIELD_MASK: Self::BitsType = + ((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT; + const MANTISSA_FIELD_MASK: Self::BitsType = (1 << Self::MANTISSA_FIELD_WIDTH) - 1; + const IMPLICIT_MANTISSA_BIT: Self::BitsType = 1 << Self::MANTISSA_FIELD_WIDTH; + const ZERO_SUBNORMAL_EXPONENT: Self::BitsType = 0; + const NAN_INFINITY_EXPONENT: Self::BitsType = (1 << Self::EXPONENT_FIELD_WIDTH) - 1; + const INFINITY_BITS: Self::BitsType = + Self::NAN_INFINITY_EXPONENT << Self::EXPONENT_FIELD_SHIFT; + const NAN_BITS: Self::BitsType = + Self::INFINITY_BITS | (1 << (Self::MANTISSA_FIELD_WIDTH - 1)); + } + }; +} + +impl_float! { + impl PrimFloat for F16 { + type BitsType = u16; + type SignedBitsType = i16; + const EXPONENT_FIELD_WIDTH: u32 = 5; + const MANTISSA_FIELD_WIDTH: u32 = 10; + } +} + +impl_float! { + impl PrimFloat for f32 { + type BitsType = u32; + type SignedBitsType = i32; + const EXPONENT_FIELD_WIDTH: u32 = 8; + const MANTISSA_FIELD_WIDTH: u32 = 23; + } +} + +impl_float! { + impl PrimFloat for f64 { + type BitsType = u64; + type SignedBitsType = i64; + const EXPONENT_FIELD_WIDTH: u32 = 11; + const MANTISSA_FIELD_WIDTH: u32 = 52; + } +} diff --git a/src/scalar.rs b/src/scalar.rs index 4eb5b98..30aaa9e 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -1,5 +1,6 @@ use crate::{ f16::F16, + prim::{PrimSInt, PrimUInt}, traits::{Bool, Compare, Context, ConvertFrom, Float, Int, Make, SInt, Select, UInt}, }; use core::ops::{ @@ -230,7 +231,10 @@ macro_rules! impl_uint { ($($ty:ident),*) => { $( impl_int!($ty); - impl UInt for Value<$ty> {} + impl UInt for Value<$ty> { + type PrimUInt = $ty; + type SignedType = Value<<$ty as PrimUInt>::SignedType>; + } )* }; } @@ -241,7 +245,10 @@ macro_rules! impl_sint { ($($ty:ident),*) => { $( impl_int!($ty); - impl SInt for Value<$ty> {} + impl SInt for Value<$ty> { + type PrimSInt = $ty; + type UnsignedType = Value<<$ty as PrimSInt>::UnsignedType>; + } )* }; } @@ -336,7 +343,7 @@ macro_rules! impl_float { ($ty:ident, $bits:ty, $signed_bits:ty) => { impl_float_ops!($ty); impl Float for Value<$ty> { - type FloatEncoding = $ty; + type PrimFloat = $ty; type BitsType = Value<$bits>; type SignedBitsType = Value<$signed_bits>; fn abs(self) -> Self { diff --git a/src/traits.rs b/src/traits.rs index 1877c21..2ec4815 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,4 +1,7 @@ -use crate::{f16::F16, ieee754::FloatEncoding}; +use crate::{ + f16::F16, + prim::{PrimFloat, PrimSInt, PrimUInt}, +}; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, @@ -132,20 +135,42 @@ pub trait Int: fn count_ones(self) -> Self; } -pub trait UInt: Int {} - -pub trait SInt: Int + Neg {} - -pub trait Float: Number + Neg { - type FloatEncoding: FloatEncoding + From<::Prim> + Into<::Prim>; - type BitsType: UInt - + Make::BitsType> - + ConvertTo +pub trait UInt: Int + Make + ConvertFrom { + type PrimUInt: PrimUInt::PrimSInt>; + type SignedType: SInt + + ConvertFrom + + Make + Compare; - type SignedBitsType: SInt - + Make::SignedBitsType> - + ConvertTo +} + +pub trait SInt: + Int + Neg + Make + ConvertFrom +{ + type PrimSInt: PrimSInt::PrimUInt>; + type UnsignedType: UInt + + ConvertFrom + + Make + Compare; +} + +pub trait Float: + Number + + Neg + + Make + + ConvertFrom + + ConvertFrom +{ + type PrimFloat: PrimFloat; + type BitsType: UInt::BitsType, SignedType = Self::SignedBitsType> + + Make::BitsType> + + Compare + + ConvertFrom; + type SignedBitsType: SInt< + PrimSInt = ::SignedBitsType, + UnsignedType = Self::BitsType, + > + Make::SignedBitsType> + + Compare + + ConvertFrom; fn abs(self) -> Self; fn trunc(self) -> Self; fn ceil(self) -> Self; @@ -169,43 +194,38 @@ pub trait Float: Number + Neg { self.abs().eq(Self::infinity(self.ctx())) } fn infinity(ctx: Self::Context) -> Self { - Self::from_bits(ctx.make(Self::FloatEncoding::INFINITY_BITS)) + Self::from_bits(ctx.make(Self::PrimFloat::INFINITY_BITS)) } fn nan(ctx: Self::Context) -> Self { - Self::from_bits(ctx.make(Self::FloatEncoding::NAN_BITS)) + Self::from_bits(ctx.make(Self::PrimFloat::NAN_BITS)) } fn is_finite(self) -> Self::Bool; fn is_zero_or_subnormal(self) -> Self::Bool { - self.extract_exponent_field().eq(self - .ctx() - .make(Self::FloatEncoding::ZERO_SUBNORMAL_EXPONENT)) + self.extract_exponent_field() + .eq(self.ctx().make(Self::PrimFloat::ZERO_SUBNORMAL_EXPONENT)) } fn from_bits(v: Self::BitsType) -> Self; fn to_bits(self) -> Self::BitsType; fn extract_exponent_field(self) -> Self::BitsType { - let mask = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_MASK); - let shift = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_SHIFT); + let mask = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_MASK); + let shift = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT); (self.to_bits() & mask) >> shift } fn extract_exponent_unbiased(self) -> Self::SignedBitsType { Self::sub_exponent_bias(self.extract_exponent_field()) } fn extract_mantissa_field(self) -> Self::BitsType { - let mask = self.ctx().make(Self::FloatEncoding::MANTISSA_FIELD_MASK); + let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK); self.to_bits() & mask } fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType { - exponent_field.to() + Self::SignedBitsType::cvt_from(exponent_field) - exponent_field .ctx() - .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED) + .make(Self::PrimFloat::EXPONENT_BIAS_SIGNED) } fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType { - (exponent - + exponent - .ctx() - .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED)) - .to() + (exponent + exponent.ctx().make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)).to() } } diff --git a/vector-math-proc-macro/src/lib.rs b/vector-math-proc-macro/src/lib.rs index 5c4de02..89d2bc4 100644 --- a/vector-math-proc-macro/src/lib.rs +++ b/vector-math-proc-macro/src/lib.rs @@ -269,12 +269,16 @@ impl TraitSets { let sint_ty = TypeKind::SInt.ty(bits, vector_scalar); let type_trait = match type_kind { TypeKind::Bool => quote! { Bool }, - TypeKind::UInt => quote! { UInt }, - TypeKind::SInt => quote! { SInt }, + TypeKind::UInt => { + quote! { UInt } + } + TypeKind::SInt => { + quote! { SInt } + } TypeKind::Float => quote! { Float< BitsType = Self::#uint_ty, SignedBitsType = Self::#sint_ty, - FloatEncoding = #prim_ty, + PrimFloat = #prim_ty, > }, }; self.add_trait(type_kind, bits, vector_scalar, type_trait);