working on adding Switch CFG
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 26 Nov 2018 09:17:31 +0000 (01:17 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 26 Nov 2018 09:17:31 +0000 (01:17 -0800)
shader-compiler/src/cfg.rs

index b25fe7a306df6655cb2f17c0b9d7aa1117134177..f5522f293970db9ee4c24cc52f7650b3cc4aa5f8 100644 (file)
@@ -193,14 +193,19 @@ impl<T: GenericNode> From<Rc<T>> for Node {
     }
 }
 
+#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
+enum SwitchCaseKind {
+    Default,
+    Normal,
+}
+
 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
 enum BlockKind {
     Unknown,
     ConditionMerge,
     LoopMerge,
     LoopContinue,
-    SwitchCase,
-    SwitchDefault,
+    SwitchCase(SwitchCaseKind),
     SwitchMerge,
 }
 
@@ -331,9 +336,34 @@ impl ParseState {
                     }
                     BlockKind::LoopMerge => unimplemented!(),
                     BlockKind::LoopContinue => unimplemented!(),
-                    BlockKind::SwitchCase => unimplemented!(),
-                    BlockKind::SwitchDefault => unimplemented!(),
-                    BlockKind::SwitchMerge => unimplemented!(),
+                    BlockKind::SwitchCase(kind) => {
+                        let mut switch = self.get_switch();
+                        let expected_target_label = match kind {
+                            SwitchCaseKind::Normal => {
+                                switch.next_case.unwrap_or(switch.default_label)
+                            }
+                            SwitchCaseKind::Default => switch.default_label,
+                        };
+                        assert_eq!(
+                            target_label, expected_target_label,
+                            "invalid branch to next switch case"
+                        );
+                        unimplemented!()
+                    }
+                    BlockKind::SwitchMerge => {
+                        assert_eq!(
+                            target_label,
+                            self.get_switch().merge_label,
+                            "invalid branch to merge block"
+                        );
+                        let retval = Rc::new(SwitchMergeNode {
+                            label: label_id,
+                            instructions: basic_block.get_instructions(),
+                            switch: Default::default(),
+                        });
+                        self.get_switch().merges.push(retval.clone());
+                        retval.into()
+                    }
                 }
             }
             (
@@ -409,11 +439,13 @@ impl ParseState {
                 get_basic_block(basic_blocks, merge_block).set_kind(BlockKind::SwitchMerge);
                 for &(_, target) in targets {
                     if target != merge_block {
-                        get_basic_block(basic_blocks, target).set_kind(BlockKind::SwitchCase);
+                        get_basic_block(basic_blocks, target)
+                            .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Normal));
                     }
                 }
                 if default_label != merge_block {
-                    get_basic_block(basic_blocks, default_label).set_kind(BlockKind::SwitchDefault);
+                    get_basic_block(basic_blocks, default_label)
+                        .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Default));
                 }
                 let old_switch = self.push_switch(ParseStateSwitch {
                     default_label: default_label,
@@ -449,7 +481,13 @@ impl ParseState {
                     } else if let Some(default_fallthrough) = &default_fallthrough {
                         unimplemented!()
                     } else {
-                        unimplemented!()
+                        (
+                            cases,
+                            Some(SwitchDefault {
+                                default_case: default,
+                                after_default_cases: vec![],
+                            }),
+                        )
                     }
                 } else {
                     (cases, None)
@@ -704,6 +742,16 @@ mod tests {
         }
     }
 
+    fn test_cfg(instructions: &[Instruction], expected: &[SerializedCFGElement]) {
+        println!("instructions:");
+        for instruction in instructions {
+            print!("{}", instruction);
+        }
+        println!();
+        let cfg = create_cfg(&instructions);
+        assert_eq!(&*cfg.serialize_cfg_into_vec(), expected);
+    }
+
     #[test]
     fn test_cfg_return() {
         let mut id_factory = IdFactory::new();
@@ -716,11 +764,7 @@ mod tests {
         });
         instructions.push(Instruction::Return);
 
-        let cfg = create_cfg(&instructions);
-        assert_eq!(
-            &cfg.serialize_cfg_into_vec(),
-            &[SerializedCFGElement::Return]
-        );
+        test_cfg(&instructions, &[SerializedCFGElement::Return]);
     }
 
     #[test]
@@ -737,11 +781,7 @@ mod tests {
             value: id_factory.next(),
         });
 
-        let cfg = create_cfg(&instructions);
-        assert_eq!(
-            &cfg.serialize_cfg_into_vec(),
-            &[SerializedCFGElement::Return]
-        );
+        test_cfg(&instructions, &[SerializedCFGElement::Return]);
     }
 
     #[test]
@@ -765,10 +805,9 @@ mod tests {
         });
         instructions.push(Instruction::Kill);
 
-        let cfg = create_cfg(&instructions);
-        assert_eq!(
-            &cfg.serialize_cfg_into_vec(),
-            &[SerializedCFGElement::Simple, SerializedCFGElement::Discard]
+        test_cfg(
+            &instructions,
+            &[SerializedCFGElement::Simple, SerializedCFGElement::Discard],
         );
     }
 
