From: Jacob Lifshay Date: Mon, 3 May 2021 11:06:34 +0000 (-0700) Subject: IR works! X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=7975aa9639f3a5a702b130a7cf992ffe71c86e2a;p=vector-math.git IR works! --- diff --git a/.gitignore b/.gitignore index 96ef6c0..6bfa6c9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target Cargo.lock +/.vscode \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 9e011cb..3cf9665 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,11 @@ license = "MIT OR Apache-2.0" [dependencies] half = { version = "1.7.1", optional = true } +typed-arena = { version = "2.0.1", optional = true } [features] default = ["f16", "fma"] f16 = ["half"] fma = ["std"] std = [] +ir = ["std", "typed-arena"] diff --git a/src/f16.rs b/src/f16.rs index 5bf2e69..20ed902 100644 --- a/src/f16.rs +++ b/src/f16.rs @@ -8,11 +8,10 @@ use crate::traits::{ConvertTo, Float}; use half::f16 as F16Impl; #[cfg(not(feature = "f16"))] -#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)] -enum F16Impl {} +type F16Impl = u16; #[derive(Clone, Copy, PartialEq, PartialOrd, Debug)] -#[cfg_attr(feature = "f16", repr(transparent))] +#[repr(transparent)] pub struct F16(F16Impl); #[cfg(feature = "f16")] @@ -67,10 +66,7 @@ macro_rules! impl_from_f16 { $( impl From for $ty { fn from(v: F16) -> Self { - #[cfg(feature = "f16")] - return v.0.into(); - #[cfg(not(feature = "f16"))] - match v.0 {} + f16_impl!(v.0.into(), [v]) } } @@ -133,7 +129,7 @@ impl Neg for F16 { type Output = Self; fn neg(self) -> Self::Output { - Self::from_bits(self.to_bits() ^ 0x8000) + f16_impl!(Self::from_bits(self.to_bits() ^ 0x8000), []) } } @@ -169,7 +165,7 @@ impl Float for F16 { type BitsType = u16; fn abs(self) -> Self { - Self::from_bits(self.to_bits() & 0x7FFF) + f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), []) } fn trunc(self) -> Self { @@ -206,11 +202,17 @@ impl Float for F16 { } fn from_bits(v: Self::BitsType) -> Self { - f16_impl!(F16(F16Impl::from_bits(v)), [v]) + #[cfg(feature = "f16")] + return F16(F16Impl::from_bits(v)); + #[cfg(not(feature = "f16"))] + return F16(v); } fn to_bits(self) -> Self::BitsType { - f16_impl!(self.0.to_bits(), []) + #[cfg(feature = "f16")] + return self.0.to_bits(); + #[cfg(not(feature = "f16"))] + return self.0; } } diff --git a/src/ir.rs b/src/ir.rs new file mode 100644 index 0000000..3eaccb9 --- /dev/null +++ b/src/ir.rs @@ -0,0 +1,1574 @@ +use crate::{ + f16::F16, + traits::{Bool, Compare, Context, ConvertTo, Float, Int, Make, SInt, Select, UInt}, +}; +use std::{ + borrow::Borrow, + cell::{Cell, RefCell}, + collections::HashMap, + fmt::{self, Write as _}, + format, + ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, + DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, + SubAssign, + }, + string::String, + vec::Vec, +}; +use typed_arena::Arena; + +macro_rules! make_enum { + ( + $vis:vis enum $enum:ident { + $( + $(#[$meta:meta])* + $name:ident, + )* + } + ) => { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + $vis enum $enum { + $( + $(#[$meta])* + $name, + )* + } + + impl $enum { + $vis const fn as_str(self) -> &'static str { + match self { + $( + Self::$name => stringify!($name), + )* + } + } + } + + impl fmt::Display for $enum { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } + } + }; +} + +make_enum! { + pub enum ScalarType { + Bool, + U8, + I8, + U16, + I16, + F16, + U32, + I32, + F32, + U64, + I64, + F64, + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct VectorType { + pub element: ScalarType, +} + +impl fmt::Display for VectorType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "vec<{}>", self.element) + } +} + +impl From for Type { + fn from(v: ScalarType) -> Self { + Type::Scalar(v) + } +} + +impl From for Type { + fn from(v: VectorType) -> Self { + Type::Vector(v) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Type { + Scalar(ScalarType), + Vector(VectorType), +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Scalar(v) => v.fmt(f), + Type::Vector(v) => v.fmt(f), + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub enum ScalarConstant { + Bool(bool), + U8(u8), + U16(u16), + U32(u32), + U64(u64), + I8(i8), + I16(i16), + I32(i32), + I64(i64), + F16 { bits: u16 }, + F32 { bits: u32 }, + F64 { bits: u64 }, +} + +macro_rules! make_scalar_constant_get { + ($ty:ident, $enumerant:ident) => { + pub fn $ty(self) -> Option<$ty> { + if let Self::$enumerant(v) = self { + Some(v) + } else { + None + } + } + }; +} + +macro_rules! make_scalar_constant_from { + ($ty:ident, $enumerant:ident) => { + impl From<$ty> for ScalarConstant { + fn from(v: $ty) -> Self { + Self::$enumerant(v) + } + } + impl From<$ty> for Constant { + fn from(v: $ty) -> Self { + Self::Scalar(v.into()) + } + } + impl From<$ty> for Value<'_> { + fn from(v: $ty) -> Self { + Self::Constant(v.into()) + } + } + }; +} + +make_scalar_constant_from!(bool, Bool); +make_scalar_constant_from!(u8, U8); +make_scalar_constant_from!(u16, U16); +make_scalar_constant_from!(u32, U32); +make_scalar_constant_from!(u64, U64); +make_scalar_constant_from!(i8, I8); +make_scalar_constant_from!(i16, I16); +make_scalar_constant_from!(i32, I32); +make_scalar_constant_from!(i64, I64); + +impl ScalarConstant { + pub const fn ty(self) -> ScalarType { + match self { + ScalarConstant::Bool(_) => ScalarType::Bool, + ScalarConstant::U8(_) => ScalarType::U8, + ScalarConstant::U16(_) => ScalarType::U16, + ScalarConstant::U32(_) => ScalarType::U32, + ScalarConstant::U64(_) => ScalarType::U64, + ScalarConstant::I8(_) => ScalarType::I8, + ScalarConstant::I16(_) => ScalarType::I16, + ScalarConstant::I32(_) => ScalarType::I32, + ScalarConstant::I64(_) => ScalarType::I64, + ScalarConstant::F16 { .. } => ScalarType::F16, + ScalarConstant::F32 { .. } => ScalarType::F32, + ScalarConstant::F64 { .. } => ScalarType::F64, + } + } + pub const fn from_f16_bits(bits: u16) -> Self { + Self::F16 { bits } + } + pub const fn from_f32_bits(bits: u32) -> Self { + Self::F32 { bits } + } + pub const fn from_f64_bits(bits: u64) -> Self { + Self::F64 { bits } + } + pub const fn f16_bits(self) -> Option { + if let Self::F16 { bits } = self { + Some(bits) + } else { + None + } + } + pub const fn f32_bits(self) -> Option { + if let Self::F32 { bits } = self { + Some(bits) + } else { + None + } + } + pub const fn f64_bits(self) -> Option { + if let Self::F64 { bits } = self { + Some(bits) + } else { + None + } + } + make_scalar_constant_get!(bool, Bool); + make_scalar_constant_get!(u8, U8); + make_scalar_constant_get!(u16, U16); + make_scalar_constant_get!(u32, U32); + make_scalar_constant_get!(u64, U64); + make_scalar_constant_get!(i8, I8); + make_scalar_constant_get!(i16, I16); + make_scalar_constant_get!(i32, I32); + make_scalar_constant_get!(i64, I64); +} + +impl fmt::Display for ScalarConstant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ScalarConstant::Bool(false) => write!(f, "false"), + ScalarConstant::Bool(true) => write!(f, "true"), + ScalarConstant::U8(v) => write!(f, "{}_u8", v), + ScalarConstant::U16(v) => write!(f, "{}_u16", v), + ScalarConstant::U32(v) => write!(f, "{}_u32", v), + ScalarConstant::U64(v) => write!(f, "{}_u64", v), + ScalarConstant::I8(v) => write!(f, "{}_i8", v), + ScalarConstant::I16(v) => write!(f, "{}_i16", v), + ScalarConstant::I32(v) => write!(f, "{}_i32", v), + ScalarConstant::I64(v) => write!(f, "{}_i64", v), + ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits), + ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits), + ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits), + } + } +} + +impl fmt::Debug for ScalarConstant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct VectorSplatConstant { + pub element: ScalarConstant, +} + +impl VectorSplatConstant { + pub const fn ty(self) -> VectorType { + VectorType { + element: self.element.ty(), + } + } +} + +impl fmt::Display for VectorSplatConstant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "splat({})", self.element) + } +} + +impl From for Constant { + fn from(v: ScalarConstant) -> Self { + Constant::Scalar(v) + } +} + +impl From for Constant { + fn from(v: VectorSplatConstant) -> Self { + Constant::VectorSplat(v) + } +} + +impl From for Value<'_> { + fn from(v: ScalarConstant) -> Self { + Value::Constant(v.into()) + } +} + +impl From for Value<'_> { + fn from(v: VectorSplatConstant) -> Self { + Value::Constant(v.into()) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Constant { + Scalar(ScalarConstant), + VectorSplat(VectorSplatConstant), +} + +impl Constant { + pub const fn ty(self) -> Type { + match self { + Constant::Scalar(v) => Type::Scalar(v.ty()), + Constant::VectorSplat(v) => Type::Vector(v.ty()), + } + } +} + +impl fmt::Display for Constant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Constant::Scalar(v) => v.fmt(f), + Constant::VectorSplat(v) => v.fmt(f), + } + } +} + +#[derive(Debug)] +pub struct Input<'ctx> { + pub name: &'ctx str, + pub ty: Type, +} + +impl fmt::Display for Input<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "in<{}>", self.name) + } +} + +#[derive(Copy, Clone)] +pub enum Value<'ctx> { + Input(&'ctx Input<'ctx>), + Constant(Constant), + OpResult(&'ctx Operation<'ctx>), +} + +impl<'ctx> Value<'ctx> { + pub const fn ty(self) -> Type { + match self { + Value::Input(v) => v.ty, + Value::Constant(v) => v.ty(), + Value::OpResult(v) => v.result_type, + } + } +} + +impl fmt::Debug for Value<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Value::Input(v) => v.fmt(f), + Value::Constant(v) => v.fmt(f), + Value::OpResult(v) => v.result_id.fmt(f), + } + } +} + +impl fmt::Display for Value<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Value::Input(v) => v.fmt(f), + Value::Constant(v) => v.fmt(f), + Value::OpResult(v) => v.result_id.fmt(f), + } + } +} + +impl<'ctx> From<&'ctx Input<'ctx>> for Value<'ctx> { + fn from(v: &'ctx Input<'ctx>) -> Self { + Value::Input(v) + } +} + +impl<'ctx> From<&'ctx Operation<'ctx>> for Value<'ctx> { + fn from(v: &'ctx Operation<'ctx>) -> Self { + Value::OpResult(v) + } +} + +impl<'ctx> From for Value<'ctx> { + fn from(v: Constant) -> Self { + Value::Constant(v) + } +} + +make_enum! { + pub enum Opcode { + Add, + Sub, + Mul, + Div, + Rem, + Fma, + Cast, + And, + Or, + Xor, + Not, + Shl, + Shr, + Neg, + Abs, + Trunc, + Ceil, + Floor, + Round, + IsInfinite, + IsFinite, + ToBits, + FromBits, + Splat, + CompareEq, + CompareNe, + CompareLt, + CompareLe, + CompareGt, + CompareGe, + Select, + } +} + +#[derive(Debug)] +pub struct Operation<'ctx> { + pub opcode: Opcode, + pub arguments: Vec>, + pub result_type: Type, + pub result_id: OperationId, +} + +impl fmt::Display for Operation<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}: {} = {}", + self.result_id, self.result_type, self.opcode + )?; + let mut separator = " "; + for i in &self.arguments { + write!(f, "{}{}", separator, i)?; + separator = ", "; + } + Ok(()) + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +pub struct OperationId(pub u64); + +impl fmt::Display for OperationId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "op_{}", self.0) + } +} + +#[derive(Default)] +pub struct IrContext<'ctx> { + bytes_arena: Arena, + inputs_arena: Arena>, + inputs: RefCell>>, + operations_arena: Arena>, + operations: RefCell>>, + next_operation_result_id: Cell, +} + +impl fmt::Debug for IrContext<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("IrContext { .. }") + } +} + +impl<'ctx> IrContext<'ctx> { + pub fn new() -> Self { + Self::default() + } + pub fn make_input + Into, T: Into>( + &'ctx self, + name: N, + ty: T, + ) -> &'ctx Input<'ctx> { + let mut inputs = self.inputs.borrow_mut(); + let name_str = name.borrow(); + let ty = ty.into(); + if !name_str.is_empty() && !inputs.contains_key(name_str) { + let name = self.bytes_arena.alloc_str(name_str); + let input = self.inputs_arena.alloc(Input { name, ty }); + inputs.insert(name, input); + return input; + } + let mut name: String = name.into(); + if name.is_empty() { + name = "in".into(); + } + let name_len = name.len(); + let mut tag = 2usize; + loop { + name.truncate(name_len); + write!(name, "_{}", tag).unwrap(); + if !inputs.contains_key(&*name) { + let name = self.bytes_arena.alloc_str(&name); + let input = self.inputs_arena.alloc(Input { name, ty }); + inputs.insert(name, input); + return input; + } + tag += 1; + } + } + pub fn make_operation>>, T: Into>( + &'ctx self, + opcode: Opcode, + arguments: A, + result_type: T, + ) -> &'ctx Operation<'ctx> { + let arguments = arguments.into(); + let result_type = result_type.into(); + let result_id = OperationId(self.next_operation_result_id.get()); + self.next_operation_result_id.set(result_id.0 + 1); + let operation = self.operations_arena.alloc(Operation { + opcode, + arguments, + result_type, + result_id, + }); + self.operations.borrow_mut().push(operation); + operation + } + pub fn replace_operations( + &'ctx self, + new_operations: Vec<&'ctx Operation<'ctx>>, + ) -> Vec<&'ctx Operation<'ctx>> { + self.operations.replace(new_operations) + } +} + +#[derive(Debug)] +pub struct IrFunction<'ctx> { + pub inputs: Vec<&'ctx Input<'ctx>>, + pub operations: Vec<&'ctx Operation<'ctx>>, + pub outputs: Vec>, +} + +impl fmt::Display for IrFunction<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "function(")?; + let mut first = true; + for input in &self.inputs { + if first { + first = false + } else { + write!(f, ", ")?; + } + write!(f, "{}: {}", input, input.ty)?; + } + match self.outputs.len() { + 0 => writeln!(f, ") {{")?, + 1 => writeln!(f, ") -> {} {{", self.outputs[0].ty())?, + _ => { + write!(f, ") -> ({}", self.outputs[0].ty())?; + for output in self.outputs.iter().skip(1) { + write!(f, ", {}", output.ty())?; + } + writeln!(f, ") {{")?; + } + } + for operation in &self.operations { + writeln!(f, " {}", operation)?; + } + match self.outputs.len() { + 0 => writeln!(f, "}}")?, + 1 => writeln!(f, " Return {}\n}}", self.outputs[0])?, + _ => { + write!(f, " Return {}", self.outputs[0])?; + for output in self.outputs.iter().skip(1) { + write!(f, ", {}", output)?; + } + writeln!(f, "\n}}")?; + } + } + Ok(()) + } +} + +impl<'ctx> IrFunction<'ctx> { + pub fn make>(ctx: &'ctx IrContext<'ctx>, f: F) -> Self { + let old_operations = ctx.replace_operations(Vec::new()); + let (v, inputs) = F::make_inputs(ctx); + let outputs = f.call(ctx, v).outputs_to_vec(); + let operations = ctx.replace_operations(old_operations); + Self { + inputs, + operations, + outputs, + } + } +} + +pub trait IrFunctionMaker<'ctx>: Sized { + type Inputs; + type Outputs: IrFunctionMakerOutputs<'ctx>; + fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs; + fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>); +} + +pub trait IrFunctionMakerOutputs<'ctx> { + fn outputs_to_vec(self) -> Vec>; +} + +impl<'ctx, T: IrValue<'ctx>> IrFunctionMakerOutputs<'ctx> for T { + fn outputs_to_vec(self) -> Vec> { + [self.value()].into() + } +} + +impl<'ctx> IrFunctionMakerOutputs<'ctx> for () { + fn outputs_to_vec(self) -> Vec> { + Vec::new() + } +} + +impl<'ctx, R: IrFunctionMakerOutputs<'ctx>> IrFunctionMaker<'ctx> + for fn(&'ctx IrContext<'ctx>) -> R +{ + type Inputs = (); + type Outputs = R; + fn call(self, ctx: &'ctx IrContext<'ctx>, _inputs: Self::Inputs) -> Self::Outputs { + self(ctx) + } + fn make_inputs(_ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) { + ((), Vec::new()) + } +} + +macro_rules! impl_ir_function_maker_io { + () => {}; + ($first_arg:ident: $first_arg_ty:ident, $($arg:ident: $arg_ty:ident,)*) => { + impl<'ctx, $first_arg_ty, $($arg_ty,)* R> IrFunctionMaker<'ctx> for fn(&'ctx IrContext<'ctx>, $first_arg_ty $(, $arg_ty)*) -> R + where + $first_arg_ty: IrValue<'ctx>, + $($arg_ty: IrValue<'ctx>,)* + R: IrFunctionMakerOutputs<'ctx>, + { + type Inputs = ($first_arg_ty, $($arg_ty,)*); + type Outputs = R; + fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs { + let ($first_arg, $($arg,)*) = inputs; + self(ctx, $first_arg$(, $arg)*) + } + fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) { + let mut $first_arg = String::new(); + $(let mut $arg = String::new();)* + for (index, arg) in [&mut $first_arg $(, &mut $arg)*].iter_mut().enumerate() { + **arg = format!("arg_{}", index); + } + let $first_arg = $first_arg_ty::make_input(ctx, $first_arg); + $(let $arg = $arg_ty::make_input(ctx, $arg);)* + (($first_arg.0, $($arg.0,)*), [$first_arg.1 $(, $arg.1)*].into()) + } + } + impl<'ctx, $first_arg_ty, $($arg_ty),*> IrFunctionMakerOutputs<'ctx> for ($first_arg_ty, $($arg_ty,)*) + where + $first_arg_ty: IrValue<'ctx>, + $($arg_ty: IrValue<'ctx>,)* + { + fn outputs_to_vec(self) -> Vec> { + let ($first_arg, $($arg,)*) = self; + [$first_arg.value() $(, $arg.value())*].into() + } + } + impl_ir_function_maker_io!($($arg: $arg_ty,)*); + }; +} + +impl_ir_function_maker_io!( + in0: In0, + in1: In1, + in2: In2, + in3: In3, + in4: In4, + in5: In5, + in6: In6, + in7: In7, + in8: In8, + in9: In9, + in10: In10, + in11: In11, +); + +pub trait IrValue<'ctx>: Copy { + const TYPE: Type; + fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self; + fn make_input + Into>( + ctx: &'ctx IrContext<'ctx>, + name: N, + ) -> (Self, &'ctx Input<'ctx>) { + let input = ctx.make_input(name, Self::TYPE); + (Self::new(ctx, input.into()), input) + } + fn ctx(self) -> &'ctx IrContext<'ctx>; + fn value(self) -> Value<'ctx>; +} + +macro_rules! ir_value { + ($name:ident, $vec_name:ident, TYPE = $scalar_type:ident, fn make($make_var:ident: $prim:ident) {$make:expr}) => { + #[derive(Clone, Copy, Debug)] + pub struct $name<'ctx> { + pub value: Value<'ctx>, + pub ctx: &'ctx IrContext<'ctx>, + } + + impl<'ctx> IrValue<'ctx> for $name<'ctx> { + const TYPE: Type = Type::Scalar(Self::SCALAR_TYPE); + fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self { + assert_eq!(value.ty(), Self::TYPE); + Self { ctx, value } + } + fn ctx(self) -> &'ctx IrContext<'ctx> { + self.ctx + } + fn value(self) -> Value<'ctx> { + self.value + } + } + + impl<'ctx> $name<'ctx> { + pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type; + } + + impl<'ctx> Make<&'ctx IrContext<'ctx>> for $name<'ctx> { + type Prim = $prim; + + fn make(ctx: &'ctx IrContext<'ctx>, $make_var: Self::Prim) -> Self { + let value: ScalarConstant = $make; + let value = value.into(); + Self { value, ctx } + } + } + + #[derive(Clone, Copy, Debug)] + pub struct $vec_name<'ctx> { + pub value: Value<'ctx>, + pub ctx: &'ctx IrContext<'ctx>, + } + + impl<'ctx> IrValue<'ctx> for $vec_name<'ctx> { + const TYPE: Type = Type::Vector(Self::VECTOR_TYPE); + fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self { + assert_eq!(value.ty(), Self::TYPE); + Self { ctx, value } + } + fn ctx(self) -> &'ctx IrContext<'ctx> { + self.ctx + } + fn value(self) -> Value<'ctx> { + self.value + } + } + + impl<'ctx> $vec_name<'ctx> { + pub const VECTOR_TYPE: VectorType = VectorType { + element: ScalarType::$scalar_type, + }; + } + + impl<'ctx> Make<&'ctx IrContext<'ctx>> for $vec_name<'ctx> { + type Prim = $prim; + + fn make(ctx: &'ctx IrContext<'ctx>, $make_var: Self::Prim) -> Self { + let element = $make; + Self { + value: VectorSplatConstant { element }.into(), + ctx, + } + } + } + + impl<'ctx> Select<$name<'ctx>> for IrBool<'ctx> { + fn select(self, true_v: $name<'ctx>, false_v: $name<'ctx>) -> $name<'ctx> { + let value = self + .ctx + .make_operation( + Opcode::Select, + [self.value, true_v.value, false_v.value], + $name::TYPE, + ) + .into(); + $name { + value, + ctx: self.ctx, + } + } + } + + impl<'ctx> Select<$vec_name<'ctx>> for IrVecBool<'ctx> { + fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> { + let value = self + .ctx + .make_operation( + Opcode::Select, + [self.value, true_v.value, false_v.value], + $vec_name::TYPE, + ) + .into(); + $vec_name { + value, + ctx: self.ctx, + } + } + } + + impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> { + fn from(v: $name<'ctx>) -> Self { + let value = v + .ctx + .make_operation(Opcode::Splat, [v.value], $vec_name::TYPE) + .into(); + Self { value, ctx: v.ctx } + } + } + }; +} + +macro_rules! impl_bit_ops { + ($ty:ident) => { + impl<'ctx> BitAnd for $ty<'ctx> { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::And, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> BitOr for $ty<'ctx> { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Or, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> BitXor for $ty<'ctx> { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Xor, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> Not for $ty<'ctx> { + type Output = Self; + + fn not(self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Not, [self.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> BitAndAssign for $ty<'ctx> { + fn bitand_assign(&mut self, rhs: Self) { + *self = *self & rhs; + } + } + impl<'ctx> BitOrAssign for $ty<'ctx> { + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } + } + impl<'ctx> BitXorAssign for $ty<'ctx> { + fn bitxor_assign(&mut self, rhs: Self) { + *self = *self ^ rhs; + } + } + }; +} + +macro_rules! impl_number_ops { + ($ty:ident, $bool:ident) => { + impl<'ctx> Add for $ty<'ctx> { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Add, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> Sub for $ty<'ctx> { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Sub, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> Mul for $ty<'ctx> { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Mul, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> Div for $ty<'ctx> { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Div, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> Rem for $ty<'ctx> { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Rem, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> AddAssign for $ty<'ctx> { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } + } + impl<'ctx> SubAssign for $ty<'ctx> { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } + } + impl<'ctx> MulAssign for $ty<'ctx> { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } + } + impl<'ctx> DivAssign for $ty<'ctx> { + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } + } + impl<'ctx> RemAssign for $ty<'ctx> { + fn rem_assign(&mut self, rhs: Self) { + *self = *self % rhs; + } + } + impl<'ctx> Compare for $ty<'ctx> { + type Bool = $bool<'ctx>; + fn eq(self, rhs: Self) -> Self::Bool { + let value = self + .ctx + .make_operation(Opcode::CompareEq, [self.value, rhs.value], $bool::TYPE) + .into(); + $bool { + value, + ctx: self.ctx, + } + } + fn ne(self, rhs: Self) -> Self::Bool { + let value = self + .ctx + .make_operation(Opcode::CompareNe, [self.value, rhs.value], $bool::TYPE) + .into(); + $bool { + value, + ctx: self.ctx, + } + } + fn lt(self, rhs: Self) -> Self::Bool { + let value = self + .ctx + .make_operation(Opcode::CompareLt, [self.value, rhs.value], $bool::TYPE) + .into(); + $bool { + value, + ctx: self.ctx, + } + } + fn gt(self, rhs: Self) -> Self::Bool { + let value = self + .ctx + .make_operation(Opcode::CompareGt, [self.value, rhs.value], $bool::TYPE) + .into(); + $bool { + value, + ctx: self.ctx, + } + } + fn le(self, rhs: Self) -> Self::Bool { + let value = self + .ctx + .make_operation(Opcode::CompareLe, [self.value, rhs.value], $bool::TYPE) + .into(); + $bool { + value, + ctx: self.ctx, + } + } + fn ge(self, rhs: Self) -> Self::Bool { + let value = self + .ctx + .make_operation(Opcode::CompareGe, [self.value, rhs.value], $bool::TYPE) + .into(); + $bool { + value, + ctx: self.ctx, + } + } + } + }; +} + +macro_rules! impl_shift_ops { + ($ty:ident, $rhs:ident) => { + impl<'ctx> Shl<$rhs<'ctx>> for $ty<'ctx> { + type Output = Self; + + fn shl(self, rhs: $rhs<'ctx>) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Shl, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> Shr<$rhs<'ctx>> for $ty<'ctx> { + type Output = Self; + + fn shr(self, rhs: $rhs<'ctx>) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Shr, [self.value, rhs.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> ShlAssign<$rhs<'ctx>> for $ty<'ctx> { + fn shl_assign(&mut self, rhs: $rhs<'ctx>) { + *self = *self << rhs; + } + } + impl<'ctx> ShrAssign<$rhs<'ctx>> for $ty<'ctx> { + fn shr_assign(&mut self, rhs: $rhs<'ctx>) { + *self = *self >> rhs; + } + } + }; +} + +macro_rules! impl_neg { + ($ty:ident) => { + impl<'ctx> Neg for $ty<'ctx> { + type Output = Self; + + fn neg(self) -> Self::Output { + let value = self + .ctx + .make_operation(Opcode::Neg, [self.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + } + }; +} + +macro_rules! impl_integer_ops { + ($scalar:ident, $vec:ident) => { + impl_bit_ops!($scalar); + impl_number_ops!($scalar, IrBool); + impl_shift_ops!($scalar, IrU32); + impl_bit_ops!($vec); + impl_number_ops!($vec, IrVecBool); + impl_shift_ops!($vec, IrVecU32); + + impl<'ctx> Int> for $scalar<'ctx> {} + impl<'ctx> Int> for $vec<'ctx> {} + }; +} + +macro_rules! impl_uint_ops { + ($scalar:ident, $vec:ident) => { + impl_integer_ops!($scalar, $vec); + + impl<'ctx> UInt> for $scalar<'ctx> {} + impl<'ctx> UInt> for $vec<'ctx> {} + }; +} + +impl_uint_ops!(IrU8, IrVecU8); +impl_uint_ops!(IrU16, IrVecU16); +impl_uint_ops!(IrU32, IrVecU32); +impl_uint_ops!(IrU64, IrVecU64); + +macro_rules! impl_sint_ops { + ($scalar:ident, $vec:ident) => { + impl_integer_ops!($scalar, $vec); + impl_neg!($scalar); + impl_neg!($vec); + + impl<'ctx> SInt> for $scalar<'ctx> {} + impl<'ctx> SInt> for $vec<'ctx> {} + }; +} + +impl_sint_ops!(IrI8, IrVecI8); +impl_sint_ops!(IrI16, IrVecI16); +impl_sint_ops!(IrI32, IrVecI32); +impl_sint_ops!(IrI64, IrVecI64); + +macro_rules! impl_float { + ($float:ident, $bits:ident, $u32:ident) => { + impl<'ctx> Float<$u32<'ctx>> for $float<'ctx> { + type BitsType = $bits<'ctx>; + fn abs(self) -> Self { + let value = self + .ctx + .make_operation(Opcode::Abs, [self.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + fn trunc(self) -> Self { + let value = self + .ctx + .make_operation(Opcode::Trunc, [self.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + fn ceil(self) -> Self { + let value = self + .ctx + .make_operation(Opcode::Ceil, [self.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + fn floor(self) -> Self { + let value = self + .ctx + .make_operation(Opcode::Floor, [self.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + fn round(self) -> Self { + let value = self + .ctx + .make_operation(Opcode::Round, [self.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + #[cfg(feature = "fma")] + fn fma(self, a: Self, b: Self) -> Self { + let value = self + .ctx + .make_operation(Opcode::Fma, [self.value, a.value, b.value], Self::TYPE) + .into(); + Self { + value, + ctx: self.ctx, + } + } + fn is_nan(self) -> Self::Bool { + let value = self + .ctx + .make_operation( + Opcode::CompareNe, + [self.value, self.value], + Self::Bool::TYPE, + ) + .into(); + Self::Bool { + value, + ctx: self.ctx, + } + } + fn is_infinite(self) -> Self::Bool { + let value = self + .ctx + .make_operation(Opcode::IsInfinite, [self.value], Self::Bool::TYPE) + .into(); + Self::Bool { + value, + ctx: self.ctx, + } + } + fn is_finite(self) -> Self::Bool { + let value = self + .ctx + .make_operation(Opcode::IsFinite, [self.value], Self::Bool::TYPE) + .into(); + Self::Bool { + value, + ctx: self.ctx, + } + } + fn from_bits(v: Self::BitsType) -> Self { + let value = v + .ctx + .make_operation(Opcode::FromBits, [v.value], Self::TYPE) + .into(); + Self { value, ctx: v.ctx } + } + fn to_bits(self) -> Self::BitsType { + let value = self + .ctx + .make_operation(Opcode::ToBits, [self.value], Self::BitsType::TYPE) + .into(); + Self::BitsType { + value, + ctx: self.ctx, + } + } + } + }; +} + +macro_rules! impl_float_ops { + ($scalar:ident, $scalar_bits:ident, $vec:ident, $vec_bits:ident) => { + impl_number_ops!($scalar, IrBool); + impl_number_ops!($vec, IrVecBool); + impl_neg!($scalar); + impl_neg!($vec); + impl_float!($scalar, $scalar_bits, IrU32); + impl_float!($vec, $vec_bits, IrVecU32); + }; +} + +impl_float_ops!(IrF16, IrU16, IrVecF16, IrVecU16); +impl_float_ops!(IrF32, IrU32, IrVecF32, IrVecU32); +impl_float_ops!(IrF64, IrU64, IrVecF64, IrVecU64); + +ir_value!( + IrBool, + IrVecBool, + TYPE = Bool, + fn make(v: bool) { + v.into() + } +); + +impl<'ctx> Bool for IrBool<'ctx> {} +impl<'ctx> Bool for IrVecBool<'ctx> {} + +impl_bit_ops!(IrBool); +impl_bit_ops!(IrVecBool); + +ir_value!( + IrU8, + IrVecU8, + TYPE = U8, + fn make(v: u8) { + v.into() + } +); +ir_value!( + IrU16, + IrVecU16, + TYPE = U16, + fn make(v: u16) { + v.into() + } +); +ir_value!( + IrU32, + IrVecU32, + TYPE = U32, + fn make(v: u32) { + v.into() + } +); +ir_value!( + IrU64, + IrVecU64, + TYPE = U64, + fn make(v: u64) { + v.into() + } +); +ir_value!( + IrI8, + IrVecI8, + TYPE = I8, + fn make(v: i8) { + v.into() + } +); +ir_value!( + IrI16, + IrVecI16, + TYPE = I16, + fn make(v: i16) { + v.into() + } +); +ir_value!( + IrI32, + IrVecI32, + TYPE = I32, + fn make(v: i32) { + v.into() + } +); +ir_value!( + IrI64, + IrVecI64, + TYPE = I64, + fn make(v: i64) { + v.into() + } +); +ir_value!( + IrF16, + IrVecF16, + TYPE = F16, + fn make(v: F16) { + ScalarConstant::from_f16_bits(v.to_bits()) + } +); +ir_value!( + IrF32, + IrVecF32, + TYPE = F32, + fn make(v: f32) { + ScalarConstant::from_f32_bits(v.to_bits()) + } +); +ir_value!( + IrF64, + IrVecF64, + TYPE = F64, + fn make(v: f64) { + ScalarConstant::from_f64_bits(v.to_bits()) + } +); + +macro_rules! impl_convert_to { + ($($src:ident -> [$($dest:ident),*];)*) => { + $($( + impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> { + fn to(self) -> $dest<'ctx> { + let value = if $src::TYPE == $dest::TYPE { + self.value + } else { + self + .ctx + .make_operation(Opcode::Cast, [self.value], $dest::TYPE) + .into() + }; + $dest { + value, + ctx: self.ctx, + } + } + } + )*)* + }; + ([$($src:ident),*] -> $dest:tt;) => { + impl_convert_to! { + $( + $src -> $dest; + )* + } + }; + ([$($src:ident),*];) => { + impl_convert_to! { + [$($src),*] -> [$($src),*]; + } + }; +} + +impl_convert_to! { + [IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64]; +} + +impl_convert_to! { + [IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64]; +} + +macro_rules! impl_from { + ($src:ident => [$($dest:ident),*]) => { + $( + impl<'ctx> From<$src<'ctx>> for $dest<'ctx> { + fn from(v: $src<'ctx>) -> Self { + v.to() + } + } + )* + }; +} + +macro_rules! impl_froms { + ( + #[u8] $u8:ident; + #[i8] $i8:ident; + #[u16] $u16:ident; + #[i16] $i16:ident; + #[f16] $f16:ident; + #[u32] $u32:ident; + #[i32] $i32:ident; + #[f32] $f32:ident; + #[u64] $u64:ident; + #[i64] $i64:ident; + #[f64] $f64:ident; + ) => { + impl_from!($u8 => [$u16, $i16, $f16, $u32, $i32, $f32, $u64, $i64, $f64]); + impl_from!($u16 => [$u32, $i32, $f32, $u64, $i64, $f64]); + impl_from!($u32 => [$u64, $i64, $f64]); + impl_from!($i8 => [$i16, $f16, $i32, $f32, $i64, $f64]); + impl_from!($i16 => [$i32, $f32, $i64, $f64]); + impl_from!($i32 => [$i64, $f64]); + impl_from!($f16 => [$f32, $f64]); + impl_from!($f32 => [$f64]); + }; +} + +impl_froms! { + #[u8] IrU8; + #[i8] IrI8; + #[u16] IrU16; + #[i16] IrI16; + #[f16] IrF16; + #[u32] IrU32; + #[i32] IrI32; + #[f32] IrF32; + #[u64] IrU64; + #[i64] IrI64; + #[f64] IrF64; +} + +impl_froms! { + #[u8] IrVecU8; + #[i8] IrVecI8; + #[u16] IrVecU16; + #[i16] IrVecI16; + #[f16] IrVecF16; + #[u32] IrVecU32; + #[i32] IrVecI32; + #[f32] IrVecF32; + #[u64] IrVecU64; + #[i64] IrVecI64; + #[f64] IrVecF64; +} + +impl<'ctx> Context for &'ctx IrContext<'ctx> { + type Bool = IrBool<'ctx>; + type U8 = IrU8<'ctx>; + type I8 = IrI8<'ctx>; + type U16 = IrU16<'ctx>; + type I16 = IrI16<'ctx>; + type F16 = IrF16<'ctx>; + type U32 = IrU32<'ctx>; + type I32 = IrI32<'ctx>; + type F32 = IrF32<'ctx>; + type U64 = IrU64<'ctx>; + type I64 = IrI64<'ctx>; + type F64 = IrF64<'ctx>; + type VecBool = IrVecBool<'ctx>; + type VecU8 = IrVecU8<'ctx>; + type VecI8 = IrVecI8<'ctx>; + type VecU16 = IrVecU16<'ctx>; + type VecI16 = IrVecI16<'ctx>; + type VecF16 = IrVecF16<'ctx>; + type VecU32 = IrVecU32<'ctx>; + type VecI32 = IrVecI32<'ctx>; + type VecF32 = IrVecF32<'ctx>; + type VecU64 = IrVecU64<'ctx>; + type VecI64 = IrVecI64<'ctx>; + type VecF64 = IrVecF64<'ctx>; +} + +#[cfg(test)] +mod tests { + use super::*; + use std::println; + + #[test] + fn test_display() { + fn f(ctx: Ctx, a: Ctx::VecU8, b: Ctx::VecF32) -> Ctx::VecF64 { + let a: Ctx::VecF32 = a.into(); + (a - (a + b - ctx.make(5f32)).floor()).to() + } + let ctx = IrContext::new(); + fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { + let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>, IrVecF32<'ctx>) -> IrVecF64<'ctx> = f; + IrFunction::make(ctx, f) + } + let text = format!("\n{}", make_it(&ctx)); + println!("{}", text); + assert_eq!( + text, + r" +function(in: vec, in: vec) -> vec { + op_0: vec = Cast in + op_1: vec = Add op_0, in + op_2: vec = Sub op_1, splat(0x40A00000_f32) + op_3: vec = Floor op_2 + op_4: vec = Sub op_0, op_3 + op_5: vec = Cast op_4 + Return op_5 +} +" + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 551ef7d..06bfb80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,5 +5,7 @@ extern crate std; pub mod f16; +#[cfg(feature = "ir")] +pub mod ir; pub mod scalar; pub mod traits; diff --git a/src/traits.rs b/src/traits.rs index fa46654..7837bfe 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -54,7 +54,7 @@ macro_rules! make_float_type { + ConvertTo $(+ ConvertTo)* $(+ ConvertTo)* - $($(+ Into)?)*; + $($(+ Into + ConvertTo)?)*; }; ( #[u32 = $u32:ident] @@ -119,9 +119,9 @@ macro_rules! make_uint_int_float_type { $($(+ ConvertTo)?)* + ConvertTo $(+ ConvertTo)? - $(+ Into)* - $(+ Into)* - $($(+ Into)?)*; + $(+ Into + ConvertTo)* + $(+ Into + ConvertTo)* + $($(+ Into + ConvertTo)?)*; type $int: SInt $(+ From)? + Compare @@ -132,8 +132,8 @@ macro_rules! make_uint_int_float_type { + ConvertTo $(+ ConvertTo)? $(+ ConvertTo)* - $(+ Into)* - $($(+ Into)?)*; + $(+ Into + ConvertTo)* + $($(+ Into + ConvertTo)?)*; make_float_type! { #[u32 = $u32] #[bool = $bool]