From 48b69fb9efb324daa90159836acf98a6f312930b Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 26 Nov 2018 22:57:58 -0800 Subject: [PATCH] implemented switch CFG parsing --- shader-compiler/src/cfg.rs | 413 ++++++++++++++++++++++++++++--------- 1 file changed, 313 insertions(+), 100 deletions(-) diff --git a/shader-compiler/src/cfg.rs b/shader-compiler/src/cfg.rs index f5522f2..393df4a 100644 --- a/shader-compiler/src/cfg.rs +++ b/shader-compiler/src/cfg.rs @@ -65,7 +65,6 @@ pub(crate) struct SwitchFallthroughNode { pub(crate) label: IdRef, pub(crate) instructions: Vec, pub(crate) switch: RefCell>, - pub(crate) target_label: IdRef, } impl GenericNode for SwitchFallthroughNode { @@ -187,6 +186,33 @@ pub(crate) enum Node { ConditionMerge(Rc), } +impl Node { + pub(crate) fn instructions(&self) -> &Vec { + match self { + Node::Simple(v) => v.instructions(), + Node::Return(v) => v.instructions(), + Node::Discard(v) => v.instructions(), + Node::Switch(v) => v.instructions(), + Node::SwitchFallthrough(v) => v.instructions(), + Node::SwitchMerge(v) => v.instructions(), + Node::Condition(v) => v.instructions(), + Node::ConditionMerge(v) => v.instructions(), + } + } + pub(crate) fn label(&self) -> IdRef { + match self { + Node::Simple(v) => v.label(), + Node::Return(v) => v.label(), + Node::Discard(v) => v.label(), + Node::Switch(v) => v.label(), + Node::SwitchFallthrough(v) => v.label(), + Node::SwitchMerge(v) => v.label(), + Node::Condition(v) => v.label(), + Node::ConditionMerge(v) => v.label(), + } + } +} + impl From> for Node { fn from(v: Rc) -> Node { GenericNode::to_node(v) @@ -251,12 +277,11 @@ struct ParseStateCondition { } struct ParseStateSwitch { - fallthrough_to_default: Option>, - fallthroughs: Vec>, + fallthrough: Option>, default_label: IdRef, - next_case: Option, merges: Vec>, merge_label: IdRef, + fallthrough_target: Option, } struct ParseState { @@ -289,6 +314,129 @@ impl ParseState { fn get_switch(&mut self) -> &mut ParseStateSwitch { self.switch.as_mut().unwrap() } + fn parse_switch( + &mut self, + basic_blocks: &HashMap, + label_id: IdRef, + basic_block: &BasicBlock, + targets: &[(T, IdRef)], + default_label: IdRef, + merge_block: IdRef, + ) -> Node { + get_basic_block(basic_blocks, merge_block).set_kind(BlockKind::SwitchMerge); + let mut last_target = None; + for &(_, target) in targets { + if Some(target) == last_target { + continue; + } + last_target = Some(target); + if target != merge_block { + get_basic_block(basic_blocks, target) + .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Normal)); + } + } + if default_label != merge_block { + get_basic_block(basic_blocks, default_label) + .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Default)); + } + let old_switch = self.push_switch(ParseStateSwitch { + default_label: default_label, + fallthrough: None, + merge_label: merge_block, + merges: vec![], + fallthrough_target: None, + }); + let default_node = if default_label != merge_block { + Some(self.parse(basic_blocks, default_label)) + } else { + None + }; + let mut default_fallthrough = self.get_switch().fallthrough.take(); + let mut default_fallthrough_target = self.get_switch().fallthrough_target.take(); + let mut cases = Vec::with_capacity(targets.len()); + struct Case { + node: Node, + fallthrough: Option>, + fallthrough_target: Option, + } + let mut last_target = None; + for (index, &(_, target)) in targets.iter().enumerate() { + if Some(target) == last_target { + continue; + } + last_target = Some(target); + let node = self.parse(basic_blocks, target); + let fallthrough_target = self.get_switch().fallthrough_target.take(); + if let Some(fallthrough_target) = fallthrough_target { + if default_label != fallthrough_target { + assert_eq!( + Some(fallthrough_target), + targets.get(index + 1).map(|v| v.1), + "invalid fallthrough branch" + ); + } + } + cases.push(Case { + node, + fallthrough: self.get_switch().fallthrough.take(), + fallthrough_target, + }); + } + let switch = self.pop_switch(old_switch); + let mut before_default_cases = None; + let mut output_cases = vec![]; + let mut fallthroughs = vec![]; + fallthroughs.extend(default_fallthrough); + for ( + index, + Case { + node, + fallthrough, + fallthrough_target, + }, + ) in cases.into_iter().enumerate() + { + if Some(node.label()) == default_fallthrough_target { + if before_default_cases.is_none() { + before_default_cases = Some(mem::replace(&mut output_cases, vec![])); + } else { + assert!(output_cases.is_empty(), "invalid fallthrough branch"); + } + } + output_cases.push(node); + fallthroughs.extend(fallthrough); + if Some(default_label) == fallthrough_target { + assert!(before_default_cases.is_none()); + before_default_cases = Some(mem::replace(&mut output_cases, vec![])); + } + } + let before_default_cases = + before_default_cases.unwrap_or_else(|| mem::replace(&mut output_cases, vec![])); + let default = if let Some(default_node) = default_node { + Some(SwitchDefault { + default_case: default_node, + after_default_cases: output_cases, + }) + } else { + None + }; + let next = self.parse(basic_blocks, merge_block); + let retval = Rc::new(SwitchNode { + label: label_id, + instructions: basic_block.get_instructions(), + before_default_cases, + default, + next, + }); + for fallthrough in fallthroughs { + fallthrough.switch.replace(Rc::downgrade(&retval)); + } + for merge in switch.merges { + merge.switch.replace(Rc::downgrade(&retval)); + } + retval.into() + } + #[cfg_attr(feature = "cargo-clippy", allow(clippy::cyclomatic_complexity))] fn parse(&mut self, basic_blocks: &HashMap, label_id: IdRef) -> Node { let basic_block = get_basic_block(basic_blocks, label_id); let (terminating_instruction, instructions_without_terminator) = basic_block @@ -338,17 +486,16 @@ impl ParseState { BlockKind::LoopContinue => unimplemented!(), BlockKind::SwitchCase(kind) => { let mut switch = self.get_switch(); - let expected_target_label = match kind { - SwitchCaseKind::Normal => { - switch.next_case.unwrap_or(switch.default_label) - } - SwitchCaseKind::Default => switch.default_label, - }; - assert_eq!( - target_label, expected_target_label, - "invalid branch to next switch case" - ); - unimplemented!() + let retval = Rc::new(SwitchFallthroughNode { + label: label_id, + instructions: basic_block.get_instructions(), + switch: Default::default(), + }); + assert!(switch.fallthrough_target.is_none()); + assert!(switch.fallthrough.is_none()); + switch.fallthrough_target = Some(target_label); + switch.fallthrough = Some(retval.clone()); + retval.into() } BlockKind::SwitchMerge => { assert_eq!( @@ -420,14 +567,19 @@ impl ParseState { } ( &Instruction::Switch32 { - default, + default: default_label, target: ref targets, .. }, Some(&Instruction::SelectionMerge { merge_block, .. }), - ) => { - unimplemented!(); - } + ) => self.parse_switch( + basic_blocks, + label_id, + basic_block, + targets, + default_label, + merge_block, + ), ( &Instruction::Switch64 { default: default_label, @@ -435,87 +587,14 @@ impl ParseState { .. }, Some(&Instruction::SelectionMerge { merge_block, .. }), - ) => { - get_basic_block(basic_blocks, merge_block).set_kind(BlockKind::SwitchMerge); - for &(_, target) in targets { - if target != merge_block { - get_basic_block(basic_blocks, target) - .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Normal)); - } - } - if default_label != merge_block { - get_basic_block(basic_blocks, default_label) - .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Default)); - } - let old_switch = self.push_switch(ParseStateSwitch { - default_label: default_label, - fallthrough_to_default: None, - merge_label: merge_block, - fallthroughs: vec![], - merges: vec![], - next_case: None, - }); - let default = if default_label != merge_block { - Some(self.parse(basic_blocks, default_label)) - } else { - None - }; - let mut default_fallthrough = None; - for i in self.get_switch().fallthroughs.drain(..) { - assert!( - default_fallthrough.is_none(), - "multiple fallthroughs from default case" - ); - default_fallthrough = Some(i); - } - let mut cases = Vec::with_capacity(targets.len()); - for (index, &(_, target)) in targets.iter().enumerate() { - self.get_switch().next_case = targets.get(index + 1).map(|v| v.1); - cases.push(self.parse(basic_blocks, target)); - } - let switch = self.pop_switch(old_switch); - let (before_default_cases, default) = if let Some(default) = default { - if let Some(fallthrough_to_default) = &switch.fallthrough_to_default { - // FIXME: handle default_fallthrough - unimplemented!() - } else if let Some(default_fallthrough) = &default_fallthrough { - unimplemented!() - } else { - ( - cases, - Some(SwitchDefault { - default_case: default, - after_default_cases: vec![], - }), - ) - } - } else { - (cases, None) - }; - let next = self.parse(basic_blocks, merge_block); - let retval = Rc::new(SwitchNode { - label: label_id, - instructions: basic_block.get_instructions(), - before_default_cases, - default, - next, - }); - if let Some(default_fallthrough) = default_fallthrough { - default_fallthrough.switch.replace(Rc::downgrade(&retval)); - } - if let Some(fallthrough_to_default) = switch.fallthrough_to_default { - fallthrough_to_default - .switch - .replace(Rc::downgrade(&retval)); - } - for fallthrough in switch.fallthroughs { - fallthrough.switch.replace(Rc::downgrade(&retval)); - } - for merge in switch.merges { - merge.switch.replace(Rc::downgrade(&retval)); - } - retval.into() - } + ) => self.parse_switch( + basic_blocks, + label_id, + basic_block, + targets, + default_label, + merge_block, + ), (&Instruction::Switch32 { .. }, _) => unreachable!("missing merge instruction"), (&Instruction::Switch64 { .. }, _) => unreachable!("missing merge instruction"), (&Instruction::Kill {}, _) => Rc::new(DiscardNode { @@ -1103,8 +1182,142 @@ mod tests { &[ SerializedCFGElement::Switch, SerializedCFGElement::SwitchCase, + SerializedCFGElement::SwitchFallthrough, + SerializedCFGElement::SwitchDefaultCase, + SerializedCFGElement::SwitchMerge, + SerializedCFGElement::SwitchEnd, + SerializedCFGElement::Return, + ], + ); + } + + #[test] + fn test_cfg_switch_fallthrough_default_fallthrough_break() { + let mut id_factory = IdFactory::new(); + let mut instructions = Vec::new(); + + let label_start = id_factory.next(); + let label_case1 = id_factory.next(); + let label_default = id_factory.next(); + let label_case2 = id_factory.next(); + let label_merge = id_factory.next(); + + instructions.push(Instruction::NoLine); + instructions.push(Instruction::Label { + id_result: IdResult(label_start), + }); + instructions.push(Instruction::SelectionMerge { + merge_block: label_merge, + selection_control: spirv_parser::SelectionControl::default(), + }); + instructions.push(Instruction::Switch64 { + selector: id_factory.next(), + default: label_default, + target: vec![(0, label_case1), (1, label_case1), (2, label_case2)], + }); + + instructions.push(Instruction::Label { + id_result: IdResult(label_case1), + }); + instructions.push(Instruction::Branch { + target_label: label_default, + }); + + instructions.push(Instruction::Label { + id_result: IdResult(label_default), + }); + instructions.push(Instruction::Branch { + target_label: label_case2, + }); + + instructions.push(Instruction::Label { + id_result: IdResult(label_case2), + }); + instructions.push(Instruction::Branch { + target_label: label_merge, + }); + + instructions.push(Instruction::Label { + id_result: IdResult(label_merge), + }); + instructions.push(Instruction::Return); + + test_cfg( + &instructions, + &[ + SerializedCFGElement::Switch, + SerializedCFGElement::SwitchCase, + SerializedCFGElement::SwitchFallthrough, + SerializedCFGElement::SwitchDefaultCase, + SerializedCFGElement::SwitchFallthrough, + SerializedCFGElement::SwitchCase, + SerializedCFGElement::SwitchMerge, + SerializedCFGElement::SwitchEnd, SerializedCFGElement::Return, + ], + ); + } + + #[test] + fn test_cfg_switch_break_default_fallthrough_break() { + let mut id_factory = IdFactory::new(); + let mut instructions = Vec::new(); + + let label_start = id_factory.next(); + let label_case1 = id_factory.next(); + let label_default = id_factory.next(); + let label_case2 = id_factory.next(); + let label_merge = id_factory.next(); + + instructions.push(Instruction::NoLine); + instructions.push(Instruction::Label { + id_result: IdResult(label_start), + }); + instructions.push(Instruction::SelectionMerge { + merge_block: label_merge, + selection_control: spirv_parser::SelectionControl::default(), + }); + instructions.push(Instruction::Switch32 { + selector: id_factory.next(), + default: label_default, + target: vec![(0, label_case1), (1, label_case1), (2, label_case2)], + }); + + instructions.push(Instruction::Label { + id_result: IdResult(label_case1), + }); + instructions.push(Instruction::Branch { + target_label: label_merge, + }); + + instructions.push(Instruction::Label { + id_result: IdResult(label_default), + }); + instructions.push(Instruction::Branch { + target_label: label_case2, + }); + + instructions.push(Instruction::Label { + id_result: IdResult(label_case2), + }); + instructions.push(Instruction::Branch { + target_label: label_merge, + }); + + instructions.push(Instruction::Label { + id_result: IdResult(label_merge), + }); + instructions.push(Instruction::Return); + + test_cfg( + &instructions, + &[ + SerializedCFGElement::Switch, + SerializedCFGElement::SwitchCase, + SerializedCFGElement::SwitchMerge, SerializedCFGElement::SwitchDefaultCase, + SerializedCFGElement::SwitchFallthrough, + SerializedCFGElement::SwitchCase, SerializedCFGElement::SwitchMerge, SerializedCFGElement::SwitchEnd, SerializedCFGElement::Return, -- 2.30.2