add function parameters
[kazan.git] / shader-compiler-llvm-7 / src / backend.rs
1 // SPDX-License-Identifier: LGPL-2.1-or-later
2 // Copyright 2018 Jacob Lifshay
3 use llvm;
4 use shader_compiler::backend;
5 use std::cell::RefCell;
6 use std::collections::HashMap;
7 use std::collections::HashSet;
8 use std::ffi::{CStr, CString};
9 use std::fmt;
10 use std::hash::Hash;
11 use std::mem;
12 use std::mem::ManuallyDrop;
13 use std::ops::Deref;
14 use std::os::raw::{c_char, c_uint};
15 use std::ptr::null_mut;
16 use std::ptr::NonNull;
17 use std::sync::{Once, ONCE_INIT};
18
19 fn to_bool(v: llvm::LLVMBool) -> bool {
20 v != 0
21 }
22
23 #[derive(Clone)]
24 pub struct LLVM7CompilerConfig {
25 pub variable_vector_length_multiplier: u32,
26 pub optimization_mode: backend::OptimizationMode,
27 }
28
29 impl Default for LLVM7CompilerConfig {
30 fn default() -> Self {
31 backend::CompilerIndependentConfig::default().into()
32 }
33 }
34
35 impl From<backend::CompilerIndependentConfig> for LLVM7CompilerConfig {
36 fn from(v: backend::CompilerIndependentConfig) -> Self {
37 let backend::CompilerIndependentConfig { optimization_mode } = v;
38 Self {
39 variable_vector_length_multiplier: 1,
40 optimization_mode,
41 }
42 }
43 }
44
/// An owned, NUL-terminated string allocated by LLVM.
///
/// It is released with `LLVMDisposeMessage` when dropped.
#[repr(transparent)]
struct LLVM7String(NonNull<c_char>);
47
48 impl Drop for LLVM7String {
49 fn drop(&mut self) {
50 unsafe {
51 llvm::LLVMDisposeMessage(self.0.as_ptr());
52 }
53 }
54 }
55
56 impl Deref for LLVM7String {
57 type Target = CStr;
58 fn deref(&self) -> &CStr {
59 unsafe { CStr::from_ptr(self.0.as_ptr()) }
60 }
61 }
62
63 impl Clone for LLVM7String {
64 fn clone(&self) -> Self {
65 Self::new(self)
66 }
67 }
68
69 impl LLVM7String {
70 fn new(v: &CStr) -> Self {
71 unsafe { Self::from_ptr(llvm::LLVMCreateMessage(v.as_ptr())).unwrap() }
72 }
73 unsafe fn from_nonnull(v: NonNull<c_char>) -> Self {
74 LLVM7String(v)
75 }
76 unsafe fn from_ptr(v: *mut c_char) -> Option<Self> {
77 NonNull::new(v).map(|v| Self::from_nonnull(v))
78 }
79 }
80
81 impl fmt::Debug for LLVM7String {
82 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83 (**self).fmt(f)
84 }
85 }
86
87 #[derive(Clone, Eq, PartialEq, Hash)]
88 #[repr(transparent)]
89 pub struct LLVM7Type(llvm::LLVMTypeRef);
90
91 impl fmt::Debug for LLVM7Type {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93 unsafe {
94 let string =
95 LLVM7String::from_ptr(llvm::LLVMPrintTypeToString(self.0)).ok_or(fmt::Error)?;
96 f.write_str(&string.to_string_lossy())
97 }
98 }
99 }
100
101 impl<'a> backend::types::Type<'a> for LLVM7Type {
102 type Context = LLVM7Context;
103 }
104
105 pub struct LLVM7TypeBuilder {
106 context: llvm::LLVMContextRef,
107 variable_vector_length_multiplier: u32,
108 }
109
110 impl<'a> backend::types::TypeBuilder<'a, LLVM7Type> for LLVM7TypeBuilder {
111 fn build_bool(&self) -> LLVM7Type {
112 unsafe { LLVM7Type(llvm::LLVMInt1TypeInContext(self.context)) }
113 }
114 fn build_i8(&self) -> LLVM7Type {
115 unsafe { LLVM7Type(llvm::LLVMInt8TypeInContext(self.context)) }
116 }
117 fn build_i16(&self) -> LLVM7Type {
118 unsafe { LLVM7Type(llvm::LLVMInt16TypeInContext(self.context)) }
119 }
120 fn build_i32(&self) -> LLVM7Type {
121 unsafe { LLVM7Type(llvm::LLVMInt32TypeInContext(self.context)) }
122 }
123 fn build_i64(&self) -> LLVM7Type {
124 unsafe { LLVM7Type(llvm::LLVMInt64TypeInContext(self.context)) }
125 }
126 fn build_f32(&self) -> LLVM7Type {
127 unsafe { LLVM7Type(llvm::LLVMFloatTypeInContext(self.context)) }
128 }
129 fn build_f64(&self) -> LLVM7Type {
130 unsafe { LLVM7Type(llvm::LLVMDoubleTypeInContext(self.context)) }
131 }
132 fn build_pointer(&self, target: LLVM7Type) -> LLVM7Type {
133 unsafe { LLVM7Type(llvm::LLVMPointerType(target.0, 0)) }
134 }
135 fn build_array(&self, element: LLVM7Type, count: usize) -> LLVM7Type {
136 assert_eq!(count as u32 as usize, count);
137 unsafe { LLVM7Type(llvm::LLVMArrayType(element.0, count as u32)) }
138 }
139 fn build_vector(&self, element: LLVM7Type, length: backend::types::VectorLength) -> LLVM7Type {
140 use self::backend::types::VectorLength::*;
141 let length = match length {
142 Fixed { length } => length,
143 Variable { base_length } => base_length
144 .checked_mul(self.variable_vector_length_multiplier)
145 .unwrap(),
146 };
147 assert_ne!(length, 0);
148 unsafe { LLVM7Type(llvm::LLVMVectorType(element.0, length)) }
149 }
150 fn build_struct(&self, members: &[LLVM7Type]) -> LLVM7Type {
151 assert_eq!(members.len() as c_uint as usize, members.len());
152 unsafe {
153 LLVM7Type(llvm::LLVMStructTypeInContext(
154 self.context,
155 members.as_ptr() as *mut llvm::LLVMTypeRef,
156 members.len() as c_uint,
157 false as llvm::LLVMBool,
158 ))
159 }
160 }
161 fn build_function(&self, arguments: &[LLVM7Type], return_type: Option<LLVM7Type>) -> LLVM7Type {
162 assert_eq!(arguments.len() as c_uint as usize, arguments.len());
163 unsafe {
164 LLVM7Type(llvm::LLVMFunctionType(
165 return_type
166 .unwrap_or_else(|| LLVM7Type(llvm::LLVMVoidTypeInContext(self.context)))
167 .0,
168 arguments.as_ptr() as *mut llvm::LLVMTypeRef,
169 arguments.len() as c_uint,
170 false as llvm::LLVMBool,
171 ))
172 }
173 }
174 }
175
176 #[derive(Clone)]
177 #[repr(transparent)]
178 pub struct LLVM7Value(llvm::LLVMValueRef);
179
180 impl fmt::Debug for LLVM7Value {
181 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
182 unsafe {
183 let string =
184 LLVM7String::from_ptr(llvm::LLVMPrintValueToString(self.0)).ok_or(fmt::Error)?;
185 f.write_str(&string.to_string_lossy())
186 }
187 }
188 }
189
190 impl<'a> backend::Value<'a> for LLVM7Value {
191 type Context = LLVM7Context;
192 }
193
194 #[derive(Clone)]
195 #[repr(transparent)]
196 pub struct LLVM7BasicBlock(llvm::LLVMBasicBlockRef);
197
198 impl fmt::Debug for LLVM7BasicBlock {
199 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
200 use self::backend::BasicBlock;
201 unsafe {
202 let string = LLVM7String::from_ptr(llvm::LLVMPrintValueToString(self.as_value().0))
203 .ok_or(fmt::Error)?;
204 f.write_str(&string.to_string_lossy())
205 }
206 }
207 }
208
209 impl<'a> backend::BasicBlock<'a> for LLVM7BasicBlock {
210 type Context = LLVM7Context;
211 fn as_value(&self) -> LLVM7Value {
212 unsafe { LLVM7Value(llvm::LLVMBasicBlockAsValue(self.0)) }
213 }
214 }
215
216 impl<'a> backend::BuildableBasicBlock<'a> for LLVM7BasicBlock {
217 type Context = LLVM7Context;
218 fn as_basic_block(&self) -> LLVM7BasicBlock {
219 self.clone()
220 }
221 }
222
223 pub struct LLVM7Function {
224 context: llvm::LLVMContextRef,
225 function: llvm::LLVMValueRef,
226 parameters: Box<[LLVM7Value]>,
227 }
228
229 impl fmt::Debug for LLVM7Function {
230 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
231 unsafe {
232 let string = LLVM7String::from_ptr(llvm::LLVMPrintValueToString(self.function))
233 .ok_or(fmt::Error)?;
234 f.write_str(&string.to_string_lossy())
235 }
236 }
237 }
238
239 impl<'a> backend::Function<'a> for LLVM7Function {
240 type Context = LLVM7Context;
241 fn as_value(&self) -> LLVM7Value {
242 LLVM7Value(self.function)
243 }
244 fn append_new_basic_block(&mut self, name: Option<&str>) -> LLVM7BasicBlock {
245 let name = CString::new(name.unwrap_or("")).unwrap();
246 unsafe {
247 LLVM7BasicBlock(llvm::LLVMAppendBasicBlockInContext(
248 self.context,
249 self.function,
250 name.as_ptr(),
251 ))
252 }
253 }
254 fn parameters(&self) -> &[LLVM7Value] {
255 &self.parameters
256 }
257 }
258
259 pub struct LLVM7Context {
260 context: Option<ManuallyDrop<OwnedContext>>,
261 modules: ManuallyDrop<RefCell<Vec<OwnedModule>>>,
262 config: LLVM7CompilerConfig,
263 }
264
265 impl Drop for LLVM7Context {
266 fn drop(&mut self) {
267 unsafe {
268 ManuallyDrop::drop(&mut self.modules);
269 if let Some(context) = &mut self.context {
270 ManuallyDrop::drop(context);
271 }
272 }
273 }
274 }
275
276 impl<'a> backend::Context<'a> for LLVM7Context {
277 type Value = LLVM7Value;
278 type BasicBlock = LLVM7BasicBlock;
279 type BuildableBasicBlock = LLVM7BasicBlock;
280 type Function = LLVM7Function;
281 type Type = LLVM7Type;
282 type TypeBuilder = LLVM7TypeBuilder;
283 type Module = LLVM7Module;
284 type VerifiedModule = LLVM7Module;
285 type AttachedBuilder = LLVM7Builder;
286 type DetachedBuilder = LLVM7Builder;
287 fn create_module(&self, name: &str) -> LLVM7Module {
288 let name = CString::new(name).unwrap();
289 let mut modules = self.modules.borrow_mut();
290 unsafe {
291 let module = OwnedModule(llvm::LLVMModuleCreateWithNameInContext(
292 name.as_ptr(),
293 self.context.as_ref().unwrap().0,
294 ));
295 let module_ref = module.0;
296 modules.push(module);
297 LLVM7Module {
298 context: self.context.as_ref().unwrap().0,
299 module: module_ref,
300 name_set: HashSet::new(),
301 }
302 }
303 }
304 fn create_builder(&self) -> LLVM7Builder {
305 unsafe {
306 LLVM7Builder(llvm::LLVMCreateBuilderInContext(
307 self.context.as_ref().unwrap().0,
308 ))
309 }
310 }
311 fn create_type_builder(&self) -> LLVM7TypeBuilder {
312 LLVM7TypeBuilder {
313 context: self.context.as_ref().unwrap().0,
314 variable_vector_length_multiplier: self.config.variable_vector_length_multiplier,
315 }
316 }
317 }
318
319 #[repr(transparent)]
320 pub struct LLVM7Builder(llvm::LLVMBuilderRef);
321
322 impl Drop for LLVM7Builder {
323 fn drop(&mut self) {
324 unsafe {
325 llvm::LLVMDisposeBuilder(self.0);
326 }
327 }
328 }
329
330 impl<'a> backend::AttachedBuilder<'a> for LLVM7Builder {
331 type Context = LLVM7Context;
332 fn current_basic_block(&self) -> LLVM7BasicBlock {
333 unsafe { LLVM7BasicBlock(llvm::LLVMGetInsertBlock(self.0)) }
334 }
335 fn build_return(self, value: Option<LLVM7Value>) -> LLVM7Builder {
336 unsafe {
337 match value {
338 Some(value) => llvm::LLVMBuildRet(self.0, value.0),
339 None => llvm::LLVMBuildRetVoid(self.0),
340 };
341 llvm::LLVMClearInsertionPosition(self.0);
342 }
343 self
344 }
345 }
346
347 impl<'a> backend::DetachedBuilder<'a> for LLVM7Builder {
348 type Context = LLVM7Context;
349 fn attach(self, basic_block: LLVM7BasicBlock) -> LLVM7Builder {
350 unsafe {
351 llvm::LLVMPositionBuilderAtEnd(self.0, basic_block.0);
352 }
353 self
354 }
355 }
356
357 struct OwnedModule(llvm::LLVMModuleRef);
358
359 impl Drop for OwnedModule {
360 fn drop(&mut self) {
361 unsafe {
362 llvm::LLVMDisposeModule(self.0);
363 }
364 }
365 }
366
367 impl OwnedModule {
368 unsafe fn take(mut self) -> llvm::LLVMModuleRef {
369 let retval = self.0;
370 self.0 = null_mut();
371 retval
372 }
373 }
374
375 struct OwnedContext(llvm::LLVMContextRef);
376
377 impl Drop for OwnedContext {
378 fn drop(&mut self) {
379 unsafe {
380 llvm::LLVMContextDispose(self.0);
381 }
382 }
383 }
384
385 pub struct LLVM7Module {
386 context: llvm::LLVMContextRef,
387 module: llvm::LLVMModuleRef,
388 name_set: HashSet<String>,
389 }
390
391 impl fmt::Debug for LLVM7Module {
392 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
393 unsafe {
394 let string = LLVM7String::from_ptr(llvm::LLVMPrintModuleToString(self.module))
395 .ok_or(fmt::Error)?;
396 f.write_str(&string.to_string_lossy())
397 }
398 }
399 }
400
401 impl<'a> backend::Module<'a> for LLVM7Module {
402 type Context = LLVM7Context;
403 fn set_source_file_name(&mut self, source_file_name: &str) {
404 unsafe {
405 llvm::LLVMSetSourceFileName(
406 self.module,
407 source_file_name.as_ptr() as *const c_char,
408 source_file_name.len(),
409 )
410 }
411 }
412 fn add_function(&mut self, name: &str, ty: LLVM7Type) -> LLVM7Function {
413 fn is_start_char(c: char) -> bool {
414 if c.is_ascii_alphabetic() {
415 true
416 } else {
417 match c {
418 '_' | '.' | '$' | '-' => true,
419 _ => false,
420 }
421 }
422 }
423 fn is_continue_char(c: char) -> bool {
424 is_start_char(c) || c.is_ascii_digit()
425 }
426 assert!(is_start_char(name.chars().next().unwrap()));
427 assert!(name.chars().all(is_continue_char));
428 assert!(self.name_set.insert(name.into()));
429 let name = CString::new(name).unwrap();
430 unsafe {
431 let function = llvm::LLVMAddFunction(self.module, name.as_ptr(), ty.0);
432 let mut parameters = Vec::new();
433 parameters.resize(llvm::LLVMCountParams(function) as usize, null_mut());
434 llvm::LLVMGetParams(function, parameters.as_mut_ptr());
435 let parameters: Vec<_> = parameters.into_iter().map(LLVM7Value).collect();
436 LLVM7Function {
437 context: self.context,
438 function: llvm::LLVMAddFunction(self.module, name.as_ptr(), ty.0),
439 parameters: parameters.into_boxed_slice(),
440 }
441 }
442 }
443 fn verify(self) -> Result<LLVM7Module, backend::VerificationFailure<'a, LLVM7Module>> {
444 unsafe {
445 let mut message = null_mut();
446 let broken = to_bool(llvm::LLVMVerifyModule(
447 self.module,
448 llvm::LLVMReturnStatusAction,
449 &mut message,
450 ));
451 if broken {
452 let message = LLVM7String::from_ptr(message).unwrap();
453 let message = message.to_string_lossy();
454 Err(backend::VerificationFailure::new(self, message.as_ref()))
455 } else {
456 Ok(self)
457 }
458 }
459 }
460 unsafe fn to_verified_module_unchecked(self) -> LLVM7Module {
461 self
462 }
463 }
464
465 impl<'a> backend::VerifiedModule<'a> for LLVM7Module {
466 type Context = LLVM7Context;
467 fn into_module(self) -> LLVM7Module {
468 self
469 }
470 }
471
472 struct LLVM7TargetMachine(llvm::LLVMTargetMachineRef);
473
474 impl Drop for LLVM7TargetMachine {
475 fn drop(&mut self) {
476 unsafe {
477 llvm::LLVMDisposeTargetMachine(self.0);
478 }
479 }
480 }
481
482 impl LLVM7TargetMachine {
483 fn take(mut self) -> llvm::LLVMTargetMachineRef {
484 let retval = self.0;
485 self.0 = null_mut();
486 retval
487 }
488 }
489
490 struct LLVM7OrcJITStack(llvm::LLVMOrcJITStackRef);
491
492 impl Drop for LLVM7OrcJITStack {
493 fn drop(&mut self) {
494 unsafe {
495 match llvm::LLVMOrcDisposeInstance(self.0) {
496 llvm::LLVMOrcErrSuccess => {}
497 _ => {
498 panic!("LLVMOrcDisposeInstance failed");
499 }
500 }
501 }
502 }
503 }
504
505 fn initialize_native_target() {
506 static ONCE: Once = ONCE_INIT;
507 ONCE.call_once(|| unsafe {
508 llvm::LLVM_InitializeNativeTarget();
509 llvm::LLVM_InitializeNativeAsmPrinter();
510 llvm::LLVM_InitializeNativeAsmParser();
511 });
512 }
513
/// Symbol resolver handed to ORC when the module is added to the JIT.
///
/// Resolving external symbols is not implemented yet. Any call prints the requested
/// symbol name to stderr and aborts the process.
///
/// It aborts instead of panicking because LLVM calls this function through an
/// `extern "C"` function pointer, and unwinding a Rust panic out of an `extern "C"`
/// function is undefined behavior.
extern "C" fn symbol_resolver_fn<Void>(name: *const c_char, _lookup_context: *mut Void) -> u64 {
    let name = unsafe { CStr::from_ptr(name) };
    eprintln!("symbol_resolver_fn is unimplemented: name = {:?}", name);
    std::process::abort()
}
518
/// The LLVM 7 JIT compiler backend.
///
/// It is stateless: configuration is passed to `run`.
#[derive(Clone, Copy)]
pub struct LLVM7Compiler;
521
522 impl backend::Compiler for LLVM7Compiler {
523 type Config = LLVM7CompilerConfig;
524 fn name(self) -> &'static str {
525 "LLVM 7"
526 }
527 fn run<U: backend::CompilerUser>(
528 self,
529 user: U,
530 config: LLVM7CompilerConfig,
531 ) -> Result<Box<dyn backend::CompiledCode<U::FunctionKey>>, U::Error> {
532 unsafe {
533 initialize_native_target();
534 let context = OwnedContext(llvm::LLVMContextCreate());
535 let modules = Vec::new();
536 let mut context = LLVM7Context {
537 context: Some(ManuallyDrop::new(context)),
538 modules: ManuallyDrop::new(RefCell::new(modules)),
539 config: config.clone(),
540 };
541 let backend::CompileInputs {
542 module,
543 callable_functions,
544 } = user.run(&context)?;
545 let callable_functions: Vec<_> = callable_functions
546 .into_iter()
547 .map(|(key, callable_function)| {
548 assert_eq!(
549 llvm::LLVMGetGlobalParent(callable_function.function),
550 module.module
551 );
552 let name: CString =
553 CStr::from_ptr(llvm::LLVMGetValueName(callable_function.function)).into();
554 assert_ne!(name.to_bytes().len(), 0);
555 (key, name)
556 })
557 .collect();
558 let module = context
559 .modules
560 .get_mut()
561 .drain(..)
562 .find(|v| v.0 == module.module)
563 .unwrap();
564 let target_triple = LLVM7String::from_ptr(llvm::LLVMGetDefaultTargetTriple()).unwrap();
565 let mut target = null_mut();
566 let mut error = null_mut();
567 let success = !to_bool(llvm::LLVMGetTargetFromTriple(
568 target_triple.as_ptr(),
569 &mut target,
570 &mut error,
571 ));
572 if !success {
573 let error = LLVM7String::from_ptr(error).unwrap();
574 return Err(U::create_error(error.to_string_lossy().into()));
575 }
576 if !to_bool(llvm::LLVMTargetHasJIT(target)) {
577 return Err(U::create_error(format!(
578 "target {:?} doesn't support JIT",
579 target_triple
580 )));
581 }
582 let host_cpu_name = LLVM7String::from_ptr(llvm::LLVMGetHostCPUName()).unwrap();
583 let host_cpu_features = LLVM7String::from_ptr(llvm::LLVMGetHostCPUFeatures()).unwrap();
584 let target_machine = LLVM7TargetMachine(llvm::LLVMCreateTargetMachine(
585 target,
586 target_triple.as_ptr(),
587 host_cpu_name.as_ptr(),
588 host_cpu_features.as_ptr(),
589 match config.optimization_mode {
590 backend::OptimizationMode::NoOptimizations => llvm::LLVMCodeGenLevelNone,
591 backend::OptimizationMode::Normal => llvm::LLVMCodeGenLevelDefault,
592 },
593 llvm::LLVMRelocDefault,
594 llvm::LLVMCodeModelJITDefault,
595 ));
596 assert!(!target_machine.0.is_null());
597 let orc_jit_stack =
598 LLVM7OrcJITStack(llvm::LLVMOrcCreateInstance(target_machine.take()));
599 let mut module_handle = 0;
600 if llvm::LLVMOrcErrSuccess != llvm::LLVMOrcAddEagerlyCompiledIR(
601 orc_jit_stack.0,
602 &mut module_handle,
603 module.take(),
604 Some(symbol_resolver_fn),
605 null_mut(),
606 ) {
607 return Err(U::create_error("compilation failed".into()));
608 }
609 let mut functions: HashMap<_, _> = HashMap::new();
610 for (key, name) in callable_functions {
611 let mut address: llvm::LLVMOrcTargetAddress = mem::zeroed();
612 if llvm::LLVMOrcErrSuccess != llvm::LLVMOrcGetSymbolAddressIn(
613 orc_jit_stack.0,
614 &mut address,
615 module_handle,
616 name.as_ptr(),
617 ) {
618 return Err(U::create_error(format!(
619 "function not found in compiled module: {:?}",
620 name
621 )));
622 }
623 let address: Option<unsafe extern "C" fn()> = mem::transmute(address as usize);
624 if functions.insert(key, address.unwrap()).is_some() {
625 return Err(U::create_error(format!("duplicate function: {:?}", name)));
626 }
627 }
628 struct CompiledCode<K: Hash + Eq + Send + Sync + 'static> {
629 functions: HashMap<K, unsafe extern "C" fn()>,
630 orc_jit_stack: ManuallyDrop<LLVM7OrcJITStack>,
631 context: ManuallyDrop<OwnedContext>,
632 }
633 unsafe impl<K: Hash + Eq + Send + Sync + 'static> Send for CompiledCode<K> {}
634 unsafe impl<K: Hash + Eq + Send + Sync + 'static> Sync for CompiledCode<K> {}
635 impl<K: Hash + Eq + Send + Sync + 'static> Drop for CompiledCode<K> {
636 fn drop(&mut self) {
637 unsafe {
638 ManuallyDrop::drop(&mut self.orc_jit_stack);
639 ManuallyDrop::drop(&mut self.context);
640 }
641 }
642 }
643 impl<K: Hash + Eq + Send + Sync + 'static> backend::CompiledCode<K> for CompiledCode<K> {
644 fn get(&self, key: &K) -> Option<unsafe extern "C" fn()> {
645 Some(*self.functions.get(key)?)
646 }
647 }
648 Ok(Box::new(CompiledCode {
649 functions,
650 orc_jit_stack: ManuallyDrop::new(orc_jit_stack),
651 context: context.context.take().unwrap(),
652 }))
653 }
654 }
655 }