add function name validation
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 17 Oct 2018 08:23:25 +0000 (01:23 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 17 Oct 2018 08:23:25 +0000 (01:23 -0700)
shader-compiler-llvm-7/src/backend.rs
shader-compiler-llvm-7/src/tests.rs

index fb466d4030233115850568278d1d277b92e97945..3e02ad799075806b1edbdc0144af6c6a062a85a5 100644 (file)
@@ -4,6 +4,7 @@ use llvm;
 use shader_compiler::backend;
 use std::cell::RefCell;
 use std::collections::HashMap;
+use std::collections::HashSet;
 use std::ffi::{CStr, CString};
 use std::fmt;
 use std::hash::Hash;
@@ -292,6 +293,7 @@ impl<'a> backend::Context<'a> for LLVM7Context {
             LLVM7Module {
                 context: self.context.as_ref().unwrap().0,
                 module: module_ref,
+                name_set: HashSet::new(),
             }
         }
     }
@@ -379,6 +381,7 @@ impl Drop for OwnedContext {
 pub struct LLVM7Module {
     context: llvm::LLVMContextRef,
     module: llvm::LLVMModuleRef,
+    name_set: HashSet<String>,
 }
 
 impl fmt::Debug for LLVM7Module {
@@ -403,6 +406,22 @@ impl<'a> backend::Module<'a> for LLVM7Module {
         }
     }
     fn add_function(&mut self, name: &str, ty: LLVM7Type) -> LLVM7Function {
+        fn is_start_char(c: char) -> bool {
+            if c.is_ascii_alphabetic() {
+                true
+            } else {
+                match c {
+                    '_' | '.' | '$' | '-' => true,
+                    _ => false,
+                }
+            }
+        }
+        fn is_continue_char(c: char) -> bool {
+            is_start_char(c) || c.is_ascii_digit()
+        }
+        assert!(is_start_char(name.chars().next().unwrap()));
+        assert!(name.chars().all(is_continue_char));
+        assert!(self.name_set.insert(name.into()));
         let name = CString::new(name).unwrap();
         unsafe {
             LLVM7Function {
index 4ce50369773bed112870dde61947a884414dd355..ac532cffa7670a32c0f025b822b35df75bfb91b2 100644 (file)
@@ -55,4 +55,46 @@ mod tests {
             function(0);
         }
     }
+
+    #[test]
+    fn test_names() {
+        const NAMES: &[&str] = &["main", "abc123-$._"];
+        type GeneratedFunctionType = unsafe extern "C" fn(u32);
+        #[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
+        struct Test;
+        impl CompilerUser for Test {
+            type FunctionKey = String;
+            type Error = String;
+            fn create_error(message: String) -> String {
+                message
+            }
+            fn run<'a, C: Context<'a>>(
+                self,
+                context: &'a C,
+            ) -> Result<CompileInputs<'a, C, String>, String> {
+                let type_builder = context.create_type_builder();
+                let mut module = context.create_module("test_module");
+                let mut functions = Vec::new();
+                let mut detached_builder = context.create_builder();
+                for name in NAMES {
+                    let mut function =
+                        module.add_function(name, type_builder.build::<GeneratedFunctionType>());
+                    let builder = detached_builder.attach(function.append_new_basic_block(None));
+                    detached_builder = builder.build_return(None);
+                    functions.push((name.to_string(), function));
+                }
+                let module = module.verify().unwrap();
+                Ok(CompileInputs {
+                    module,
+                    callable_functions: functions.into_iter().collect(),
+                })
+            }
+        }
+        let compiled_code = make_compiler().run(Test, Default::default()).unwrap();
+        let function = compiled_code.get(&"main".to_string()).unwrap();
+        unsafe {
+            let function: GeneratedFunctionType = mem::transmute(function);
+            function(0);
+        }
+    }
 }