28ff917cb4a17c9a16e3e9acd97b26270db70777
[bigint-presentation-code.git] / register_allocator / src / function.rs
1 use crate::{
2 error::{Error, Result},
3 index::{BlockIdx, InstIdx, InstRange, SSAValIdx},
4 interned::{GlobalState, Intern, Interned},
5 loc::{BaseTy, Loc, Ty},
6 loc_set::LocSet,
7 };
8 use core::fmt;
9 use hashbrown::HashSet;
10 use petgraph::{
11 algo::dominators,
12 visit::{GraphBase, GraphProp, IntoNeighbors, VisitMap, Visitable},
13 Directed,
14 };
15 use serde::{Deserialize, Serialize};
16 use smallvec::SmallVec;
17 use std::{
18 collections::{btree_map, BTreeMap, BTreeSet},
19 mem,
20 ops::{Index, IndexMut},
21 };
22
/// The place where an SSA value is defined.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub enum SSAValDef {
    /// Defined as parameter number `param_idx` of block `block`.
    BlockParam { block: BlockIdx, param_idx: usize },
    /// Defined by operand number `operand_idx` of instruction `inst`.
    Operand { inst: InstIdx, operand_idx: usize },
}
28
/// A use of an SSA value as a parameter passed to a successor block by a
/// block-terminating instruction.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
pub struct BranchSuccParamUse {
    /// the block-terminating instruction that passes the value
    branch_inst: InstIdx,
    /// the successor block receiving the value
    succ: BlockIdx,
    /// index into the parameter list passed to `succ`
    param_idx: usize,
}
35
/// A use of an SSA value by an instruction operand.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
pub struct OperandUse {
    /// the instruction containing the operand
    inst: InstIdx,
    /// index of the operand within `inst`'s operand list
    operand_idx: usize,
}
41
/// An SSA value together with its definition and every recorded use.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct SSAVal {
    /// type of the value
    pub ty: Ty,
    /// where the value is defined
    pub def: SSAValDef,
    /// every instruction operand recorded as using this value
    pub operand_uses: BTreeSet<OperandUse>,
    /// every successor-block parameter slot of a block terminator recorded as
    /// using this value
    pub branch_succ_param_uses: BTreeSet<BranchSuccParamUse>,
}
49
50 impl SSAVal {
51 fn validate(&self, ssa_val_idx: SSAValIdx, func: &FnFields) -> Result<()> {
52 let Self {
53 ty: _,
54 def,
55 operand_uses,
56 branch_succ_param_uses,
57 } = self;
58 match *def {
59 SSAValDef::BlockParam { block, param_idx } => {
60 let block_param = func.try_get_block_param(block, param_idx)?;
61 if ssa_val_idx != block_param {
62 return Err(Error::MismatchedBlockParamDef {
63 ssa_val_idx,
64 block,
65 param_idx,
66 });
67 }
68 }
69 SSAValDef::Operand { inst, operand_idx } => {
70 let operand = func.try_get_operand(inst, operand_idx)?;
71 if ssa_val_idx != operand.ssa_val {
72 return Err(Error::SSAValDefIsNotOperandsSSAVal {
73 ssa_val_idx,
74 inst,
75 operand_idx,
76 });
77 }
78 }
79 }
80 for &OperandUse { inst, operand_idx } in operand_uses {
81 let operand = func.try_get_operand(inst, operand_idx)?;
82 if ssa_val_idx != operand.ssa_val {
83 return Err(Error::SSAValUseIsNotOperandsSSAVal {
84 ssa_val_idx,
85 inst,
86 operand_idx,
87 });
88 }
89 }
90 for &BranchSuccParamUse {
91 branch_inst,
92 succ,
93 param_idx,
94 } in branch_succ_param_uses
95 {
96 if ssa_val_idx != func.try_get_branch_target_param(branch_inst, succ, param_idx)? {
97 return Err(Error::MismatchedBranchTargetBlockParamUse {
98 ssa_val_idx,
99 branch_inst,
100 tgt_block: succ,
101 param_idx,
102 });
103 }
104 }
105 Ok(())
106 }
107 }
108
/// The stage of an instruction, `Early` or `Late`.
///
/// The discriminant is stored in the low bit of a [`ProgPoint`].
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum InstStage {
    Early = 0,
    Late = 1,
}
115
/// A program point: an instruction index combined with an [`InstStage`].
///
/// Packed into a single `usize` as `(inst << 1) | stage`; (de)serialized
/// through `SerializedProgPoint` as separate `inst` and `stage` fields.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(try_from = "SerializedProgPoint", into = "SerializedProgPoint")]
pub struct ProgPoint(usize);
119
120 impl ProgPoint {
121 pub const fn new(inst: InstIdx, stage: InstStage) -> Self {
122 const_unwrap_res!(Self::try_new(inst, stage))
123 }
124 pub const fn try_new(inst: InstIdx, stage: InstStage) -> Result<Self> {
125 let Some(inst) = inst.get().checked_shl(1) else {
126 return Err(Error::InstIdxTooBig);
127 };
128 Ok(Self(inst | stage as usize))
129 }
130 pub const fn inst(self) -> InstIdx {
131 InstIdx::new(self.0 >> 1)
132 }
133 pub const fn stage(self) -> InstStage {
134 if self.0 & 1 != 0 {
135 InstStage::Late
136 } else {
137 InstStage::Early
138 }
139 }
140 pub const fn next(self) -> Self {
141 Self(self.0 + 1)
142 }
143 pub const fn prev(self) -> Self {
144 Self(self.0 - 1)
145 }
146 }
147
148 impl fmt::Debug for ProgPoint {
149 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150 f.debug_struct("ProgPoint")
151 .field("inst", &self.inst())
152 .field("stage", &self.stage())
153 .finish()
154 }
155 }
156
/// Serialization form of [`ProgPoint`], with the instruction index and stage
/// as separate fields instead of packed into one integer.
#[derive(Serialize, Deserialize)]
struct SerializedProgPoint {
    inst: InstIdx,
    stage: InstStage,
}
162
163 impl From<ProgPoint> for SerializedProgPoint {
164 fn from(value: ProgPoint) -> Self {
165 Self {
166 inst: value.inst(),
167 stage: value.stage(),
168 }
169 }
170 }
171
172 impl TryFrom<SerializedProgPoint> for ProgPoint {
173 type Error = Error;
174
175 fn try_from(value: SerializedProgPoint) -> Result<Self, Self::Error> {
176 ProgPoint::try_new(value.inst, value.stage)
177 }
178 }
179
/// Whether an operand uses an existing SSA value or defines one.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum OperandKind {
    /// the operand reads an SSA value
    Use = 0,
    /// the operand is the definition of an SSA value
    Def = 1,
}
186
/// Restriction on which locations an operand may be assigned.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
pub enum Constraint {
    /// any register or stack location
    Any,
    /// r1-r32
    BaseGpr,
    /// r2,r4,r6,r8,...r126
    SVExtra2VGpr,
    /// r1-63
    SVExtra2SGpr,
    /// r1-127
    SVExtra3Gpr,
    /// any stack location
    Stack,
    /// exactly the given location
    FixedLoc(Loc),
}
203
204 impl Constraint {
205 pub fn is_any(&self) -> bool {
206 matches!(self, Self::Any)
207 }
208 }
209
210 impl Default for Constraint {
211 fn default() -> Self {
212 Self::Any
213 }
214 }
215
/// Unit type that (de)serializes as an [`OperandKind`] but only accepts
/// [`OperandKind::Def`]; conversion from [`OperandKind::Use`] fails.
#[derive(
    Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize, Default,
)]
#[serde(try_from = "OperandKind", into = "OperandKind")]
pub struct OperandKindDefOnly;
221
222 impl TryFrom<OperandKind> for OperandKindDefOnly {
223 type Error = Error;
224
225 fn try_from(value: OperandKind) -> Result<Self, Self::Error> {
226 match value {
227 OperandKind::Use => Err(Error::OperandKindMustBeDef),
228 OperandKind::Def => Ok(Self),
229 }
230 }
231 }
232
233 impl From<OperandKindDefOnly> for OperandKind {
234 fn from(_value: OperandKindDefOnly) -> Self {
235 Self::Def
236 }
237 }
238
/// An operand's kind combined with how its location constraint is obtained.
///
/// Serialized untagged: the variant is distinguished by which fields are
/// present.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum KindAndConstraint {
    /// A def operand whose constraint is taken from operand
    /// `reuse_operand_idx` of the same instruction.
    Reuse {
        kind: OperandKindDefOnly,
        reuse_operand_idx: usize,
    },
    /// An operand with its own explicit constraint.
    Constraint {
        kind: OperandKind,
        // omitted from the serialized form when it is `Constraint::Any`
        #[serde(default, skip_serializing_if = "Constraint::is_any")]
        constraint: Constraint,
    },
}
252
253 impl KindAndConstraint {
254 pub fn kind(self) -> OperandKind {
255 match self {
256 Self::Reuse { .. } => OperandKind::Def,
257 Self::Constraint { kind, .. } => kind,
258 }
259 }
260 pub fn is_reuse(self) -> bool {
261 matches!(self, Self::Reuse { .. })
262 }
263 }
264
/// A single operand of an instruction.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct Operand {
    /// the SSA value this operand uses or defines
    pub ssa_val: SSAValIdx,
    /// kind (use/def) and constraint; flattened into the operand when serialized
    #[serde(flatten)]
    pub kind_and_constraint: KindAndConstraint,
    /// the instruction stage associated with this operand
    pub stage: InstStage,
}
272
273 impl Operand {
274 pub fn try_get_reuse_src<'f>(
275 &self,
276 inst: InstIdx,
277 func: &'f FnFields,
278 ) -> Result<Option<&'f Operand>> {
279 if let KindAndConstraint::Reuse {
280 reuse_operand_idx, ..
281 } = self.kind_and_constraint
282 {
283 Ok(Some(func.try_get_operand(inst, reuse_operand_idx)?))
284 } else {
285 Ok(None)
286 }
287 }
288 pub fn try_constraint(&self, inst: InstIdx, func: &FnFields) -> Result<Constraint> {
289 Ok(match self.kind_and_constraint {
290 KindAndConstraint::Reuse {
291 kind: _,
292 reuse_operand_idx,
293 } => {
294 let operand = func.try_get_operand(inst, reuse_operand_idx)?;
295 match operand.kind_and_constraint {
296 KindAndConstraint::Reuse { .. }
297 | KindAndConstraint::Constraint {
298 kind: OperandKind::Def,
299 ..
300 } => {
301 return Err(Error::ReuseTargetOperandMustBeUse {
302 inst,
303 reuse_target_operand_idx: reuse_operand_idx,
304 })
305 }
306 KindAndConstraint::Constraint {
307 kind: OperandKind::Use,
308 constraint,
309 } => constraint,
310 }
311 }
312 KindAndConstraint::Constraint { constraint, .. } => constraint,
313 })
314 }
315 pub fn constraint(&self, inst: InstIdx, func: &Function) -> Constraint {
316 self.try_constraint(inst, func).unwrap()
317 }
318 fn validate(
319 self,
320 _block: BlockIdx,
321 inst: InstIdx,
322 operand_idx: usize,
323 func: &FnFields,
324 global_state: &GlobalState,
325 ) -> Result<()> {
326 let Self {
327 ssa_val: ssa_val_idx,
328 kind_and_constraint,
329 stage: _,
330 } = self;
331 let ssa_val = func.try_get_ssa_val(ssa_val_idx)?;
332 match kind_and_constraint.kind() {
333 OperandKind::Use => {
334 if !ssa_val
335 .operand_uses
336 .contains(&OperandUse { inst, operand_idx })
337 {
338 return Err(Error::MissingOperandUse {
339 ssa_val_idx,
340 inst,
341 operand_idx,
342 });
343 }
344 }
345 OperandKind::Def => {
346 let def = SSAValDef::Operand { inst, operand_idx };
347 if ssa_val.def != def {
348 return Err(Error::OperandDefIsNotSSAValDef {
349 ssa_val_idx,
350 inst,
351 operand_idx,
352 });
353 }
354 }
355 }
356 let constraint = self.try_constraint(inst, func)?;
357 match constraint {
358 Constraint::Any | Constraint::Stack => {}
359 Constraint::BaseGpr | Constraint::SVExtra2SGpr => {
360 if ssa_val.ty != Ty::scalar(BaseTy::Bits64) {
361 return Err(Error::ConstraintTyMismatch {
362 ssa_val_idx,
363 inst,
364 operand_idx,
365 });
366 }
367 }
368 Constraint::SVExtra2VGpr | Constraint::SVExtra3Gpr => {
369 if ssa_val.ty.base_ty != BaseTy::Bits64 {
370 return Err(Error::ConstraintTyMismatch {
371 ssa_val_idx,
372 inst,
373 operand_idx,
374 });
375 }
376 }
377 Constraint::FixedLoc(loc) => {
378 if func
379 .try_get_inst(inst)?
380 .clobbers
381 .clone()
382 .conflicts_with(loc, global_state)
383 {
384 return Err(Error::FixedLocConflictsWithClobbers { inst, operand_idx });
385 }
386 if ssa_val.ty != loc.ty() {
387 return Err(Error::ConstraintTyMismatch {
388 ssa_val_idx,
389 inst,
390 operand_idx,
391 });
392 }
393 }
394 }
395 Ok(())
396 }
397 }
398
/// copy concatenates all `srcs` together and de-concatenates the result into all `dests`.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct CopyInstKind {
    /// indexes into the instruction's operands of the copy sources
    pub src_operand_idxs: Vec<usize>,
    /// indexes into the instruction's operands of the copy destinations;
    /// validation rejects duplicates
    pub dest_operand_idxs: Vec<usize>,
    /// type of the concatenated value; validation requires both the sources
    /// and the destinations to concatenate to exactly this type
    pub copy_ty: Ty,
}
406
407 impl CopyInstKind {
408 fn calc_copy_ty(operand_idxs: &[usize], inst: InstIdx, func: &FnFields) -> Result<Option<Ty>> {
409 let mut retval: Option<Ty> = None;
410 for &operand_idx in operand_idxs {
411 let operand = func.try_get_operand(inst, operand_idx)?;
412 let ssa_val = func.try_get_ssa_val(operand.ssa_val)?;
413 retval = Some(match retval {
414 Some(retval) => retval.try_concat(ssa_val.ty)?,
415 None => ssa_val.ty,
416 });
417 }
418 Ok(retval)
419 }
420 }
421
/// Data specific to a block-terminating instruction.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct BlockTermInstKind {
    /// map from each successor block to the SSA values passed as that
    /// block's parameters
    pub succs_and_params: BTreeMap<BlockIdx, Vec<SSAValIdx>>,
}
426
/// The kind of an instruction.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub enum InstKind {
    /// an instruction with no kind-specific data; this is the default
    Normal,
    /// a copy instruction
    Copy(CopyInstKind),
    /// a block terminator; validation requires it to be the last instruction
    /// of its block
    BlockTerm(BlockTermInstKind),
}
433
434 impl InstKind {
435 pub fn is_normal(&self) -> bool {
436 matches!(self, Self::Normal)
437 }
438 pub fn is_block_term(&self) -> bool {
439 matches!(self, Self::BlockTerm { .. })
440 }
441 pub fn is_copy(&self) -> bool {
442 matches!(self, Self::Copy { .. })
443 }
444 pub fn block_term(&self) -> Option<&BlockTermInstKind> {
445 match self {
446 InstKind::BlockTerm(v) => Some(v),
447 _ => None,
448 }
449 }
450 pub fn block_term_mut(&mut self) -> Option<&mut BlockTermInstKind> {
451 match self {
452 InstKind::BlockTerm(v) => Some(v),
453 _ => None,
454 }
455 }
456 pub fn copy(&self) -> Option<&CopyInstKind> {
457 match self {
458 InstKind::Copy(v) => Some(v),
459 _ => None,
460 }
461 }
462 }
463
464 impl Default for InstKind {
465 fn default() -> Self {
466 InstKind::Normal
467 }
468 }
469
/// serde `skip_serializing_if` helper for `Inst::clobbers`: returns `true`
/// when the clobber set is empty.
fn loc_set_is_empty(clobbers: &Interned<LocSet>) -> bool {
    clobbers.is_empty()
}
473
474 fn empty_loc_set() -> Interned<LocSet> {
475 GlobalState::get(|global_state| LocSet::default().into_interned(global_state))
476 }
477
/// A single instruction.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct Inst {
    /// kind of instruction; omitted from the serialized form when `Normal`
    #[serde(default, skip_serializing_if = "InstKind::is_normal")]
    pub kind: InstKind,
    /// the instruction's operands, referenced elsewhere by their index
    pub operands: Vec<Operand>,
    /// locations clobbered by this instruction; omitted from the serialized
    /// form when empty
    #[serde(default = "empty_loc_set", skip_serializing_if = "loc_set_is_empty")]
    pub clobbers: Interned<LocSet>,
}
486
487 impl Inst {
488 fn validate(
489 &self,
490 block: BlockIdx,
491 inst: InstIdx,
492 func: &FnFields,
493 global_state: &GlobalState,
494 ) -> Result<()> {
495 let Self {
496 kind,
497 operands,
498 clobbers: _,
499 } = self;
500 let is_at_end_of_block = func.blocks[block].insts.last() == Some(inst);
501 if kind.is_block_term() != is_at_end_of_block {
502 return Err(if is_at_end_of_block {
503 Error::BlocksLastInstMustBeTerm { term_idx: inst }
504 } else {
505 Error::TermInstOnlyAllowedAtBlockEnd { inst_idx: inst }
506 });
507 }
508 for (idx, operand) in operands.iter().enumerate() {
509 operand.validate(block, inst, idx, func, global_state)?;
510 }
511 match kind {
512 InstKind::Normal => {}
513 InstKind::Copy(CopyInstKind {
514 src_operand_idxs,
515 dest_operand_idxs,
516 copy_ty,
517 }) => {
518 let mut seen_dest_operands = SmallVec::<[bool; 16]>::new();
519 seen_dest_operands.resize(operands.len(), false);
520 for &dest_operand_idx in dest_operand_idxs {
521 let seen_dest_operand = seen_dest_operands.get_mut(dest_operand_idx).ok_or(
522 Error::OperandIndexOutOfRange {
523 inst,
524 operand_idx: dest_operand_idx,
525 },
526 )?;
527 if mem::replace(seen_dest_operand, true) {
528 return Err(Error::DupCopyDestOperand {
529 inst,
530 operand_idx: dest_operand_idx,
531 });
532 }
533 }
534 if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&src_operand_idxs, inst, func)? {
535 return Err(Error::CopySrcTyMismatch { inst });
536 }
537 if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&dest_operand_idxs, inst, func)? {
538 return Err(Error::CopyDestTyMismatch { inst });
539 }
540 }
541 InstKind::BlockTerm(BlockTermInstKind { succs_and_params }) => {
542 for (&succ_idx, params) in succs_and_params {
543 let succ = func.try_get_block(succ_idx)?;
544 if !succ.preds.contains(&block) {
545 return Err(Error::SrcBlockMissingFromBranchTgtBlocksPreds {
546 src_block: block,
547 branch_inst: inst,
548 tgt_block: succ_idx,
549 });
550 }
551 if succ.params.len() != params.len() {
552 return Err(Error::BranchSuccParamCountMismatch {
553 inst,
554 succ: succ_idx,
555 block_param_count: succ.params.len(),
556 branch_param_count: params.len(),
557 });
558 }
559 for (param_idx, (&branch_ssa_val_idx, &block_ssa_val_idx)) in
560 params.iter().zip(&succ.params).enumerate()
561 {
562 let branch_ssa_val = func.try_get_ssa_val(branch_ssa_val_idx)?;
563 let block_ssa_val = func.try_get_ssa_val(block_ssa_val_idx)?;
564 if !branch_ssa_val
565 .branch_succ_param_uses
566 .contains(&BranchSuccParamUse {
567 branch_inst: inst,
568 succ: succ_idx,
569 param_idx,
570 })
571 {
572 return Err(Error::MissingBranchSuccParamUse {
573 ssa_val_idx: branch_ssa_val_idx,
574 inst,
575 succ: succ_idx,
576 param_idx,
577 });
578 }
579 if block_ssa_val.ty != branch_ssa_val.ty {
580 return Err(Error::BranchSuccParamTyMismatch {
581 inst,
582 succ: succ_idx,
583 param_idx,
584 block_param_ty: block_ssa_val.ty,
585 branch_param_ty: branch_ssa_val.ty,
586 });
587 }
588 }
589 }
590 }
591 }
592 Ok(())
593 }
594 pub fn try_get_operand(&self, inst: InstIdx, operand_idx: usize) -> Result<&Operand> {
595 self.operands
596 .get(operand_idx)
597 .ok_or(Error::OperandIndexOutOfRange { inst, operand_idx })
598 }
599 }
600
/// A basic block.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct Block {
    /// SSA values defined as this block's parameters
    pub params: Vec<SSAValIdx>,
    /// range of instructions in this block; validation requires it to be
    /// non-empty and to start where the previous block's range ends
    pub insts: InstRange,
    /// blocks whose terminator branches to this block
    pub preds: BTreeSet<BlockIdx>,
    /// immediate dominator of this block; checked against the computed
    /// dominator tree when a `Function` is constructed
    pub immediate_dominator: Option<BlockIdx>,
}
608
609 impl Block {
610 fn validate(&self, block: BlockIdx, func: &FnFields, global_state: &GlobalState) -> Result<()> {
611 let Self {
612 params,
613 insts,
614 preds,
615 immediate_dominator: _, // validated by Function::new_with_global_state
616 } = self;
617 const _: () = assert!(BlockIdx::ENTRY_BLOCK.get() == 0);
618 let expected_start = if block == BlockIdx::ENTRY_BLOCK {
619 InstIdx::new(0)
620 } else {
621 func.blocks[block.prev()].insts.end
622 };
623 if insts.start != expected_start {
624 return Err(Error::BlockHasInvalidStart {
625 start: insts.start,
626 expected_start,
627 });
628 }
629 let term_inst_idx = insts.last().ok_or(Error::BlockIsEmpty { block })?;
630 func.insts
631 .get(term_inst_idx.get())
632 .ok_or(Error::BlockEndOutOfRange { end: insts.end })?;
633 if block.get() == func.blocks.len() - 1 && insts.end.get() != func.insts.len() {
634 return Err(Error::InstHasNoBlock { inst: insts.end });
635 }
636 if block == BlockIdx::ENTRY_BLOCK {
637 if !params.is_empty() {
638 return Err(Error::EntryBlockCantHaveParams);
639 }
640 if !preds.is_empty() {
641 return Err(Error::EntryBlockCantHavePreds);
642 }
643 }
644 for inst in *insts {
645 func.insts[inst].validate(block, inst, func, global_state)?;
646 }
647 for (param_idx, &ssa_val_idx) in params.iter().enumerate() {
648 let ssa_val = func.try_get_ssa_val(ssa_val_idx)?;
649 let def = SSAValDef::BlockParam { block, param_idx };
650 if ssa_val.def != def {
651 return Err(Error::MismatchedBlockParamDef {
652 ssa_val_idx,
653 block,
654 param_idx,
655 });
656 }
657 }
658 for &pred in preds {
659 let (term_inst, BlockTermInstKind { succs_and_params }) =
660 func.try_get_block_term_inst_and_kind(pred)?;
661 if !succs_and_params.contains_key(&pred) {
662 return Err(Error::PredMissingFromPredsTermBranchsTargets {
663 src_block: pred,
664 branch_inst: term_inst,
665 tgt_block: block,
666 });
667 }
668 if preds.len() > 1 && succs_and_params.len() > 1 {
669 return Err(Error::CriticalEdgeNotAllowed {
670 src_block: pred,
671 branch_inst: term_inst,
672 tgt_block: block,
673 });
674 }
675 }
676 Ok(())
677 }
678 }
679
// Declares `Function` along with its fields type `FnFields`. A `Function` is
// only constructed by `Function::new`/`Function::new_with_global_state`, which
// validate the fields first.
validated_fields! {
    #[fields_ty = FnFields]
    #[derive(Clone, PartialEq, Eq, Debug, Hash)]
    pub struct Function {
        pub ssa_vals: Vec<SSAVal>,
        pub insts: Vec<Inst>,
        pub blocks: Vec<Block>,
        // not serialized; rebuilt by `FnFields::fill_start_inst_to_block_map`
        #[serde(skip)]
        /// map from blocks' start instruction's index to their block index, doesn't contain the entry block
        pub start_inst_to_block_map: BTreeMap<InstIdx, BlockIdx>,
    }
}
692
693 impl Function {
694 pub fn new(fields: FnFields) -> Result<Self> {
695 GlobalState::get(|global_state| Self::new_with_global_state(fields, global_state))
696 }
697 pub fn new_with_global_state(mut fields: FnFields, global_state: &GlobalState) -> Result<Self> {
698 fields.fill_start_inst_to_block_map();
699 let FnFields {
700 ssa_vals,
701 insts: _,
702 blocks,
703 start_inst_to_block_map: _,
704 } = &fields;
705 blocks
706 .get(BlockIdx::ENTRY_BLOCK.get())
707 .ok_or(Error::MissingEntryBlock)?;
708 for (idx, block) in blocks.iter().enumerate() {
709 block.validate(BlockIdx::new(idx), &fields, global_state)?;
710 }
711 let dominators = dominators::simple_fast(&fields, BlockIdx::ENTRY_BLOCK);
712 for (idx, block) in blocks.iter().enumerate() {
713 let block_idx = BlockIdx::new(idx);
714 let expected = dominators.immediate_dominator(block_idx);
715 if block.immediate_dominator != expected {
716 return Err(Error::IncorrectImmediateDominator {
717 block_idx,
718 found: block.immediate_dominator,
719 expected,
720 });
721 }
722 }
723 for (idx, ssa_val) in ssa_vals.iter().enumerate() {
724 ssa_val.validate(SSAValIdx::new(idx), &fields)?;
725 }
726 Ok(Self(fields))
727 }
728 pub fn entry_block(&self) -> &Block {
729 &self.blocks[0]
730 }
731 pub fn block_term_kind(&self, block: BlockIdx) -> &BlockTermInstKind {
732 self.insts[self.blocks[block].insts.last().unwrap()]
733 .kind
734 .block_term()
735 .unwrap()
736 }
737 }
738
739 impl FnFields {
740 pub fn fill_start_inst_to_block_map(&mut self) {
741 self.start_inst_to_block_map.clear();
742 for (idx, block) in self.blocks.iter().enumerate() {
743 let block_idx = BlockIdx::new(idx);
744 if block_idx != BlockIdx::ENTRY_BLOCK {
745 self.start_inst_to_block_map
746 .insert(block.insts.start, block_idx);
747 }
748 }
749 }
750 pub fn try_get_ssa_val(&self, idx: SSAValIdx) -> Result<&SSAVal> {
751 self.ssa_vals
752 .get(idx.get())
753 .ok_or(Error::SSAValIdxOutOfRange { idx })
754 }
755 pub fn try_get_inst(&self, idx: InstIdx) -> Result<&Inst> {
756 self.insts
757 .get(idx.get())
758 .ok_or(Error::InstIdxOutOfRange { idx })
759 }
760 pub fn try_get_inst_mut(&mut self, idx: InstIdx) -> Result<&mut Inst> {
761 self.insts
762 .get_mut(idx.get())
763 .ok_or(Error::InstIdxOutOfRange { idx })
764 }
765 pub fn try_get_operand(&self, inst: InstIdx, operand_idx: usize) -> Result<&Operand> {
766 self.try_get_inst(inst)?.try_get_operand(inst, operand_idx)
767 }
768 pub fn try_get_block(&self, idx: BlockIdx) -> Result<&Block> {
769 self.blocks
770 .get(idx.get())
771 .ok_or(Error::BlockIdxOutOfRange { idx })
772 }
773 pub fn try_get_block_param(&self, block: BlockIdx, param_idx: usize) -> Result<SSAValIdx> {
774 self.try_get_block(block)?
775 .params
776 .get(param_idx)
777 .copied()
778 .ok_or(Error::BlockParamIdxOutOfRange { block, param_idx })
779 }
780 pub fn try_get_block_term_inst_idx(&self, block: BlockIdx) -> Result<InstIdx> {
781 self.try_get_block(block)?
782 .insts
783 .last()
784 .ok_or(Error::BlockIsEmpty { block })
785 }
786 pub fn try_get_block_term_inst_and_kind(
787 &self,
788 block: BlockIdx,
789 ) -> Result<(InstIdx, &BlockTermInstKind)> {
790 let term_idx = self.try_get_block_term_inst_idx(block)?;
791 let term_kind = self
792 .try_get_inst(term_idx)?
793 .kind
794 .block_term()
795 .ok_or(Error::BlocksLastInstMustBeTerm { term_idx })?;
796 Ok((term_idx, term_kind))
797 }
798 pub fn try_get_block_term_inst_and_kind_mut(
799 &mut self,
800 block: BlockIdx,
801 ) -> Result<(InstIdx, &mut BlockTermInstKind)> {
802 let term_idx = self.try_get_block_term_inst_idx(block)?;
803 let term_kind = self
804 .try_get_inst_mut(term_idx)?
805 .kind
806 .block_term_mut()
807 .ok_or(Error::BlocksLastInstMustBeTerm { term_idx })?;
808 Ok((term_idx, term_kind))
809 }
810 pub fn try_get_branch_target_params(
811 &self,
812 branch_inst: InstIdx,
813 succ: BlockIdx,
814 ) -> Result<&[SSAValIdx]> {
815 let inst = self.try_get_inst(branch_inst)?;
816 let BlockTermInstKind { succs_and_params } = inst
817 .kind
818 .block_term()
819 .ok_or(Error::InstIsNotBlockTerm { inst: branch_inst })?;
820 Ok(succs_and_params
821 .get(&succ)
822 .ok_or(Error::BranchTargetNotFound {
823 branch_inst,
824 tgt_block: succ,
825 })?)
826 }
827 pub fn try_get_branch_target_param(
828 &self,
829 branch_inst: InstIdx,
830 succ: BlockIdx,
831 param_idx: usize,
832 ) -> Result<SSAValIdx> {
833 Ok(*self
834 .try_get_branch_target_params(branch_inst, succ)?
835 .get(param_idx)
836 .ok_or(Error::BranchTargetParamIdxOutOfRange {
837 branch_inst,
838 tgt_block: succ,
839 param_idx,
840 })?)
841 }
842 pub fn inst_to_block(&self, inst: InstIdx) -> BlockIdx {
843 self.start_inst_to_block_map
844 .range(..=inst)
845 .next_back()
846 .map(|v| *v.1)
847 .unwrap_or(BlockIdx::ENTRY_BLOCK)
848 }
849 }
850
851 impl Index<SSAValIdx> for Vec<SSAVal> {
852 type Output = SSAVal;
853
854 fn index(&self, index: SSAValIdx) -> &Self::Output {
855 &self[index.get()]
856 }
857 }
858
859 impl IndexMut<SSAValIdx> for Vec<SSAVal> {
860 fn index_mut(&mut self, index: SSAValIdx) -> &mut Self::Output {
861 &mut self[index.get()]
862 }
863 }
864
865 impl Index<InstIdx> for Vec<Inst> {
866 type Output = Inst;
867
868 fn index(&self, index: InstIdx) -> &Self::Output {
869 &self[index.get()]
870 }
871 }
872
873 impl IndexMut<InstIdx> for Vec<Inst> {
874 fn index_mut(&mut self, index: InstIdx) -> &mut Self::Output {
875 &mut self[index.get()]
876 }
877 }
878
879 impl Index<BlockIdx> for Vec<Block> {
880 type Output = Block;
881
882 fn index(&self, index: BlockIdx) -> &Self::Output {
883 &self[index.get()]
884 }
885 }
886
887 impl IndexMut<BlockIdx> for Vec<Block> {
888 fn index_mut(&mut self, index: BlockIdx) -> &mut Self::Output {
889 &mut self[index.get()]
890 }
891 }
892
/// Lets petgraph treat the function's control-flow graph as a graph whose
/// nodes are blocks.
impl GraphBase for FnFields {
    /// an edge is identified by a pair of block indices
    type EdgeId = (BlockIdx, BlockIdx);
    type NodeId = BlockIdx;
}
897
/// Iterator over the successor blocks of a block.
pub struct Neighbors<'a> {
    /// keys of the block terminator's successor map; `None` when the block's
    /// terminator couldn't be looked up, in which case nothing is yielded
    iter: Option<btree_map::Keys<'a, BlockIdx, Vec<SSAValIdx>>>,
}
901
902 impl Iterator for Neighbors<'_> {
903 type Item = BlockIdx;
904
905 fn next(&mut self) -> Option<Self::Item> {
906 Some(*self.iter.as_mut()?.next()?)
907 }
908 }
909
910 impl<'a> IntoNeighbors for &'a FnFields {
911 type Neighbors = Neighbors<'a>;
912
913 fn neighbors(self, block_idx: Self::NodeId) -> Self::Neighbors {
914 Neighbors {
915 iter: self
916 .try_get_block_term_inst_and_kind(block_idx)
917 .ok()
918 .map(|(_, BlockTermInstKind { succs_and_params })| succs_and_params.keys()),
919 }
920 }
921 }
922
/// Set of visited blocks, used as the petgraph visit map for [`FnFields`].
pub struct VisitedMap(HashSet<BlockIdx>);
924
925 impl VisitMap<BlockIdx> for VisitedMap {
926 fn visit(&mut self, block: BlockIdx) -> bool {
927 self.0.insert(block)
928 }
929
930 fn is_visited(&self, block: &BlockIdx) -> bool {
931 self.0.contains(block)
932 }
933 }
934
935 impl Visitable for FnFields {
936 type Map = VisitedMap;
937
938 fn visit_map(&self) -> Self::Map {
939 VisitedMap(HashSet::new())
940 }
941
942 fn reset_map(&self, map: &mut Self::Map) {
943 map.0.clear();
944 }
945 }
946
/// The control-flow graph is directed: edges go from a block to its successors.
impl GraphProp for FnFields {
    type EdgeType = Directed;
}