1 // SPDX-License-Identifier: LGPL-2.1-or-later
2 // Copyright 2018 Jacob Lifshay
4 use shader_compiler::backend;
5 use std::cell::RefCell;
6 use std::collections::HashMap;
7 use std::collections::HashSet;
8 use std::ffi::{CStr, CString};
12 use std::mem::ManuallyDrop;
14 use std::os::raw::{c_char, c_uint};
15 use std::ptr::null_mut;
16 use std::ptr::NonNull;
17 use std::sync::{Once, ONCE_INIT};
19 fn to_bool(v: llvm::LLVMBool) -> bool {
24 pub struct LLVM7CompilerConfig {
25 pub variable_vector_length_multiplier: u32,
26 pub optimization_mode: backend::OptimizationMode,
29 impl Default for LLVM7CompilerConfig {
30 fn default() -> Self {
31 backend::CompilerIndependentConfig::default().into()
35 impl From<backend::CompilerIndependentConfig> for LLVM7CompilerConfig {
36 fn from(v: backend::CompilerIndependentConfig) -> Self {
37 let backend::CompilerIndependentConfig { optimization_mode } = v;
39 variable_vector_length_multiplier: 1,
46 struct LLVM7String(NonNull<c_char>);
48 impl Drop for LLVM7String {
51 llvm::LLVMDisposeMessage(self.0.as_ptr());
56 impl Deref for LLVM7String {
58 fn deref(&self) -> &CStr {
59 unsafe { CStr::from_ptr(self.0.as_ptr()) }
63 impl Clone for LLVM7String {
64 fn clone(&self) -> Self {
70 fn new(v: &CStr) -> Self {
71 unsafe { Self::from_ptr(llvm::LLVMCreateMessage(v.as_ptr())).unwrap() }
73 unsafe fn from_nonnull(v: NonNull<c_char>) -> Self {
76 unsafe fn from_ptr(v: *mut c_char) -> Option<Self> {
77 NonNull::new(v).map(|v| Self::from_nonnull(v))
81 impl fmt::Debug for LLVM7String {
82 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 #[derive(Clone, Eq, PartialEq, Hash)]
89 pub struct LLVM7Type(llvm::LLVMTypeRef);
91 impl fmt::Debug for LLVM7Type {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
95 LLVM7String::from_ptr(llvm::LLVMPrintTypeToString(self.0)).ok_or(fmt::Error)?;
96 f.write_str(&string.to_string_lossy())
101 impl<'a> backend::types::Type<'a> for LLVM7Type {
102 type Context = LLVM7Context;
105 pub struct LLVM7TypeBuilder {
106 context: llvm::LLVMContextRef,
107 variable_vector_length_multiplier: u32,
110 impl<'a> backend::types::TypeBuilder<'a, LLVM7Type> for LLVM7TypeBuilder {
111 fn build_bool(&self) -> LLVM7Type {
112 unsafe { LLVM7Type(llvm::LLVMInt1TypeInContext(self.context)) }
114 fn build_i8(&self) -> LLVM7Type {
115 unsafe { LLVM7Type(llvm::LLVMInt8TypeInContext(self.context)) }
117 fn build_i16(&self) -> LLVM7Type {
118 unsafe { LLVM7Type(llvm::LLVMInt16TypeInContext(self.context)) }
120 fn build_i32(&self) -> LLVM7Type {
121 unsafe { LLVM7Type(llvm::LLVMInt32TypeInContext(self.context)) }
123 fn build_i64(&self) -> LLVM7Type {
124 unsafe { LLVM7Type(llvm::LLVMInt64TypeInContext(self.context)) }
126 fn build_f32(&self) -> LLVM7Type {
127 unsafe { LLVM7Type(llvm::LLVMFloatTypeInContext(self.context)) }
129 fn build_f64(&self) -> LLVM7Type {
130 unsafe { LLVM7Type(llvm::LLVMDoubleTypeInContext(self.context)) }
132 fn build_pointer(&self, target: LLVM7Type) -> LLVM7Type {
133 unsafe { LLVM7Type(llvm::LLVMPointerType(target.0, 0)) }
135 fn build_array(&self, element: LLVM7Type, count: usize) -> LLVM7Type {
136 assert_eq!(count as u32 as usize, count);
137 unsafe { LLVM7Type(llvm::LLVMArrayType(element.0, count as u32)) }
139 fn build_vector(&self, element: LLVM7Type, length: backend::types::VectorLength) -> LLVM7Type {
140 use self::backend::types::VectorLength::*;
141 let length = match length {
142 Fixed { length } => length,
143 Variable { base_length } => base_length
144 .checked_mul(self.variable_vector_length_multiplier)
147 assert_ne!(length, 0);
148 unsafe { LLVM7Type(llvm::LLVMVectorType(element.0, length)) }
150 fn build_struct(&self, members: &[LLVM7Type]) -> LLVM7Type {
151 assert_eq!(members.len() as c_uint as usize, members.len());
153 LLVM7Type(llvm::LLVMStructTypeInContext(
155 members.as_ptr() as *mut llvm::LLVMTypeRef,
156 members.len() as c_uint,
157 false as llvm::LLVMBool,
161 fn build_function(&self, arguments: &[LLVM7Type], return_type: Option<LLVM7Type>) -> LLVM7Type {
162 assert_eq!(arguments.len() as c_uint as usize, arguments.len());
164 LLVM7Type(llvm::LLVMFunctionType(
166 .unwrap_or_else(|| LLVM7Type(llvm::LLVMVoidTypeInContext(self.context)))
168 arguments.as_ptr() as *mut llvm::LLVMTypeRef,
169 arguments.len() as c_uint,
170 false as llvm::LLVMBool,
178 pub struct LLVM7Value(llvm::LLVMValueRef);
180 impl fmt::Debug for LLVM7Value {
181 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
184 LLVM7String::from_ptr(llvm::LLVMPrintValueToString(self.0)).ok_or(fmt::Error)?;
185 f.write_str(&string.to_string_lossy())
190 impl<'a> backend::Value<'a> for LLVM7Value {
191 type Context = LLVM7Context;
196 pub struct LLVM7BasicBlock(llvm::LLVMBasicBlockRef);
198 impl fmt::Debug for LLVM7BasicBlock {
199 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
200 use self::backend::BasicBlock;
202 let string = LLVM7String::from_ptr(llvm::LLVMPrintValueToString(self.as_value().0))
204 f.write_str(&string.to_string_lossy())
209 impl<'a> backend::BasicBlock<'a> for LLVM7BasicBlock {
210 type Context = LLVM7Context;
211 fn as_value(&self) -> LLVM7Value {
212 unsafe { LLVM7Value(llvm::LLVMBasicBlockAsValue(self.0)) }
216 impl<'a> backend::BuildableBasicBlock<'a> for LLVM7BasicBlock {
217 type Context = LLVM7Context;
218 fn as_basic_block(&self) -> LLVM7BasicBlock {
223 pub struct LLVM7Function {
224 context: llvm::LLVMContextRef,
225 function: llvm::LLVMValueRef,
228 impl fmt::Debug for LLVM7Function {
229 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
231 let string = LLVM7String::from_ptr(llvm::LLVMPrintValueToString(self.function))
233 f.write_str(&string.to_string_lossy())
238 impl<'a> backend::Function<'a> for LLVM7Function {
239 type Context = LLVM7Context;
240 fn as_value(&self) -> LLVM7Value {
241 LLVM7Value(self.function)
243 fn append_new_basic_block(&mut self, name: Option<&str>) -> LLVM7BasicBlock {
244 let name = CString::new(name.unwrap_or("")).unwrap();
246 LLVM7BasicBlock(llvm::LLVMAppendBasicBlockInContext(
255 pub struct LLVM7Context {
256 context: Option<ManuallyDrop<OwnedContext>>,
257 modules: ManuallyDrop<RefCell<Vec<OwnedModule>>>,
258 config: LLVM7CompilerConfig,
261 impl Drop for LLVM7Context {
264 ManuallyDrop::drop(&mut self.modules);
265 if let Some(context) = &mut self.context {
266 ManuallyDrop::drop(context);
272 impl<'a> backend::Context<'a> for LLVM7Context {
273 type Value = LLVM7Value;
274 type BasicBlock = LLVM7BasicBlock;
275 type BuildableBasicBlock = LLVM7BasicBlock;
276 type Function = LLVM7Function;
277 type Type = LLVM7Type;
278 type TypeBuilder = LLVM7TypeBuilder;
279 type Module = LLVM7Module;
280 type VerifiedModule = LLVM7Module;
281 type AttachedBuilder = LLVM7Builder;
282 type DetachedBuilder = LLVM7Builder;
283 fn create_module(&self, name: &str) -> LLVM7Module {
284 let name = CString::new(name).unwrap();
285 let mut modules = self.modules.borrow_mut();
287 let module = OwnedModule(llvm::LLVMModuleCreateWithNameInContext(
289 self.context.as_ref().unwrap().0,
291 let module_ref = module.0;
292 modules.push(module);
294 context: self.context.as_ref().unwrap().0,
296 name_set: HashSet::new(),
300 fn create_builder(&self) -> LLVM7Builder {
302 LLVM7Builder(llvm::LLVMCreateBuilderInContext(
303 self.context.as_ref().unwrap().0,
307 fn create_type_builder(&self) -> LLVM7TypeBuilder {
309 context: self.context.as_ref().unwrap().0,
310 variable_vector_length_multiplier: self.config.variable_vector_length_multiplier,
316 pub struct LLVM7Builder(llvm::LLVMBuilderRef);
318 impl Drop for LLVM7Builder {
321 llvm::LLVMDisposeBuilder(self.0);
326 impl<'a> backend::AttachedBuilder<'a> for LLVM7Builder {
327 type Context = LLVM7Context;
328 fn current_basic_block(&self) -> LLVM7BasicBlock {
329 unsafe { LLVM7BasicBlock(llvm::LLVMGetInsertBlock(self.0)) }
331 fn build_return(self, value: Option<LLVM7Value>) -> LLVM7Builder {
334 Some(value) => llvm::LLVMBuildRet(self.0, value.0),
335 None => llvm::LLVMBuildRetVoid(self.0),
337 llvm::LLVMClearInsertionPosition(self.0);
343 impl<'a> backend::DetachedBuilder<'a> for LLVM7Builder {
344 type Context = LLVM7Context;
345 fn attach(self, basic_block: LLVM7BasicBlock) -> LLVM7Builder {
347 llvm::LLVMPositionBuilderAtEnd(self.0, basic_block.0);
353 struct OwnedModule(llvm::LLVMModuleRef);
355 impl Drop for OwnedModule {
358 llvm::LLVMDisposeModule(self.0);
364 unsafe fn take(mut self) -> llvm::LLVMModuleRef {
371 struct OwnedContext(llvm::LLVMContextRef);
373 impl Drop for OwnedContext {
376 llvm::LLVMContextDispose(self.0);
381 pub struct LLVM7Module {
382 context: llvm::LLVMContextRef,
383 module: llvm::LLVMModuleRef,
384 name_set: HashSet<String>,
387 impl fmt::Debug for LLVM7Module {
388 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
390 let string = LLVM7String::from_ptr(llvm::LLVMPrintModuleToString(self.module))
392 f.write_str(&string.to_string_lossy())
397 impl<'a> backend::Module<'a> for LLVM7Module {
398 type Context = LLVM7Context;
399 fn set_source_file_name(&mut self, source_file_name: &str) {
401 llvm::LLVMSetSourceFileName(
403 source_file_name.as_ptr() as *const c_char,
404 source_file_name.len(),
408 fn add_function(&mut self, name: &str, ty: LLVM7Type) -> LLVM7Function {
409 fn is_start_char(c: char) -> bool {
410 if c.is_ascii_alphabetic() {
414 '_' | '.' | '$' | '-' => true,
419 fn is_continue_char(c: char) -> bool {
420 is_start_char(c) || c.is_ascii_digit()
422 assert!(is_start_char(name.chars().next().unwrap()));
423 assert!(name.chars().all(is_continue_char));
424 assert!(self.name_set.insert(name.into()));
425 let name = CString::new(name).unwrap();
428 context: self.context,
429 function: llvm::LLVMAddFunction(self.module, name.as_ptr(), ty.0),
433 fn verify(self) -> Result<LLVM7Module, backend::VerificationFailure<'a, LLVM7Module>> {
435 let mut message = null_mut();
436 let broken = to_bool(llvm::LLVMVerifyModule(
438 llvm::LLVMReturnStatusAction,
442 let message = LLVM7String::from_ptr(message).unwrap();
443 let message = message.to_string_lossy();
444 Err(backend::VerificationFailure::new(self, message.as_ref()))
450 unsafe fn to_verified_module_unchecked(self) -> LLVM7Module {
455 impl<'a> backend::VerifiedModule<'a> for LLVM7Module {
456 type Context = LLVM7Context;
457 fn into_module(self) -> LLVM7Module {
462 struct LLVM7TargetMachine(llvm::LLVMTargetMachineRef);
464 impl Drop for LLVM7TargetMachine {
467 llvm::LLVMDisposeTargetMachine(self.0);
472 impl LLVM7TargetMachine {
473 fn take(mut self) -> llvm::LLVMTargetMachineRef {
480 struct LLVM7OrcJITStack(llvm::LLVMOrcJITStackRef);
482 impl Drop for LLVM7OrcJITStack {
485 match llvm::LLVMOrcDisposeInstance(self.0) {
486 llvm::LLVMOrcErrSuccess => {}
488 panic!("LLVMOrcDisposeInstance failed");
495 fn initialize_native_target() {
496 static ONCE: Once = ONCE_INIT;
497 ONCE.call_once(|| unsafe {
498 llvm::LLVM_InitializeNativeTarget();
499 llvm::LLVM_InitializeNativeAsmPrinter();
500 llvm::LLVM_InitializeNativeAsmParser();
/// Symbol-resolver callback handed to the ORC JIT when IR is added.
///
/// External symbol resolution is not implemented, so any lookup is fatal. The
/// process is aborted after printing the requested name, rather than panicking.
/// Unwinding out of an `extern "C"` function into LLVM's C++ frames is undefined
/// behavior on older Rust toolchains; an explicit abort is well-defined everywhere.
extern "C" fn symbol_resolver_fn<Void>(name: *const c_char, _lookup_context: *mut Void) -> u64 {
    // SAFETY: LLVM passes a valid NUL-terminated symbol name.
    let name = unsafe { CStr::from_ptr(name) };
    eprintln!("symbol_resolver_fn is unimplemented: name = {:?}", name);
    std::process::abort()
}
/// The LLVM 7 JIT backend: a stateless marker type. Per-run settings arrive as an
/// `LLVM7CompilerConfig`.
#[derive(Clone, Copy)]
pub struct LLVM7Compiler;
512 impl backend::Compiler for LLVM7Compiler {
513 type Config = LLVM7CompilerConfig;
514 fn name(self) -> &'static str {
517 fn run<U: backend::CompilerUser>(
520 config: LLVM7CompilerConfig,
521 ) -> Result<Box<dyn backend::CompiledCode<U::FunctionKey>>, U::Error> {
523 initialize_native_target();
524 let context = OwnedContext(llvm::LLVMContextCreate());
525 let modules = Vec::new();
526 let mut context = LLVM7Context {
527 context: Some(ManuallyDrop::new(context)),
528 modules: ManuallyDrop::new(RefCell::new(modules)),
529 config: config.clone(),
531 let backend::CompileInputs {
534 } = user.run(&context)?;
535 let callable_functions: Vec<_> = callable_functions
537 .map(|(key, callable_function)| {
539 llvm::LLVMGetGlobalParent(callable_function.function),
543 CStr::from_ptr(llvm::LLVMGetValueName(callable_function.function)).into();
544 assert_ne!(name.to_bytes().len(), 0);
552 .find(|v| v.0 == module.module)
554 let target_triple = LLVM7String::from_ptr(llvm::LLVMGetDefaultTargetTriple()).unwrap();
555 let mut target = null_mut();
556 let mut error = null_mut();
557 let success = !to_bool(llvm::LLVMGetTargetFromTriple(
558 target_triple.as_ptr(),
563 let error = LLVM7String::from_ptr(error).unwrap();
564 return Err(U::create_error(error.to_string_lossy().into()));
566 if !to_bool(llvm::LLVMTargetHasJIT(target)) {
567 return Err(U::create_error(format!(
568 "target {:?} doesn't support JIT",
572 let host_cpu_name = LLVM7String::from_ptr(llvm::LLVMGetHostCPUName()).unwrap();
573 let host_cpu_features = LLVM7String::from_ptr(llvm::LLVMGetHostCPUFeatures()).unwrap();
574 let target_machine = LLVM7TargetMachine(llvm::LLVMCreateTargetMachine(
576 target_triple.as_ptr(),
577 host_cpu_name.as_ptr(),
578 host_cpu_features.as_ptr(),
579 match config.optimization_mode {
580 backend::OptimizationMode::NoOptimizations => llvm::LLVMCodeGenLevelNone,
581 backend::OptimizationMode::Normal => llvm::LLVMCodeGenLevelDefault,
583 llvm::LLVMRelocDefault,
584 llvm::LLVMCodeModelJITDefault,
586 assert!(!target_machine.0.is_null());
588 LLVM7OrcJITStack(llvm::LLVMOrcCreateInstance(target_machine.take()));
589 let mut module_handle = 0;
590 if llvm::LLVMOrcErrSuccess != llvm::LLVMOrcAddEagerlyCompiledIR(
594 Some(symbol_resolver_fn),
597 return Err(U::create_error("compilation failed".into()));
599 let mut functions: HashMap<_, _> = HashMap::new();
600 for (key, name) in callable_functions {
601 let mut address: llvm::LLVMOrcTargetAddress = mem::zeroed();
602 if llvm::LLVMOrcErrSuccess != llvm::LLVMOrcGetSymbolAddressIn(
608 return Err(U::create_error(format!(
609 "function not found in compiled module: {:?}",
613 let address: Option<unsafe extern "C" fn()> = mem::transmute(address as usize);
614 if functions.insert(key, address.unwrap()).is_some() {
615 return Err(U::create_error(format!("duplicate function: {:?}", name)));
618 struct CompiledCode<K: Hash + Eq + Send + Sync + 'static> {
619 functions: HashMap<K, unsafe extern "C" fn()>,
620 orc_jit_stack: ManuallyDrop<LLVM7OrcJITStack>,
621 context: ManuallyDrop<OwnedContext>,
623 unsafe impl<K: Hash + Eq + Send + Sync + 'static> Send for CompiledCode<K> {}
624 unsafe impl<K: Hash + Eq + Send + Sync + 'static> Sync for CompiledCode<K> {}
625 impl<K: Hash + Eq + Send + Sync + 'static> Drop for CompiledCode<K> {
628 ManuallyDrop::drop(&mut self.orc_jit_stack);
629 ManuallyDrop::drop(&mut self.context);
633 impl<K: Hash + Eq + Send + Sync + 'static> backend::CompiledCode<K> for CompiledCode<K> {
634 fn get(&self, key: &K) -> Option<unsafe extern "C" fn()> {
635 Some(*self.functions.get(key)?)
638 Ok(Box::new(CompiledCode {
640 orc_jit_stack: ManuallyDrop::new(orc_jit_stack),
641 context: context.context.take().unwrap(),