second stage of shader parsing works for vulkan_minimal_compute's shader
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 12 Nov 2018 10:48:58 +0000 (02:48 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 12 Nov 2018 10:48:58 +0000 (02:48 -0800)
shader-compiler/src/lib.rs
spirv-parser-generator/src/generate.rs

index d0d1987c4ca60c4af8383b8b31275daa11c6f486..2e1a4767b610b5725a7b34bbd1110b19e01a53a8 100644 (file)
@@ -1,27 +1,38 @@
 // SPDX-License-Identifier: LGPL-2.1-or-later
 // Copyright 2018 Jacob Lifshay
 
-#[macro_use]
 extern crate shader_compiler_backend;
 extern crate spirv_parser;
 
 use spirv_parser::{
     BuiltIn, Decoration, ExecutionMode, ExecutionModel, IdRef, Instruction, StorageClass,
 };
-use std::error;
+use std::cell::RefCell;
+use std::collections::HashSet;
 use std::fmt;
+use std::hash::{Hash, Hasher};
 use std::mem;
 use std::ops::{Index, IndexMut};
 use std::rc::Rc;
 
-#[derive(Default)]
 pub struct Context {
     types: pointer_type::ContextTypes,
+    next_struct_id: usize,
+}
+
+impl Default for Context {
+    fn default() -> Context {
+        Context {
+            types: Default::default(),
+            next_struct_id: 0,
+        }
+    }
 }
 
 mod pointer_type {
     use super::{Context, Type};
     use std::cell::RefCell;
+    use std::fmt;
     use std::hash::{Hash, Hasher};
     use std::rc::{Rc, Weak};
 
@@ -35,11 +46,23 @@ mod pointer_type {
         Unresolved,
     }
 
-    #[derive(Clone, Debug)]
+    #[derive(Clone)]
     pub struct PointerType {
         pointee: RefCell<PointerTypeState>,
     }
 
+    impl fmt::Debug for PointerType {
+        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+            let mut state = f.debug_struct("PointerType");
+            if let PointerTypeState::Unresolved = *self.pointee.borrow() {
+                state.field("pointee", &PointerTypeState::Unresolved);
+            } else {
+                state.field("pointee", &self.pointee());
+            }
+            state.finish()
+        }
+    }
+
     impl PointerType {
         pub fn new(context: &mut Context, pointee: Option<Rc<Type>>) -> Self {
             Self {
@@ -121,24 +144,99 @@ pub enum ScalarType {
 }
 
 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
-pub enum Type {
-    Scalar(ScalarType),
-    Vector {
-        element: ScalarType,
-        element_count: usize,
-    },
+pub struct VectorType {
+    pub element: ScalarType,
+    pub element_count: usize,
 }
 
-#[derive(Debug)]
-pub struct NotAPointer;
+#[derive(Clone, Eq, PartialEq, Hash, Debug)]
+pub struct StructMember {
+    pub decorations: Vec<Decoration>,
+    pub member_type: Rc<Type>,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
+pub struct StructId(usize);
+
+impl StructId {
+    pub fn new(context: &mut Context) -> Self {
+        let retval = StructId(context.next_struct_id);
+        context.next_struct_id += 1;
+        retval
+    }
+}
+
+#[derive(Clone)]
+pub struct StructType {
+    pub id: StructId,
+    pub decorations: Vec<Decoration>,
+    pub members: Vec<StructMember>,
+}
 
-impl fmt::Display for NotAPointer {
+impl Eq for StructType {}
+
+impl PartialEq for StructType {
+    fn eq(&self, rhs: &Self) -> bool {
+        self.id == rhs.id
+    }
+}
+
+impl Hash for StructType {
+    fn hash<H: Hasher>(&self, h: &mut H) {
+        self.id.hash(h)
+    }
+}
+
+impl fmt::Debug for StructType {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(f, "not a pointer")
+        thread_local! {
+            static CURRENTLY_FORMATTING: RefCell<HashSet<StructId>> = RefCell::new(HashSet::new());
+        }
+        struct CurrentlyFormatting {
+            id: StructId,
+            was_formatting: bool,
+        }
+        impl CurrentlyFormatting {
+            fn new(id: StructId) -> Self {
+                let was_formatting = CURRENTLY_FORMATTING
+                    .with(|currently_formatting| !currently_formatting.borrow_mut().insert(id));
+                Self { id, was_formatting }
+            }
+        }
+        impl Drop for CurrentlyFormatting {
+            fn drop(&mut self) {
+                if !self.was_formatting {
+                    CURRENTLY_FORMATTING.with(|currently_formatting| {
+                        currently_formatting.borrow_mut().remove(&self.id);
+                    });
+                }
+            }
+        }
+        let currently_formatting = CurrentlyFormatting::new(self.id);
+        let mut state = f.debug_struct("StructType");
+        state.field("id", &self.id);
+        if !currently_formatting.was_formatting {
+            state.field("decorations", &self.decorations);
+            state.field("members", &self.members);
+        }
+        state.finish()
     }
 }
 
-impl error::Error for NotAPointer {}
+#[derive(Clone, Eq, PartialEq, Hash, Debug)]
+pub struct ArrayType {
+    pub decorations: Vec<Decoration>,
+    pub element: Rc<Type>,
+    pub element_count: Option<usize>,
+}
+
+#[derive(Clone, Eq, PartialEq, Hash, Debug)]
+pub enum Type {
+    Scalar(ScalarType),
+    Vector(VectorType),
+    Struct(StructType),
+    Array(ArrayType),
+}
 
 impl Type {
     pub fn is_pointer(&self) -> bool {
@@ -148,53 +246,234 @@ impl Type {
             false
         }
     }
-    pub fn get_pointee(&self) -> Result<Option<Rc<Type>>, NotAPointer> {
+    pub fn is_scalar(&self) -> bool {
+        if let Type::Scalar(_) = self {
+            true
+        } else {
+            false
+        }
+    }
+    pub fn is_vector(&self) -> bool {
+        if let Type::Vector(_) = self {
+            true
+        } else {
+            false
+        }
+    }
+    pub fn get_pointee(&self) -> Option<Rc<Type>> {
         if let Type::Scalar(ScalarType::Pointer(pointer)) = self {
-            Ok(pointer.pointee())
+            pointer.pointee()
         } else {
-            Err(NotAPointer)
+            unreachable!("not a pointer")
         }
     }
     pub fn get_nonvoid_pointee(&self) -> Rc<Type> {
-        self.get_pointee()
-            .unwrap()
-            .expect("void is not allowed here")
+        self.get_pointee().expect("void is not allowed here")
+    }
+    pub fn get_scalar(&self) -> &ScalarType {
+        if let Type::Scalar(scalar) = self {
+            scalar
+        } else {
+            unreachable!("not a scalar type")
+        }
+    }
+    pub fn get_vector(&self) -> &VectorType {
+        if let Type::Vector(vector) = self {
+            vector
+        } else {
+            unreachable!("not a vector type")
+        }
+    }
+}
+
+/// value that can be either defined or undefined
+#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
+pub enum Undefable<T> {
+    Undefined,
+    Defined(T),
+}
+
+impl<T> Undefable<T> {
+    pub fn unwrap(self) -> T {
+        match self {
+            Undefable::Undefined => panic!("Undefable::unwrap called on Undefined"),
+            Undefable::Defined(v) => v,
+        }
+    }
+}
+
+impl<T> From<T> for Undefable<T> {
+    fn from(v: T) -> Undefable<T> {
+        Undefable::Defined(v)
+    }
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum ScalarConstant {
+    U8(Undefable<u8>),
+    U16(Undefable<u16>),
+    U32(Undefable<u32>),
+    U64(Undefable<u64>),
+    I8(Undefable<i8>),
+    I16(Undefable<i16>),
+    I32(Undefable<i32>),
+    I64(Undefable<i64>),
+    F16(Undefable<u16>),
+    F32(Undefable<f32>),
+    F64(Undefable<f64>),
+    Bool(Undefable<bool>),
+}
+
+macro_rules! define_scalar_vector_constant_impl_without_from {
+    ($type:ident, $name:ident, $get_name:ident) => {
+        impl ScalarConstant {
+            pub fn $get_name(self) -> Undefable<$type> {
+                match self {
+                    ScalarConstant::$name(v) => v,
+                    _ => unreachable!(concat!("expected a constant ", stringify!($type))),
+                }
+            }
+        }
+        impl VectorConstant {
+            pub fn $get_name(&self) -> &Vec<Undefable<$type>> {
+                match self {
+                    VectorConstant::$name(v) => v,
+                    _ => unreachable!(concat!(
+                        "expected a constant vector with ",
+                        stringify!($type),
+                        " elements"
+                    )),
+                }
+            }
+        }
+    };
+}
+
+macro_rules! define_scalar_vector_constant_impl {
+    ($type:ident, $name:ident, $get_name:ident) => {
+        define_scalar_vector_constant_impl_without_from!($type, $name, $get_name);
+        impl From<Undefable<$type>> for ScalarConstant {
+            fn from(v: Undefable<$type>) -> ScalarConstant {
+                ScalarConstant::$name(v)
+            }
+        }
+        impl From<Vec<Undefable<$type>>> for VectorConstant {
+            fn from(v: Vec<Undefable<$type>>) -> VectorConstant {
+                VectorConstant::$name(v)
+            }
+        }
+    };
+}
+
+define_scalar_vector_constant_impl!(u8, U8, get_u8);
+define_scalar_vector_constant_impl!(u16, U16, get_u16);
+define_scalar_vector_constant_impl!(u32, U32, get_u32);
+define_scalar_vector_constant_impl!(u64, U64, get_u64);
+define_scalar_vector_constant_impl!(i8, I8, get_i8);
+define_scalar_vector_constant_impl!(i16, I16, get_i16);
+define_scalar_vector_constant_impl!(i32, I32, get_i32);
+define_scalar_vector_constant_impl!(i64, I64, get_i64);
+define_scalar_vector_constant_impl_without_from!(u16, F16, get_f16);
+define_scalar_vector_constant_impl!(f32, F32, get_f32);
+define_scalar_vector_constant_impl!(f64, F64, get_f64);
+define_scalar_vector_constant_impl!(bool, Bool, get_bool);
+
+impl ScalarConstant {
+    pub fn get_type(self) -> Type {
+        Type::Scalar(self.get_scalar_type())
+    }
+    pub fn get_scalar_type(self) -> ScalarType {
+        match self {
+            ScalarConstant::U8(_) => ScalarType::U8,
+            ScalarConstant::U16(_) => ScalarType::U16,
+            ScalarConstant::U32(_) => ScalarType::U32,
+            ScalarConstant::U64(_) => ScalarType::U64,
+            ScalarConstant::I8(_) => ScalarType::I8,
+            ScalarConstant::I16(_) => ScalarType::I16,
+            ScalarConstant::I32(_) => ScalarType::I32,
+            ScalarConstant::I64(_) => ScalarType::I64,
+            ScalarConstant::F16(_) => ScalarType::F16,
+            ScalarConstant::F32(_) => ScalarType::F32,
+            ScalarConstant::F64(_) => ScalarType::F64,
+            ScalarConstant::Bool(_) => ScalarType::Bool,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub enum VectorConstant {
+    U8(Vec<Undefable<u8>>),
+    U16(Vec<Undefable<u16>>),
+    U32(Vec<Undefable<u32>>),
+    U64(Vec<Undefable<u64>>),
+    I8(Vec<Undefable<i8>>),
+    I16(Vec<Undefable<i16>>),
+    I32(Vec<Undefable<i32>>),
+    I64(Vec<Undefable<i64>>),
+    F16(Vec<Undefable<u16>>),
+    F32(Vec<Undefable<f32>>),
+    F64(Vec<Undefable<f64>>),
+    Bool(Vec<Undefable<bool>>),
+}
+
+impl VectorConstant {
+    pub fn get_element_type(&self) -> ScalarType {
+        match self {
+            VectorConstant::U8(_) => ScalarType::U8,
+            VectorConstant::U16(_) => ScalarType::U16,
+            VectorConstant::U32(_) => ScalarType::U32,
+            VectorConstant::U64(_) => ScalarType::U64,
+            VectorConstant::I8(_) => ScalarType::I8,
+            VectorConstant::I16(_) => ScalarType::I16,
+            VectorConstant::I32(_) => ScalarType::I32,
+            VectorConstant::I64(_) => ScalarType::I64,
+            VectorConstant::F16(_) => ScalarType::F16,
+            VectorConstant::F32(_) => ScalarType::F32,
+            VectorConstant::F64(_) => ScalarType::F64,
+            VectorConstant::Bool(_) => ScalarType::Bool,
+        }
+    }
+    pub fn get_element_count(&self) -> usize {
+        match self {
+            VectorConstant::U8(v) => v.len(),
+            VectorConstant::U16(v) => v.len(),
+            VectorConstant::U32(v) => v.len(),
+            VectorConstant::U64(v) => v.len(),
+            VectorConstant::I8(v) => v.len(),
+            VectorConstant::I16(v) => v.len(),
+            VectorConstant::I32(v) => v.len(),
+            VectorConstant::I64(v) => v.len(),
+            VectorConstant::F16(v) => v.len(),
+            VectorConstant::F32(v) => v.len(),
+            VectorConstant::F64(v) => v.len(),
+            VectorConstant::Bool(v) => v.len(),
+        }
+    }
+    pub fn get_type(&self) -> Type {
+        Type::Vector(VectorType {
+            element: self.get_element_type(),
+            element_count: self.get_element_count(),
+        })
     }
 }
 
 #[derive(Clone, Debug)]
 pub enum Constant {
-    Undef(Rc<Type>),
-    U8(u8),
-    U16(u16),
-    U32(u32),
-    U64(u64),
-    I8(i8),
-    I16(i16),
-    I32(i32),
-    I64(i64),
-    F16(u16),
-    F32(f32),
-    F64(f64),
-    Bool(bool),
+    Scalar(ScalarConstant),
+    Vector(VectorConstant),
 }
 
 impl Constant {
-    pub fn get_type(&self) -> &Type {
+    pub fn get_type(&self) -> Type {
+        match self {
+            Constant::Scalar(v) => v.get_type(),
+            Constant::Vector(v) => v.get_type(),
+        }
+    }
+    pub fn get_scalar(&self) -> &ScalarConstant {
         match self {
-            Constant::Undef(t) => &*t,
-            Constant::U8(_) => &Type::Scalar(ScalarType::U8),
-            Constant::U16(_) => &Type::Scalar(ScalarType::U16),
-            Constant::U32(_) => &Type::Scalar(ScalarType::U32),
-            Constant::U64(_) => &Type::Scalar(ScalarType::U64),
-            Constant::I8(_) => &Type::Scalar(ScalarType::I8),
-            Constant::I16(_) => &Type::Scalar(ScalarType::I16),
-            Constant::I32(_) => &Type::Scalar(ScalarType::I32),
-            Constant::I64(_) => &Type::Scalar(ScalarType::I64),
-            Constant::F16(_) => &Type::Scalar(ScalarType::F16),
-            Constant::F32(_) => &Type::Scalar(ScalarType::F32),
-            Constant::F64(_) => &Type::Scalar(ScalarType::F64),
-            Constant::Bool(_) => &Type::Scalar(ScalarType::Bool),
+            Constant::Scalar(v) => v,
+            _ => unreachable!("not a scalar constant"),
         }
     }
 }
@@ -213,15 +492,22 @@ struct BuiltInVariable {
 impl BuiltInVariable {
     fn get_type(&self, _context: &mut Context) -> Rc<Type> {
         match self.built_in {
-            BuiltIn::GlobalInvocationId => Rc::new(Type::Vector {
+            BuiltIn::GlobalInvocationId => Rc::new(Type::Vector(VectorType {
                 element: ScalarType::U32,
                 element_count: 3,
-            }),
+            })),
             _ => unreachable!("unknown built-in"),
         }
     }
 }
 
+#[derive(Debug, Clone)]
+struct UniformVariable {
+    binding: u32,
+    descriptor_set: u32,
+    variable_type: Rc<Type>,
+}
+
 #[derive(Debug)]
 enum IdKind {
     Undefined,
@@ -234,7 +520,8 @@ enum IdKind {
     },
     ForwardPointer(Rc<Type>),
     BuiltInVariable(BuiltInVariable),
-    Constant(Constant),
+    Constant(Rc<Constant>),
+    UniformVariable(UniformVariable),
 }
 
 #[derive(Debug)]
@@ -245,6 +532,13 @@ struct IdProperties {
 }
 
 impl IdProperties {
+    fn is_empty(&self) -> bool {
+        match self.kind {
+            IdKind::Undefined => {}
+            _ => return false,
+        }
+        self.decorations.is_empty() && self.member_decorations.is_empty()
+    }
     fn set_kind(&mut self, kind: IdKind) {
         match &self.kind {
             IdKind::Undefined => {}
@@ -252,16 +546,22 @@ impl IdProperties {
         }
         self.kind = kind;
     }
-    fn get_type(&self) -> Option<Rc<Type>> {
+    fn get_type(&self) -> Option<&Rc<Type>> {
         match &self.kind {
-            IdKind::Type(t) => Some(t.clone()),
+            IdKind::Type(t) => Some(t),
             IdKind::VoidType => None,
             _ => unreachable!("id is not type"),
         }
     }
-    fn get_nonvoid_type(&self) -> Rc<Type> {
+    fn get_nonvoid_type(&self) -> &Rc<Type> {
         self.get_type().expect("void is not allowed here")
     }
+    fn get_constant(&self) -> &Rc<Constant> {
+        match &self.kind {
+            IdKind::Constant(c) => c,
+            _ => unreachable!("id is not a constant"),
+        }
+    }
     fn assert_no_member_decorations(&self, id: IdRef) {
         for member_decoration in &self.member_decorations {
             unreachable!(
@@ -278,9 +578,26 @@ impl IdProperties {
     }
 }
 
-#[derive(Debug)]
 struct Ids(Vec<IdProperties>);
 
+impl fmt::Debug for Ids {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_map()
+            .entries(
+                self.0
+                    .iter()
+                    .enumerate()
+                    .filter_map(|(id_index, id_properties)| {
+                        if id_properties.is_empty() {
+                            return None;
+                        }
+                        Some((IdRef(id_index as u32), id_properties))
+                    }),
+            )
+            .finish()
+    }
+}
+
 impl Index<IdRef> for Ids {
     type Output = IdProperties;
     fn index(&self, index: IdRef) -> &IdProperties {
@@ -298,13 +615,24 @@ struct ParsedShaderFunction {
     instructions: Vec<Instruction>,
 }
 
-#[allow(dead_code)]
+impl fmt::Debug for ParsedShaderFunction {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "ParsedShaderFunction:\n")?;
+        for instruction in &self.instructions {
+            write!(f, "{}", instruction)?;
+        }
+        Ok(())
+    }
+}
+
+#[derive(Debug)]
 struct ParsedShader {
     ids: Ids,
     functions: Vec<ParsedShaderFunction>,
     main_function_id: IdRef,
     interface_variables: Vec<IdRef>,
     execution_modes: Vec<ExecutionMode>,
+    workgroup_size: Option<(u32, u32, u32)>,
 }
 
 struct ShaderEntryPoint {
@@ -341,6 +669,7 @@ impl ParsedShader {
         let mut current_function: Option<ParsedShaderFunction> = None;
         let mut functions = Vec::new();
         let mut execution_modes = Vec::new();
+        let mut workgroup_size = None;
         for instruction in instructions {
             match current_function {
                 Some(mut function) => {
@@ -441,14 +770,10 @@ impl ParsedShader {
                 } => {
                     ids[id_result.0].assert_no_decorations(id_result.0);
                     let kind = IdKind::FunctionType {
-                        return_type: ids[return_type].get_type(),
+                        return_type: ids[return_type].get_type().map(Clone::clone),
                         arguments: parameter_types
                             .iter()
-                            .map(|argument| {
-                                ids[*argument]
-                                    .get_type()
-                                    .expect("void is not allowed as a function argument")
-                            })
+                            .map(|argument| ids[*argument].get_nonvoid_type().clone())
                             .collect(),
                     };
                     ids[id_result.0].set_kind(kind);
@@ -501,17 +826,11 @@ impl ParsedShader {
                     component_count,
                 } => {
                     ids[id_result.0].assert_no_decorations(id_result.0);
-                    let element = match &*ids[component_type]
-                        .get_type()
-                        .expect("void is not a valid vector element type")
-                    {
-                        Type::Scalar(v) => v.clone(),
-                        _ => unreachable!("vector element type must be a scalar"),
-                    };
-                    ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Vector {
+                    let element = ids[component_type].get_nonvoid_type().get_scalar().clone();
+                    ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Vector(VectorType {
                         element,
                         element_count: component_count as usize,
-                    })));
+                    }))));
                 }
                 Instruction::TypeForwardPointer { pointer_type, .. } => {
                     ids[pointer_type].set_kind(IdKind::ForwardPointer(Rc::new(Type::Scalar(
@@ -524,7 +843,7 @@ impl ParsedShader {
                     ..
                 } => {
                     ids[id_result.0].assert_no_decorations(id_result.0);
-                    let pointee = ids[pointee].get_type();
+                    let pointee = ids[pointee].get_type().map(Clone::clone);
                     let pointer = match mem::replace(&mut ids[id_result.0].kind, IdKind::Undefined)
                     {
                         IdKind::Undefined => Rc::new(Type::Scalar(ScalarType::Pointer(
@@ -542,6 +861,49 @@ impl ParsedShader {
                     };
                     ids[id_result.0].set_kind(IdKind::Type(pointer));
                 }
+                Instruction::TypeStruct {
+                    id_result,
+                    member_types,
+                } => {
+                    let decorations = ids[id_result.0].decorations.clone();
+                    let struct_type = {
+                        let mut members: Vec<_> = member_types
+                            .into_iter()
+                            .map(|member_type| StructMember {
+                                decorations: Vec::new(),
+                                member_type: match ids[member_type].kind {
+                                    IdKind::Type(ref t) => t.clone(),
+                                    IdKind::ForwardPointer(ref t) => t.clone(),
+                                    _ => unreachable!("invalid struct member type"),
+                                },
+                            })
+                            .collect();
+                        for member_decoration in &ids[id_result.0].member_decorations {
+                            members[member_decoration.member as usize]
+                                .decorations
+                                .push(member_decoration.decoration.clone());
+                        }
+                        StructType {
+                            id: StructId::new(context),
+                            decorations,
+                            members,
+                        }
+                    };
+                    ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Struct(struct_type))));
+                }
+                Instruction::TypeRuntimeArray {
+                    id_result,
+                    element_type,
+                } => {
+                    ids[id_result.0].assert_no_member_decorations(id_result.0);
+                    let decorations = ids[id_result.0].decorations.clone();
+                    let element = ids[element_type].get_nonvoid_type().clone();
+                    ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Array(ArrayType {
+                        decorations,
+                        element,
+                        element_count: None,
+                    }))));
+                }
                 Instruction::Variable {
                     id_result_type,
                     id_result,
@@ -582,7 +944,46 @@ impl ParsedShader {
                         );
                         ids[id_result.0].set_kind(IdKind::BuiltInVariable(built_in_variable));
                     } else {
+                        let variable_type = ids[id_result_type.0].get_nonvoid_type().clone();
                         match storage_class {
+                            StorageClass::Uniform => {
+                                let mut descriptor_set = None;
+                                let mut binding = None;
+                                for decoration in &ids[id_result.0].decorations {
+                                    match *decoration {
+                                        Decoration::DescriptorSet { descriptor_set: v } => {
+                                            assert!(
+                                                descriptor_set.is_none(),
+                                                "duplicate DescriptorSet decoration"
+                                            );
+                                            descriptor_set = Some(v);
+                                        }
+                                        Decoration::Binding { binding_point: v } => {
+                                            assert!(
+                                                binding.is_none(),
+                                                "duplicate Binding decoration"
+                                            );
+                                            binding = Some(v);
+                                        }
+                                        _ => unimplemented!(
+                                            "unimplemented decoration on uniform variable: {:?}",
+                                            decoration
+                                        ),
+                                    }
+                                }
+                                let descriptor_set = descriptor_set
+                                    .expect("uniform variable is missing DescriptorSet decoration");
+                                let binding = binding
+                                    .expect("uniform variable is missing Binding decoration");
+                                assert!(initializer.is_none());
+                                ids[id_result.0].set_kind(IdKind::UniformVariable(
+                                    UniformVariable {
+                                        binding,
+                                        descriptor_set,
+                                        variable_type,
+                                    },
+                                ));
+                            }
                             StorageClass::Input => unimplemented!(),
                             _ => unimplemented!(
                                 "unimplemented OpVariable StorageClass: {:?}",
@@ -598,38 +999,54 @@ impl ParsedShader {
                 } => {
                     ids[id_result.0].assert_no_decorations(id_result.0);
                     #[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_lossless))]
-                    let constant = match &*ids[id_result_type.0].get_nonvoid_type() {
+                    let constant = match **ids[id_result_type.0].get_nonvoid_type() {
                         Type::Scalar(ScalarType::U8) => {
                             let converted_value = value as u8;
                             assert_eq!(converted_value as u32, value);
-                            Constant::U8(converted_value)
+                            Constant::Scalar(ScalarConstant::U8(Undefable::Defined(
+                                converted_value,
+                            )))
                         }
                         Type::Scalar(ScalarType::U16) => {
                             let converted_value = value as u16;
                             assert_eq!(converted_value as u32, value);
-                            Constant::U16(converted_value)
+                            Constant::Scalar(ScalarConstant::U16(Undefable::Defined(
+                                converted_value,
+                            )))
+                        }
+                        Type::Scalar(ScalarType::U32) => {
+                            Constant::Scalar(ScalarConstant::U32(Undefable::Defined(value)))
                         }
-                        Type::Scalar(ScalarType::U32) => Constant::U32(value),
                         Type::Scalar(ScalarType::I8) => {
                             let converted_value = value as i8;
                             assert_eq!(converted_value as u32, value);
-                            Constant::I8(converted_value)
+                            Constant::Scalar(ScalarConstant::I8(Undefable::Defined(
+                                converted_value,
+                            )))
                         }
                         Type::Scalar(ScalarType::I16) => {
                             let converted_value = value as i16;
                             assert_eq!(converted_value as u32, value);
-                            Constant::I16(converted_value)
+                            Constant::Scalar(ScalarConstant::I16(Undefable::Defined(
+                                converted_value,
+                            )))
+                        }
+                        Type::Scalar(ScalarType::I32) => {
+                            Constant::Scalar(ScalarConstant::I32(Undefable::Defined(value as i32)))
                         }
-                        Type::Scalar(ScalarType::I32) => Constant::I32(value as i32),
                         Type::Scalar(ScalarType::F16) => {
                             let converted_value = value as u16;
                             assert_eq!(converted_value as u32, value);
-                            Constant::F16(converted_value)
+                            Constant::Scalar(ScalarConstant::F16(Undefable::Defined(
+                                converted_value,
+                            )))
                         }
-                        Type::Scalar(ScalarType::F32) => Constant::F32(f32::from_bits(value)),
+                        Type::Scalar(ScalarType::F32) => Constant::Scalar(ScalarConstant::F32(
+                            Undefable::Defined(f32::from_bits(value)),
+                        )),
                         _ => unreachable!("invalid type"),
                     };
-                    ids[id_result.0].set_kind(IdKind::Constant(constant));
+                    ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant)));
                 }
                 Instruction::Constant64 {
                     id_result_type,
@@ -637,35 +1054,133 @@ impl ParsedShader {
                     value,
                 } => {
                     ids[id_result.0].assert_no_decorations(id_result.0);
-                    let constant = match &*ids[id_result_type.0].get_nonvoid_type() {
-                        Type::Scalar(ScalarType::U64) => Constant::U64(value),
-                        Type::Scalar(ScalarType::I64) => Constant::I64(value as i64),
-                        Type::Scalar(ScalarType::F64) => Constant::F64(f64::from_bits(value)),
+                    let constant = match **ids[id_result_type.0].get_nonvoid_type() {
+                        Type::Scalar(ScalarType::U64) => {
+                            Constant::Scalar(ScalarConstant::U64(Undefable::Defined(value)))
+                        }
+                        Type::Scalar(ScalarType::I64) => {
+                            Constant::Scalar(ScalarConstant::I64(Undefable::Defined(value as i64)))
+                        }
+                        Type::Scalar(ScalarType::F64) => Constant::Scalar(ScalarConstant::F64(
+                            Undefable::Defined(f64::from_bits(value)),
+                        )),
                         _ => unreachable!("invalid type"),
                     };
-                    ids[id_result.0].set_kind(IdKind::Constant(constant));
+                    ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant)));
                 }
                 Instruction::ConstantFalse {
                     id_result_type,
                     id_result,
                 } => {
                     ids[id_result.0].assert_no_decorations(id_result.0);
-                    let constant = match &*ids[id_result_type.0].get_nonvoid_type() {
-                        Type::Scalar(ScalarType::Bool) => Constant::Bool(false),
+                    let constant = match **ids[id_result_type.0].get_nonvoid_type() {
+                        Type::Scalar(ScalarType::Bool) => {
+                            Constant::Scalar(ScalarConstant::Bool(Undefable::Defined(false)))
+                        }
                         _ => unreachable!("invalid type"),
                     };
-                    ids[id_result.0].set_kind(IdKind::Constant(constant));
+                    ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant)));
                 }
                 Instruction::ConstantTrue {
                     id_result_type,
                     id_result,
                 } => {
                     ids[id_result.0].assert_no_decorations(id_result.0);
-                    let constant = match &*ids[id_result_type.0].get_nonvoid_type() {
-                        Type::Scalar(ScalarType::Bool) => Constant::Bool(true),
+                    let constant = match **ids[id_result_type.0].get_nonvoid_type() {
+                        Type::Scalar(ScalarType::Bool) => {
+                            Constant::Scalar(ScalarConstant::Bool(Undefable::Defined(true)))
+                        }
                         _ => unreachable!("invalid type"),
                     };
-                    ids[id_result.0].set_kind(IdKind::Constant(constant));
+                    ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant)));
+                }
+                Instruction::ConstantComposite {
+                    id_result_type,
+                    id_result,
+                    constituents,
+                } => {
+                    let constant = match **ids[id_result_type.0].get_nonvoid_type() {
+                        Type::Vector(VectorType {
+                            ref element,
+                            element_count,
+                        }) => {
+                            assert_eq!(element_count, constituents.len());
+                            let constituents = constituents
+                                .iter()
+                                .map(|id| *ids[*id].get_constant().get_scalar());
+                            match *element {
+                                ScalarType::U8 => {
+                                    VectorConstant::U8(constituents.map(|v| v.get_u8()).collect())
+                                }
+                                ScalarType::U16 => {
+                                    VectorConstant::U16(constituents.map(|v| v.get_u16()).collect())
+                                }
+                                ScalarType::U32 => {
+                                    VectorConstant::U32(constituents.map(|v| v.get_u32()).collect())
+                                }
+                                ScalarType::U64 => {
+                                    VectorConstant::U64(constituents.map(|v| v.get_u64()).collect())
+                                }
+                                ScalarType::I8 => {
+                                    VectorConstant::I8(constituents.map(|v| v.get_i8()).collect())
+                                }
+                                ScalarType::I16 => {
+                                    VectorConstant::I16(constituents.map(|v| v.get_i16()).collect())
+                                }
+                                ScalarType::I32 => {
+                                    VectorConstant::I32(constituents.map(|v| v.get_i32()).collect())
+                                }
+                                ScalarType::I64 => {
+                                    VectorConstant::I64(constituents.map(|v| v.get_i64()).collect())
+                                }
+                                ScalarType::F16 => {
+                                    VectorConstant::F16(constituents.map(|v| v.get_f16()).collect())
+                                }
+                                ScalarType::F32 => {
+                                    VectorConstant::F32(constituents.map(|v| v.get_f32()).collect())
+                                }
+                                ScalarType::F64 => {
+                                    VectorConstant::F64(constituents.map(|v| v.get_f64()).collect())
+                                }
+                                ScalarType::Bool => VectorConstant::Bool(
+                                    constituents.map(|v| v.get_bool()).collect(),
+                                ),
+                                ScalarType::Pointer(_) => unimplemented!(),
+                            }
+                        }
+                        _ => unimplemented!(),
+                    };
+                    for decoration in &ids[id_result.0].decorations {
+                        match decoration {
+                            Decoration::BuiltIn {
+                                built_in: BuiltIn::WorkgroupSize,
+                            } => {
+                                assert!(
+                                    workgroup_size.is_none(),
+                                    "duplicate WorkgroupSize decorations"
+                                );
+                                workgroup_size = match constant {
+                                    VectorConstant::U32(ref v) => {
+                                        assert_eq!(
+                                            v.len(),
+                                            3,
+                                            "invalid type for WorkgroupSize built-in"
+                                        );
+                                        Some((v[0].unwrap(), v[1].unwrap(), v[2].unwrap()))
+                                    }
+                                    _ => unreachable!("invalid type for WorkgroupSize built-in"),
+                                };
+                            }
+                            _ => unimplemented!(
+                                "unimplemented decoration on constant {:?}: {:?}",
+                                Constant::Vector(constant),
+                                decoration
+                            ),
+                        }
+                    }
+                    ids[id_result.0].assert_no_member_decorations(id_result.0);
+                    ids[id_result.0]
+                        .set_kind(IdKind::Constant(Rc::new(Constant::Vector(constant))));
                 }
                 Instruction::MemoryModel {
                     addressing_model,
@@ -702,6 +1217,7 @@ impl ParsedShader {
             main_function_id,
             interface_variables,
             execution_modes,
+            workgroup_size,
         }
     }
 }
@@ -741,11 +1257,12 @@ impl ComputePipeline {
         compute_shader_stage: ShaderStageCreateInfo,
     ) -> ComputePipeline {
         let mut context = Context::default();
-        let _parsed_shader = ParsedShader::create(
+        let parsed_shader = ParsedShader::create(
             &mut context,
             compute_shader_stage,
             ExecutionModel::GLCompute,
         );
+        println!("parsed_shader:\n{:#?}", parsed_shader);
         unimplemented!()
     }
 }
index a1423b1e4c51ded784b466e60e8bcc18d1bf7a4a..4f709609d33d677ab761386a09ff0e690c7109ab 100644 (file)
@@ -449,7 +449,7 @@ pub(crate) fn generate(
                     let enumerant_parse_operation;
                     if enumerant.parameters.is_empty() {
                         enumerant_items.push(quote!{
-                            #[derive(Clone, Debug, Default)]
+                            #[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
                             pub struct #type_name;
                         });
                         enumerant_parse_operation = quote!{(Some(#type_name), words)};
@@ -486,7 +486,7 @@ pub(crate) fn generate(
                             });
                         }
                         enumerant_items.push(quote!{
-                            #[derive(Clone, Debug, Default)]
+                            #[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
                             pub struct #type_name(#(#enumerant_parameter_declarations)*);
                         });
                         let enumerant_parameter_names = &enumerant_parameter_names;
@@ -527,7 +527,7 @@ pub(crate) fn generate(
                     &mut out,
                     "{}",
                     quote!{
-                        #[derive(Clone, Debug, Default)]
+                        #[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
                         pub struct #kind_id {
                             #(#enumerant_members),*
                         }
@@ -646,12 +646,15 @@ pub(crate) fn generate(
                         });
                     }
                 }
-                let mut derives = vec![quote!{Clone}, quote!{Debug}];
+                let mut derives = vec![
+                    quote!{Clone},
+                    quote!{Debug},
+                    quote!{Eq},
+                    quote!{PartialEq},
+                    quote!{Hash},
+                ];
                 if !has_any_parameters {
-                    derives.push(quote!{Eq});
-                    derives.push(quote!{PartialEq});
                     derives.push(quote!{Copy});
-                    derives.push(quote!{Hash});
                 }
                 writeln!(
                     &mut out,