switch to using separate VecBool8/16/32/64
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 5 May 2021 05:55:36 +0000 (22:55 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 5 May 2021 05:55:36 +0000 (22:55 -0700)
Cargo.toml
src/f16.rs
src/ieee754.rs
src/ir.rs
src/scalar.rs
src/traits.rs
vector-math-proc-macro/Cargo.toml [new file with mode: 0644]
vector-math-proc-macro/src/lib.rs [new file with mode: 0644]

index 5f82460a1e0c49ffa8d95b64b8d795ac0eec75e5..858dbd9fb34bdc8b700d562657ce010cf2109851 100644 (file)
@@ -9,6 +9,7 @@ license = "MIT OR Apache-2.0"
 half = { version = "1.7.1", optional = true }
 typed-arena = { version = "2.0.1", optional = true }
 core_simd = { version = "0.1.0", git = "https://github.com/rust-lang/stdsimd", optional = true }
+vector-math-proc-macro = { version = "=0.1.0", path = "vector-math-proc-macro" }
 
 [features]
 default = ["f16", "fma"]
@@ -17,3 +18,6 @@ fma = ["std"]
 std = []
 ir = ["std", "typed-arena"]
 stdsimd = ["core_simd"]
+
+[workspace]
+members = [".", "vector-math-proc-macro"]
index bc40c782a20f514d5658463c6f5b098d80a72d56..ee13d9992135320be52859df3d5c51f6482ff933 100644 (file)
@@ -161,7 +161,7 @@ impl_bin_op_using_f32! {
     Rem, rem, RemAssign, rem_assign;
 }
 
-impl Float<u32> for F16 {
+impl Float for F16 {
     type FloatEncoding = F16;
     type BitsType = u16;
     type SignedBitsType = i16;
index 0da587de373707c95e51a364ff9fb75281450638..6f9fea7359f868e466a45ea84603a02c98a1c20a 100644 (file)
@@ -14,16 +14,16 @@ mod sealed {
 }
 
 pub trait FloatEncoding:
-    sealed::Sealed + Copy + 'static + Send + Sync + Float<u32> + Make<Context = Scalar>
+    sealed::Sealed + Copy + 'static + Send + Sync + Float + Make<Context = Scalar>
 {
     const EXPONENT_BIAS_UNSIGNED: Self::BitsType;
     const EXPONENT_BIAS_SIGNED: Self::SignedBitsType;
-    const SIGN_FIELD_WIDTH: u32;
-    const EXPONENT_FIELD_WIDTH: u32;
-    const MANTISSA_FIELD_WIDTH: u32;
-    const SIGN_FIELD_SHIFT: u32;
-    const EXPONENT_FIELD_SHIFT: u32;
-    const MANTISSA_FIELD_SHIFT: u32;
+    const SIGN_FIELD_WIDTH: Self::BitsType;
+    const EXPONENT_FIELD_WIDTH: Self::BitsType;
+    const MANTISSA_FIELD_WIDTH: Self::BitsType;
+    const SIGN_FIELD_SHIFT: Self::BitsType;
+    const EXPONENT_FIELD_SHIFT: Self::BitsType;
+    const MANTISSA_FIELD_SHIFT: Self::BitsType;
     const SIGN_FIELD_MASK: Self::BitsType;
     const EXPONENT_FIELD_MASK: Self::BitsType;
     const MANTISSA_FIELD_MASK: Self::BitsType;
@@ -45,12 +45,13 @@ macro_rules! impl_float_encoding {
             const EXPONENT_BIAS_UNSIGNED: Self::BitsType =
                 (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1;
             const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _;
-            const SIGN_FIELD_WIDTH: u32 = 1;
-            const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width;
-            const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width;
-            const SIGN_FIELD_SHIFT: u32 = Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH;
-            const EXPONENT_FIELD_SHIFT: u32 = Self::MANTISSA_FIELD_WIDTH;
-            const MANTISSA_FIELD_SHIFT: u32 = 0;
+            const SIGN_FIELD_WIDTH: Self::BitsType = 1;
+            const EXPONENT_FIELD_WIDTH: Self::BitsType = $exponent_field_width;
+            const MANTISSA_FIELD_WIDTH: Self::BitsType = $mantissa_field_width;
+            const SIGN_FIELD_SHIFT: Self::BitsType =
+                Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH;
+            const EXPONENT_FIELD_SHIFT: Self::BitsType = Self::MANTISSA_FIELD_WIDTH;
+            const MANTISSA_FIELD_SHIFT: Self::BitsType = 0;
             const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT;
             const EXPONENT_FIELD_MASK: Self::BitsType =
                 ((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT;
index f1de05aa7d6bad699471d3afcbd4a5bd86724ac2..a31bece004894b45b290738fd954b96009f070fc 100644 (file)
--- a/src/ir.rs
+++ b/src/ir.rs
@@ -809,6 +809,23 @@ macro_rules! ir_value {
             }
         }
 
+        impl<'ctx> Select<$vec_name<'ctx>> for IrBool<'ctx> {
+            fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> {
+                let value = self
+                    .ctx
+                    .make_operation(
+                        Opcode::Select,
+                        [self.value, true_v.value, false_v.value],
+                        $vec_name::TYPE,
+                    )
+                    .into();
+                $vec_name {
+                    value,
+                    ctx: self.ctx,
+                }
+            }
+        }
+
         impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> {
             fn from(v: $name<'ctx>) -> Self {
                 let value = v
@@ -1060,12 +1077,41 @@ macro_rules! impl_number_ops {
     };
 }
 
+macro_rules! impl_bool_compare {
+    ($ty:ident) => {
+        impl<'ctx> Compare for $ty<'ctx> {
+            type Bool = Self;
+            fn eq(self, rhs: Self) -> Self::Bool {
+                !(self ^ rhs)
+            }
+            fn ne(self, rhs: Self) -> Self::Bool {
+                self ^ rhs
+            }
+            fn lt(self, rhs: Self) -> Self::Bool {
+                !self & rhs
+            }
+            fn gt(self, rhs: Self) -> Self::Bool {
+                self & !rhs
+            }
+            fn le(self, rhs: Self) -> Self::Bool {
+                !self | rhs
+            }
+            fn ge(self, rhs: Self) -> Self::Bool {
+                self | !rhs
+            }
+        }
+    };
+}
+
+impl_bool_compare!(IrBool);
+impl_bool_compare!(IrVecBool);
+
 macro_rules! impl_shift_ops {
-    ($ty:ident, $rhs:ident) => {
-        impl<'ctx> Shl<$rhs<'ctx>> for $ty<'ctx> {
+    ($ty:ident) => {
+        impl<'ctx> Shl for $ty<'ctx> {
             type Output = Self;
 
-            fn shl(self, rhs: $rhs<'ctx>) -> Self::Output {
+            fn shl(self, rhs: Self) -> Self::Output {
                 let value = self
                     .ctx
                     .make_operation(Opcode::Shl, [self.value, rhs.value], Self::TYPE)
@@ -1076,10 +1122,10 @@ macro_rules! impl_shift_ops {
                 }
             }
         }
-        impl<'ctx> Shr<$rhs<'ctx>> for $ty<'ctx> {
+        impl<'ctx> Shr for $ty<'ctx> {
             type Output = Self;
 
-            fn shr(self, rhs: $rhs<'ctx>) -> Self::Output {
+            fn shr(self, rhs: Self) -> Self::Output {
                 let value = self
                     .ctx
                     .make_operation(Opcode::Shr, [self.value, rhs.value], Self::TYPE)
@@ -1090,13 +1136,13 @@ macro_rules! impl_shift_ops {
                 }
             }
         }
-        impl<'ctx> ShlAssign<$rhs<'ctx>> for $ty<'ctx> {
-            fn shl_assign(&mut self, rhs: $rhs<'ctx>) {
+        impl<'ctx> ShlAssign for $ty<'ctx> {
+            fn shl_assign(&mut self, rhs: Self) {
                 *self = *self << rhs;
             }
         }
-        impl<'ctx> ShrAssign<$rhs<'ctx>> for $ty<'ctx> {
-            fn shr_assign(&mut self, rhs: $rhs<'ctx>) {
+        impl<'ctx> ShrAssign for $ty<'ctx> {
+            fn shr_assign(&mut self, rhs: Self) {
                 *self = *self >> rhs;
             }
         }
@@ -1123,8 +1169,8 @@ macro_rules! impl_neg {
 }
 
 macro_rules! impl_int_trait {
-    ($ty:ident, $u32:ident) => {
-        impl<'ctx> Int<$u32<'ctx>> for $ty<'ctx> {
+    ($ty:ident) => {
+        impl<'ctx> Int for $ty<'ctx> {
             fn leading_zeros(self) -> Self {
                 let value = self
                     .ctx
@@ -1163,12 +1209,12 @@ macro_rules! impl_integer_ops {
     ($scalar:ident, $vec:ident) => {
         impl_bit_ops!($scalar);
         impl_number_ops!($scalar, IrBool);
-        impl_shift_ops!($scalar, IrU32);
+        impl_shift_ops!($scalar);
         impl_bit_ops!($vec);
         impl_number_ops!($vec, IrVecBool);
-        impl_shift_ops!($vec, IrVecU32);
-        impl_int_trait!($scalar, IrU32);
-        impl_int_trait!($vec, IrVecU32);
+        impl_shift_ops!($vec);
+        impl_int_trait!($scalar);
+        impl_int_trait!($vec);
     };
 }
 
@@ -1176,8 +1222,8 @@ macro_rules! impl_uint_ops {
     ($scalar:ident, $vec:ident) => {
         impl_integer_ops!($scalar, $vec);
 
-        impl<'ctx> UInt<IrU32<'ctx>> for $scalar<'ctx> {}
-        impl<'ctx> UInt<IrVecU32<'ctx>> for $vec<'ctx> {}
+        impl<'ctx> UInt for $scalar<'ctx> {}
+        impl<'ctx> UInt for $vec<'ctx> {}
     };
 }
 
@@ -1192,8 +1238,8 @@ macro_rules! impl_sint_ops {
         impl_neg!($scalar);
         impl_neg!($vec);
 
-        impl<'ctx> SInt<IrU32<'ctx>> for $scalar<'ctx> {}
-        impl<'ctx> SInt<IrVecU32<'ctx>> for $vec<'ctx> {}
+        impl<'ctx> SInt for $scalar<'ctx> {}
+        impl<'ctx> SInt for $vec<'ctx> {}
     };
 }
 
@@ -1203,8 +1249,8 @@ impl_sint_ops!(IrI32, IrVecI32);
 impl_sint_ops!(IrI64, IrVecI64);
 
 macro_rules! impl_float {
-    ($float:ident, $bits:ident, $signed_bits:ident, $u32:ident) => {
-        impl<'ctx> Float<$u32<'ctx>> for $float<'ctx> {
+    ($float:ident, $bits:ident, $signed_bits:ident) => {
+        impl<'ctx> Float for $float<'ctx> {
             type FloatEncoding = <$float<'ctx> as Make>::Prim;
             type BitsType = $bits<'ctx>;
             type SignedBitsType = $signed_bits<'ctx>;
@@ -1330,8 +1376,8 @@ macro_rules! impl_float_ops {
         impl_number_ops!($vec, IrVecBool);
         impl_neg!($scalar);
         impl_neg!($vec);
-        impl_float!($scalar, $scalar_bits, $scalar_signed_bits, IrU32);
-        impl_float!($vec, $vec_bits, $vec_signed_bits, IrVecU32);
+        impl_float!($scalar, $scalar_bits, $scalar_signed_bits);
+        impl_float!($vec, $vec_bits, $vec_signed_bits);
     };
 }
 
@@ -1444,47 +1490,40 @@ ir_value!(
 );
 
 macro_rules! impl_convert_to {
-    ($($src:ident -> [$($dest:ident),*];)*) => {
-        $($(
-            impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> {
-                fn to(self) -> $dest<'ctx> {
-                    let value = if $src::TYPE == $dest::TYPE {
-                        self.value
-                    } else {
-                        self
-                            .ctx
-                            .make_operation(Opcode::Cast, [self.value], $dest::TYPE)
-                            .into()
-                    };
-                    $dest {
-                        value,
-                        ctx: self.ctx,
-                    }
+    ($src:ident -> $dest:ident) => {
+        impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> {
+            fn to(self) -> $dest<'ctx> {
+                let value = if $src::TYPE == $dest::TYPE {
+                    self.value
+                } else {
+                    self
+                        .ctx
+                        .make_operation(Opcode::Cast, [self.value], $dest::TYPE)
+                        .into()
+                };
+                $dest {
+                    value,
+                    ctx: self.ctx,
                 }
             }
-        )*)*
-    };
-    ([$($src:ident),*] -> $dest:tt;) => {
-        impl_convert_to! {
-            $(
-                $src -> $dest;
-            )*
         }
     };
-    ([$($src:ident),*];) => {
-        impl_convert_to! {
-            [$($src),*] -> [$($src),*];
-        }
+    ($first:ident $(, $ty:ident)*) => {
+        $(
+            impl_convert_to!($first -> $ty);
+            impl_convert_to!($ty -> $first);
+        )*
+        impl_convert_to![$($ty),*];
+    };
+    () => {
     };
 }
+impl_convert_to![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
 
-impl_convert_to! {
-    [IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
-}
-
-impl_convert_to! {
-    [IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64];
-}
+impl_convert_to![
+    IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64,
+    IrVecF32, IrVecF64
+];
 
 macro_rules! impl_from {
     ($src:ident => [$($dest:ident),*]) => {
@@ -1564,15 +1603,18 @@ impl<'ctx> Context for &'ctx IrContext<'ctx> {
     type U64 = IrU64<'ctx>;
     type I64 = IrI64<'ctx>;
     type F64 = IrF64<'ctx>;
-    type VecBool = IrVecBool<'ctx>;
+    type VecBool8 = IrVecBool<'ctx>;
     type VecU8 = IrVecU8<'ctx>;
     type VecI8 = IrVecI8<'ctx>;
+    type VecBool16 = IrVecBool<'ctx>;
     type VecU16 = IrVecU16<'ctx>;
     type VecI16 = IrVecI16<'ctx>;
     type VecF16 = IrVecF16<'ctx>;
+    type VecBool32 = IrVecBool<'ctx>;
     type VecU32 = IrVecU32<'ctx>;
     type VecI32 = IrVecI32<'ctx>;
     type VecF32 = IrVecF32<'ctx>;
+    type VecBool64 = IrVecBool<'ctx>;
     type VecU64 = IrVecU64<'ctx>;
     type VecI64 = IrVecI64<'ctx>;
     type VecF64 = IrVecF64<'ctx>;
index d1e137d0a747c8fc4fe586ba960923675ab179c0..fb83af695799114ffdd7c31f37f5d77b0bd84144 100644 (file)
@@ -34,53 +34,32 @@ macro_rules! impl_context {
 impl_context! {
     impl Context for Scalar {
         type Bool = bool;
-
         type U8 = u8;
-
         type I8 = i8;
-
         type U16 = u16;
-
         type I16 = i16;
-
         type F16 = crate::f16::F16;
-
         type U32 = u32;
-
         type I32 = i32;
-
         type F32 = f32;
-
         type U64 = u64;
-
         type I64 = i64;
-
         type F64 = f64;
-
         #[vec]
-
-        type VecBool = bool;
-
+        type VecBool8 = bool;
         type VecU8 = u8;
-
         type VecI8 = i8;
-
+        type VecBool16 = bool;
         type VecU16 = u16;
-
         type VecI16 = i16;
-
         type VecF16 = crate::f16::F16;
-
+        type VecBool32 = bool;
         type VecU32 = u32;
-
         type VecI32 = i32;
-
         type VecF32 = f32;
-
+        type VecBool64 = bool;
         type VecU64 = u64;
-
         type VecI64 = i64;
-
         type VecF64 = f64;
     }
 }
index 942c67f827734139aa610d17c9ef9400e729fdf8..ffc766b81007113f57da2355fcba0330cc25b332 100644 (file)
+use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar};
 use core::ops::{
     Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
     Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
 };
 
-use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar};
-
-#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
-macro_rules! make_float_type {
-    (
-        #[u32 = $u32:ident]
-        #[bool = $bool:ident]
-        [
-            $({
-                #[uint]
-                $uint_smaller:ident;
-                #[int]
-                $int_smaller:ident;
-                $(
-                    #[float]
-                    $float_smaller:ident;
-                )?
-            },)*
-        ],
-        {
-            #[uint]
-            $uint:ident;
-            #[int]
-            $int:ident;
-            #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)]
-            $float:ident;
-        },
-        [
-            $({
-                #[uint]
-                $uint_larger:ident;
-                #[int]
-                $int_larger:ident;
-                $(
-                    #[float]
-                    $float_larger:ident;
-                )?
-            },)*
-        ]
-    ) => {
-        type $float: Float<Self::$u32, BitsType = Self::$uint, SignedBitsType = Self::$int, FloatEncoding = $float_prim>
-            $(+ From<Self::$float_scalar>)?
-            + Compare<Bool = Self::$bool>
-            + Make<Context = Self, Prim = $float_prim>
-            $(+ ConvertTo<Self::$uint_smaller>)*
-            $(+ ConvertTo<Self::$int_smaller>)*
-            $($(+ ConvertTo<Self::$float_smaller>)?)*
-            + ConvertTo<Self::$uint>
-            + ConvertTo<Self::$int>
-            $(+ ConvertTo<Self::$uint_larger>)*
-            $(+ ConvertTo<Self::$int_larger>)*
-            $($(+ Into<Self::$float_larger> + ConvertTo<Self::$float_larger>)?)*;
-    };
-    (
-        #[u32 = $u32:ident]
-        #[bool = $bool:ident]
-        [$($smaller:tt,)*],
-        {
-            #[uint]
-            $uint:ident;
-            #[int]
-            $int:ident;
-        },
-        [$($larger:tt,)*]
-    ) => {};
-}
-
-#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
-macro_rules! make_uint_int_float_type {
-    (
-        #[u32 = $u32:ident]
-        #[bool = $bool:ident]
-        [
-            $({
-                #[uint($($uint_smaller_traits:tt)*)]
-                $uint_smaller:ident;
-                #[int($($int_smaller_traits:tt)*)]
-                $int_smaller:ident;
-                $(
-                    #[float($($float_smaller_traits:tt)*)]
-                    $float_smaller:ident;
-                )?
-            },)*
-        ],
-        {
-            #[uint(prim = $uint_prim:ident $(, scalar = $uint_scalar:ident)?)]
-            $uint:ident;
-            #[int(prim = $int_prim:ident $(, scalar = $int_scalar:ident)?)]
-            $int:ident;
-            $(
-                #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)]
-                $float:ident;
-            )?
-        },
-        [
-            $({
-                #[uint($($uint_larger_traits:tt)*)]
-                $uint_larger:ident;
-                #[int($($int_larger_traits:tt)*)]
-                $int_larger:ident;
-                $(
-                    #[float($($float_larger_traits:tt)*)]
-                    $float_larger:ident;
-                )?
-            },)*
-        ]
-    ) => {
-        type $uint: UInt<Self::$u32>
-            $(+ From<Self::$uint_scalar>)?
-            + Compare<Bool = Self::$bool>
-            + Make<Context = Self, Prim = $uint_prim>
-            $(+ ConvertTo<Self::$uint_smaller>)*
-            $(+ ConvertTo<Self::$int_smaller>)*
-            $($(+ ConvertTo<Self::$float_smaller>)?)*
-            + ConvertTo<Self::$int>
-            $(+ ConvertTo<Self::$float>)?
-            $(+ Into<Self::$uint_larger> + ConvertTo<Self::$uint_larger>)*
-            $(+ Into<Self::$int_larger> + ConvertTo<Self::$int_larger>)*
-            $($(+ Into<Self::$float_larger> + ConvertTo<Self::$float_larger>)?)*;
-        type $int: SInt<Self::$u32>
-            $(+ From<Self::$int_scalar>)?
-            + Compare<Bool = Self::$bool>
-            + Make<Context = Self, Prim = $int_prim>
-            $(+ ConvertTo<Self::$uint_smaller>)*
-            $(+ ConvertTo<Self::$int_smaller>)*
-            $($(+ ConvertTo<Self::$float_smaller>)?)*
-            + ConvertTo<Self::$uint>
-            $(+ ConvertTo<Self::$float>)?
-            $(+ ConvertTo<Self::$uint_larger>)*
-            $(+ Into<Self::$int_larger> + ConvertTo<Self::$int_larger>)*
-            $($(+ Into<Self::$float_larger> + ConvertTo<Self::$float_larger>)?)*;
-        make_float_type! {
-            #[u32 = $u32]
-            #[bool = $bool]
-            [
-                $({
-                    #[uint]
-                    $uint_smaller;
-                    #[int]
-                    $int_smaller;
-                    $(
-                        #[float]
-                        $float_smaller;
-                    )?
-                },)*
-            ],
-            {
-                #[uint]
-                $uint;
-                #[int]
-                $int;
-                $(
-                    #[float(prim = $float_prim $(, scalar = $float_scalar)?)]
-                    $float;
-                )?
-            },
-            [
-                $({
-                    #[uint]
-                    $uint_larger;
-                    #[int]
-                    $int_larger;
-                    $(
-                        #[float]
-                        $float_larger;
-                    )?
-                },)*
-            ]
-        }
-    };
-}
-
-macro_rules! make_uint_int_float_types {
-    (
-        #[u32 = $u32:ident]
-        #[bool = $bool:ident]
-        [$($smaller:tt,)*],
-        $current:tt,
-        [$first_larger:tt, $($larger:tt,)*]
-    ) => {
-        make_uint_int_float_type! {
-            #[u32 = $u32]
-            #[bool = $bool]
-            [$($smaller,)*],
-            $current,
-            [$first_larger, $($larger,)*]
-        }
-        make_uint_int_float_types! {
-            #[u32 = $u32]
-            #[bool = $bool]
-            [$($smaller,)* $current,],
-            $first_larger,
-            [$($larger,)*]
-        }
-    };
-    (
-        #[u32 = $u32:ident]
-        #[bool = $bool:ident]
-        [$($smaller:tt,)*],
-        $current:tt,
-        []
-    ) => {
-        make_uint_int_float_type! {
-            #[u32 = $u32]
-            #[bool = $bool]
-            [$($smaller,)*],
-            $current,
-            []
-        }
-    };
-}
-
-#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
-macro_rules! make_types {
-    (
-        #[bool]
-        $(#[scalar = $ScalarBool:ident])?
-        type $Bool:ident;
-
-        #[u8]
-        $(#[scalar = $ScalarU8:ident])?
-        type $U8:ident;
-
-        #[u16]
-        $(#[scalar = $ScalarU16:ident])?
-        type $U16:ident;
-
-        #[u32]
-        $(#[scalar = $ScalarU32:ident])?
-        type $U32:ident;
-
-        #[u64]
-        $(#[scalar = $ScalarU64:ident])?
-        type $U64:ident;
-
-        #[i8]
-        $(#[scalar = $ScalarI8:ident])?
-        type $I8:ident;
-
-        #[i16]
-        $(#[scalar = $ScalarI16:ident])?
-        type $I16:ident;
-
-        #[i32]
-        $(#[scalar = $ScalarI32:ident])?
-        type $I32:ident;
-
-        #[i64]
-        $(#[scalar = $ScalarI64:ident])?
-        type $I64:ident;
-
-        #[f16]
-        $(#[scalar = $ScalarF16:ident])?
-        type $F16:ident;
-
-        #[f32]
-        $(#[scalar = $ScalarF32:ident])?
-        type $F32:ident;
-
-        #[f64]
-        $(#[scalar = $ScalarF64:ident])?
-        type $F64:ident;
-    ) => {
-        type $Bool: Bool
-            $(+ From<Self::$ScalarBool>)?
-            + Make<Context = Self, Prim = bool>
-            + Select<Self::$Bool>
-            + Select<Self::$U8>
-            + Select<Self::$U16>
-            + Select<Self::$U32>
-            + Select<Self::$U64>
-            + Select<Self::$I8>
-            + Select<Self::$I16>
-            + Select<Self::$I32>
-            + Select<Self::$I64>
-            + Select<Self::$F16>
-            + Select<Self::$F32>
-            + Select<Self::$F64>;
-        make_uint_int_float_types! {
-            #[u32 = $U32]
-            #[bool = $Bool]
-            [],
-            {
-                #[uint(prim = u8 $(, scalar = $ScalarU8)?)]
-                $U8;
-                #[int(prim = i8 $(, scalar = $ScalarI8)?)]
-                $I8;
-            },
-            [
-                {
-                    #[uint(prim = u16 $(, scalar = $ScalarU16)?)]
-                    $U16;
-                    #[int(prim = i16 $(, scalar = $ScalarI16)?)]
-                    $I16;
-                    #[float(prim = F16 $(, scalar = $ScalarF16)?)]
-                    $F16;
-                },
-                {
-                    #[uint(prim = u32 $(, scalar = $ScalarU32)?)]
-                    $U32;
-                    #[int(prim = i32 $(, scalar = $ScalarI32)?)]
-                    $I32;
-                    #[float(prim = f32 $(, scalar = $ScalarF32)?)]
-                    $F32;
-                },
-                {
-                    #[uint(prim = u64 $(, scalar = $ScalarU64)?)]
-                    $U64;
-                    #[int(prim = i64 $(, scalar = $ScalarI64)?)]
-                    $I64;
-                    #[float(prim = f64 $(, scalar = $ScalarF64)?)]
-                    $F64;
-                },
-            ]
-        }
-    };
-}
-
 /// reference used to build IR for Kazan; an empty type for `core::simd`
 pub trait Context: Copy {
-    make_types! {
-        #[bool]
-        type Bool;
-
-        #[u8]
-        type U8;
-
-        #[u16]
-        type U16;
-
-        #[u32]
-        type U32;
-
-        #[u64]
-        type U64;
-
-        #[i8]
-        type I8;
-
-        #[i16]
-        type I16;
-
-        #[i32]
-        type I32;
-
-        #[i64]
-        type I64;
-
-        #[f16]
-        type F16;
-
-        #[f32]
-        type F32;
-
-        #[f64]
-        type F64;
-    }
-    make_types! {
-        #[bool]
-        #[scalar = Bool]
-        type VecBool;
-
-        #[u8]
-        #[scalar = U8]
-        type VecU8;
-
-        #[u16]
-        #[scalar = U16]
-        type VecU16;
-
-        #[u32]
-        #[scalar = U32]
-        type VecU32;
-
-        #[u64]
-        #[scalar = U64]
-        type VecU64;
-
-        #[i8]
-        #[scalar = I8]
-        type VecI8;
-
-        #[i16]
-        #[scalar = I16]
-        type VecI16;
-
-        #[i32]
-        #[scalar = I32]
-        type VecI32;
-
-        #[i64]
-        #[scalar = I64]
-        type VecI64;
-
-        #[f16]
-        #[scalar = F16]
-        type VecF16;
-
-        #[f32]
-        #[scalar = F32]
-        type VecF32;
-
-        #[f64]
-        #[scalar = F64]
-        type VecF64;
-    }
+    vector_math_proc_macro::make_context_types!();
     fn make<T: Make<Context = Self>>(self, v: T::Prim) -> T {
         T::make(self, v)
     }
@@ -425,33 +23,33 @@ pub trait ConvertTo<T> {
     fn to(self) -> T;
 }
 
+impl<T> ConvertTo<T> for T {
+    fn to(self) -> T {
+        self
+    }
+}
+
 macro_rules! impl_convert_to_using_as {
-    ($($src:ident -> [$($dest:ident),*];)*) => {
-        $($(
-            impl ConvertTo<$dest> for $src {
-                fn to(self) -> $dest {
-                    self as $dest
+    ($first:ident $(, $ty:ident)*) => {
+        $(
+            impl ConvertTo<$first> for $ty {
+                fn to(self) -> $first {
+                    self as $first
                 }
             }
-        )*)*
-    };
-    ([$($src:ident),*] -> $dest:tt;) => {
-        impl_convert_to_using_as! {
-            $(
-                $src -> $dest;
-            )*
-        }
+            impl ConvertTo<$ty> for $first {
+                fn to(self) -> $ty {
+                    self as $ty
+                }
+            }
+        )*
+        impl_convert_to_using_as![$($ty),*];
     };
-    ([$($src:ident),*];) => {
-        impl_convert_to_using_as! {
-            [$($src),*] -> [$($src),*];
-        }
+    () => {
     };
 }
 
-impl_convert_to_using_as! {
-    [u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
-}
+impl_convert_to_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
 
 pub trait Number:
     Compare
@@ -507,13 +105,8 @@ impl<T> BitOps for T where
 {
 }
 
-pub trait Int<ShiftRhs>:
-    Number
-    + BitOps
-    + Shl<ShiftRhs, Output = Self>
-    + Shr<ShiftRhs, Output = Self>
-    + ShlAssign<ShiftRhs>
-    + ShrAssign<ShiftRhs>
+pub trait Int:
+    Number + BitOps + Shl<Output = Self> + Shr<Output = Self> + ShlAssign + ShrAssign
 {
     fn leading_zeros(self) -> Self;
     fn leading_ones(self) -> Self {
@@ -529,13 +122,13 @@ pub trait Int<ShiftRhs>:
     fn count_ones(self) -> Self;
 }
 
-pub trait UInt<ShiftRhs>: Int<ShiftRhs> {}
+pub trait UInt: Int {}
 
-pub trait SInt<ShiftRhs>: Int<ShiftRhs> + Neg<Output = Self> {}
+pub trait SInt: Int + Neg<Output = Self> {}
 
 macro_rules! impl_int {
     ($ty:ident) => {
-        impl Int<u32> for $ty {
+        impl Int for $ty {
             fn leading_zeros(self) -> Self {
                 self.leading_zeros() as Self
             }
@@ -562,7 +155,7 @@ macro_rules! impl_uint {
     ($($ty:ident),*) => {
         $(
             impl_int!($ty);
-            impl UInt<u32> for $ty {}
+            impl UInt for $ty {}
         )*
     };
 }
@@ -573,23 +166,21 @@ macro_rules! impl_sint {
     ($($ty:ident),*) => {
         $(
             impl_int!($ty);
-            impl SInt<u32> for $ty {}
+            impl SInt for $ty {}
         )*
     };
 }
 
 impl_sint![i8, i16, i32, i64];
 
-pub trait Float<BitsShiftRhs: Make<Context = Self::Context, Prim = u32>>:
-    Number + Neg<Output = Self>
-{
+pub trait Float: Number + Neg<Output = Self> {
     type FloatEncoding: FloatEncoding + Make<Context = Scalar, Prim = <Self as Make>::Prim>;
-    type BitsType: UInt<BitsShiftRhs>
-        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float<u32>>::BitsType>
+    type BitsType: UInt
+        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float>::BitsType>
         + ConvertTo<Self::SignedBitsType>
         + Compare<Bool = Self::Bool>;
-    type SignedBitsType: SInt<BitsShiftRhs>
-        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float<u32>>::SignedBitsType>
+    type SignedBitsType: SInt
+        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float>::SignedBitsType>
         + ConvertTo<Self::BitsType>
         + Compare<Bool = Self::Bool>;
     fn abs(self) -> Self;
@@ -648,7 +239,7 @@ pub trait Float<BitsShiftRhs: Make<Context = Self::Context, Prim = u32>>:
 
 macro_rules! impl_float {
     ($ty:ty, $bits:ty, $signed_bits:ty) => {
-        impl Float<u32> for $ty {
+        impl Float for $ty {
             type FloatEncoding = $ty;
             type BitsType = $bits;
             type SignedBitsType = $signed_bits;
@@ -763,4 +354,4 @@ macro_rules! impl_compare_using_partial_cmp {
     };
 }
 
-impl_compare_using_partial_cmp![u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];
+impl_compare_using_partial_cmp![bool, u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];
diff --git a/vector-math-proc-macro/Cargo.toml b/vector-math-proc-macro/Cargo.toml
new file mode 100644 (file)
index 0000000..f65c085
--- /dev/null
@@ -0,0 +1,14 @@
+[package]
+name = "vector-math-proc-macro"
+version = "0.1.0"
+authors = ["Jacob Lifshay <programmerjake@gmail.com>"]
+edition = "2018"
+license = "MIT OR Apache-2.0"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+quote = "1.0"
+proc-macro2 = "1.0"
+syn = { version = "1.0", features = [] }
diff --git a/vector-math-proc-macro/src/lib.rs b/vector-math-proc-macro/src/lib.rs
new file mode 100644 (file)
index 0000000..1f5797a
--- /dev/null
@@ -0,0 +1,382 @@
+use std::{
+    cmp::Ordering,
+    collections::{BTreeSet, HashMap},
+    hash::Hash,
+};
+
+use proc_macro2::{Ident, Span, TokenStream};
+use quote::{quote, ToTokens};
+use syn::{
+    parse::{Parse, ParseStream},
+    parse_macro_input,
+};
+
+struct Input {}
+
+impl Parse for Input {
+    fn parse(_input: ParseStream) -> syn::Result<Self> {
+        Ok(Input {})
+    }
+}
+
+macro_rules! make_enum {
+    (
+        $vis:vis enum $ty:ident {
+            $(
+                $field:ident $(= $value:expr)?,
+            )*
+        }
+    ) => {
+        #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
+        #[repr(u8)]
+        $vis enum $ty {
+            $(
+                $field $(= $value)?,
+            )*
+        }
+
+        impl $ty {
+            #[allow(dead_code)]
+            $vis const VALUES: &'static [Self] = &[
+                $(
+                    Self::$field,
+                )*
+            ];
+        }
+    };
+}
+
+make_enum! {
+    enum TypeKind {
+        Bool,
+        UInt,
+        SInt,
+        Float,
+    }
+}
+
+make_enum! {
+    enum VectorScalar {
+        Scalar,
+        Vector,
+    }
+}
+
+make_enum! {
+    enum TypeBits {
+        Bits8 = 8,
+        Bits16 = 16,
+        Bits32 = 32,
+        Bits64 = 64,
+    }
+}
+
+impl TypeBits {
+    const fn bits(self) -> u32 {
+        self as u8 as u32
+    }
+}
+
+make_enum! {
+    enum Convertibility {
+        Impossible,
+        Lossy,
+        Lossless,
+    }
+}
+
+impl Convertibility {
+    const fn make_possible(lossless: bool) -> Self {
+        if lossless {
+            Self::Lossless
+        } else {
+            Self::Lossy
+        }
+    }
+    const fn make_non_lossy(possible: bool) -> Self {
+        if possible {
+            Self::Lossless
+        } else {
+            Self::Impossible
+        }
+    }
+    const fn possible(self) -> bool {
+        match self {
+            Convertibility::Impossible => false,
+            Convertibility::Lossy | Convertibility::Lossless => true,
+        }
+    }
+}
+
+impl TypeKind {
+    fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool {
+        match self {
+            TypeKind::Float => bits >= TypeBits::Bits16,
+            TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector,
+            TypeKind::UInt | TypeKind::SInt => true,
+        }
+    }
+    fn prim_ty(self, bits: TypeBits) -> Ident {
+        Ident::new(
+            &match self {
+                TypeKind::Bool => "bool".into(),
+                TypeKind::UInt => format!("u{}", bits.bits()),
+                TypeKind::SInt => format!("i{}", bits.bits()),
+                TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(),
+                TypeKind::Float => format!("f{}", bits.bits()),
+            },
+            Span::call_site(),
+        )
+    }
+    fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident {
+        let vec_prefix = match vector_scalar {
+            VectorScalar::Scalar => "",
+            VectorScalar::Vector => "Vec",
+        };
+        Ident::new(
+            &match self {
+                TypeKind::Bool => match vector_scalar {
+                    VectorScalar::Scalar => "Bool".into(),
+                    VectorScalar::Vector => format!("VecBool{}", bits.bits()),
+                },
+                TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()),
+                TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()),
+                TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()),
+            },
+            Span::call_site(),
+        )
+    }
+    fn convertibility_to(
+        self,
+        src_bits: TypeBits,
+        dest_type_kind: TypeKind,
+        dest_bits: TypeBits,
+    ) -> Convertibility {
+        Convertibility::make_possible(match (self, dest_type_kind) {
+            (TypeKind::Bool, _) | (_, TypeKind::Bool) => {
+                return Convertibility::make_non_lossy(self == dest_type_kind);
+            }
+            (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits,
+            (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits,
+            (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits,
+            (TypeKind::SInt, TypeKind::UInt) => false,
+            (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits,
+            (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits,
+            (TypeKind::Float, TypeKind::UInt) => false,
+            (TypeKind::Float, TypeKind::SInt) => false,
+            (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits,
+        })
+    }
+}
+
+#[derive(Default, Debug)]
+struct TokenStreamSetElement {
+    token_stream: TokenStream,
+    text: String,
+}
+
+impl Ord for TokenStreamSetElement {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.text.cmp(&other.text)
+    }
+}
+
+impl PartialOrd for TokenStreamSetElement {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl PartialEq for TokenStreamSetElement {
+    fn eq(&self, other: &Self) -> bool {
+        self.text == other.text
+    }
+}
+
+impl Eq for TokenStreamSetElement {}
+
+impl From<TokenStream> for TokenStreamSetElement {
+    fn from(token_stream: TokenStream) -> Self {
+        let text = token_stream.to_string();
+        Self { token_stream, text }
+    }
+}
+
+impl ToTokens for TokenStreamSetElement {
+    fn to_tokens(&self, tokens: &mut TokenStream) {
+        self.token_stream.to_tokens(tokens)
+    }
+
+    fn to_token_stream(&self) -> TokenStream {
+        self.token_stream.to_token_stream()
+    }
+
+    fn into_token_stream(self) -> TokenStream {
+        self.token_stream
+    }
+}
+
+type TokenStreamSet = BTreeSet<TokenStreamSetElement>;
+
+#[derive(Debug, Default)]
+struct TraitSets {
+    trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>,
+}
+
+impl TraitSets {
+    fn get(
+        &mut self,
+        type_kind: TypeKind,
+        mut bits: TypeBits,
+        vector_scalar: VectorScalar,
+    ) -> &mut TokenStreamSet {
+        if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar {
+            bits = TypeBits::Bits8;
+        }
+        self.trait_sets_map
+            .entry((type_kind, bits, vector_scalar))
+            .or_default()
+    }
+    fn add_trait(
+        &mut self,
+        type_kind: TypeKind,
+        bits: TypeBits,
+        vector_scalar: VectorScalar,
+        v: impl Into<TokenStreamSetElement>,
+    ) {
+        self.get(type_kind, bits, vector_scalar).insert(v.into());
+    }
+    fn fill(&mut self) {
+        for &bits in TypeBits::VALUES {
+            for &type_kind in TypeKind::VALUES {
+                for &vector_scalar in VectorScalar::VALUES {
+                    if !type_kind.is_valid(bits, vector_scalar) {
+                        continue;
+                    }
+                    let prim_ty = type_kind.prim_ty(bits);
+                    let ty = type_kind.ty(bits, vector_scalar);
+                    if vector_scalar == VectorScalar::Vector {
+                        let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar);
+                        self.add_trait(
+                            type_kind,
+                            bits,
+                            vector_scalar,
+                            quote! { From<Self::#scalar_ty> },
+                        );
+                    }
+                    let bool_ty = TypeKind::Bool.ty(bits, vector_scalar);
+                    let uint_ty = TypeKind::UInt.ty(bits, vector_scalar);
+                    let sint_ty = TypeKind::SInt.ty(bits, vector_scalar);
+                    let type_trait = match type_kind {
+                        TypeKind::Bool => quote! { Bool },
+                        TypeKind::UInt => quote! { UInt },
+                        TypeKind::SInt => quote! { SInt },
+                        TypeKind::Float => quote! { Float<
+                            BitsType = Self::#uint_ty,
+                            SignedBitsType = Self::#sint_ty,
+                            FloatEncoding = #prim_ty,
+                        > },
+                    };
+                    self.add_trait(type_kind, bits, vector_scalar, type_trait);
+                    self.add_trait(
+                        type_kind,
+                        bits,
+                        vector_scalar,
+                        quote! { Compare<Bool = Self::#bool_ty> },
+                    );
+                    self.add_trait(
+                        TypeKind::Bool,
+                        bits,
+                        vector_scalar,
+                        quote! { Select<Self::#ty> },
+                    );
+                    self.add_trait(
+                        TypeKind::Bool,
+                        TypeBits::Bits8,
+                        VectorScalar::Scalar,
+                        quote! { Select<Self::#ty> },
+                    );
+                    for &other_bits in TypeBits::VALUES {
+                        for &other_type_kind in TypeKind::VALUES {
+                            if !other_type_kind.is_valid(other_bits, vector_scalar) {
+                                continue;
+                            }
+                            if other_bits == bits && other_type_kind == type_kind {
+                                continue;
+                            }
+                            let other_ty = other_type_kind.ty(other_bits, vector_scalar);
+                            let convertibility =
+                                other_type_kind.convertibility_to(other_bits, type_kind, bits);
+                            if convertibility == Convertibility::Lossless {
+                                self.add_trait(
+                                    type_kind,
+                                    bits,
+                                    vector_scalar,
+                                    quote! { From<Self::#other_ty> },
+                                );
+                            }
+                            if convertibility.possible() {
+                                self.add_trait(
+                                    other_type_kind,
+                                    other_bits,
+                                    vector_scalar,
+                                    quote! { ConvertTo<Self::#ty> },
+                                );
+                            }
+                        }
+                    }
+                    self.add_trait(
+                        type_kind,
+                        bits,
+                        vector_scalar,
+                        quote! { Make<Context = Self, Prim = #prim_ty> },
+                    );
+                }
+            }
+        }
+    }
+}
+
+impl Input {
+    fn to_tokens(&self) -> syn::Result<TokenStream> {
+        let mut types = Vec::new();
+        let mut trait_sets = TraitSets::default();
+        trait_sets.fill();
+        for &bits in TypeBits::VALUES {
+            for &type_kind in TypeKind::VALUES {
+                for &vector_scalar in VectorScalar::VALUES {
+                    if !type_kind.is_valid(bits, vector_scalar) {
+                        continue;
+                    }
+                    let ty = type_kind.ty(bits, vector_scalar);
+                    let traits = trait_sets.get(type_kind, bits, vector_scalar);
+                    types.push(quote! {
+                        type #ty: #(#traits)+*;
+                    });
+                }
+            }
+        }
+        Ok(quote! {#(#types)*})
+    }
+}
+
+#[proc_macro]
+pub fn make_context_types(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let input = parse_macro_input!(input as Input);
+    match input.to_tokens() {
+        Ok(retval) => retval,
+        Err(err) => err.to_compile_error(),
+    }
+    .into()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test() -> syn::Result<()> {
+        Input {}.to_tokens()?;
+        Ok(())
+    }
+}