2a1e811b88f0ea377f52743af3aeae5300ac76e1
[bigint-presentation-code.git] / register_allocator / src / function.rs
1 use crate::{
2 error::{Error, Result},
3 index::{BlockIdx, InstIdx, InstRange, SSAValIdx},
4 interned::{GlobalState, Intern, Interned},
5 loc::{BaseTy, Loc, Ty},
6 loc_set::LocSet,
7 };
8 use core::fmt;
9 use hashbrown::HashSet;
10 use petgraph::{
11 algo::dominators,
12 visit::{GraphBase, GraphProp, IntoNeighbors, VisitMap, Visitable},
13 Directed,
14 };
15 use serde::{Deserialize, Serialize};
16 use smallvec::SmallVec;
17 use std::{
18 collections::{btree_map, BTreeMap, BTreeSet},
19 mem,
20 ops::Index,
21 };
22
23 #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
24 pub enum SSAValDef {
25 BlockParam { block: BlockIdx, param_idx: usize },
26 Operand { inst: InstIdx, operand_idx: usize },
27 }
28
29 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
30 pub struct BranchSuccParamUse {
31 branch_inst: InstIdx,
32 succ: BlockIdx,
33 param_idx: usize,
34 }
35
36 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
37 pub struct OperandUse {
38 inst: InstIdx,
39 operand_idx: usize,
40 }
41
42 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
43 pub struct SSAVal {
44 pub ty: Ty,
45 pub def: SSAValDef,
46 pub operand_uses: BTreeSet<OperandUse>,
47 pub branch_succ_param_uses: BTreeSet<BranchSuccParamUse>,
48 }
49
50 impl SSAVal {
51 fn validate(&self, ssa_val_idx: SSAValIdx, func: &FnFields) -> Result<()> {
52 let Self {
53 ty: _,
54 def,
55 operand_uses,
56 branch_succ_param_uses,
57 } = self;
58 match *def {
59 SSAValDef::BlockParam { block, param_idx } => {
60 let block_param = func.try_get_block_param(block, param_idx)?;
61 if ssa_val_idx != block_param {
62 return Err(Error::MismatchedBlockParamDef {
63 ssa_val_idx,
64 block,
65 param_idx,
66 });
67 }
68 }
69 SSAValDef::Operand { inst, operand_idx } => {
70 let operand = func.try_get_operand(inst, operand_idx)?;
71 if ssa_val_idx != operand.ssa_val {
72 return Err(Error::SSAValDefIsNotOperandsSSAVal {
73 ssa_val_idx,
74 inst,
75 operand_idx,
76 });
77 }
78 }
79 }
80 for &OperandUse { inst, operand_idx } in operand_uses {
81 let operand = func.try_get_operand(inst, operand_idx)?;
82 if ssa_val_idx != operand.ssa_val {
83 return Err(Error::SSAValUseIsNotOperandsSSAVal {
84 ssa_val_idx,
85 inst,
86 operand_idx,
87 });
88 }
89 }
90 for &BranchSuccParamUse {
91 branch_inst,
92 succ,
93 param_idx,
94 } in branch_succ_param_uses
95 {
96 if ssa_val_idx != func.try_get_branch_target_param(branch_inst, succ, param_idx)? {
97 return Err(Error::MismatchedBranchTargetBlockParamUse {
98 ssa_val_idx,
99 branch_inst,
100 tgt_block: succ,
101 param_idx,
102 });
103 }
104 }
105 Ok(())
106 }
107 }
108
109 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
110 #[repr(u8)]
111 pub enum InstStage {
112 Early = 0,
113 Late = 1,
114 }
115
116 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
117 #[serde(try_from = "SerializedProgPoint", into = "SerializedProgPoint")]
118 pub struct ProgPoint(usize);
119
120 impl ProgPoint {
121 pub const fn new(inst: InstIdx, stage: InstStage) -> Self {
122 const_unwrap_res!(Self::try_new(inst, stage))
123 }
124 pub const fn try_new(inst: InstIdx, stage: InstStage) -> Result<Self> {
125 let Some(inst) = inst.get().checked_shl(1) else {
126 return Err(Error::InstIdxTooBig);
127 };
128 Ok(Self(inst | stage as usize))
129 }
130 pub const fn inst(self) -> InstIdx {
131 InstIdx::new(self.0 >> 1)
132 }
133 pub const fn stage(self) -> InstStage {
134 if self.0 & 1 != 0 {
135 InstStage::Late
136 } else {
137 InstStage::Early
138 }
139 }
140 pub const fn next(self) -> Self {
141 Self(self.0 + 1)
142 }
143 pub const fn prev(self) -> Self {
144 Self(self.0 - 1)
145 }
146 }
147
148 impl fmt::Debug for ProgPoint {
149 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150 f.debug_struct("ProgPoint")
151 .field("inst", &self.inst())
152 .field("stage", &self.stage())
153 .finish()
154 }
155 }
156
157 #[derive(Serialize, Deserialize)]
158 struct SerializedProgPoint {
159 inst: InstIdx,
160 stage: InstStage,
161 }
162
163 impl From<ProgPoint> for SerializedProgPoint {
164 fn from(value: ProgPoint) -> Self {
165 Self {
166 inst: value.inst(),
167 stage: value.stage(),
168 }
169 }
170 }
171
172 impl TryFrom<SerializedProgPoint> for ProgPoint {
173 type Error = Error;
174
175 fn try_from(value: SerializedProgPoint) -> Result<Self, Self::Error> {
176 ProgPoint::try_new(value.inst, value.stage)
177 }
178 }
179
180 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
181 #[repr(u8)]
182 pub enum OperandKind {
183 Use = 0,
184 Def = 1,
185 }
186
187 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
188 pub enum Constraint {
189 /// any register or stack location
190 Any,
191 /// r1-r32
192 BaseGpr,
193 /// r2,r4,r6,r8,...r126
194 SVExtra2VGpr,
195 /// r1-63
196 SVExtra2SGpr,
197 /// r1-127
198 SVExtra3Gpr,
199 /// any stack location
200 Stack,
201 FixedLoc(Loc),
202 }
203
204 impl Constraint {
205 pub fn is_any(&self) -> bool {
206 matches!(self, Self::Any)
207 }
208 }
209
210 impl Default for Constraint {
211 fn default() -> Self {
212 Self::Any
213 }
214 }
215
216 #[derive(
217 Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize, Default,
218 )]
219 #[serde(try_from = "OperandKind", into = "OperandKind")]
220 pub struct OperandKindDefOnly;
221
222 impl TryFrom<OperandKind> for OperandKindDefOnly {
223 type Error = Error;
224
225 fn try_from(value: OperandKind) -> Result<Self, Self::Error> {
226 match value {
227 OperandKind::Use => Err(Error::OperandKindMustBeDef),
228 OperandKind::Def => Ok(Self),
229 }
230 }
231 }
232
233 impl From<OperandKindDefOnly> for OperandKind {
234 fn from(_value: OperandKindDefOnly) -> Self {
235 Self::Def
236 }
237 }
238
239 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
240 #[serde(untagged)]
241 pub enum KindAndConstraint {
242 Reuse {
243 kind: OperandKindDefOnly,
244 reuse_operand_idx: usize,
245 },
246 Constraint {
247 kind: OperandKind,
248 #[serde(default, skip_serializing_if = "Constraint::is_any")]
249 constraint: Constraint,
250 },
251 }
252
253 impl KindAndConstraint {
254 pub fn kind(self) -> OperandKind {
255 match self {
256 Self::Reuse { .. } => OperandKind::Def,
257 Self::Constraint { kind, .. } => kind,
258 }
259 }
260 pub fn is_reuse(self) -> bool {
261 matches!(self, Self::Reuse { .. })
262 }
263 }
264
265 #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
266 pub struct Operand {
267 pub ssa_val: SSAValIdx,
268 #[serde(flatten)]
269 pub kind_and_constraint: KindAndConstraint,
270 pub stage: InstStage,
271 }
272
273 impl Operand {
274 pub fn try_get_reuse_src<'f>(
275 &self,
276 inst: InstIdx,
277 func: &'f FnFields,
278 ) -> Result<Option<&'f Operand>> {
279 if let KindAndConstraint::Reuse {
280 reuse_operand_idx, ..
281 } = self.kind_and_constraint
282 {
283 Ok(Some(func.try_get_operand(inst, reuse_operand_idx)?))
284 } else {
285 Ok(None)
286 }
287 }
288 pub fn try_constraint(&self, inst: InstIdx, func: &FnFields) -> Result<Constraint> {
289 Ok(match self.kind_and_constraint {
290 KindAndConstraint::Reuse {
291 kind: _,
292 reuse_operand_idx,
293 } => {
294 let operand = func.try_get_operand(inst, reuse_operand_idx)?;
295 match operand.kind_and_constraint {
296 KindAndConstraint::Reuse { .. }
297 | KindAndConstraint::Constraint {
298 kind: OperandKind::Def,
299 ..
300 } => {
301 return Err(Error::ReuseTargetOperandMustBeUse {
302 inst,
303 reuse_target_operand_idx: reuse_operand_idx,
304 })
305 }
306 KindAndConstraint::Constraint {
307 kind: OperandKind::Use,
308 constraint,
309 } => constraint,
310 }
311 }
312 KindAndConstraint::Constraint { constraint, .. } => constraint,
313 })
314 }
315 pub fn constraint(&self, inst: InstIdx, func: &Function) -> Constraint {
316 self.try_constraint(inst, func).unwrap()
317 }
318 fn validate(
319 self,
320 _block: BlockIdx,
321 inst: InstIdx,
322 operand_idx: usize,
323 func: &FnFields,
324 global_state: &GlobalState,
325 ) -> Result<()> {
326 let Self {
327 ssa_val: ssa_val_idx,
328 kind_and_constraint,
329 stage: _,
330 } = self;
331 let ssa_val = func.try_get_ssa_val(ssa_val_idx)?;
332 match kind_and_constraint.kind() {
333 OperandKind::Use => {
334 if !ssa_val
335 .operand_uses
336 .contains(&OperandUse { inst, operand_idx })
337 {
338 return Err(Error::MissingOperandUse {
339 ssa_val_idx,
340 inst,
341 operand_idx,
342 });
343 }
344 }
345 OperandKind::Def => {
346 let def = SSAValDef::Operand { inst, operand_idx };
347 if ssa_val.def != def {
348 return Err(Error::OperandDefIsNotSSAValDef {
349 ssa_val_idx,
350 inst,
351 operand_idx,
352 });
353 }
354 }
355 }
356 let constraint = self.try_constraint(inst, func)?;
357 match constraint {
358 Constraint::Any | Constraint::Stack => {}
359 Constraint::BaseGpr | Constraint::SVExtra2SGpr => {
360 if ssa_val.ty != Ty::scalar(BaseTy::Bits64) {
361 return Err(Error::ConstraintTyMismatch {
362 ssa_val_idx,
363 inst,
364 operand_idx,
365 });
366 }
367 }
368 Constraint::SVExtra2VGpr | Constraint::SVExtra3Gpr => {
369 if ssa_val.ty.base_ty != BaseTy::Bits64 {
370 return Err(Error::ConstraintTyMismatch {
371 ssa_val_idx,
372 inst,
373 operand_idx,
374 });
375 }
376 }
377 Constraint::FixedLoc(loc) => {
378 if func
379 .try_get_inst(inst)?
380 .clobbers
381 .clone()
382 .conflicts_with(loc, global_state)
383 {
384 return Err(Error::FixedLocConflictsWithClobbers { inst, operand_idx });
385 }
386 if ssa_val.ty != loc.ty() {
387 return Err(Error::ConstraintTyMismatch {
388 ssa_val_idx,
389 inst,
390 operand_idx,
391 });
392 }
393 }
394 }
395 Ok(())
396 }
397 }
398
399 /// copy concatenates all `srcs` together and de-concatenates the result into all `dests`.
400 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
401 pub struct CopyInstKind {
402 pub src_operand_idxs: Vec<usize>,
403 pub dest_operand_idxs: Vec<usize>,
404 pub copy_ty: Ty,
405 }
406
407 impl CopyInstKind {
408 fn calc_copy_ty(operand_idxs: &[usize], inst: InstIdx, func: &FnFields) -> Result<Option<Ty>> {
409 let mut retval: Option<Ty> = None;
410 for &operand_idx in operand_idxs {
411 let operand = func.try_get_operand(inst, operand_idx)?;
412 let ssa_val = func.try_get_ssa_val(operand.ssa_val)?;
413 retval = Some(match retval {
414 Some(retval) => retval.try_concat(ssa_val.ty)?,
415 None => ssa_val.ty,
416 });
417 }
418 Ok(retval)
419 }
420 }
421
422 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
423 pub struct BlockTermInstKind {
424 pub succs_and_params: BTreeMap<BlockIdx, Vec<SSAValIdx>>,
425 }
426
427 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
428 pub enum InstKind {
429 Normal,
430 Copy(CopyInstKind),
431 BlockTerm(BlockTermInstKind),
432 }
433
434 impl InstKind {
435 pub fn is_normal(&self) -> bool {
436 matches!(self, Self::Normal)
437 }
438 pub fn is_block_term(&self) -> bool {
439 matches!(self, Self::BlockTerm { .. })
440 }
441 pub fn is_copy(&self) -> bool {
442 matches!(self, Self::Copy { .. })
443 }
444 pub fn block_term(&self) -> Option<&BlockTermInstKind> {
445 match self {
446 InstKind::BlockTerm(v) => Some(v),
447 _ => None,
448 }
449 }
450 pub fn copy(&self) -> Option<&CopyInstKind> {
451 match self {
452 InstKind::Copy(v) => Some(v),
453 _ => None,
454 }
455 }
456 }
457
458 impl Default for InstKind {
459 fn default() -> Self {
460 InstKind::Normal
461 }
462 }
463
464 fn loc_set_is_empty(clobbers: &Interned<LocSet>) -> bool {
465 clobbers.is_empty()
466 }
467
468 fn empty_loc_set() -> Interned<LocSet> {
469 GlobalState::get(|global_state| LocSet::default().into_interned(global_state))
470 }
471
472 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
473 pub struct Inst {
474 #[serde(default, skip_serializing_if = "InstKind::is_normal")]
475 pub kind: InstKind,
476 pub operands: Vec<Operand>,
477 #[serde(default = "empty_loc_set", skip_serializing_if = "loc_set_is_empty")]
478 pub clobbers: Interned<LocSet>,
479 }
480
481 impl Inst {
482 fn validate(
483 &self,
484 block: BlockIdx,
485 inst: InstIdx,
486 func: &FnFields,
487 global_state: &GlobalState,
488 ) -> Result<()> {
489 let Self {
490 kind,
491 operands,
492 clobbers: _,
493 } = self;
494 let is_at_end_of_block = func.blocks[block].insts.last() == Some(inst);
495 if kind.is_block_term() != is_at_end_of_block {
496 return Err(if is_at_end_of_block {
497 Error::BlocksLastInstMustBeTerm { term_idx: inst }
498 } else {
499 Error::TermInstOnlyAllowedAtBlockEnd { inst_idx: inst }
500 });
501 }
502 for (idx, operand) in operands.iter().enumerate() {
503 operand.validate(block, inst, idx, func, global_state)?;
504 }
505 match kind {
506 InstKind::Normal => {}
507 InstKind::Copy(CopyInstKind {
508 src_operand_idxs,
509 dest_operand_idxs,
510 copy_ty,
511 }) => {
512 let mut seen_dest_operands = SmallVec::<[bool; 16]>::new();
513 seen_dest_operands.resize(operands.len(), false);
514 for &dest_operand_idx in dest_operand_idxs {
515 let seen_dest_operand = seen_dest_operands.get_mut(dest_operand_idx).ok_or(
516 Error::OperandIndexOutOfRange {
517 inst,
518 operand_idx: dest_operand_idx,
519 },
520 )?;
521 if mem::replace(seen_dest_operand, true) {
522 return Err(Error::DupCopyDestOperand {
523 inst,
524 operand_idx: dest_operand_idx,
525 });
526 }
527 }
528 if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&src_operand_idxs, inst, func)? {
529 return Err(Error::CopySrcTyMismatch { inst });
530 }
531 if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&dest_operand_idxs, inst, func)? {
532 return Err(Error::CopyDestTyMismatch { inst });
533 }
534 }
535 InstKind::BlockTerm(BlockTermInstKind { succs_and_params }) => {
536 for (&succ_idx, params) in succs_and_params {
537 let succ = func.try_get_block(succ_idx)?;
538 if !succ.preds.contains(&block) {
539 return Err(Error::SrcBlockMissingFromBranchTgtBlocksPreds {
540 src_block: block,
541 branch_inst: inst,
542 tgt_block: succ_idx,
543 });
544 }
545 if succ.params.len() != params.len() {
546 return Err(Error::BranchSuccParamCountMismatch {
547 inst,
548 succ: succ_idx,
549 block_param_count: succ.params.len(),
550 branch_param_count: params.len(),
551 });
552 }
553 for (param_idx, (&branch_ssa_val_idx, &block_ssa_val_idx)) in
554 params.iter().zip(&succ.params).enumerate()
555 {
556 let branch_ssa_val = func.try_get_ssa_val(branch_ssa_val_idx)?;
557 let block_ssa_val = func.try_get_ssa_val(block_ssa_val_idx)?;
558 if !branch_ssa_val
559 .branch_succ_param_uses
560 .contains(&BranchSuccParamUse {
561 branch_inst: inst,
562 succ: succ_idx,
563 param_idx,
564 })
565 {
566 return Err(Error::MissingBranchSuccParamUse {
567 ssa_val_idx: branch_ssa_val_idx,
568 inst,
569 succ: succ_idx,
570 param_idx,
571 });
572 }
573 if block_ssa_val.ty != branch_ssa_val.ty {
574 return Err(Error::BranchSuccParamTyMismatch {
575 inst,
576 succ: succ_idx,
577 param_idx,
578 block_param_ty: block_ssa_val.ty,
579 branch_param_ty: branch_ssa_val.ty,
580 });
581 }
582 }
583 }
584 }
585 }
586 Ok(())
587 }
588 pub fn try_get_operand(&self, inst: InstIdx, operand_idx: usize) -> Result<&Operand> {
589 self.operands
590 .get(operand_idx)
591 .ok_or(Error::OperandIndexOutOfRange { inst, operand_idx })
592 }
593 }
594
595 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
596 pub struct Block {
597 pub params: Vec<SSAValIdx>,
598 pub insts: InstRange,
599 pub preds: BTreeSet<BlockIdx>,
600 pub immediate_dominator: Option<BlockIdx>,
601 }
602
603 impl Block {
604 fn validate(&self, block: BlockIdx, func: &FnFields, global_state: &GlobalState) -> Result<()> {
605 let Self {
606 params,
607 insts,
608 preds,
609 immediate_dominator: _,
610 } = self;
611 const _: () = assert!(BlockIdx::ENTRY_BLOCK.get() == 0);
612 let expected_start = if block == BlockIdx::ENTRY_BLOCK {
613 InstIdx::new(0)
614 } else {
615 func.blocks[block.prev()].insts.end
616 };
617 if insts.start != expected_start {
618 return Err(Error::BlockHasInvalidStart {
619 start: insts.start,
620 expected_start,
621 });
622 }
623 let term_inst_idx = insts.last().ok_or(Error::BlockIsEmpty { block })?;
624 func.insts
625 .get(term_inst_idx.get())
626 .ok_or(Error::BlockEndOutOfRange { end: insts.end })?;
627 if block.get() == func.blocks.len() - 1 && insts.end.get() != func.insts.len() {
628 return Err(Error::InstHasNoBlock { inst: insts.end });
629 }
630 if block == BlockIdx::ENTRY_BLOCK {
631 if !params.is_empty() {
632 return Err(Error::EntryBlockCantHaveParams);
633 }
634 if !preds.is_empty() {
635 return Err(Error::EntryBlockCantHavePreds);
636 }
637 }
638 for inst in *insts {
639 func.insts[inst].validate(block, inst, func, global_state)?;
640 }
641 for (param_idx, &ssa_val_idx) in params.iter().enumerate() {
642 let ssa_val = func.try_get_ssa_val(ssa_val_idx)?;
643 let def = SSAValDef::BlockParam { block, param_idx };
644 if ssa_val.def != def {
645 return Err(Error::MismatchedBlockParamDef {
646 ssa_val_idx,
647 block,
648 param_idx,
649 });
650 }
651 }
652 for &pred in preds {
653 let (term_inst, BlockTermInstKind { succs_and_params }) =
654 func.try_get_block_term_inst_and_kind(pred)?;
655 if !succs_and_params.contains_key(&pred) {
656 return Err(Error::PredMissingFromPredsTermBranchsTargets {
657 src_block: pred,
658 branch_inst: term_inst,
659 tgt_block: block,
660 });
661 }
662 }
663 Ok(())
664 }
665 }
666
// Declares `Function` and its field struct `FnFields` (named by `fields_ty`).
// `Function::new_with_global_state` builds a `Function` as `Self(fields)` only
// after validating the fields. The macro is defined elsewhere in the crate, so
// its exact expansion (e.g. Deref/serde impls) is not visible here.
validated_fields! {
    #[fields_ty = FnFields]
    #[derive(Clone, PartialEq, Eq, Debug, Hash)]
    pub struct Function {
        // indexed by `SSAValIdx`
        pub ssa_vals: Vec<SSAVal>,
        // indexed by `InstIdx`; each block owns a contiguous range of these
        pub insts: Vec<Inst>,
        // indexed by `BlockIdx`; the first block is the entry block
        pub blocks: Vec<Block>,
        // not serialized; rebuilt by `FnFields::fill_start_inst_to_block_map`
        #[serde(skip)]
        /// map from blocks' start instruction's index to their block index, doesn't contain the entry block
        pub start_inst_to_block_map: BTreeMap<InstIdx, BlockIdx>,
    }
}
679
680 impl Function {
681 pub fn new(fields: FnFields) -> Result<Self> {
682 GlobalState::get(|global_state| Self::new_with_global_state(fields, global_state))
683 }
684 pub fn new_with_global_state(mut fields: FnFields, global_state: &GlobalState) -> Result<Self> {
685 fields.fill_start_inst_to_block_map();
686 let FnFields {
687 ssa_vals,
688 insts: _,
689 blocks,
690 start_inst_to_block_map: _,
691 } = &fields;
692 blocks
693 .get(BlockIdx::ENTRY_BLOCK.get())
694 .ok_or(Error::MissingEntryBlock)?;
695 for (idx, block) in blocks.iter().enumerate() {
696 block.validate(BlockIdx::new(idx), &fields, global_state)?;
697 }
698 let dominators = dominators::simple_fast(&fields, BlockIdx::ENTRY_BLOCK);
699 for (idx, block) in blocks.iter().enumerate() {
700 let block_idx = BlockIdx::new(idx);
701 let expected = dominators.immediate_dominator(block_idx);
702 if block.immediate_dominator != expected {
703 return Err(Error::IncorrectImmediateDominator {
704 block_idx,
705 found: block.immediate_dominator,
706 expected,
707 });
708 }
709 }
710 for (idx, ssa_val) in ssa_vals.iter().enumerate() {
711 ssa_val.validate(SSAValIdx::new(idx), &fields)?;
712 }
713 Ok(Self(fields))
714 }
715 pub fn entry_block(&self) -> &Block {
716 &self.blocks[0]
717 }
718 pub fn block_term_kind(&self, block: BlockIdx) -> &BlockTermInstKind {
719 self.insts[self.blocks[block].insts.last().unwrap()]
720 .kind
721 .block_term()
722 .unwrap()
723 }
724 }
725
726 impl FnFields {
727 pub fn fill_start_inst_to_block_map(&mut self) {
728 self.start_inst_to_block_map.clear();
729 for (idx, block) in self.blocks.iter().enumerate() {
730 let block_idx = BlockIdx::new(idx);
731 if block_idx != BlockIdx::ENTRY_BLOCK {
732 self.start_inst_to_block_map
733 .insert(block.insts.start, block_idx);
734 }
735 }
736 }
737 pub fn try_get_ssa_val(&self, idx: SSAValIdx) -> Result<&SSAVal> {
738 self.ssa_vals
739 .get(idx.get())
740 .ok_or(Error::SSAValIdxOutOfRange { idx })
741 }
742 pub fn try_get_inst(&self, idx: InstIdx) -> Result<&Inst> {
743 self.insts
744 .get(idx.get())
745 .ok_or(Error::InstIdxOutOfRange { idx })
746 }
747 pub fn try_get_operand(&self, inst: InstIdx, operand_idx: usize) -> Result<&Operand> {
748 self.try_get_inst(inst)?.try_get_operand(inst, operand_idx)
749 }
750 pub fn try_get_block(&self, idx: BlockIdx) -> Result<&Block> {
751 self.blocks
752 .get(idx.get())
753 .ok_or(Error::BlockIdxOutOfRange { idx })
754 }
755 pub fn try_get_block_param(&self, block: BlockIdx, param_idx: usize) -> Result<SSAValIdx> {
756 self.try_get_block(block)?
757 .params
758 .get(param_idx)
759 .copied()
760 .ok_or(Error::BlockParamIdxOutOfRange { block, param_idx })
761 }
762 pub fn try_get_block_term_inst_idx(&self, block: BlockIdx) -> Result<InstIdx> {
763 self.try_get_block(block)?
764 .insts
765 .last()
766 .ok_or(Error::BlockIsEmpty { block })
767 }
768 pub fn try_get_block_term_inst_and_kind(
769 &self,
770 block: BlockIdx,
771 ) -> Result<(InstIdx, &BlockTermInstKind)> {
772 let term_idx = self.try_get_block_term_inst_idx(block)?;
773 let term_kind = self
774 .try_get_inst(term_idx)?
775 .kind
776 .block_term()
777 .ok_or(Error::BlocksLastInstMustBeTerm { term_idx })?;
778 Ok((term_idx, term_kind))
779 }
780 pub fn try_get_branch_target_params(
781 &self,
782 branch_inst: InstIdx,
783 succ: BlockIdx,
784 ) -> Result<&[SSAValIdx]> {
785 let inst = self.try_get_inst(branch_inst)?;
786 let BlockTermInstKind { succs_and_params } = inst
787 .kind
788 .block_term()
789 .ok_or(Error::InstIsNotBlockTerm { inst: branch_inst })?;
790 Ok(succs_and_params
791 .get(&succ)
792 .ok_or(Error::BranchTargetNotFound {
793 branch_inst,
794 tgt_block: succ,
795 })?)
796 }
797 pub fn try_get_branch_target_param(
798 &self,
799 branch_inst: InstIdx,
800 succ: BlockIdx,
801 param_idx: usize,
802 ) -> Result<SSAValIdx> {
803 Ok(*self
804 .try_get_branch_target_params(branch_inst, succ)?
805 .get(param_idx)
806 .ok_or(Error::BranchTargetParamIdxOutOfRange {
807 branch_inst,
808 tgt_block: succ,
809 param_idx,
810 })?)
811 }
812 pub fn inst_to_block(&self, inst: InstIdx) -> BlockIdx {
813 self.start_inst_to_block_map
814 .range(..=inst)
815 .next_back()
816 .map(|v| *v.1)
817 .unwrap_or(BlockIdx::ENTRY_BLOCK)
818 }
819 }
820
821 impl Index<SSAValIdx> for Vec<SSAVal> {
822 type Output = SSAVal;
823
824 fn index(&self, index: SSAValIdx) -> &Self::Output {
825 &self[index.get()]
826 }
827 }
828
829 impl Index<InstIdx> for Vec<Inst> {
830 type Output = Inst;
831
832 fn index(&self, index: InstIdx) -> &Self::Output {
833 &self[index.get()]
834 }
835 }
836
837 impl Index<BlockIdx> for Vec<Block> {
838 type Output = Block;
839
840 fn index(&self, index: BlockIdx) -> &Self::Output {
841 &self[index.get()]
842 }
843 }
844
845 impl GraphBase for FnFields {
846 type EdgeId = (BlockIdx, BlockIdx);
847 type NodeId = BlockIdx;
848 }
849
850 pub struct Neighbors<'a> {
851 iter: Option<btree_map::Keys<'a, BlockIdx, Vec<SSAValIdx>>>,
852 }
853
854 impl Iterator for Neighbors<'_> {
855 type Item = BlockIdx;
856
857 fn next(&mut self) -> Option<Self::Item> {
858 Some(*self.iter.as_mut()?.next()?)
859 }
860 }
861
862 impl<'a> IntoNeighbors for &'a FnFields {
863 type Neighbors = Neighbors<'a>;
864
865 fn neighbors(self, block_idx: Self::NodeId) -> Self::Neighbors {
866 Neighbors {
867 iter: self
868 .try_get_block_term_inst_and_kind(block_idx)
869 .ok()
870 .map(|(_, BlockTermInstKind { succs_and_params })| succs_and_params.keys()),
871 }
872 }
873 }
874
875 pub struct VisitedMap(HashSet<BlockIdx>);
876
877 impl VisitMap<BlockIdx> for VisitedMap {
878 fn visit(&mut self, block: BlockIdx) -> bool {
879 self.0.insert(block)
880 }
881
882 fn is_visited(&self, block: &BlockIdx) -> bool {
883 self.0.contains(block)
884 }
885 }
886
887 impl Visitable for FnFields {
888 type Map = VisitedMap;
889
890 fn visit_map(&self) -> Self::Map {
891 VisitedMap(HashSet::new())
892 }
893
894 fn reset_map(&self, map: &mut Self::Map) {
895 map.0.clear();
896 }
897 }
898
899 impl GraphProp for FnFields {
900 type EdgeType = Directed;
901 }