implement creating functions and basic blocks
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 21 Nov 2018 05:11:52 +0000 (21:11 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 21 Nov 2018 05:11:52 +0000 (21:11 -0800)
shader-compiler-backend-llvm-7/src/backend.rs
shader-compiler-backend/src/lib.rs
shader-compiler/src/lib.rs
shader-compiler/src/parsed_shader_compile.rs
shader-compiler/src/parsed_shader_create.rs

index 95bfa5a258643f2a5991c1afb1cd8aa6c8935dc4..18ea0f958cfa3a6f2507f8d256e907fb14153731 100644 (file)
@@ -20,7 +20,7 @@ fn to_bool(v: llvm::LLVMBool) -> bool {
     v != 0
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct LLVM7CompilerConfig {
     pub variable_vector_length_multiplier: u32,
     pub optimization_mode: backend::OptimizationMode,
@@ -262,6 +262,14 @@ pub struct LLVM7Context {
     config: LLVM7CompilerConfig,
 }
 
+impl fmt::Debug for LLVM7Context {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("LLVM7Context")
+            .field("config", &self.config)
+            .finish()
+    }
+}
+
 impl Drop for LLVM7Context {
     fn drop(&mut self) {
         unsafe {
@@ -597,24 +605,26 @@ impl backend::Compiler for LLVM7Compiler {
             let orc_jit_stack =
                 LLVM7OrcJITStack(llvm::LLVMOrcCreateInstance(target_machine.take()));
             let mut module_handle = 0;
-            if llvm::LLVMOrcErrSuccess != llvm::LLVMOrcAddEagerlyCompiledIR(
-                orc_jit_stack.0,
-                &mut module_handle,
-                module.take(),
-                Some(symbol_resolver_fn),
-                null_mut(),
-            ) {
+            if llvm::LLVMOrcErrSuccess
+                != llvm::LLVMOrcAddEagerlyCompiledIR(
+                    orc_jit_stack.0,
+                    &mut module_handle,
+                    module.take(),
+                    Some(symbol_resolver_fn),
+                    null_mut(),
+                ) {
                 return Err(U::create_error("compilation failed".into()));
             }
             let mut functions: HashMap<_, _> = HashMap::new();
             for (key, name) in callable_functions {
                 let mut address: llvm::LLVMOrcTargetAddress = mem::zeroed();
-                if llvm::LLVMOrcErrSuccess != llvm::LLVMOrcGetSymbolAddressIn(
-                    orc_jit_stack.0,
-                    &mut address,
-                    module_handle,
-                    name.as_ptr(),
-                ) {
+                if llvm::LLVMOrcErrSuccess
+                    != llvm::LLVMOrcGetSymbolAddressIn(
+                        orc_jit_stack.0,
+                        &mut address,
+                        module_handle,
+                        name.as_ptr(),
+                    ) {
                     return Err(U::create_error(format!(
                         "function not found in compiled module: {:?}",
                         name
index 0c3e1c3599ee459f3d61e6a0de66ac25d6452c26..72ab36eaa679609b59b3e2d891791604ddb3d3fe 100644 (file)
@@ -154,7 +154,7 @@ pub trait VerifiedModule<'a>: Debug + Sized {
 }
 
 /// instance of a compiler backend; equivalent to LLVM's `LLVMContext`
-pub trait Context<'a>: Sized {
+pub trait Context<'a>: Sized + fmt::Debug {
     /// the `Value` type
     type Value: Value<'a, Context = Self>;
     /// the `BasicBlock` type
index bd668c7781b2b57889afcf847aff16bcafd203ad..dd4c4408b52f54ba376112107dde6e96cbabe2c2 100644 (file)
@@ -517,7 +517,7 @@ struct UniformVariable {
 }
 
 #[derive(Debug)]
-enum IdKind {
+enum IdKind<'a, C: shader_compiler_backend::Context<'a>> {
     Undefined,
     DecorationGroup,
     Type(Rc<Type>),
@@ -531,16 +531,20 @@ enum IdKind {
     Constant(Rc<Constant>),
     UniformVariable(UniformVariable),
     Function(Option<ParsedShaderFunction>),
+    BasicBlock {
+        basic_block: C::BasicBlock,
+        buildable_basic_block: Option<C::BuildableBasicBlock>,
+    },
 }
 
 #[derive(Debug)]
-struct IdProperties {
-    kind: IdKind,
+struct IdProperties<'a, C: shader_compiler_backend::Context<'a>> {
+    kind: IdKind<'a, C>,
     decorations: Vec<Decoration>,
     member_decorations: Vec<MemberDecoration>,
 }
 
-impl IdProperties {
+impl<'a, C: shader_compiler_backend::Context<'a>> IdProperties<'a, C> {
     fn is_empty(&self) -> bool {
         match self.kind {
             IdKind::Undefined => {}
@@ -548,7 +552,7 @@ impl IdProperties {
         }
         self.decorations.is_empty() && self.member_decorations.is_empty()
     }
-    fn set_kind(&mut self, kind: IdKind) {
+    fn set_kind(&mut self, kind: IdKind<'a, C>) {
         match &self.kind {
             IdKind::Undefined => {}
             _ => unreachable!("duplicate id"),
@@ -587,15 +591,15 @@ impl IdProperties {
     }
 }
 
-struct Ids(Vec<IdProperties>);
+struct Ids<'a, C: shader_compiler_backend::Context<'a>>(Vec<IdProperties<'a, C>>);
 
-impl Ids {
-    pub fn iter(&self) -> impl Iterator<Item = (IdRef, &IdProperties)> {
+impl<'a, C: shader_compiler_backend::Context<'a>> Ids<'a, C> {
+    pub fn iter(&self) -> impl Iterator<Item = (IdRef, &IdProperties<'a, C>)> {
         (1..self.0.len()).map(move |index| (IdRef(index as u32), &self.0[index]))
     }
 }
 
-impl fmt::Debug for Ids {
+impl<'a, C: shader_compiler_backend::Context<'a>> fmt::Debug for Ids<'a, C> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_map()
             .entries(
@@ -613,15 +617,15 @@ impl fmt::Debug for Ids {
     }
 }
 
-impl Index<IdRef> for Ids {
-    type Output = IdProperties;
-    fn index(&self, index: IdRef) -> &IdProperties {
+impl<'a, C: shader_compiler_backend::Context<'a>> Index<IdRef> for Ids<'a, C> {
+    type Output = IdProperties<'a, C>;
+    fn index<'b>(&'b self, index: IdRef) -> &'b IdProperties<'a, C> {
         &self.0[index.0 as usize]
     }
 }
 
-impl IndexMut<IdRef> for Ids {
-    fn index_mut(&mut self, index: IdRef) -> &mut IdProperties {
+impl<'a, C: shader_compiler_backend::Context<'a>> IndexMut<IdRef> for Ids<'a, C> {
+    fn index_mut(&mut self, index: IdRef) -> &mut IdProperties<'a, C> {
         &mut self.0[index.0 as usize]
     }
 }
@@ -642,8 +646,8 @@ impl fmt::Debug for ParsedShaderFunction {
 }
 
 #[derive(Debug)]
-struct ParsedShader {
-    ids: Ids,
+struct ParsedShader<'a, C: shader_compiler_backend::Context<'a>> {
+    ids: Ids<'a, C>,
     main_function_id: IdRef,
     interface_variables: Vec<IdRef>,
     execution_modes: Vec<ExecutionMode>,
@@ -655,7 +659,7 @@ struct ShaderEntryPoint {
     interface_variables: Vec<IdRef>,
 }
 
-impl ParsedShader {
+impl<'a, C: shader_compiler_backend::Context<'a>> ParsedShader<'a, C> {
     fn create(
         context: &mut Context,
         stage_info: ShaderStageCreateInfo,
@@ -725,19 +729,13 @@ impl ComputePipeline {
         backend_compiler: C,
     ) -> ComputePipeline {
         let mut frontend_context = Context::default();
-        let parsed_shader = ParsedShader::create(
-            &mut frontend_context,
-            compute_shader_stage,
-            ExecutionModel::GLCompute,
-        );
-        println!("parsed_shader:\n{:#?}", parsed_shader);
-        struct CompilerUser {
+        struct CompilerUser<'a> {
             frontend_context: Context,
-            parsed_shader: ParsedShader,
+            compute_shader_stage: ShaderStageCreateInfo<'a>,
         }
         #[derive(Debug)]
         enum CompileError {}
-        impl shader_compiler_backend::CompilerUser for CompilerUser {
+        impl<'cu> shader_compiler_backend::CompilerUser for CompilerUser<'cu> {
             type FunctionKey = CompiledFunctionKey;
             type Error = CompileError;
             fn create_error(message: String) -> CompileError {
@@ -753,11 +751,20 @@ impl ComputePipeline {
                 let backend_context = context;
                 let CompilerUser {
                     mut frontend_context,
-                    parsed_shader,
+                    compute_shader_stage,
                 } = self;
+                let parsed_shader = ParsedShader::create(
+                    &mut frontend_context,
+                    compute_shader_stage,
+                    ExecutionModel::GLCompute,
+                );
                 let mut module = backend_context.create_module("");
-                let function =
-                    parsed_shader.compile(&mut frontend_context, backend_context, &mut module);
+                let function = parsed_shader.compile(
+                    &mut frontend_context,
+                    backend_context,
+                    &mut module,
+                    "fn_",
+                );
                 Ok(shader_compiler_backend::CompileInputs {
                     module: module.verify().unwrap(),
                     callable_functions: iter::once((
@@ -772,7 +779,7 @@ impl ComputePipeline {
             .run(
                 CompilerUser {
                     frontend_context,
-                    parsed_shader,
+                    compute_shader_stage,
                 },
                 shader_compiler_backend::CompilerIndependentConfig {
                     optimization_mode: options.generic_options.optimization_mode,
index 06af9579049efdf30ed942eef1d7a9e7ea7e1b91..027451a25a41183d1a05e1ce937d00c9a7fa9794 100644 (file)
@@ -1,18 +1,25 @@
 // SPDX-License-Identifier: LGPL-2.1-or-later
 // Copyright 2018 Jacob Lifshay
 
-use super::{Context, IdKind, IdProperties, ParsedShader, ParsedShaderFunction};
+use super::{Context, IdKind, Ids, ParsedShader, ParsedShaderFunction};
+use shader_compiler_backend::{
+    types::TypeBuilder, BuildableBasicBlock, DetachedBuilder, Function, Module,
+};
+use spirv_parser::Decoration;
 use spirv_parser::{FunctionControl, IdRef, IdResult, IdResultType, Instruction};
+use std::cell::Cell;
 use std::collections::hash_map;
 use std::collections::{HashMap, HashSet};
 use std::hash::Hash;
+use std::rc::Rc;
 
-pub(crate) trait ParsedShaderCompile {
-    fn compile<'a, C: shader_compiler_backend::Context<'a>>(
+pub(crate) trait ParsedShaderCompile<'ctx, C: shader_compiler_backend::Context<'ctx>> {
+    fn compile(
         self,
         frontend_context: &mut Context,
-        backend_context: &C,
+        backend_context: &'ctx C,
         module: &mut C::Module,
+        function_name_prefix: &str,
     ) -> C::Function;
 }
 
@@ -44,12 +51,118 @@ impl<T: Eq + Hash + Clone> Default for Worklist<T> {
     }
 }
 
-impl ParsedShaderCompile for ParsedShader {
-    fn compile<'a, C: shader_compiler_backend::Context<'a>>(
+struct FunctionInstruction {
+    id_result_type: IdResultType,
+    id_result: IdResult,
+    function_control: FunctionControl,
+    function_type: IdRef,
+}
+
+struct FunctionState<'ctx, C: shader_compiler_backend::Context<'ctx>> {
+    function_instruction: FunctionInstruction,
+    instructions: Vec<Instruction>,
+    decorations: Vec<Decoration>,
+    backend_function: Cell<Option<C::Function>>,
+    backend_function_value: C::Value,
+}
+
+struct GetOrAddFunctionState<'ctx, 'tb, 'fnp, C: shader_compiler_backend::Context<'ctx>>
+where
+    C::TypeBuilder: 'tb,
+{
+    reachable_functions: HashMap<IdRef, Rc<FunctionState<'ctx, C>>>,
+    type_builder: &'tb C::TypeBuilder,
+    function_name_prefix: &'fnp str,
+}
+
+impl<'ctx, 'tb, 'fnp, C: shader_compiler_backend::Context<'ctx>>
+    GetOrAddFunctionState<'ctx, 'tb, 'fnp, C>
+{
+    fn call(
+        &mut self,
+        reachable_functions_worklist: &mut Vec<IdRef>,
+        ids: &mut Ids<'ctx, C>,
+        module: &mut C::Module,
+        function_id: IdRef,
+    ) -> Rc<FunctionState<'ctx, C>> {
+        match self.reachable_functions.entry(function_id) {
+            hash_map::Entry::Occupied(v) => v.get().clone(),
+            hash_map::Entry::Vacant(v) => {
+                reachable_functions_worklist.push(function_id);
+                let ParsedShaderFunction {
+                    instructions,
+                    decorations,
+                } = match &mut ids[function_id].kind {
+                    IdKind::Function(function) => function.take().unwrap(),
+                    _ => unreachable!("id is not a function"),
+                };
+                let function_instruction = match instructions.get(0) {
+                    Some(&Instruction::Function {
+                        id_result_type,
+                        id_result,
+                        ref function_control,
+                        function_type,
+                    }) => FunctionInstruction {
+                        id_result_type,
+                        id_result,
+                        function_control: function_control.clone(),
+                        function_type,
+                    },
+                    _ => unreachable!("missing OpFunction"),
+                };
+                for decoration in &decorations {
+                    match decoration {
+                        _ => unreachable!(
+                            "unimplemented function decoration: {:?} on {}",
+                            decoration, function_id
+                        ),
+                    }
+                }
+                let function_type = match &ids[function_instruction.function_type].kind {
+                    IdKind::FunctionType {
+                        return_type,
+                        arguments,
+                    } => {
+                        let return_type = match return_type {
+                            None => None,
+                            Some(v) => unimplemented!(),
+                        };
+                        let arguments: Vec<_> = arguments
+                            .iter()
+                            .enumerate()
+                            .map(|(argument_index, argument)| unimplemented!())
+                            .collect();
+                        self.type_builder.build_function(&arguments, return_type)
+                    }
+                    _ => unreachable!("not a function type"),
+                };
+                let backend_function = module.add_function(
+                    &format!("{}{}", self.function_name_prefix, function_id.0),
+                    function_type,
+                );
+                let backend_function_value = backend_function.as_value();
+                v.insert(Rc::new(FunctionState {
+                    function_instruction,
+                    instructions,
+                    decorations,
+                    backend_function: Cell::new(Some(backend_function)),
+                    backend_function_value,
+                }))
+                .clone()
+            }
+        }
+    }
+}
+
+impl<'ctx, C: shader_compiler_backend::Context<'ctx>> ParsedShaderCompile<'ctx, C>
+    for ParsedShader<'ctx, C>
+{
+    fn compile(
         self,
         frontend_context: &mut Context,
-        backend_context: &C,
+        backend_context: &'ctx C,
         module: &mut C::Module,
+        function_name_prefix: &str,
     ) -> C::Function {
         let ParsedShader {
             mut ids,
@@ -58,56 +171,88 @@ impl ParsedShaderCompile for ParsedShader {
             execution_modes,
             workgroup_size,
         } = self;
-        let mut reachable_functions = HashMap::new();
-        let mut reachable_function_worklist = Worklist::default();
-        reachable_function_worklist.add(main_function_id);
-        while let Some(function_id) = reachable_function_worklist.get_next() {
-            let function = match &mut ids[function_id].kind {
-                IdKind::Function(function) => function.take().unwrap(),
-                _ => unreachable!("id is not a function"),
-            };
-            let mut function = match reachable_functions.entry(function_id) {
-                hash_map::Entry::Vacant(entry) => entry.insert(function),
-                _ => unreachable!(),
+        let type_builder = backend_context.create_type_builder();
+        let mut reachable_functions_worklist = Vec::new();
+        let mut get_or_add_function_state = GetOrAddFunctionState {
+            reachable_functions: HashMap::new(),
+            type_builder: &type_builder,
+            function_name_prefix,
+        };
+        let mut get_or_add_function = |reachable_functions_worklist: &mut Vec<IdRef>,
+                                       ids: &mut Ids<'ctx, C>,
+                                       module: &mut C::Module,
+                                       function_id: IdRef| {
+            get_or_add_function_state.call(reachable_functions_worklist, ids, module, function_id)
+        };
+        let get_or_add_basic_block =
+            |ids: &mut Ids<'ctx, C>, label_id: IdRef, backend_function: &mut C::Function| {
+                if let IdKind::BasicBlock { basic_block, .. } = &ids[label_id].kind {
+                    return basic_block.clone();
+                }
+                let buildable_basic_block =
+                    backend_function.append_new_basic_block(Some(&format!("L{}", label_id.0)));
+                let basic_block = buildable_basic_block.as_basic_block();
+                ids[label_id].set_kind(IdKind::BasicBlock {
+                    buildable_basic_block: Some(buildable_basic_block),
+                    basic_block: basic_block.clone(),
+                });
+                basic_block
             };
-            let (function_instruction, instructions) = function
-                .instructions
-                .split_first()
-                .expect("missing OpFunction");
-            struct FunctionInstruction {
-                id_result_type: IdResultType,
-                id_result: IdResult,
-                function_control: FunctionControl,
-                function_type: IdRef,
-            }
-            let function_instruction = match *function_instruction {
-                Instruction::Function {
-                    id_result_type,
-                    id_result,
-                    ref function_control,
-                    function_type,
-                } => FunctionInstruction {
-                    id_result_type,
-                    id_result,
-                    function_control: function_control.clone(),
-                    function_type,
+        get_or_add_function(
+            &mut reachable_functions_worklist,
+            &mut ids,
+            module,
+            main_function_id,
+        );
+        while let Some(function_id) = reachable_functions_worklist.pop() {
+            let function_state = get_or_add_function(
+                &mut reachable_functions_worklist,
+                &mut ids,
+                module,
+                function_id,
+            );
+            let mut backend_function = function_state.backend_function.replace(None).unwrap();
+            enum BasicBlockState<'ctx, C: shader_compiler_backend::Context<'ctx>> {
+                Detached {
+                    builder: C::DetachedBuilder,
+                },
+                Attached {
+                    builder: C::AttachedBuilder,
+                    current_label: IdRef,
                 },
-                _ => unreachable!("missing OpFunction"),
+            }
+            let mut current_basic_block: BasicBlockState<C> = BasicBlockState::Detached {
+                builder: backend_context.create_builder(),
             };
-            let mut current_basic_block: Option<IdRef> = None;
-            for instruction in instructions {
-                if let Some(basic_block) = current_basic_block {
-                    match instruction {
+            for instruction in &function_state.instructions {
+                match current_basic_block {
+                    BasicBlockState::Attached {
+                        builder,
+                        current_label,
+                    } => match instruction {
                         _ => unimplemented!("unimplemented instruction:\n{}", instruction),
-                    }
-                } else {
-                    match instruction {
+                    },
+                    BasicBlockState::Detached { builder } => match instruction {
+                        Instruction::Function { .. } => {
+                            current_basic_block = BasicBlockState::Detached { builder };
+                        }
                         Instruction::Label { id_result } => {
                             ids[id_result.0].assert_no_decorations(id_result.0);
-                            current_basic_block = Some(id_result.0);
+                            get_or_add_basic_block(&mut ids, id_result.0, &mut backend_function);
+                            let buildable_basic_block = match ids[id_result.0].kind {
+                                IdKind::BasicBlock {
+                                    ref mut buildable_basic_block,
+                                    ..
+                                } => buildable_basic_block.take().expect("duplicate OpLabel"),
+                                _ => unreachable!(),
+                            };
+                            current_basic_block = BasicBlockState::Attached {
+                                builder: builder.attach(buildable_basic_block),
+                                current_label: id_result.0,
+                            };
                         }
                         _ => unimplemented!("unimplemented instruction:\n{}", instruction),
-                    }
+                    },
                 }
             }
         }
index c1a226e86f4d9a033cf462de659ceede90edbeef..c85feb6f59eb07b24ea9e25c2792c7e240c1b087 100644 (file)
@@ -12,11 +12,11 @@ use std::mem;
 use std::rc::Rc;
 
 #[cfg_attr(feature = "cargo-clippy", allow(clippy::cyclomatic_complexity))]
-pub(super) fn create(
+pub(super) fn create<'a, C: shader_compiler_backend::Context<'a>>(
     context: &mut Context,
     stage_info: ShaderStageCreateInfo,
     execution_model: ExecutionModel,
-) -> ParsedShader {
+) -> ParsedShader<'a, C> {
     let parser = spirv_parser::Parser::start(stage_info.code).unwrap();
     let header = *parser.header();
     assert_eq!(header.instruction_schema, 0);