2e1a4767b610b5725a7b34bbd1110b19e01a53a8
[kazan.git] / shader-compiler / src / lib.rs
1 // SPDX-License-Identifier: LGPL-2.1-or-later
2 // Copyright 2018 Jacob Lifshay
3
4 extern crate shader_compiler_backend;
5 extern crate spirv_parser;
6
7 use spirv_parser::{
8 BuiltIn, Decoration, ExecutionMode, ExecutionModel, IdRef, Instruction, StorageClass,
9 };
10 use std::cell::RefCell;
11 use std::collections::HashSet;
12 use std::fmt;
13 use std::hash::{Hash, Hasher};
14 use std::mem;
15 use std::ops::{Index, IndexMut};
16 use std::rc::Rc;
17
18 pub struct Context {
19 types: pointer_type::ContextTypes,
20 next_struct_id: usize,
21 }
22
23 impl Default for Context {
24 fn default() -> Context {
25 Context {
26 types: Default::default(),
27 next_struct_id: 0,
28 }
29 }
30 }
31
32 mod pointer_type {
33 use super::{Context, Type};
34 use std::cell::RefCell;
35 use std::fmt;
36 use std::hash::{Hash, Hasher};
37 use std::rc::{Rc, Weak};
38
39 #[derive(Default)]
40 pub struct ContextTypes(Vec<Rc<Type>>);
41
42 #[derive(Clone, Debug)]
43 enum PointerTypeState {
44 Void,
45 Normal(Weak<Type>),
46 Unresolved,
47 }
48
49 #[derive(Clone)]
50 pub struct PointerType {
51 pointee: RefCell<PointerTypeState>,
52 }
53
54 impl fmt::Debug for PointerType {
55 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56 let mut state = f.debug_struct("PointerType");
57 if let PointerTypeState::Unresolved = *self.pointee.borrow() {
58 state.field("pointee", &PointerTypeState::Unresolved);
59 } else {
60 state.field("pointee", &self.pointee());
61 }
62 state.finish()
63 }
64 }
65
66 impl PointerType {
67 pub fn new(context: &mut Context, pointee: Option<Rc<Type>>) -> Self {
68 Self {
69 pointee: RefCell::new(match pointee {
70 Some(pointee) => {
71 let weak = Rc::downgrade(&pointee);
72 context.types.0.push(pointee);
73 PointerTypeState::Normal(weak)
74 }
75 None => PointerTypeState::Void,
76 }),
77 }
78 }
79 pub fn new_void() -> Self {
80 Self {
81 pointee: RefCell::new(PointerTypeState::Void),
82 }
83 }
84 pub fn unresolved() -> Self {
85 Self {
86 pointee: RefCell::new(PointerTypeState::Unresolved),
87 }
88 }
89 pub fn resolve(&self, context: &mut Context, new_pointee: Option<Rc<Type>>) {
90 let mut pointee = self.pointee.borrow_mut();
91 match &*pointee {
92 PointerTypeState::Unresolved => {}
93 _ => unreachable!("pointer already resolved"),
94 }
95 *pointee = Self::new(context, new_pointee).pointee.into_inner();
96 }
97 pub fn pointee(&self) -> Option<Rc<Type>> {
98 match *self.pointee.borrow() {
99 PointerTypeState::Normal(ref pointee) => Some(
100 pointee
101 .upgrade()
102 .expect("PointerType is not valid after the associated Context is dropped"),
103 ),
104 PointerTypeState::Void => None,
105 PointerTypeState::Unresolved => {
106 unreachable!("pointee() called on unresolved pointer")
107 }
108 }
109 }
110 }
111
112 impl PartialEq for PointerType {
113 fn eq(&self, rhs: &Self) -> bool {
114 self.pointee() == rhs.pointee()
115 }
116 }
117
118 impl Eq for PointerType {}
119
120 impl Hash for PointerType {
121 fn hash<H: Hasher>(&self, hasher: &mut H) {
122 self.pointee().hash(hasher);
123 }
124 }
125 }
126
127 pub use pointer_type::PointerType;
128
129 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
130 pub enum ScalarType {
131 I8,
132 U8,
133 I16,
134 U16,
135 I32,
136 U32,
137 I64,
138 U64,
139 F16,
140 F32,
141 F64,
142 Bool,
143 Pointer(PointerType),
144 }
145
146 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
147 pub struct VectorType {
148 pub element: ScalarType,
149 pub element_count: usize,
150 }
151
152 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
153 pub struct StructMember {
154 pub decorations: Vec<Decoration>,
155 pub member_type: Rc<Type>,
156 }
157
158 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
159 pub struct StructId(usize);
160
161 impl StructId {
162 pub fn new(context: &mut Context) -> Self {
163 let retval = StructId(context.next_struct_id);
164 context.next_struct_id += 1;
165 retval
166 }
167 }
168
169 #[derive(Clone)]
170 pub struct StructType {
171 pub id: StructId,
172 pub decorations: Vec<Decoration>,
173 pub members: Vec<StructMember>,
174 }
175
176 impl Eq for StructType {}
177
178 impl PartialEq for StructType {
179 fn eq(&self, rhs: &Self) -> bool {
180 self.id == rhs.id
181 }
182 }
183
184 impl Hash for StructType {
185 fn hash<H: Hasher>(&self, h: &mut H) {
186 self.id.hash(h)
187 }
188 }
189
190 impl fmt::Debug for StructType {
191 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192 thread_local! {
193 static CURRENTLY_FORMATTING: RefCell<HashSet<StructId>> = RefCell::new(HashSet::new());
194 }
195 struct CurrentlyFormatting {
196 id: StructId,
197 was_formatting: bool,
198 }
199 impl CurrentlyFormatting {
200 fn new(id: StructId) -> Self {
201 let was_formatting = CURRENTLY_FORMATTING
202 .with(|currently_formatting| !currently_formatting.borrow_mut().insert(id));
203 Self { id, was_formatting }
204 }
205 }
206 impl Drop for CurrentlyFormatting {
207 fn drop(&mut self) {
208 if !self.was_formatting {
209 CURRENTLY_FORMATTING.with(|currently_formatting| {
210 currently_formatting.borrow_mut().remove(&self.id);
211 });
212 }
213 }
214 }
215 let currently_formatting = CurrentlyFormatting::new(self.id);
216 let mut state = f.debug_struct("StructType");
217 state.field("id", &self.id);
218 if !currently_formatting.was_formatting {
219 state.field("decorations", &self.decorations);
220 state.field("members", &self.members);
221 }
222 state.finish()
223 }
224 }
225
226 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
227 pub struct ArrayType {
228 pub decorations: Vec<Decoration>,
229 pub element: Rc<Type>,
230 pub element_count: Option<usize>,
231 }
232
233 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
234 pub enum Type {
235 Scalar(ScalarType),
236 Vector(VectorType),
237 Struct(StructType),
238 Array(ArrayType),
239 }
240
241 impl Type {
242 pub fn is_pointer(&self) -> bool {
243 if let Type::Scalar(ScalarType::Pointer(_)) = self {
244 true
245 } else {
246 false
247 }
248 }
249 pub fn is_scalar(&self) -> bool {
250 if let Type::Scalar(_) = self {
251 true
252 } else {
253 false
254 }
255 }
256 pub fn is_vector(&self) -> bool {
257 if let Type::Vector(_) = self {
258 true
259 } else {
260 false
261 }
262 }
263 pub fn get_pointee(&self) -> Option<Rc<Type>> {
264 if let Type::Scalar(ScalarType::Pointer(pointer)) = self {
265 pointer.pointee()
266 } else {
267 unreachable!("not a pointer")
268 }
269 }
270 pub fn get_nonvoid_pointee(&self) -> Rc<Type> {
271 self.get_pointee().expect("void is not allowed here")
272 }
273 pub fn get_scalar(&self) -> &ScalarType {
274 if let Type::Scalar(scalar) = self {
275 scalar
276 } else {
277 unreachable!("not a scalar type")
278 }
279 }
280 pub fn get_vector(&self) -> &VectorType {
281 if let Type::Vector(vector) = self {
282 vector
283 } else {
284 unreachable!("not a vector type")
285 }
286 }
287 }
288
/// A value that is either defined (`Defined(v)`) or explicitly undefined.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Undefable<T> {
    Undefined,
    Defined(T),
}