@@ -800,14 +839,13 @@ mod tests {
         });
         instructions.push(Instruction::Return);
 
-        let cfg = create_cfg(&instructions);
-        assert_eq!(
-            &cfg.serialize_cfg_into_vec(),
+        test_cfg(
+            &instructions,
             &[
                 SerializedCFGElement::Condition,
                 SerializedCFGElement::ConditionEnd,
-                SerializedCFGElement::Return
-            ]
+                SerializedCFGElement::Return,
+            ],
         );
     }
 
@@ -847,16 +885,15 @@ mod tests {
         });
         instructions.push(Instruction::Return);
 
-        let cfg = create_cfg(&instructions);
-        assert_eq!(
-            &cfg.serialize_cfg_into_vec(),
+        test_cfg(
+            &instructions,
             &[
                 SerializedCFGElement::Condition,
                 SerializedCFGElement::ConditionTrue,
                 SerializedCFGElement::ConditionMerge,
                 SerializedCFGElement::ConditionEnd,
-                SerializedCFGElement::Return
-            ]
+                SerializedCFGElement::Return,
+            ],
         );
     }
 
@@ -902,9 +939,8 @@ mod tests {
         });
         instructions.push(Instruction::Return);
 
-        let cfg = create_cfg(&instructions);
-        assert_eq!(
-            &cfg.serialize_cfg_into_vec(),
+        test_cfg(
+            &instructions,
             &[
                 SerializedCFGElement::Condition,
                 SerializedCFGElement::ConditionTrue,
@@ -912,8 +948,8 @@ mod tests {
                 SerializedCFGElement::ConditionFalse,
                 SerializedCFGElement::ConditionMerge,
                 SerializedCFGElement::ConditionEnd,
-                SerializedCFGElement::Return
-            ]
+                SerializedCFGElement::Return,
+            ],
         );
     }
 
@@ -952,14 +988,127 @@ mod tests {
         });
         instructions.push(Instruction::Return);
 
-        let cfg = create_cfg(&instructions);
-        assert_eq!(
-            &cfg.serialize_cfg_into_vec(),
+        test_cfg(
+            &instructions,
+            &[
+                SerializedCFGElement::Switch,
+                SerializedCFGElement::SwitchDefaultCase,
+                SerializedCFGElement::SwitchMerge,
+                SerializedCFGElement::SwitchEnd,
+                SerializedCFGElement::Return,
+            ],
+        );
+    }
+
+    #[test]
+    fn test_cfg_switch_return_default_break() {
+        let mut id_factory = IdFactory::new();
+        let mut instructions = Vec::new();
+
+        let label_start = id_factory.next();
+        let label_case1 = id_factory.next();
+        let label_default = id_factory.next();
+        let label_merge = id_factory.next();
+
+        instructions.push(Instruction::NoLine);
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_start),
+        });
+        instructions.push(Instruction::SelectionMerge {
+            merge_block: label_merge,
+            selection_control: spirv_parser::SelectionControl::default(),
+        });
+        instructions.push(Instruction::Switch64 {
+            selector: id_factory.next(),
+            default: label_default,
+            target: vec![(0, label_case1)],
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_case1),
+        });
+        instructions.push(Instruction::Return);
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_default),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_merge,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_merge),
+        });
+        instructions.push(Instruction::Return);
+
+        test_cfg(
+            &instructions,
+            &[
+                SerializedCFGElement::Switch,
+                SerializedCFGElement::SwitchCase,
+                SerializedCFGElement::Return,
+                SerializedCFGElement::SwitchDefaultCase,
+                SerializedCFGElement::SwitchMerge,
+                SerializedCFGElement::SwitchEnd,
+                SerializedCFGElement::Return,
+            ],
+        );
+    }
+
+    #[test]
+    fn test_cfg_switch_fallthrough_default_break() {
+        let mut id_factory = IdFactory::new();
+        let mut instructions = Vec::new();
+
+        let label_start = id_factory.next();
+        let label_case1 = id_factory.next();
+        let label_default = id_factory.next();
+        let label_merge = id_factory.next();
+
+        instructions.push(Instruction::NoLine);
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_start),
+        });
+        instructions.push(Instruction::SelectionMerge {
+            merge_block: label_merge,
+            selection_control: spirv_parser::SelectionControl::default(),
+        });
+        instructions.push(Instruction::Switch64 {
+            selector: id_factory.next(),
+            default: label_default,
+            target: vec![(0, label_case1)],
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_case1),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_default,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_default),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_merge,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_merge),
+        });
+        instructions.push(Instruction::Return);
+
+        test_cfg(
+            &instructions,
             &[
                 SerializedCFGElement::Switch,
+                SerializedCFGElement::SwitchCase,
+                SerializedCFGElement::Return,
+                SerializedCFGElement::SwitchDefaultCase,
+                SerializedCFGElement::SwitchMerge,
                 SerializedCFGElement::SwitchEnd,
-                SerializedCFGElement::Return
-            ]
+                SerializedCFGElement::Return,
+            ],
         );
     }
 }