From cca2522a8bea586559fb11002e246bced61d9d09 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 19 Jan 2023 23:21:25 -0800 Subject: [PATCH] Function verification should be complete, no tests yet --- Cargo.lock | 35 +- register_allocator/Cargo.toml | 1 + register_allocator/src/error.rs | 189 +++++++- register_allocator/src/function.rs | 754 ++++++++++++++++++++++++++--- register_allocator/src/index.rs | 32 ++ register_allocator/src/loc.rs | 32 ++ register_allocator/src/loc_set.rs | 50 +- 7 files changed, 989 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3e2f1c9..761adf6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -25,9 +25,10 @@ version = "0.1.0" dependencies = [ "enum-map", "eyre", - "hashbrown", + "hashbrown 0.13.2", "num-bigint", "num-traits", + "petgraph", "scoped-tls", "serde", "serde_json", @@ -73,6 +74,18 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.13.2" @@ -89,6 +102,16 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" +[[package]] +name = "indexmap" +version = "1.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + [[package]] name = "itoa" version = "1.0.5" @@ -132,6 +155,16 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" +[[package]] +name = "petgraph" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5014253a1331579ce62aa67443b4a658c5e7dd03d4bc6d302b94474888143" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "proc-macro2" version = "1.0.49" diff --git a/register_allocator/Cargo.toml b/register_allocator/Cargo.toml index 8a87392..25e6f0e 100644 --- a/register_allocator/Cargo.toml +++ b/register_allocator/Cargo.toml @@ -17,3 +17,4 @@ smallvec = { version = "1.10.0", features = ["serde", "union", "const_generics", num-bigint = { version = "0.4.3", features = ["serde"] } enum-map = { version = "2.4.2", features = ["serde"] } num-traits = "0.2.15" +petgraph = "0.6.2" diff --git a/register_allocator/src/error.rs b/register_allocator/src/error.rs index 104795e..d8e27d1 100644 --- a/register_allocator/src/error.rs +++ b/register_allocator/src/error.rs @@ -1,5 +1,5 @@ use crate::{ - index::{BlockIdx, InstIdx}, + index::{BlockIdx, DisplayOptionIdx, InstIdx, SSAValIdx}, loc::{BaseTy, Ty}, }; use thiserror::Error; @@ -43,9 +43,194 @@ pub enum Error { #[error("block's last instruction must be a block terminator: {term_idx}")] BlocksLastInstMustBeTerm { term_idx: InstIdx }, #[error( - "block terminator instructions are only allowed as a block's last instruction: {inst_idx}" + "block terminator instructions are only allowed as a block's last \ + instruction: {inst_idx}" )] TermInstOnlyAllowedAtBlockEnd { inst_idx: InstIdx }, + #[error("instruction not in a block: {inst}")] + InstHasNoBlock { inst: InstIdx }, + #[error("operand index {operand_idx} out of range for {inst}")] + OperandIndexOutOfRange { inst: InstIdx, operand_idx: usize }, + #[error("duplicate copy destination operand: operand index {operand_idx} for {inst}")] + DupCopyDestOperand { inst: InstIdx, operand_idx: usize }, + #[error("SSA value index {idx} out of range")] + SSAValIdxOutOfRange { 
idx: SSAValIdx }, + #[error("instruction index {idx} out of range")] + InstIdxOutOfRange { idx: InstIdx }, + #[error("block index {idx} out of range")] + BlockIdxOutOfRange { idx: BlockIdx }, + #[error("copy instruction's source type doesn't match source operands")] + CopySrcTyMismatch { inst: InstIdx }, + #[error("copy instruction's destination type doesn't match destination operands")] + CopyDestTyMismatch { inst: InstIdx }, + #[error( + "operand index {operand_idx} for {inst} is missing from SSA value \ + {ssa_val_idx}'s uses" + )] + MissingOperandUse { + ssa_val_idx: SSAValIdx, + inst: InstIdx, + operand_idx: usize, + }, + #[error( + "operand index {operand_idx} for {inst} has kind `Def` but isn't \ + SSA value {ssa_val_idx}'s definition" + )] + OperandDefIsNotSSAValDef { + ssa_val_idx: SSAValIdx, + inst: InstIdx, + operand_idx: usize, + }, + #[error( + "SSA value {ssa_val_idx}'s definition isn't the corresponding \ + operand's SSA Value: operand index {operand_idx} for {inst}" + )] + SSAValDefIsNotOperandsSSAVal { + ssa_val_idx: SSAValIdx, + inst: InstIdx, + operand_idx: usize, + }, + #[error( + "SSA value {ssa_val_idx}'s type can't be used with the constraint on \ + operand index {operand_idx} for {inst}" + )] + ConstraintTyMismatch { + ssa_val_idx: SSAValIdx, + inst: InstIdx, + operand_idx: usize, + }, + #[error( + "fixed location constraint on operand index {operand_idx} for \ + {inst} conflicts with clobbers" + )] + FixedLocConflictsWithClobbers { inst: InstIdx, operand_idx: usize }, + #[error("operand kind must be def")] + OperandKindMustBeDef, + #[error( + "reuse target operand (index {reuse_target_operand_idx}) for \ + {inst} must have kind `Use`" + )] + ReuseTargetOperandMustBeUse { + inst: InstIdx, + reuse_target_operand_idx: usize, + }, + #[error( + "source block {src_block} missing from branch {branch_inst}'s \ + target block {tgt_block}'s predecessors" + )] + SrcBlockMissingFromBranchTgtBlocksPreds { + src_block: BlockIdx, + branch_inst: InstIdx, + 
tgt_block: BlockIdx, + }, + #[error( + "branch {inst}'s parameter (index {param_idx}) for successor {succ} \ + is missing from SSA value {ssa_val_idx}'s uses" + )] + MissingBranchSuccParamUse { + ssa_val_idx: SSAValIdx, + inst: InstIdx, + succ: BlockIdx, + param_idx: usize, + }, + #[error( + "the number of parameters ({branch_param_count}) for branch {inst}'s \ + successor {succ} doesn't match the number of parameters \ + ({block_param_count}) declared in that block" + )] + BranchSuccParamCountMismatch { + inst: InstIdx, + succ: BlockIdx, + block_param_count: usize, + branch_param_count: usize, + }, + #[error( + "the type {branch_param_ty:?} of parameter {param_idx} for branch \ + {inst}'s successor {succ} doesn't match the type {block_param_ty:?} \ + declared in that block" + )] + BranchSuccParamTyMismatch { + inst: InstIdx, + succ: BlockIdx, + param_idx: usize, + block_param_ty: Ty, + branch_param_ty: Ty, + }, + #[error( + "block {block}'s parameter {param_idx} doesn't match SSA value \ + {ssa_val_idx}'s definition" + )] + MismatchedBlockParamDef { + ssa_val_idx: SSAValIdx, + block: BlockIdx, + param_idx: usize, + }, + #[error( + "target block {tgt_block} is missing from its predecessor {src_block}'s \ + terminating branch {branch_inst}'s targets" + )] + PredMissingFromPredsTermBranchsTargets { + src_block: BlockIdx, + branch_inst: InstIdx, + tgt_block: BlockIdx, + }, + #[error("block parameter index {param_idx} is out of range for block {block}")] + BlockParamIdxOutOfRange { block: BlockIdx, param_idx: usize }, + #[error( + "SSA value {ssa_val_idx}'s use isn't the corresponding \ + operand's SSA Value: operand index {operand_idx} for {inst}" + )] + SSAValUseIsNotOperandsSSAVal { + ssa_val_idx: SSAValIdx, + inst: InstIdx, + operand_idx: usize, + }, + #[error( + "SSA value {ssa_val_idx} is used as a branch instruction {inst}'s \ + block parameter, but that instruction isn't a branch instruction" + )] + SSAValUseIsNotBranch { + ssa_val_idx: 
SSAValIdx, + inst: InstIdx, + }, + #[error("expected instruction {inst} to be a `BlockTerm` instruction")] + InstIsNotBlockTerm { inst: InstIdx }, + #[error("target block {tgt_block} not found in branch instruction {branch_inst}")] + BranchTargetNotFound { + branch_inst: InstIdx, + tgt_block: BlockIdx, + }, + #[error( + "branch instruction {branch_inst}'s block parameter index {param_idx} \ + is out of range for target block {tgt_block}" + )] + BranchTargetParamIdxOutOfRange { + branch_inst: InstIdx, + tgt_block: BlockIdx, + param_idx: usize, + }, + #[error( + "SSA value {ssa_val_idx}'s use isn't the corresponding \ + branch {branch_inst}'s target block parameter's SSA Value for \ + target block {tgt_block}'s parameter index {param_idx}" + )] + MismatchedBranchTargetBlockParamUse { + ssa_val_idx: SSAValIdx, + branch_inst: InstIdx, + tgt_block: BlockIdx, + param_idx: usize, + }, + #[error( + "block {block_idx} has incorrect immediate dominator: expected \ + {} found {}", + .expected.display_option_idx(), + .found.display_option_idx(), + )] + IncorrectImmediateDominator { + block_idx: BlockIdx, + found: Option, + expected: Option, + }, } pub type Result = std::result::Result; diff --git a/register_allocator/src/function.rs b/register_allocator/src/function.rs index f9eacec..2a1e811 100644 --- a/register_allocator/src/function.rs +++ b/register_allocator/src/function.rs @@ -1,17 +1,109 @@ use crate::{ error::{Error, Result}, index::{BlockIdx, InstIdx, InstRange, SSAValIdx}, - interned::Interned, - loc::{Loc, Ty}, + interned::{GlobalState, Intern, Interned}, + loc::{BaseTy, Loc, Ty}, loc_set::LocSet, }; use core::fmt; +use hashbrown::HashSet; +use petgraph::{ + algo::dominators, + visit::{GraphBase, GraphProp, IntoNeighbors, VisitMap, Visitable}, + Directed, +}; use serde::{Deserialize, Serialize}; -use std::ops::Index; +use smallvec::SmallVec; +use std::{ + collections::{btree_map, BTreeMap, BTreeSet}, + mem, + ops::Index, +}; #[derive(Copy, Clone, PartialEq, Eq, 
Debug, Hash, Serialize, Deserialize)] +pub enum SSAValDef { + BlockParam { block: BlockIdx, param_idx: usize }, + Operand { inst: InstIdx, operand_idx: usize }, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +pub struct BranchSuccParamUse { + branch_inst: InstIdx, + succ: BlockIdx, + param_idx: usize, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +pub struct OperandUse { + inst: InstIdx, + operand_idx: usize, +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] pub struct SSAVal { pub ty: Ty, + pub def: SSAValDef, + pub operand_uses: BTreeSet, + pub branch_succ_param_uses: BTreeSet, +} + +impl SSAVal { + fn validate(&self, ssa_val_idx: SSAValIdx, func: &FnFields) -> Result<()> { + let Self { + ty: _, + def, + operand_uses, + branch_succ_param_uses, + } = self; + match *def { + SSAValDef::BlockParam { block, param_idx } => { + let block_param = func.try_get_block_param(block, param_idx)?; + if ssa_val_idx != block_param { + return Err(Error::MismatchedBlockParamDef { + ssa_val_idx, + block, + param_idx, + }); + } + } + SSAValDef::Operand { inst, operand_idx } => { + let operand = func.try_get_operand(inst, operand_idx)?; + if ssa_val_idx != operand.ssa_val { + return Err(Error::SSAValDefIsNotOperandsSSAVal { + ssa_val_idx, + inst, + operand_idx, + }); + } + } + } + for &OperandUse { inst, operand_idx } in operand_uses { + let operand = func.try_get_operand(inst, operand_idx)?; + if ssa_val_idx != operand.ssa_val { + return Err(Error::SSAValUseIsNotOperandsSSAVal { + ssa_val_idx, + inst, + operand_idx, + }); + } + } + for &BranchSuccParamUse { + branch_inst, + succ, + param_idx, + } in branch_succ_param_uses + { + if ssa_val_idx != func.try_get_branch_target_param(branch_inst, succ, param_idx)? 
{ + return Err(Error::MismatchedBranchTargetBlockParamUse { + ssa_val_idx, + branch_inst, + tgt_block: succ, + param_idx, + }); + } + } + Ok(()) + } } #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] @@ -107,35 +199,236 @@ pub enum Constraint { /// any stack location Stack, FixedLoc(Loc), - Reuse(usize), +} + +impl Constraint { + pub fn is_any(&self) -> bool { + matches!(self, Self::Any) + } +} + +impl Default for Constraint { + fn default() -> Self { + Self::Any + } +} + +#[derive( + Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize, Default, +)] +#[serde(try_from = "OperandKind", into = "OperandKind")] +pub struct OperandKindDefOnly; + +impl TryFrom for OperandKindDefOnly { + type Error = Error; + + fn try_from(value: OperandKind) -> Result { + match value { + OperandKind::Use => Err(Error::OperandKindMustBeDef), + OperandKind::Def => Ok(Self), + } + } +} + +impl From for OperandKind { + fn from(_value: OperandKindDefOnly) -> Self { + Self::Def + } +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +#[serde(untagged)] +pub enum KindAndConstraint { + Reuse { + kind: OperandKindDefOnly, + reuse_operand_idx: usize, + }, + Constraint { + kind: OperandKind, + #[serde(default, skip_serializing_if = "Constraint::is_any")] + constraint: Constraint, + }, +} + +impl KindAndConstraint { + pub fn kind(self) -> OperandKind { + match self { + Self::Reuse { .. } => OperandKind::Def, + Self::Constraint { kind, .. } => kind, + } + } + pub fn is_reuse(self) -> bool { + matches!(self, Self::Reuse { .. 
}) + } } #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] pub struct Operand { pub ssa_val: SSAValIdx, - pub constraint: Constraint, - pub kind: OperandKind, + #[serde(flatten)] + pub kind_and_constraint: KindAndConstraint, pub stage: InstStage, } +impl Operand { + pub fn try_get_reuse_src<'f>( + &self, + inst: InstIdx, + func: &'f FnFields, + ) -> Result> { + if let KindAndConstraint::Reuse { + reuse_operand_idx, .. + } = self.kind_and_constraint + { + Ok(Some(func.try_get_operand(inst, reuse_operand_idx)?)) + } else { + Ok(None) + } + } + pub fn try_constraint(&self, inst: InstIdx, func: &FnFields) -> Result { + Ok(match self.kind_and_constraint { + KindAndConstraint::Reuse { + kind: _, + reuse_operand_idx, + } => { + let operand = func.try_get_operand(inst, reuse_operand_idx)?; + match operand.kind_and_constraint { + KindAndConstraint::Reuse { .. } + | KindAndConstraint::Constraint { + kind: OperandKind::Def, + .. + } => { + return Err(Error::ReuseTargetOperandMustBeUse { + inst, + reuse_target_operand_idx: reuse_operand_idx, + }) + } + KindAndConstraint::Constraint { + kind: OperandKind::Use, + constraint, + } => constraint, + } + } + KindAndConstraint::Constraint { constraint, .. 
} => constraint, + }) + } + pub fn constraint(&self, inst: InstIdx, func: &Function) -> Constraint { + self.try_constraint(inst, func).unwrap() + } + fn validate( + self, + _block: BlockIdx, + inst: InstIdx, + operand_idx: usize, + func: &FnFields, + global_state: &GlobalState, + ) -> Result<()> { + let Self { + ssa_val: ssa_val_idx, + kind_and_constraint, + stage: _, + } = self; + let ssa_val = func.try_get_ssa_val(ssa_val_idx)?; + match kind_and_constraint.kind() { + OperandKind::Use => { + if !ssa_val + .operand_uses + .contains(&OperandUse { inst, operand_idx }) + { + return Err(Error::MissingOperandUse { + ssa_val_idx, + inst, + operand_idx, + }); + } + } + OperandKind::Def => { + let def = SSAValDef::Operand { inst, operand_idx }; + if ssa_val.def != def { + return Err(Error::OperandDefIsNotSSAValDef { + ssa_val_idx, + inst, + operand_idx, + }); + } + } + } + let constraint = self.try_constraint(inst, func)?; + match constraint { + Constraint::Any | Constraint::Stack => {} + Constraint::BaseGpr | Constraint::SVExtra2SGpr => { + if ssa_val.ty != Ty::scalar(BaseTy::Bits64) { + return Err(Error::ConstraintTyMismatch { + ssa_val_idx, + inst, + operand_idx, + }); + } + } + Constraint::SVExtra2VGpr | Constraint::SVExtra3Gpr => { + if ssa_val.ty.base_ty != BaseTy::Bits64 { + return Err(Error::ConstraintTyMismatch { + ssa_val_idx, + inst, + operand_idx, + }); + } + } + Constraint::FixedLoc(loc) => { + if func + .try_get_inst(inst)? + .clobbers + .clone() + .conflicts_with(loc, global_state) + { + return Err(Error::FixedLocConflictsWithClobbers { inst, operand_idx }); + } + if ssa_val.ty != loc.ty() { + return Err(Error::ConstraintTyMismatch { + ssa_val_idx, + inst, + operand_idx, + }); + } + } + } + Ok(()) + } +} + +/// copy concatenates all `srcs` together and de-concatenates the result into all `dests`. 
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] -pub struct BranchSucc { - pub block: BlockIdx, - pub params: Vec, +pub struct CopyInstKind { + pub src_operand_idxs: Vec, + pub dest_operand_idxs: Vec, + pub copy_ty: Ty, +} + +impl CopyInstKind { + fn calc_copy_ty(operand_idxs: &[usize], inst: InstIdx, func: &FnFields) -> Result> { + let mut retval: Option = None; + for &operand_idx in operand_idxs { + let operand = func.try_get_operand(inst, operand_idx)?; + let ssa_val = func.try_get_ssa_val(operand.ssa_val)?; + retval = Some(match retval { + Some(retval) => retval.try_concat(ssa_val.ty)?, + None => ssa_val.ty, + }); + } + Ok(retval) + } +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub struct BlockTermInstKind { + pub succs_and_params: BTreeMap>, } #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] pub enum InstKind { Normal, - /// copy concatenates all `srcs` together and de-concatenates the result into all `dests`. - Copy { - srcs: Vec, - dests: Vec, - }, - Return, - Branch { - succs: Vec, - }, + Copy(CopyInstKind), + BlockTerm(BlockTermInstKind), } impl InstKind { @@ -143,13 +436,21 @@ impl InstKind { matches!(self, Self::Normal) } pub fn is_block_term(&self) -> bool { - matches!(self, Self::Return | Self::Branch { .. }) + matches!(self, Self::BlockTerm { .. }) + } + pub fn is_copy(&self) -> bool { + matches!(self, Self::Copy { .. }) } - pub fn succs(&self) -> Option<&[BranchSucc]> { + pub fn block_term(&self) -> Option<&BlockTermInstKind> { match self { - InstKind::Normal | InstKind::Copy { .. 
} => None, - InstKind::Return => Some(&[]), - InstKind::Branch { succs } => Some(succs), + InstKind::BlockTerm(v) => Some(v), + _ => None, + } + } + pub fn copy(&self) -> Option<&CopyInstKind> { + match self { + InstKind::Copy(v) => Some(v), + _ => None, } } } @@ -160,19 +461,207 @@ impl Default for InstKind { } } +fn loc_set_is_empty(clobbers: &Interned) -> bool { + clobbers.is_empty() +} + +fn empty_loc_set() -> Interned { + GlobalState::get(|global_state| LocSet::default().into_interned(global_state)) +} + #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] pub struct Inst { #[serde(default, skip_serializing_if = "InstKind::is_normal")] pub kind: InstKind, pub operands: Vec, + #[serde(default = "empty_loc_set", skip_serializing_if = "loc_set_is_empty")] pub clobbers: Interned, } +impl Inst { + fn validate( + &self, + block: BlockIdx, + inst: InstIdx, + func: &FnFields, + global_state: &GlobalState, + ) -> Result<()> { + let Self { + kind, + operands, + clobbers: _, + } = self; + let is_at_end_of_block = func.blocks[block].insts.last() == Some(inst); + if kind.is_block_term() != is_at_end_of_block { + return Err(if is_at_end_of_block { + Error::BlocksLastInstMustBeTerm { term_idx: inst } + } else { + Error::TermInstOnlyAllowedAtBlockEnd { inst_idx: inst } + }); + } + for (idx, operand) in operands.iter().enumerate() { + operand.validate(block, inst, idx, func, global_state)?; + } + match kind { + InstKind::Normal => {} + InstKind::Copy(CopyInstKind { + src_operand_idxs, + dest_operand_idxs, + copy_ty, + }) => { + let mut seen_dest_operands = SmallVec::<[bool; 16]>::new(); + seen_dest_operands.resize(operands.len(), false); + for &dest_operand_idx in dest_operand_idxs { + let seen_dest_operand = seen_dest_operands.get_mut(dest_operand_idx).ok_or( + Error::OperandIndexOutOfRange { + inst, + operand_idx: dest_operand_idx, + }, + )?; + if mem::replace(seen_dest_operand, true) { + return Err(Error::DupCopyDestOperand { + inst, + operand_idx: 
dest_operand_idx, + }); + } + } + if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&src_operand_idxs, inst, func)? { + return Err(Error::CopySrcTyMismatch { inst }); + } + if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&dest_operand_idxs, inst, func)? { + return Err(Error::CopyDestTyMismatch { inst }); + } + } + InstKind::BlockTerm(BlockTermInstKind { succs_and_params }) => { + for (&succ_idx, params) in succs_and_params { + let succ = func.try_get_block(succ_idx)?; + if !succ.preds.contains(&block) { + return Err(Error::SrcBlockMissingFromBranchTgtBlocksPreds { + src_block: block, + branch_inst: inst, + tgt_block: succ_idx, + }); + } + if succ.params.len() != params.len() { + return Err(Error::BranchSuccParamCountMismatch { + inst, + succ: succ_idx, + block_param_count: succ.params.len(), + branch_param_count: params.len(), + }); + } + for (param_idx, (&branch_ssa_val_idx, &block_ssa_val_idx)) in + params.iter().zip(&succ.params).enumerate() + { + let branch_ssa_val = func.try_get_ssa_val(branch_ssa_val_idx)?; + let block_ssa_val = func.try_get_ssa_val(block_ssa_val_idx)?; + if !branch_ssa_val + .branch_succ_param_uses + .contains(&BranchSuccParamUse { + branch_inst: inst, + succ: succ_idx, + param_idx, + }) + { + return Err(Error::MissingBranchSuccParamUse { + ssa_val_idx: branch_ssa_val_idx, + inst, + succ: succ_idx, + param_idx, + }); + } + if block_ssa_val.ty != branch_ssa_val.ty { + return Err(Error::BranchSuccParamTyMismatch { + inst, + succ: succ_idx, + param_idx, + block_param_ty: block_ssa_val.ty, + branch_param_ty: branch_ssa_val.ty, + }); + } + } + } + } + } + Ok(()) + } + pub fn try_get_operand(&self, inst: InstIdx, operand_idx: usize) -> Result<&Operand> { + self.operands + .get(operand_idx) + .ok_or(Error::OperandIndexOutOfRange { inst, operand_idx }) + } +} + #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] pub struct Block { pub params: Vec, pub insts: InstRange, - pub preds: Vec, + pub preds: BTreeSet, + pub 
immediate_dominator: Option, +} + +impl Block { + fn validate(&self, block: BlockIdx, func: &FnFields, global_state: &GlobalState) -> Result<()> { + let Self { + params, + insts, + preds, + immediate_dominator: _, + } = self; + const _: () = assert!(BlockIdx::ENTRY_BLOCK.get() == 0); + let expected_start = if block == BlockIdx::ENTRY_BLOCK { + InstIdx::new(0) + } else { + func.blocks[block.prev()].insts.end + }; + if insts.start != expected_start { + return Err(Error::BlockHasInvalidStart { + start: insts.start, + expected_start, + }); + } + let term_inst_idx = insts.last().ok_or(Error::BlockIsEmpty { block })?; + func.insts + .get(term_inst_idx.get()) + .ok_or(Error::BlockEndOutOfRange { end: insts.end })?; + if block.get() == func.blocks.len() - 1 && insts.end.get() != func.insts.len() { + return Err(Error::InstHasNoBlock { inst: insts.end }); + } + if block == BlockIdx::ENTRY_BLOCK { + if !params.is_empty() { + return Err(Error::EntryBlockCantHaveParams); + } + if !preds.is_empty() { + return Err(Error::EntryBlockCantHavePreds); + } + } + for inst in *insts { + func.insts[inst].validate(block, inst, func, global_state)?; + } + for (param_idx, &ssa_val_idx) in params.iter().enumerate() { + let ssa_val = func.try_get_ssa_val(ssa_val_idx)?; + let def = SSAValDef::BlockParam { block, param_idx }; + if ssa_val.def != def { + return Err(Error::MismatchedBlockParamDef { + ssa_val_idx, + block, + param_idx, + }); + } + } + for &pred in preds { + let (term_inst, BlockTermInstKind { succs_and_params }) = + func.try_get_block_term_inst_and_kind(pred)?; + if !succs_and_params.contains_key(&block) { + return Err(Error::PredMissingFromPredsTermBranchsTargets { + src_block: pred, + branch_inst: term_inst, + tgt_block: block, + }); + } + } + Ok(()) + } } validated_fields! { @@ -182,68 +671,153 @@ validated_fields! 
{ pub ssa_vals: Vec, pub insts: Vec, pub blocks: Vec, + #[serde(skip)] + /// map from blocks' start instruction's index to their block index, doesn't contain the entry block + pub start_inst_to_block_map: BTreeMap, } } impl Function { pub fn new(fields: FnFields) -> Result { + GlobalState::get(|global_state| Self::new_with_global_state(fields, global_state)) + } + pub fn new_with_global_state(mut fields: FnFields, global_state: &GlobalState) -> Result { + fields.fill_start_inst_to_block_map(); let FnFields { ssa_vals, - insts: insts_vec, + insts: _, blocks, + start_inst_to_block_map: _, } = &fields; - let entry_block = blocks + blocks .get(BlockIdx::ENTRY_BLOCK.get()) .ok_or(Error::MissingEntryBlock)?; - if !entry_block.params.is_empty() { - return Err(Error::EntryBlockCantHaveParams); - } - if !entry_block.preds.is_empty() { - return Err(Error::EntryBlockCantHavePreds); + for (idx, block) in blocks.iter().enumerate() { + block.validate(BlockIdx::new(idx), &fields, global_state)?; } - let mut expected_start = InstIdx::new(0); - for (block_idx, block) in fields.blocks.iter().enumerate() { - let block_idx = BlockIdx::new(block_idx); - let Block { - params, - insts: inst_range, - preds, - } = block; - if inst_range.start != expected_start { - return Err(Error::BlockHasInvalidStart { - start: inst_range.start, - expected_start, + let dominators = dominators::simple_fast(&fields, BlockIdx::ENTRY_BLOCK); + for (idx, block) in blocks.iter().enumerate() { + let block_idx = BlockIdx::new(idx); + let expected = dominators.immediate_dominator(block_idx); + if block.immediate_dominator != expected { + return Err(Error::IncorrectImmediateDominator { + block_idx, + found: block.immediate_dominator, + expected, }); } - let Some((term_idx, non_term_inst_range)) = inst_range.split_last() else { - return Err(Error::BlockIsEmpty { block: block_idx }); - }; - expected_start = inst_range.end; - let Some(Inst { kind: term_kind, .. 
}) = insts_vec.get(term_idx.get()) else { - return Err(Error::BlockEndOutOfRange { end: inst_range.end }); - }; - if !term_kind.is_block_term() { - return Err(Error::BlocksLastInstMustBeTerm { term_idx }); - } - for inst_idx in non_term_inst_range { - if insts_vec[inst_idx].kind.is_block_term() { - return Err(Error::TermInstOnlyAllowedAtBlockEnd { inst_idx }); - } - } } - todo!() + for (idx, ssa_val) in ssa_vals.iter().enumerate() { + ssa_val.validate(SSAValIdx::new(idx), &fields)?; + } + Ok(Self(fields)) } pub fn entry_block(&self) -> &Block { &self.blocks[0] } - pub fn block_succs(&self, block: BlockIdx) -> &[BranchSucc] { + pub fn block_term_kind(&self, block: BlockIdx) -> &BlockTermInstKind { self.insts[self.blocks[block].insts.last().unwrap()] .kind - .succs() + .block_term() .unwrap() } } +impl FnFields { + pub fn fill_start_inst_to_block_map(&mut self) { + self.start_inst_to_block_map.clear(); + for (idx, block) in self.blocks.iter().enumerate() { + let block_idx = BlockIdx::new(idx); + if block_idx != BlockIdx::ENTRY_BLOCK { + self.start_inst_to_block_map + .insert(block.insts.start, block_idx); + } + } + } + pub fn try_get_ssa_val(&self, idx: SSAValIdx) -> Result<&SSAVal> { + self.ssa_vals + .get(idx.get()) + .ok_or(Error::SSAValIdxOutOfRange { idx }) + } + pub fn try_get_inst(&self, idx: InstIdx) -> Result<&Inst> { + self.insts + .get(idx.get()) + .ok_or(Error::InstIdxOutOfRange { idx }) + } + pub fn try_get_operand(&self, inst: InstIdx, operand_idx: usize) -> Result<&Operand> { + self.try_get_inst(inst)?.try_get_operand(inst, operand_idx) + } + pub fn try_get_block(&self, idx: BlockIdx) -> Result<&Block> { + self.blocks + .get(idx.get()) + .ok_or(Error::BlockIdxOutOfRange { idx }) + } + pub fn try_get_block_param(&self, block: BlockIdx, param_idx: usize) -> Result { + self.try_get_block(block)? 
+ .params + .get(param_idx) + .copied() + .ok_or(Error::BlockParamIdxOutOfRange { block, param_idx }) + } + pub fn try_get_block_term_inst_idx(&self, block: BlockIdx) -> Result { + self.try_get_block(block)? + .insts + .last() + .ok_or(Error::BlockIsEmpty { block }) + } + pub fn try_get_block_term_inst_and_kind( + &self, + block: BlockIdx, + ) -> Result<(InstIdx, &BlockTermInstKind)> { + let term_idx = self.try_get_block_term_inst_idx(block)?; + let term_kind = self + .try_get_inst(term_idx)? + .kind + .block_term() + .ok_or(Error::BlocksLastInstMustBeTerm { term_idx })?; + Ok((term_idx, term_kind)) + } + pub fn try_get_branch_target_params( + &self, + branch_inst: InstIdx, + succ: BlockIdx, + ) -> Result<&[SSAValIdx]> { + let inst = self.try_get_inst(branch_inst)?; + let BlockTermInstKind { succs_and_params } = inst + .kind + .block_term() + .ok_or(Error::InstIsNotBlockTerm { inst: branch_inst })?; + Ok(succs_and_params + .get(&succ) + .ok_or(Error::BranchTargetNotFound { + branch_inst, + tgt_block: succ, + })?) + } + pub fn try_get_branch_target_param( + &self, + branch_inst: InstIdx, + succ: BlockIdx, + param_idx: usize, + ) -> Result { + Ok(*self + .try_get_branch_target_params(branch_inst, succ)? + .get(param_idx) + .ok_or(Error::BranchTargetParamIdxOutOfRange { + branch_inst, + tgt_block: succ, + param_idx, + })?) + } + pub fn inst_to_block(&self, inst: InstIdx) -> BlockIdx { + self.start_inst_to_block_map + .range(..=inst) + .next_back() + .map(|v| *v.1) + .unwrap_or(BlockIdx::ENTRY_BLOCK) + } +} + impl Index for Vec { type Output = SSAVal; @@ -267,3 +841,61 @@ impl Index for Vec { &self[index.get()] } } + +impl GraphBase for FnFields { + type EdgeId = (BlockIdx, BlockIdx); + type NodeId = BlockIdx; +} + +pub struct Neighbors<'a> { + iter: Option>>, +} + +impl Iterator for Neighbors<'_> { + type Item = BlockIdx; + + fn next(&mut self) -> Option { + Some(*self.iter.as_mut()?.next()?) 
+ } +} + +impl<'a> IntoNeighbors for &'a FnFields { + type Neighbors = Neighbors<'a>; + + fn neighbors(self, block_idx: Self::NodeId) -> Self::Neighbors { + Neighbors { + iter: self + .try_get_block_term_inst_and_kind(block_idx) + .ok() + .map(|(_, BlockTermInstKind { succs_and_params })| succs_and_params.keys()), + } + } +} + +pub struct VisitedMap(HashSet); + +impl VisitMap for VisitedMap { + fn visit(&mut self, block: BlockIdx) -> bool { + self.0.insert(block) + } + + fn is_visited(&self, block: &BlockIdx) -> bool { + self.0.contains(block) + } +} + +impl Visitable for FnFields { + type Map = VisitedMap; + + fn visit_map(&self) -> Self::Map { + VisitedMap(HashSet::new()) + } + + fn reset_map(&self, map: &mut Self::Map) { + map.0.clear(); + } +} + +impl GraphProp for FnFields { + type EdgeType = Directed; +} diff --git a/register_allocator/src/index.rs b/register_allocator/src/index.rs index d1423f3..3c99fb6 100644 --- a/register_allocator/src/index.rs +++ b/register_allocator/src/index.rs @@ -1,6 +1,11 @@ use serde::{Deserialize, Serialize}; use std::{fmt, iter::FusedIterator, ops::Range}; +pub trait DisplayOptionIdx { + type Type: fmt::Display; + fn display_option_idx(self) -> Self::Type; +} + macro_rules! define_index { ($name:ident) => { #[derive( @@ -29,6 +34,33 @@ macro_rules! 
define_index { } } } + + const _: () = { + #[derive(Copy, Clone)] + pub struct DisplayOptionIdxImpl(Option<$name>); + + impl fmt::Debug for DisplayOptionIdxImpl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, f) + } + } + + impl fmt::Display for DisplayOptionIdxImpl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + Some(v) => v.fmt(f), + None => write!(f, "none"), + } + } + } + impl DisplayOptionIdx for Option<$name> { + type Type = DisplayOptionIdxImpl; + + fn display_option_idx(self) -> Self::Type { + DisplayOptionIdxImpl(self) + } + } + }; }; } diff --git a/register_allocator/src/loc.rs b/register_allocator/src/loc.rs index 9af021e..eec4290 100644 --- a/register_allocator/src/loc.rs +++ b/register_allocator/src/loc.rs @@ -83,6 +83,38 @@ impl Ty { Ok(Self(fields)) } } + /// returns the `Ty` for `fields` or if there was an error, returns the corresponding scalar type (where `reg_len` is `1`) + pub const fn new_or_scalar(fields: TyFields) -> Self { + match Self::new(fields) { + Ok(v) => v, + Err(_) => Self::scalar(fields.base_ty), + } + } + pub const fn scalar(base_ty: BaseTy) -> Self { + Self(TyFields { + base_ty, + reg_len: nzu32_lit!(1), + }) + } + pub const fn bits64(reg_len: NonZeroU32) -> Self { + Self(TyFields { + base_ty: BaseTy::Bits64, + reg_len, + }) + } + pub const fn try_concat(self, rhs: Self) -> Result { + if !self.get().base_ty.const_eq(rhs.get().base_ty) { + Err(Error::BaseTyMismatch) + } else { + let Some(reg_len) = self.get().reg_len.checked_add(rhs.get().reg_len.get()) else { + return Err(Error::RegLenOutOfRange); + }; + Ty::new(TyFields { + base_ty: self.get().base_ty, + reg_len, + }) + } + } } validated_fields! 
{ diff --git a/register_allocator/src/loc_set.rs b/register_allocator/src/loc_set.rs index b10358b..05344b7 100644 --- a/register_allocator/src/loc_set.rs +++ b/register_allocator/src/loc_set.rs @@ -83,16 +83,9 @@ impl LocSet { for (kind, starts) in &starts { if !starts.is_zero() { empty = false; - let expected_ty = Ty::new(TyFields { + let expected_ty = Ty::new_or_scalar(TyFields { base_ty: kind.base_ty(), reg_len: ty.map(|v| v.reg_len).unwrap_or(nzu32_lit!(1)), - }) - .unwrap_or_else(|_| { - Ty::new(TyFields { - base_ty: kind.base_ty(), - reg_len: nzu32_lit!(1), - }) - .unwrap() }); if ty != Some(expected_ty) { return Err(Error::TyMismatch { @@ -120,7 +113,7 @@ impl LocSet { v.assign_from_slice(&[]); } } - pub fn contains(&self, value: Loc) -> bool { + pub fn contains_exact(&self, value: Loc) -> bool { Some(value.ty()) == self.ty && self.starts[value.kind].bit(value.start as _) } pub fn try_insert(&mut self, value: Loc) -> Result { @@ -144,7 +137,7 @@ impl LocSet { self.try_insert(value).unwrap() } pub fn remove(&mut self, value: Loc) -> bool { - if self.contains(value) { + if self.contains_exact(value) { self.starts[value.kind].set_bit(value.start as u64, false); if self.starts.values().all(BigUint::is_zero) { self.ty = None; @@ -157,36 +150,6 @@ impl LocSet { pub fn is_empty(&self) -> bool { self.ty.is_none() } - pub fn is_disjoint(&self, other: &LocSet) -> bool { - if self.ty != other.ty || self.is_empty() { - return true; - } - for (k, lhs) in self.starts.iter() { - let rhs = &other.starts[k]; - if !(lhs & rhs).is_zero() { - return false; - } - } - true - } - pub fn is_subset(&self, containing_set: &LocSet) -> bool { - if self.is_empty() { - return true; - } - if self.ty != containing_set.ty { - return false; - } - for (k, v) in self.starts.iter() { - let containing_set = &containing_set.starts[k]; - if !and_not(v, containing_set).is_zero() { - return false; - } - } - true - } - pub fn is_superset(&self, contained_set: &LocSet) -> bool { - 
contained_set.is_subset(self) - } pub fn iter(&self) -> Iter<'_> { if let Some(ty) = self.ty { let mut starts = self.starts.iter().peekable(); @@ -609,6 +572,13 @@ impl Interned { ) .result(global_state) } + pub fn conflicts_with(self, rhs: Rhs, global_state: &GlobalState) -> bool + where + Rhs: LocSetMaxConflictsWithTrait, + LocSetMaxConflictsWith: InternTarget, + { + self.max_conflicts_with(rhs, global_state) != 0 + } } #[cfg(test)] -- 2.30.2