impl<T> Undefable<T> {
    /// Returns the defined value.
    ///
    /// Panics when called on `Undefined`.
    pub fn unwrap(self) -> T {
        if let Undefable::Defined(value) = self {
            value
        } else {
            panic!("Undefable::unwrap called on Undefined")
        }
    }
}

/// Any plain value converts into a defined value.
impl<T> From<T> for Undefable<T> {
    fn from(value: T) -> Self {
        Undefable::Defined(value)
    }
}
310
311 #[derive(Copy, Clone, Debug)]
312 pub enum ScalarConstant {
313 U8(Undefable<u8>),
314 U16(Undefable<u16>),
315 U32(Undefable<u32>),
316 U64(Undefable<u64>),
317 I8(Undefable<i8>),
318 I16(Undefable<i16>),
319 I32(Undefable<i32>),
320 I64(Undefable<i64>),
321 F16(Undefable<u16>),
322 F32(Undefable<f32>),
323 F64(Undefable<f64>),
324 Bool(Undefable<bool>),
325 }
326
/// Generates the typed accessor `$getter` on both `ScalarConstant` and
/// `VectorConstant` for the variant `$variant`, whose payload element type is
/// `$elem`. The accessors panic when the constant holds a different variant.
macro_rules! define_scalar_vector_constant_impl_without_from {
    ($elem:ident, $variant:ident, $getter:ident) => {
        impl ScalarConstant {
            pub fn $getter(self) -> Undefable<$elem> {
                if let ScalarConstant::$variant(value) = self {
                    value
                } else {
                    unreachable!(concat!("expected a constant ", stringify!($elem)))
                }
            }
        }
        impl VectorConstant {
            pub fn $getter(&self) -> &Vec<Undefable<$elem>> {
                if let VectorConstant::$variant(values) = self {
                    values
                } else {
                    unreachable!(concat!(
                        "expected a constant vector with ",
                        stringify!($elem),
                        " elements"
                    ))
                }
            }
        }
    };
}
351
352 macro_rules! define_scalar_vector_constant_impl {
353 ($type:ident, $name:ident, $get_name:ident) => {
354 define_scalar_vector_constant_impl_without_from!($type, $name, $get_name);
355 impl From<Undefable<$type>> for ScalarConstant {
356 fn from(v: Undefable<$type>) -> ScalarConstant {
357 ScalarConstant::$name(v)
358 }
359 }
360 impl From<Vec<Undefable<$type>>> for VectorConstant {
361 fn from(v: Vec<Undefable<$type>>) -> VectorConstant {
362 VectorConstant::$name(v)
363 }
364 }
365 };
366 }
367
368 define_scalar_vector_constant_impl!(u8, U8, get_u8);
369 define_scalar_vector_constant_impl!(u16, U16, get_u16);
370 define_scalar_vector_constant_impl!(u32, U32, get_u32);
371 define_scalar_vector_constant_impl!(u64, U64, get_u64);
372 define_scalar_vector_constant_impl!(i8, I8, get_i8);
373 define_scalar_vector_constant_impl!(i16, I16, get_i16);
374 define_scalar_vector_constant_impl!(i32, I32, get_i32);
375 define_scalar_vector_constant_impl!(i64, I64, get_i64);
376 define_scalar_vector_constant_impl_without_from!(u16, F16, get_f16);
377 define_scalar_vector_constant_impl!(f32, F32, get_f32);
378 define_scalar_vector_constant_impl!(f64, F64, get_f64);
379 define_scalar_vector_constant_impl!(bool, Bool, get_bool);
380
381 impl ScalarConstant {
382 pub fn get_type(self) -> Type {
383 Type::Scalar(self.get_scalar_type())
384 }
385 pub fn get_scalar_type(self) -> ScalarType {
386 match self {
387 ScalarConstant::U8(_) => ScalarType::U8,
388 ScalarConstant::U16(_) => ScalarType::U16,
389 ScalarConstant::U32(_) => ScalarType::U32,
390 ScalarConstant::U64(_) => ScalarType::U64,
391 ScalarConstant::I8(_) => ScalarType::I8,
392 ScalarConstant::I16(_) => ScalarType::I16,
393 ScalarConstant::I32(_) => ScalarType::I32,
394 ScalarConstant::I64(_) => ScalarType::I64,
395 ScalarConstant::F16(_) => ScalarType::F16,
396 ScalarConstant::F32(_) => ScalarType::F32,
397 ScalarConstant::F64(_) => ScalarType::F64,
398 ScalarConstant::Bool(_) => ScalarType::Bool,
399 }
400 }
401 }
402
403 #[derive(Clone, Debug)]
404 pub enum VectorConstant {
405 U8(Vec<Undefable<u8>>),
406 U16(Vec<Undefable<u16>>),
407 U32(Vec<Undefable<u32>>),
408 U64(Vec<Undefable<u64>>),
409 I8(Vec<Undefable<i8>>),
410 I16(Vec<Undefable<i16>>),
411 I32(Vec<Undefable<i32>>),
412 I64(Vec<Undefable<i64>>),
413 F16(Vec<Undefable<u16>>),
414 F32(Vec<Undefable<f32>>),
415 F64(Vec<Undefable<f64>>),
416 Bool(Vec<Undefable<bool>>),
417 }
418
419 impl VectorConstant {
420 pub fn get_element_type(&self) -> ScalarType {
421 match self {
422 VectorConstant::U8(_) => ScalarType::U8,
423 VectorConstant::U16(_) => ScalarType::U16,
424 VectorConstant::U32(_) => ScalarType::U32,
425 VectorConstant::U64(_) => ScalarType::U64,
426 VectorConstant::I8(_) => ScalarType::I8,
427 VectorConstant::I16(_) => ScalarType::I16,
428 VectorConstant::I32(_) => ScalarType::I32,
429 VectorConstant::I64(_) => ScalarType::I64,
430 VectorConstant::F16(_) => ScalarType::F16,
431 VectorConstant::F32(_) => ScalarType::F32,
432 VectorConstant::F64(_) => ScalarType::F64,
433 VectorConstant::Bool(_) => ScalarType::Bool,
434 }
435 }
436 pub fn get_element_count(&self) -> usize {
437 match self {
438 VectorConstant::U8(v) => v.len(),
439 VectorConstant::U16(v) => v.len(),
440 VectorConstant::U32(v) => v.len(),
441 VectorConstant::U64(v) => v.len(),
442 VectorConstant::I8(v) => v.len(),
443 VectorConstant::I16(v) => v.len(),
444 VectorConstant::I32(v) => v.len(),
445 VectorConstant::I64(v) => v.len(),
446 VectorConstant::F16(v) => v.len(),
447 VectorConstant::F32(v) => v.len(),
448 VectorConstant::F64(v) => v.len(),
449 VectorConstant::Bool(v) => v.len(),
450 }
451 }
452 pub fn get_type(&self) -> Type {
453 Type::Vector(VectorType {
454 element: self.get_element_type(),
455 element_count: self.get_element_count(),
456 })
457 }
458 }
459
460 #[derive(Clone, Debug)]
461 pub enum Constant {
462 Scalar(ScalarConstant),
463 Vector(VectorConstant),
464 }
465
466 impl Constant {
467 pub fn get_type(&self) -> Type {
468 match self {
469 Constant::Scalar(v) => v.get_type(),
470 Constant::Vector(v) => v.get_type(),
471 }
472 }
473 pub fn get_scalar(&self) -> &ScalarConstant {
474 match self {
475 Constant::Scalar(v) => v,
476 _ => unreachable!("not a scalar constant"),
477 }
478 }
479 }
480
481 #[derive(Debug, Clone)]
482 struct MemberDecoration {
483 member: u32,
484 decoration: Decoration,
485 }
486
487 #[derive(Debug, Clone)]
488 struct BuiltInVariable {
489 built_in: BuiltIn,
490 }
491
492 impl BuiltInVariable {
493 fn get_type(&self, _context: &mut Context) -> Rc<Type> {
494 match self.built_in {
495 BuiltIn::GlobalInvocationId => Rc::new(Type::Vector(VectorType {
496 element: ScalarType::U32,
497 element_count: 3,
498 })),
499 _ => unreachable!("unknown built-in"),
500 }
501 }
502 }
503
504 #[derive(Debug, Clone)]
505 struct UniformVariable {
506 binding: u32,
507 descriptor_set: u32,
508 variable_type: Rc<Type>,
509 }
510
511 #[derive(Debug)]
512 enum IdKind {
513 Undefined,
514 DecorationGroup,
515 Type(Rc<Type>),
516 VoidType,
517 FunctionType {
518 return_type: Option<Rc<Type>>,
519 arguments: Vec<Rc<Type>>,
520 },
521 ForwardPointer(Rc<Type>),
522 BuiltInVariable(BuiltInVariable),
523 Constant(Rc<Constant>),
524 UniformVariable(UniformVariable),
525 }
526
527 #[derive(Debug)]
528 struct IdProperties {
529 kind: IdKind,
530 decorations: Vec<Decoration>,
531 member_decorations: Vec<MemberDecoration>,
532 }
533
534 impl IdProperties {
535 fn is_empty(&self) -> bool {
536 match self.kind {
537 IdKind::Undefined => {}
538 _ => return false,
539 }
540 self.decorations.is_empty() && self.member_decorations.is_empty()
541 }
542 fn set_kind(&mut self, kind: IdKind) {
543 match &self.kind {
544 IdKind::Undefined => {}
545 _ => unreachable!("duplicate id"),
546 }
547 self.kind = kind;
548 }
549 fn get_type(&self) -> Option<&Rc<Type>> {
550 match &self.kind {
551 IdKind::Type(t) => Some(t),
552 IdKind::VoidType => None,
553 _ => unreachable!("id is not type"),
554 }
555 }
556 fn get_nonvoid_type(&self) -> &Rc<Type> {
557 self.get_type().expect("void is not allowed here")
558 }
559 fn get_constant(&self) -> &Rc<Constant> {
560 match &self.kind {
561 IdKind::Constant(c) => c,
562 _ => unreachable!("id is not a constant"),
563 }
564 }
565 fn assert_no_member_decorations(&self, id: IdRef) {
566 for member_decoration in &self.member_decorations {
567 unreachable!(
568 "member decoration not allowed on {}: {:?}",
569 id, member_decoration
570 );
571 }
572 }
573 fn assert_no_decorations(&self, id: IdRef) {
574 self.assert_no_member_decorations(id);
575 for decoration in &self.decorations {
576 unreachable!("decoration not allowed on {}: {:?}", id, decoration);
577 }
578 }
579 }
580
581 struct Ids(Vec<IdProperties>);
582
583 impl fmt::Debug for Ids {
584 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
585 f.debug_map()
586 .entries(
587 self.0
588 .iter()
589 .enumerate()
590 .filter_map(|(id_index, id_properties)| {
591 if id_properties.is_empty() {
592 return None;
593 }
594 Some((IdRef(id_index as u32), id_properties))
595 }),
596 )
597 .finish()
598 }
599 }
600
601 impl Index<IdRef> for Ids {
602 type Output = IdProperties;
603 fn index(&self, index: IdRef) -> &IdProperties {
604 &self.0[index.0 as usize]
605 }
606 }
607
608 impl IndexMut<IdRef> for Ids {
609 fn index_mut(&mut self, index: IdRef) -> &mut IdProperties {
610 &mut self.0[index.0 as usize]
611 }
612 }
613
614 struct ParsedShaderFunction {
615 instructions: Vec<Instruction>,
616 }
617
618 impl fmt::Debug for ParsedShaderFunction {
619 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
620 write!(f, "ParsedShaderFunction:\n")?;
621 for instruction in &self.instructions {
622 write!(f, "{}", instruction)?;
623 }
624 Ok(())
625 }
626 }
627
628 #[derive(Debug)]
629 struct ParsedShader {
630 ids: Ids,
631 functions: Vec<ParsedShaderFunction>,
632 main_function_id: IdRef,
633 interface_variables: Vec<IdRef>,
634 execution_modes: Vec<ExecutionMode>,
635 workgroup_size: Option<(u32, u32, u32)>,
636 }
637
638 struct ShaderEntryPoint {
639 main_function_id: IdRef,
640 interface_variables: Vec<IdRef>,
641 }
642
643 impl ParsedShader {
644 #[cfg_attr(feature = "cargo-clippy", allow(clippy::cyclomatic_complexity))]
645 fn create(
646 context: &mut Context,
647 stage_info: ShaderStageCreateInfo,
648 execution_model: ExecutionModel,
649 ) -> Self {
650 let parser = spirv_parser::Parser::start(stage_info.code).unwrap();
651 let header = *parser.header();
652 assert_eq!(header.instruction_schema, 0);
653 assert_eq!(header.version.0, 1);
654 assert!(header.version.1 <= 3);
655 let instructions: Vec<_> = parser.map(Result::unwrap).collect();
656 println!("Parsing Shader:");
657 print!("{}", header);
658 for instruction in instructions.iter() {
659 print!("{}", instruction);
660 }
661 let mut ids = Ids((0..header.bound)
662 .map(|_| IdProperties {
663 kind: IdKind::Undefined,
664 decorations: Vec::new(),
665 member_decorations: Vec::new(),
666 })
667 .collect());
668 let mut entry_point = None;
669 let mut current_function: Option<ParsedShaderFunction> = None;
670 let mut functions = Vec::new();
671 let mut execution_modes = Vec::new();
672 let mut workgroup_size = None;
673 for instruction in instructions {
674 match current_function {
675 Some(mut function) => {
676 current_function = match instruction {
677 instruction @ Instruction::FunctionEnd {} => {
678 function.instructions.push(instruction);
679 functions.push(function);
680 None
681 }
682 instruction => {
683 function.instructions.push(instruction);
684 Some(function)
685 }
686 };
687 continue;
688 }
689 None => current_function = None,
690 }
691 match instruction {
692 instruction @ Instruction::Function { .. } => {
693 current_function = Some(ParsedShaderFunction {
694 instructions: vec![instruction],
695 });
696 }
697 Instruction::EntryPoint {
698 execution_model: current_execution_model,
699 entry_point: main_function_id,
700 name,
701 interface,
702 } => {
703 if execution_model == current_execution_model
704 && name == stage_info.entry_point_name
705 {
706 assert!(entry_point.is_none());
707 entry_point = Some(ShaderEntryPoint {
708 main_function_id,
709 interface_variables: interface.clone(),
710 });
711 }
712 }
713 Instruction::ExecutionMode {
714 entry_point: entry_point_id,
715 mode,
716 }
717 | Instruction::ExecutionModeId {
718 entry_point: entry_point_id,
719 mode,
720 } => {
721 if entry_point_id == entry_point.as_ref().unwrap().main_function_id {
722 execution_modes.push(mode);
723 }
724 }
725 Instruction::Decorate { target, decoration }
726 | Instruction::DecorateId { target, decoration } => {
727 ids[target].decorations.push(decoration);
728 }
729 Instruction::MemberDecorate {
730 structure_type,
731 member,
732 decoration,
733 } => {
734 ids[structure_type]
735 .member_decorations
736 .push(MemberDecoration { member, decoration });
737 }
738 Instruction::DecorationGroup { id_result } => {
739 ids[id_result.0].set_kind(IdKind::DecorationGroup);
740 }
741 Instruction::GroupDecorate {
742 decoration_group,
743 targets,
744 } => {
745 let decorations = ids[decoration_group].decorations.clone();
746 for target in targets {
747 ids[target]
748 .decorations
749 .extend(decorations.iter().map(Clone::clone));
750 }
751 }
752 Instruction::GroupMemberDecorate {
753 decoration_group,
754 targets,
755 } => {
756 let decorations = ids[decoration_group].decorations.clone();
757 for target in targets {
758 ids[target.0]
759 .member_decorations
760 .extend(decorations.iter().map(|decoration| MemberDecoration {
761 member: target.1,
762 decoration: decoration.clone(),
763 }));
764 }
765 }
766 Instruction::TypeFunction {
767 id_result,
768 return_type,
769 parameter_types,
770 } => {
771 ids[id_result.0].assert_no_decorations(id_result.0);
772 let kind = IdKind::FunctionType {
773 return_type: ids[return_type].get_type().map(Clone::clone),
774 arguments: parameter_types
775 .iter()
776 .map(|argument| ids[*argument].get_nonvoid_type().clone())
777 .collect(),
778 };
779 ids[id_result.0].set_kind(kind);
780 }
781 Instruction::TypeVoid { id_result } => {
782 ids[id_result.0].assert_no_decorations(id_result.0);
783 ids[id_result.0].set_kind(IdKind::VoidType);
784 }
785 Instruction::TypeBool { id_result } => {
786 ids[id_result.0].assert_no_decorations(id_result.0);
787 ids[id_result.0]
788 .set_kind(IdKind::Type(Rc::new(Type::Scalar(ScalarType::Bool))));
789 }
790 Instruction::TypeInt {
791 id_result,
792 width,
793 signedness,
794 } => {
795 ids[id_result.0].assert_no_decorations(id_result.0);
796 ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Scalar(
797 match (width, signedness != 0) {
798 (8, false) => ScalarType::U8,
799 (8, true) => ScalarType::I8,
800 (16, false) => ScalarType::U16,
801 (16, true) => ScalarType::I16,
802 (32, false) => ScalarType::U32,
803 (32, true) => ScalarType::I32,
804 (64, false) => ScalarType::U64,
805 (64, true) => ScalarType::I64,
806 (width, signedness) => unreachable!(
807 "unsupported int type: {}{}",
808 if signedness { "i" } else { "u" },
809 width
810 ),
811 },
812 ))));
813 }
814 Instruction::TypeFloat { id_result, width } => {
815 ids[id_result.0].assert_no_decorations(id_result.0);
816 ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Scalar(match width {
817 16 => ScalarType::F16,
818 32 => ScalarType::F32,
819 64 => ScalarType::F64,
820 _ => unreachable!("unsupported float type: f{}", width),
821 }))));
822 }
823 Instruction::TypeVector {
824 id_result,
825 component_type,
826 component_count,
827 } => {
828 ids[id_result.0].assert_no_decorations(id_result.0);
829 let element = ids[component_type].get_nonvoid_type().get_scalar().clone();
830 ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Vector(VectorType {
831 element,
832 element_count: component_count as usize,
833 }))));
834 }
835 Instruction::TypeForwardPointer { pointer_type, .. } => {
836 ids[pointer_type].set_kind(IdKind::ForwardPointer(Rc::new(Type::Scalar(
837 ScalarType::Pointer(PointerType::unresolved()),
838 ))));
839 }
840 Instruction::TypePointer {
841 id_result,
842 type_: pointee,
843 ..
844 } => {
845 ids[id_result.0].assert_no_decorations(id_result.0);
846 let pointee = ids[pointee].get_type().map(Clone::clone);
847 let pointer = match mem::replace(&mut ids[id_result.0].kind, IdKind::Undefined)
848 {
849 IdKind::Undefined => Rc::new(Type::Scalar(ScalarType::Pointer(
850 PointerType::new(context, pointee),
851 ))),
852 IdKind::ForwardPointer(pointer) => {
853 if let Type::Scalar(ScalarType::Pointer(pointer)) = &*pointer {
854 pointer.resolve(context, pointee);
855 } else {
856 unreachable!();
857 }
858 pointer
859 }
860 _ => unreachable!("duplicate id"),
861 };
862 ids[id_result.0].set_kind(IdKind::Type(pointer));
863 }
864 Instruction::TypeStruct {
865 id_result,
866 member_types,
867 } => {
868 let decorations = ids[id_result.0].decorations.clone();
869 let struct_type = {
870 let mut members: Vec<_> = member_types
871 .into_iter()
872 .map(|member_type| StructMember {
873 decorations: Vec::new(),
874 member_type: match ids[member_type].kind {
875 IdKind::Type(ref t) => t.clone(),
876 IdKind::ForwardPointer(ref t) => t.clone(),
877 _ => unreachable!("invalid struct member type"),
878 },
879 })
880 .collect();
881 for member_decoration in &ids[id_result.0].member_decorations {
882 members[member_decoration.member as usize]
883 .decorations
884 .push(member_decoration.decoration.clone());
885 }
886 StructType {
887 id: StructId::new(context),
888 decorations,
889 members,
890 }
891 };
892 ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Struct(struct_type))));
893 }
894 Instruction::TypeRuntimeArray {
895 id_result,
896 element_type,
897 } => {
898 ids[id_result.0].assert_no_member_decorations(id_result.0);
899 let decorations = ids[id_result.0].decorations.clone();
900 let element = ids[element_type].get_nonvoid_type().clone();
901 ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Array(ArrayType {
902 decorations,
903 element,
904 element_count: None,
905 }))));
906 }
907 Instruction::Variable {
908 id_result_type,
909 id_result,
910 storage_class,
911 initializer,
912 } => {
913 ids[id_result.0].assert_no_member_decorations(id_result.0);
914 if let Some(built_in) =
915 ids[id_result.0]
916 .decorations
917 .iter()
918 .find_map(|decoration| match *decoration {
919 Decoration::BuiltIn { built_in } => Some(built_in),
920 _ => None,
921 }) {
922 let built_in_variable = match built_in {
923 BuiltIn::GlobalInvocationId => {
924 for decoration in &ids[id_result.0].decorations {
925 match decoration {
926 Decoration::BuiltIn { .. } => {}
927 _ => unimplemented!(
928 "unimplemented decoration on {:?}: {:?}",
929 built_in,
930 decoration
931 ),
932 }
933 }
934 assert!(initializer.is_none());
935 BuiltInVariable { built_in }
936 }
937 _ => unimplemented!("unimplemented built-in: {:?}", built_in),
938 };
939 assert_eq!(
940 built_in_variable.get_type(context),
941 ids[id_result_type.0]
942 .get_nonvoid_type()
943 .get_nonvoid_pointee()
944 );
945 ids[id_result.0].set_kind(IdKind::BuiltInVariable(built_in_variable));
946 } else {
947 let variable_type = ids[id_result_type.0].get_nonvoid_type().clone();
948 match storage_class {
949 StorageClass::Uniform => {
950 let mut descriptor_set = None;
951 let mut binding = None;
952 for decoration in &ids[id_result.0].decorations {
953 match *decoration {
954 Decoration::DescriptorSet { descriptor_set: v } => {
955 assert!(
956 descriptor_set.is_none(),
957 "duplicate DescriptorSet decoration"
958 );
959 descriptor_set = Some(v);
960 }
961 Decoration::Binding { binding_point: v } => {
962 assert!(
963 binding.is_none(),
964 "duplicate Binding decoration"
965 );
966 binding = Some(v);
967 }
968 _ => unimplemented!(
969 "unimplemented decoration on uniform variable: {:?}",
970 decoration
971 ),
972 }
973 }
974 let descriptor_set = descriptor_set
975 .expect("uniform variable is missing DescriptorSet decoration");
976 let binding = binding
977 .expect("uniform variable is missing Binding decoration");
978 assert!(initializer.is_none());
979 ids[id_result.0].set_kind(IdKind::UniformVariable(
980 UniformVariable {
981 binding,
982 descriptor_set,
983 variable_type,
984 },
985 ));
986 }
987 StorageClass::Input => unimplemented!(),
988 _ => unimplemented!(
989 "unimplemented OpVariable StorageClass: {:?}",
990 storage_class
991 ),
992 }
993 }
994 }
995 Instruction::Constant32 {
996 id_result_type,
997 id_result,
998 value,
999 } => {
1000 ids[id_result.0].assert_no_decorations(id_result.0);
1001 #[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_lossless))]
1002 let constant = match **ids[id_result_type.0].get_nonvoid_type() {
1003 Type::Scalar(ScalarType::U8) => {
1004 let converted_value = value as u8;
1005 assert_eq!(converted_value as u32, value);
1006 Constant::Scalar(ScalarConstant::U8(Undefable::Defined(
1007 converted_value,
1008 )))
1009 }
1010 Type::Scalar(ScalarType::U16) => {
1011 let converted_value = value as u16;
1012 assert_eq!(converted_value as u32, value);
1013 Constant::Scalar(ScalarConstant::U16(Undefable::Defined(
1014 converted_value,
1015 )))
1016 }
1017 Type::Scalar(ScalarType::U32) => {
1018 Constant::Scalar(ScalarConstant::U32(Undefable::Defined(value)))
1019 }
1020 Type::Scalar(ScalarType::I8) => {
1021 let converted_value = value as i8;
1022 assert_eq!(converted_value as u32, value);
1023 Constant::Scalar(ScalarConstant::I8(Undefable::Defined(
1024 converted_value,
1025 )))
1026 }
1027 Type::Scalar(ScalarType::I16) => {
1028 let converted_value = value as i16;
1029 assert_eq!(converted_value as u32, value);
1030 Constant::Scalar(ScalarConstant::I16(Undefable::Defined(
1031 converted_value,
1032 )))
1033 }
1034 Type::Scalar(ScalarType::I32) => {
1035 Constant::Scalar(ScalarConstant::I32(Undefable::Defined(value as i32)))
1036 }
1037 Type::Scalar(ScalarType::F16) => {
1038 let converted_value = value as u16;
1039 assert_eq!(converted_value as u32, value);
1040 Constant::Scalar(ScalarConstant::F16(Undefable::Defined(
1041 converted_value,
1042 )))
1043 }
1044 Type::Scalar(ScalarType::F32) => Constant::Scalar(ScalarConstant::F32(
1045 Undefable::Defined(f32::from_bits(value)),
1046 )),
1047 _ => unreachable!("invalid type"),
1048 };
1049 ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant)));
1050 }
1051 Instruction::Constant64 {
1052 id_result_type,
1053 id_result,
1054 value,
1055 } => {
1056 ids[id_result.0].assert_no_decorations(id_result.0);
1057 let constant = match **ids[id_result_type.0].get_nonvoid_type() {
1058 Type::Scalar(ScalarType::U64) => {
1059 Constant::Scalar(ScalarConstant::U64(Undefable::Defined(value)))
1060 }
1061 Type::Scalar(ScalarType::I64) => {
1062 Constant::Scalar(ScalarConstant::I64(Undefable::Defined(value as i64)))
1063 }
1064 Type::Scalar(ScalarType::F64) => Constant::Scalar(ScalarConstant::F64(
1065 Undefable::Defined(f64::from_bits(value)),
1066 )),
1067 _ => unreachable!("invalid type"),
1068 };
1069 ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant)));
1070 }
1071 Instruction::ConstantFalse {
1072 id_result_type,
1073 id_result,
1074 } => {
1075 ids[id_result.0].assert_no_decorations(id_result.0);
1076 let constant = match **ids[id_result_type.0].get_nonvoid_type() {
1077 Type::Scalar(ScalarType::Bool) => {
1078 Constant::Scalar(ScalarConstant::Bool(Undefable::Defined(false)))
1079 }
1080 _ => unreachable!("invalid type"),
1081 };
1082 ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant)));
1083 }
1084 Instruction::ConstantTrue {
1085 id_result_type,
1086 id_result,
1087 } => {
1088 ids[id_result.0].assert_no_decorations(id_result.0);
1089 let constant = match **ids[id_result_type.0].get_nonvoid_type() {
1090 Type::Scalar(ScalarType::Bool) => {
1091 Constant::Scalar(ScalarConstant::Bool(Undefable::Defined(true)))
1092 }
1093 _ => unreachable!("invalid type"),
1094 };
1095 ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant)));
1096 }
1097 Instruction::ConstantComposite {
1098 id_result_type,
1099 id_result,
1100 constituents,
1101 } => {
1102 let constant = match **ids[id_result_type.0].get_nonvoid_type() {
1103 Type::Vector(VectorType {
1104 ref element,
1105 element_count,
1106 }) => {
1107 assert_eq!(element_count, constituents.len());
1108 let constituents = constituents
1109 .iter()
1110 .map(|id| *ids[*id].get_constant().get_scalar());
1111 match *element {
1112 ScalarType::U8 => {
1113 VectorConstant::U8(constituents.map(|v| v.get_u8()).collect())
1114 }
1115 ScalarType::U16 => {
1116 VectorConstant::U16(constituents.map(|v| v.get_u16()).collect())
1117 }
1118 ScalarType::U32 => {
1119 VectorConstant::U32(constituents.map(|v| v.get_u32()).collect())
1120 }
1121 ScalarType::U64 => {
1122 VectorConstant::U64(constituents.map(|v| v.get_u64()).collect())
1123 }
1124 ScalarType::I8 => {
1125 VectorConstant::I8(constituents.map(|v| v.get_i8()).collect())
1126 }
1127 ScalarType::I16 => {
1128 VectorConstant::I16(constituents.map(|v| v.get_i16()).collect())
1129 }
1130 ScalarType::I32 => {
1131 VectorConstant::I32(constituents.map(|v| v.get_i32()).collect())
1132 }
1133 ScalarType::I64 => {
1134 VectorConstant::I64(constituents.map(|v| v.get_i64()).collect())
1135 }
1136 ScalarType::F16 => {
1137 VectorConstant::F16(constituents.map(|v| v.get_f16()).collect())
1138 }
1139 ScalarType::F32 => {
1140 VectorConstant::F32(constituents.map(|v| v.get_f32()).collect())
1141 }
1142 ScalarType::F64 => {
1143 VectorConstant::F64(constituents.map(|v| v.get_f64()).collect())
1144 }
1145 ScalarType::Bool => VectorConstant::Bool(
1146 constituents.map(|v| v.get_bool()).collect(),
1147 ),
1148 ScalarType::Pointer(_) => unimplemented!(),
1149 }
1150 }
1151 _ => unimplemented!(),
1152 };
1153 for decoration in &ids[id_result.0].decorations {
1154 match decoration {
1155 Decoration::BuiltIn {
1156 built_in: BuiltIn::WorkgroupSize,
1157 } => {
1158 assert!(
1159 workgroup_size.is_none(),
1160 "duplicate WorkgroupSize decorations"
1161 );
1162 workgroup_size = match constant {
1163 VectorConstant::U32(ref v) => {
1164 assert_eq!(
1165 v.len(),
1166 3,
1167 "invalid type for WorkgroupSize built-in"
1168 );
1169 Some((v[0].unwrap(), v[1].unwrap(), v[2].unwrap()))
1170 }
1171 _ => unreachable!("invalid type for WorkgroupSize built-in"),
1172 };
1173 }
1174 _ => unimplemented!(
1175 "unimplemented decoration on constant {:?}: {:?}",
1176 Constant::Vector(constant),
1177 decoration
1178 ),
1179 }
1180 }
1181 ids[id_result.0].assert_no_member_decorations(id_result.0);
1182 ids[id_result.0]
1183 .set_kind(IdKind::Constant(Rc::new(Constant::Vector(constant))));
1184 }
1185 Instruction::MemoryModel {
1186 addressing_model,
1187 memory_model,
1188 } => {
1189 assert_eq!(addressing_model, spirv_parser::AddressingModel::Logical);
1190 assert_eq!(memory_model, spirv_parser::MemoryModel::GLSL450);
1191 }
1192 Instruction::Capability { .. }
1193 | Instruction::ExtInstImport { .. }
1194 | Instruction::Source { .. }
1195 | Instruction::SourceExtension { .. }
1196 | Instruction::Name { .. }
1197 | Instruction::MemberName { .. } => {}
1198 Instruction::SpecConstant32 { .. } => unimplemented!(),
1199 Instruction::SpecConstant64 { .. } => unimplemented!(),
1200 Instruction::SpecConstantTrue { .. } => unimplemented!(),
1201 Instruction::SpecConstantFalse { .. } => unimplemented!(),
1202 Instruction::SpecConstantOp { .. } => unimplemented!(),
1203 instruction => unimplemented!("unimplemented instruction:\n{}", instruction),
1204 }
1205 }
1206 assert!(
1207 current_function.is_none(),
1208 "missing terminating OpFunctionEnd"
1209 );
1210 let ShaderEntryPoint {
1211 main_function_id,
1212 interface_variables,
1213 } = entry_point.unwrap();
1214 ParsedShader {
1215 ids,
1216 functions,
1217 main_function_id,
1218 interface_variables,
1219 execution_modes,
1220 workgroup_size,
1221 }
1222 }
1223 }
1224
1225 #[derive(Clone, Debug)]
1226 pub struct GenericPipelineOptions {
1227 pub optimization_mode: shader_compiler_backend::OptimizationMode,
1228 }
1229
/// Pipeline layout; currently an empty placeholder with no fields.
#[derive(Debug)]
pub struct PipelineLayout {}
1232
/// A compute pipeline; currently an empty placeholder with no fields.
#[derive(Debug)]
pub struct ComputePipeline {}
1235
1236 #[derive(Clone, Debug)]
1237 pub struct ComputePipelineOptions {
1238 pub generic_options: GenericPipelineOptions,
1239 }
1240
/// The value supplied for one specialization constant.
#[derive(Clone, Copy, Debug)]
pub struct Specialization<'a> {
    /// Id of the specialization constant.
    pub id: u32,
    /// The supplied value, as bytes.
    pub bytes: &'a [u8],
}

/// Everything needed to create one shader stage.
#[derive(Clone, Copy, Debug)]
pub struct ShaderStageCreateInfo<'a> {
    /// The SPIR-V module, as 32-bit words.
    pub code: &'a [u32],
    /// Name of the entry point to use.
    pub entry_point_name: &'a str,
    /// Specialization constant values to apply.
    pub specializations: &'a [Specialization<'a>],
}
1253
1254 impl ComputePipeline {
1255 pub fn new(
1256 _options: &ComputePipelineOptions,
1257 compute_shader_stage: ShaderStageCreateInfo,
1258 ) -> ComputePipeline {
1259 let mut context = Context::default();
1260 let parsed_shader = ParsedShader::create(
1261 &mut context,
1262 compute_shader_stage,
1263 ExecutionModel::GLCompute,
1264 );
1265 println!("parsed_shader:\n{:#?}", parsed_shader);
1266 unimplemented!()
1267 }
1268 }