02419e538731560f8aea5a82465f7a0a2471d6e2
[bigint-presentation-code.git] / register_allocator / src / function.rs
1 use crate::{
2 error::{
3 BlockIdxOutOfRange, BlockParamIdxOutOfRange, Error, InstIdxOutOfRange,
4 OperandIdxOutOfRange, Result, SSAValIdxOutOfRange,
5 },
6 index::{
7 BlockIdx, BlockParamIdx, IndexTy, InstIdx, InstRange, OperandIdx, RangeIter, SSAValIdx,
8 },
9 interned::{GlobalState, Intern, Interned},
10 loc::{BaseTy, Loc, Ty},
11 loc_set::LocSet,
12 };
13 use arbitrary::Arbitrary;
14 use core::fmt;
15 use enum_map::Enum;
16 use hashbrown::HashSet;
17 use petgraph::{
18 algo::dominators,
19 visit::{GraphBase, GraphProp, IntoNeighbors, VisitMap, Visitable},
20 Directed,
21 };
22 use serde::{Deserialize, Serialize};
23 use smallvec::SmallVec;
24 use std::{
25 collections::{btree_map, BTreeMap, BTreeSet},
26 iter::FusedIterator,
27 mem,
28 ops::{Index, IndexMut},
29 };
30
/// Where an SSA value is defined.
///
/// Recorded in [`SSAVal::def`] by `FnFields::fill_ssa_defs_uses` and checked
/// against the function by `SSAVal::validate`.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub enum SSAValDef {
    /// Defined as parameter `param_idx` of block `block`.
    BlockParam {
        block: BlockIdx,
        param_idx: BlockParamIdx,
    },
    /// Defined by operand `operand_idx` (an `OperandKind::Def` operand) of
    /// instruction `inst`.
    Operand {
        inst: InstIdx,
        operand_idx: OperandIdx,
    },
}
42
43 impl SSAValDef {
44 pub const fn invalid() -> Self {
45 SSAValDef::BlockParam {
46 block: BlockIdx::ENTRY_BLOCK,
47 param_idx: BlockParamIdx::new(!0),
48 }
49 }
50 }
51
/// A use of an SSA value as a block argument: the terminator `branch_inst`
/// passes the value as parameter `param_idx` of its successor block `succ`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
pub struct BranchSuccParamUse {
    pub branch_inst: InstIdx,
    pub succ: BlockIdx,
    pub param_idx: BlockParamIdx,
}
58
/// A use of an SSA value by operand `operand_idx` (an `OperandKind::Use`
/// operand) of instruction `inst`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
pub struct OperandUse {
    pub inst: InstIdx,
    pub operand_idx: OperandIdx,
}
64
/// An SSA value: its type plus cached def/use information.
///
/// Only `ty` is serialized. The other fields are skipped by serde and are
/// recomputed from the blocks and instructions by
/// `FnFields::fill_ssa_defs_uses`.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct SSAVal {
    pub ty: Ty,
    /// The definition site; deserializes as `SSAValDef::invalid()`.
    #[serde(skip, default = "SSAValDef::invalid")]
    pub def: SSAValDef,
    /// Every instruction operand that reads this value.
    #[serde(skip)]
    pub operand_uses: BTreeSet<OperandUse>,
    /// Every branch that passes this value as a successor's block parameter.
    #[serde(skip)]
    pub branch_succ_param_uses: BTreeSet<BranchSuccParamUse>,
}
75
76 impl SSAVal {
77 fn validate(&self, ssa_val_idx: SSAValIdx, func: &FnFields) -> Result<()> {
78 let Self {
79 ty: _,
80 def,
81 operand_uses,
82 branch_succ_param_uses,
83 } = self;
84 match *def {
85 SSAValDef::BlockParam { block, param_idx } => {
86 let block_param = func.try_get_block_param(block, param_idx)?;
87 if ssa_val_idx != block_param {
88 return Err(Error::MismatchedBlockParamDef {
89 ssa_val_idx,
90 block,
91 param_idx,
92 });
93 }
94 }
95 SSAValDef::Operand { inst, operand_idx } => {
96 let operand = func.try_get_operand(inst, operand_idx)?;
97 if ssa_val_idx != operand.ssa_val {
98 return Err(Error::SSAValDefIsNotOperandsSSAVal {
99 ssa_val_idx,
100 inst,
101 operand_idx,
102 });
103 }
104 }
105 }
106 for &OperandUse { inst, operand_idx } in operand_uses {
107 let operand = func.try_get_operand(inst, operand_idx)?;
108 if ssa_val_idx != operand.ssa_val {
109 return Err(Error::SSAValUseIsNotOperandsSSAVal {
110 ssa_val_idx,
111 inst,
112 operand_idx,
113 });
114 }
115 }
116 for &BranchSuccParamUse {
117 branch_inst,
118 succ,
119 param_idx,
120 } in branch_succ_param_uses
121 {
122 if ssa_val_idx != func.try_get_branch_target_param(branch_inst, succ, param_idx)? {
123 return Err(Error::MismatchedBranchTargetBlockParamUse {
124 ssa_val_idx,
125 branch_inst,
126 tgt_block: succ,
127 param_idx,
128 });
129 }
130 }
131 Ok(())
132 }
133 }
134
/// A stage within an instruction, attached to each [`Operand`].
///
/// `Early` (0) orders before `Late` (1); the discriminant is also the low bit
/// of a `ProgPoint`.
#[derive(
    Copy,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Debug,
    Hash,
    Serialize,
    Deserialize,
    Arbitrary,
    Enum,
)]
#[repr(u8)]
pub enum InstStage {
    Early = 0,
    Late = 1,
}
154
/// A program point: an instruction index plus an [`InstStage`], packed into a
/// single `usize` as `inst * 2 + stage`, so the derived ordering follows
/// program order (early before late, then the next instruction).
///
/// Serialized through `SerializedProgPoint` as separate `inst` and `stage`
/// fields.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(try_from = "SerializedProgPoint", into = "SerializedProgPoint")]
pub struct ProgPoint(usize);
158
159 impl ProgPoint {
160 pub const fn new(inst: InstIdx, stage: InstStage) -> Self {
161 const_unwrap_res!(Self::try_new(inst, stage))
162 }
163 pub const fn try_new(inst: InstIdx, stage: InstStage) -> Result<Self> {
164 let Some(inst) = inst.get().checked_shl(1) else {
165 return Err(Error::InstIdxTooBig);
166 };
167 Ok(Self(inst | stage as usize))
168 }
169 pub const fn inst(self) -> InstIdx {
170 InstIdx::new(self.0 >> 1)
171 }
172 pub const fn stage(self) -> InstStage {
173 if self.0 & 1 != 0 {
174 InstStage::Late
175 } else {
176 InstStage::Early
177 }
178 }
179 pub const fn next(self) -> Self {
180 Self(self.0 + 1)
181 }
182 pub const fn prev(self) -> Self {
183 Self(self.0 - 1)
184 }
185 }
186
187 impl fmt::Debug for ProgPoint {
188 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189 f.debug_struct("ProgPoint")
190 .field("inst", &self.inst())
191 .field("stage", &self.stage())
192 .finish()
193 }
194 }
195
/// Serde representation of [`ProgPoint`], with the instruction index and the
/// stage as separate fields instead of the packed `usize`.
#[derive(Serialize, Deserialize)]
struct SerializedProgPoint {
    inst: InstIdx,
    stage: InstStage,
}
201
202 impl From<ProgPoint> for SerializedProgPoint {
203 fn from(value: ProgPoint) -> Self {
204 Self {
205 inst: value.inst(),
206 stage: value.stage(),
207 }
208 }
209 }
210
211 impl TryFrom<SerializedProgPoint> for ProgPoint {
212 type Error = Error;
213
214 fn try_from(value: SerializedProgPoint) -> Result<Self, Self::Error> {
215 ProgPoint::try_new(value.inst, value.stage)
216 }
217 }
218
/// Whether an operand reads its SSA value (`Use`) or defines it (`Def`).
///
/// `FnFields::fill_ssa_defs_uses` records `Use` operands in
/// `SSAVal::operand_uses` and `Def` operands as `SSAVal::def`.
#[derive(
    Copy,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Debug,
    Hash,
    Serialize,
    Deserialize,
    Arbitrary,
    Enum,
)]
#[repr(u8)]
pub enum OperandKind {
    Use = 0,
    Def = 1,
}
238
/// Where an operand's value may be placed.
///
/// Compatibility with an SSA value's type is checked by
/// `Constraint::check_for_ty_mismatch`.
#[derive(
    Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize, Arbitrary,
)]
pub enum Constraint {
    /// any register or stack location
    Any,
    /// r1-r32
    // NOTE(review): the sibling ranges below end at 2^n - 1 (r63, r127), so
    // "r1-r32" may be an off-by-one for "r1-r31" — confirm against `Loc`.
    BaseGpr,
    /// r2,r4,r6,r8,...r126 (even-numbered registers only)
    SVExtra2VGpr,
    /// r1-63
    SVExtra2SGpr,
    /// r1-127
    SVExtra3Gpr,
    /// any stack location
    Stack,
    /// exactly the given location; its type must equal the SSA value's type
    FixedLoc(Loc),
}
257
258 impl Constraint {
259 pub fn is_any(&self) -> bool {
260 matches!(self, Self::Any)
261 }
262 pub fn fixed_loc(&self) -> Option<Loc> {
263 match *self {
264 Constraint::Any
265 | Constraint::BaseGpr
266 | Constraint::SVExtra2VGpr
267 | Constraint::SVExtra2SGpr
268 | Constraint::SVExtra3Gpr
269 | Constraint::Stack => None,
270 Constraint::FixedLoc(v) => Some(v),
271 }
272 }
273 pub fn non_fixed_choices_for_ty(ty: Ty) -> &'static [Constraint] {
274 match (ty.base_ty, ty.reg_len.get()) {
275 (BaseTy::Bits64, 1) => &[
276 Constraint::Any,
277 Constraint::BaseGpr,
278 Constraint::SVExtra2SGpr,
279 Constraint::SVExtra2VGpr,
280 Constraint::SVExtra3Gpr,
281 Constraint::Stack,
282 ],
283 (BaseTy::Bits64, _) => &[
284 Constraint::Any,
285 Constraint::SVExtra2VGpr,
286 Constraint::SVExtra3Gpr,
287 Constraint::Stack,
288 ],
289 (BaseTy::Ca, _) | (BaseTy::VlMaxvl, _) => &[Constraint::Any, Constraint::Stack],
290 }
291 }
292 pub fn arbitrary_with_ty(
293 ty: Ty,
294 u: &mut arbitrary::Unstructured<'_>,
295 ) -> arbitrary::Result<Self> {
296 let non_fixed_choices = Self::non_fixed_choices_for_ty(ty);
297 if let Some(&retval) = non_fixed_choices.get(u.choose_index(non_fixed_choices.len() + 1)?) {
298 Ok(retval)
299 } else {
300 Ok(Constraint::FixedLoc(Loc::arbitrary_with_ty(ty, u)?))
301 }
302 }
303 pub fn check_for_ty_mismatch(&self, ty: Ty) -> Result<(), ()> {
304 match self {
305 Constraint::Any | Constraint::Stack => {}
306 Constraint::BaseGpr | Constraint::SVExtra2SGpr => {
307 if ty != Ty::scalar(BaseTy::Bits64) {
308 return Err(());
309 }
310 }
311 Constraint::SVExtra2VGpr | Constraint::SVExtra3Gpr => {
312 if ty.base_ty != BaseTy::Bits64 {
313 return Err(());
314 }
315 }
316 Constraint::FixedLoc(loc) => {
317 if ty != loc.ty() {
318 return Err(());
319 }
320 }
321 }
322 Ok(())
323 }
324 }
325
326 impl Default for Constraint {
327 fn default() -> Self {
328 Self::Any
329 }
330 }
331
/// Marker for an operand kind that can only be `OperandKind::Def`.
///
/// It serializes as `OperandKind::Def` and fails to deserialize from any other
/// `OperandKind` (`Error::OperandKindMustBeDef`). It is used as the `kind` of
/// `KindAndConstraint::Reuse`.
#[derive(
    Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize, Default,
)]
#[serde(try_from = "OperandKind", into = "OperandKind")]
pub struct OperandKindDefOnly;
337
338 impl TryFrom<OperandKind> for OperandKindDefOnly {
339 type Error = Error;
340
341 fn try_from(value: OperandKind) -> Result<Self, Self::Error> {
342 match value {
343 OperandKind::Use => Err(Error::OperandKindMustBeDef),
344 OperandKind::Def => Ok(Self),
345 }
346 }
347 }
348
349 impl From<OperandKindDefOnly> for OperandKind {
350 fn from(_value: OperandKindDefOnly) -> Self {
351 Self::Def
352 }
353 }
354
/// An operand's kind together with its placement constraint.
///
/// Serialized `#[serde(untagged)]`, so on deserialization serde tries `Reuse`
/// first and then `Constraint`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum KindAndConstraint {
    /// A definition that takes its constraint from operand `reuse_operand_idx`
    /// of the same instruction, which must be a `Use` with an explicit
    /// constraint (see `Operand::try_constraint`).
    Reuse {
        kind: OperandKindDefOnly,
        reuse_operand_idx: OperandIdx,
    },
    /// An explicit kind and constraint. `constraint` defaults to
    /// `Constraint::Any` and is omitted from the serialized form when `Any`.
    Constraint {
        kind: OperandKind,
        #[serde(default, skip_serializing_if = "Constraint::is_any")]
        constraint: Constraint,
    },
}
368
369 impl KindAndConstraint {
370 pub fn kind(self) -> OperandKind {
371 match self {
372 Self::Reuse { .. } => OperandKind::Def,
373 Self::Constraint { kind, .. } => kind,
374 }
375 }
376 pub fn is_reuse(self) -> bool {
377 matches!(self, Self::Reuse { .. })
378 }
379 }
380
/// One operand of an instruction.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct Operand {
    /// The SSA value read (`Use`) or defined (`Def`) by this operand.
    pub ssa_val: SSAValIdx,
    /// Kind and constraint; its fields are flattened into the operand's
    /// serialized form.
    #[serde(flatten)]
    pub kind_and_constraint: KindAndConstraint,
    /// The instruction stage this operand is associated with.
    pub stage: InstStage,
}
388
impl Operand {
    /// If this operand is a `KindAndConstraint::Reuse`, returns the operand of
    /// instruction `inst` that it reuses; otherwise returns `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns an error if `inst` or the reused operand index is out of range.
    pub fn try_get_reuse_src<'f>(
        &self,
        inst: InstIdx,
        func: &'f FnFields,
    ) -> Result<Option<&'f Operand>> {
        if let KindAndConstraint::Reuse {
            reuse_operand_idx, ..
        } = self.kind_and_constraint
        {
            Ok(Some(func.try_get_operand(inst, reuse_operand_idx)?))
        } else {
            Ok(None)
        }
    }
    /// Returns the effective placement constraint of this operand of
    /// instruction `inst`.
    ///
    /// A `Reuse` operand has no constraint of its own: it takes the constraint
    /// of the operand it reuses, which must be a `Use` with an explicit
    /// constraint.
    ///
    /// # Errors
    ///
    /// Returns an error if the reused operand is out of range, or
    /// `Error::ReuseTargetOperandMustBeUse` if the reused operand is itself a
    /// reuse or a `Def`.
    pub fn try_constraint(&self, inst: InstIdx, func: &FnFields) -> Result<Constraint> {
        Ok(match self.kind_and_constraint {
            KindAndConstraint::Reuse {
                kind: _,
                reuse_operand_idx,
            } => {
                let operand = func.try_get_operand(inst, reuse_operand_idx)?;
                match operand.kind_and_constraint {
                    KindAndConstraint::Reuse { .. }
                    | KindAndConstraint::Constraint {
                        kind: OperandKind::Def,
                        ..
                    } => {
                        return Err(Error::ReuseTargetOperandMustBeUse {
                            inst,
                            reuse_target_operand_idx: reuse_operand_idx,
                        })
                    }
                    KindAndConstraint::Constraint {
                        kind: OperandKind::Use,
                        constraint,
                    } => constraint,
                }
            }
            KindAndConstraint::Constraint { constraint, .. } => constraint,
        })
    }
    /// Infallible form of [`Operand::try_constraint`], intended for operands of
    /// a validated [`Function`].
    ///
    /// # Panics
    ///
    /// Panics if `try_constraint` returns an error.
    pub fn constraint(&self, inst: InstIdx, func: &Function) -> Constraint {
        self.try_constraint(inst, func).unwrap()
    }
    /// Validates this operand, which is operand `operand_idx` of instruction
    /// `inst` in block `block`.
    ///
    /// Checks, returning the first failure:
    /// 1. its SSA value index is in range;
    /// 2. for a `Use`, the SSA value lists this operand in `operand_uses`; for a
    ///    `Def`, the SSA value's `def` is this operand;
    /// 3. for a `Reuse`, the reused operand's SSA value has the same type;
    /// 4. the effective constraint (see `try_constraint`) is compatible with the
    ///    SSA value's type;
    /// 5. a `FixedLoc` constraint does not conflict with the instruction's
    ///    clobbers.
    fn validate(
        self,
        block: BlockIdx,
        inst: InstIdx,
        operand_idx: OperandIdx,
        func: &FnFields,
        global_state: &GlobalState,
    ) -> Result<()> {
        let Self {
            ssa_val: ssa_val_idx,
            kind_and_constraint,
            stage: _,
        } = self;
        let ssa_val = func
            .ssa_vals
            .try_index(ssa_val_idx)
            .map_err(|e| e.with_block_inst_and_operand(block, inst, operand_idx))?;
        // The cached def/use data on the SSA value must point back here.
        match kind_and_constraint.kind() {
            OperandKind::Use => {
                if !ssa_val
                    .operand_uses
                    .contains(&OperandUse { inst, operand_idx })
                {
                    return Err(Error::MissingOperandUse {
                        ssa_val_idx,
                        inst,
                        operand_idx,
                    });
                }
            }
            OperandKind::Def => {
                let def = SSAValDef::Operand { inst, operand_idx };
                if ssa_val.def != def {
                    return Err(Error::OperandDefIsNotSSAValDef {
                        ssa_val_idx,
                        inst,
                        operand_idx,
                    });
                }
            }
        }
        // A reuse operand must have the same type as the operand it reuses.
        if let KindAndConstraint::Reuse {
            kind: _,
            reuse_operand_idx,
        } = self.kind_and_constraint
        {
            let reuse_src = func.try_get_operand(inst, reuse_operand_idx)?;
            let reuse_src_ssa_val = func
                .ssa_vals
                .try_index(reuse_src.ssa_val)
                .map_err(|e| e.with_block_inst_and_operand(block, inst, reuse_operand_idx))?;
            if ssa_val.ty != reuse_src_ssa_val.ty {
                return Err(Error::ReuseOperandTyMismatch {
                    inst,
                    tgt_operand_idx: operand_idx,
                    src_operand_idx: reuse_operand_idx,
                    src_ty: reuse_src_ssa_val.ty,
                    tgt_ty: ssa_val.ty,
                });
            }
        }
        let constraint = self.try_constraint(inst, func)?;
        constraint
            .check_for_ty_mismatch(ssa_val.ty)
            .map_err(|()| Error::ConstraintTyMismatch {
                ssa_val_idx,
                inst,
                operand_idx,
            })?;
        // A fixed location must not be one the instruction clobbers.
        if let Some(fixed_loc) = constraint.fixed_loc() {
            if func
                .insts
                .try_index(inst)?
                .clobbers
                .clone()
                .conflicts_with(fixed_loc, global_state)
            {
                return Err(Error::FixedLocConflictsWithClobbers { inst, operand_idx });
            }
        }
        Ok(())
    }
}
517
/// A copy concatenates the values of all source operands and splits the
/// result across all destination operands.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct CopyInstKind {
    /// Indices (into the instruction's operands) of the sources, in
    /// concatenation order.
    pub src_operand_idxs: Vec<OperandIdx>,
    /// Indices (into the instruction's operands) of the destinations, in
    /// concatenation order; `Inst::validate` rejects duplicates.
    pub dest_operand_idxs: Vec<OperandIdx>,
    /// Type of the concatenated value; `Inst::validate` checks that it equals
    /// both the concatenated source type and the concatenated destination type.
    pub copy_ty: Ty,
}
525
526 impl CopyInstKind {
527 fn calc_copy_ty(
528 operand_idxs: &[OperandIdx],
529 inst: InstIdx,
530 func: &FnFields,
531 ) -> Result<Option<Ty>> {
532 let mut retval: Option<Ty> = None;
533 for &operand_idx in operand_idxs {
534 let operand = func.try_get_operand(inst, operand_idx)?;
535 let ssa_val = func
536 .ssa_vals
537 .try_index(operand.ssa_val)
538 .map_err(|e| e.with_inst_and_operand(inst, operand_idx))?;
539 retval = Some(match retval {
540 Some(retval) => retval.try_concat(ssa_val.ty)?,
541 None => ssa_val.ty,
542 });
543 }
544 Ok(retval)
545 }
546 }
547
/// Data for a block-terminating instruction.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct BlockTermInstKind {
    /// For each successor block, the SSA values passed as that block's
    /// parameters, in parameter order (the count must match the successor's
    /// parameter count).
    pub succs_and_params: BTreeMap<BlockIdx, Vec<SSAValIdx>>,
}
552
/// The kind of an [`Inst`].
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub enum InstKind {
    /// An ordinary instruction with no extra data.
    Normal,
    /// A copy between operands; see [`CopyInstKind`].
    Copy(CopyInstKind),
    /// A block terminator; `Inst::validate` requires it to be exactly the last
    /// instruction of its block.
    BlockTerm(BlockTermInstKind),
}
559
560 impl InstKind {
561 pub fn is_normal(&self) -> bool {
562 matches!(self, Self::Normal)
563 }
564 pub fn is_block_term(&self) -> bool {
565 matches!(self, Self::BlockTerm { .. })
566 }
567 pub fn is_copy(&self) -> bool {
568 matches!(self, Self::Copy { .. })
569 }
570 pub fn block_term(&self) -> Option<&BlockTermInstKind> {
571 match self {
572 InstKind::BlockTerm(v) => Some(v),
573 _ => None,
574 }
575 }
576 pub fn block_term_mut(&mut self) -> Option<&mut BlockTermInstKind> {
577 match self {
578 InstKind::BlockTerm(v) => Some(v),
579 _ => None,
580 }
581 }
582 pub fn copy(&self) -> Option<&CopyInstKind> {
583 match self {
584 InstKind::Copy(v) => Some(v),
585 _ => None,
586 }
587 }
588 }
589
590 impl Default for InstKind {
591 fn default() -> Self {
592 InstKind::Normal
593 }
594 }
595
596 fn loc_set_is_empty(clobbers: &Interned<LocSet>) -> bool {
597 clobbers.is_empty()
598 }
599
600 fn empty_loc_set() -> Interned<LocSet> {
601 GlobalState::get(|global_state| LocSet::default().into_interned(global_state))
602 }
603
/// One instruction.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct Inst {
    /// Defaults to, and is omitted from the serialized form when,
    /// `InstKind::Normal`.
    #[serde(default, skip_serializing_if = "InstKind::is_normal")]
    pub kind: InstKind,
    /// The instruction's operands, indexed by `OperandIdx`.
    pub operands: Vec<Operand>,
    /// Clobbered locations; a `FixedLoc` operand constraint must not conflict
    /// with them (checked in `Operand::validate`). Defaults to, and is omitted
    /// from the serialized form when, empty.
    #[serde(default = "empty_loc_set", skip_serializing_if = "loc_set_is_empty")]
    pub clobbers: Interned<LocSet>,
}
612
613 impl Inst {
614 fn validate(
615 &self,
616 block: BlockIdx,
617 inst: InstIdx,
618 func: &FnFields,
619 global_state: &GlobalState,
620 ) -> Result<()> {
621 let Self {
622 kind,
623 operands,
624 clobbers: _,
625 } = self;
626 let is_at_end_of_block = func.blocks[block].insts.last() == Some(inst);
627 if kind.is_block_term() != is_at_end_of_block {
628 return Err(if is_at_end_of_block {
629 Error::BlocksLastInstMustBeTerm { term_idx: inst }
630 } else {
631 Error::TermInstOnlyAllowedAtBlockEnd { inst_idx: inst }
632 });
633 }
634 for (operand_idx, operand) in operands.entries() {
635 operand.validate(block, inst, operand_idx, func, global_state)?;
636 }
637 match kind {
638 InstKind::Normal => {}
639 InstKind::Copy(CopyInstKind {
640 src_operand_idxs,
641 dest_operand_idxs,
642 copy_ty,
643 }) => {
644 let mut seen_dest_operands = SmallVec::<[bool; 16]>::new();
645 seen_dest_operands.resize(operands.len(), false);
646 for &dest_operand_idx in dest_operand_idxs {
647 let seen_dest_operand = seen_dest_operands
648 .get_mut(dest_operand_idx.get())
649 .ok_or_else(|| {
650 OperandIdxOutOfRange {
651 idx: dest_operand_idx,
652 }
653 .with_inst(inst)
654 })?;
655 if mem::replace(seen_dest_operand, true) {
656 return Err(Error::DupCopyDestOperand {
657 inst,
658 operand_idx: dest_operand_idx,
659 });
660 }
661 }
662 if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&src_operand_idxs, inst, func)? {
663 return Err(Error::CopySrcTyMismatch { inst });
664 }
665 if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&dest_operand_idxs, inst, func)? {
666 return Err(Error::CopyDestTyMismatch { inst });
667 }
668 }
669 InstKind::BlockTerm(BlockTermInstKind { succs_and_params }) => {
670 for (&succ_idx, params) in succs_and_params {
671 let succ = func.blocks.try_index(succ_idx)?;
672 if !succ.preds.contains(&block) {
673 return Err(Error::SrcBlockMissingFromBranchTgtBlocksPreds {
674 src_block: block,
675 branch_inst: inst,
676 tgt_block: succ_idx,
677 });
678 }
679 if succ.params.len() != params.len() {
680 return Err(Error::BranchSuccParamCountMismatch {
681 inst,
682 succ: succ_idx,
683 block_param_count: succ.params.len(),
684 branch_param_count: params.len(),
685 });
686 }
687 for ((param_idx, &branch_ssa_val_idx), &block_ssa_val_idx) in
688 params.entries().zip(&succ.params)
689 {
690 let branch_ssa_val = func
691 .ssa_vals
692 .try_index(branch_ssa_val_idx)
693 .map_err(|e| e.with_inst_succ_and_param(inst, succ_idx, param_idx))?;
694 let block_ssa_val = func
695 .ssa_vals
696 .try_index(block_ssa_val_idx)
697 .map_err(|e| e.with_block_and_param(succ_idx, param_idx))?;
698 if !branch_ssa_val
699 .branch_succ_param_uses
700 .contains(&BranchSuccParamUse {
701 branch_inst: inst,
702 succ: succ_idx,
703 param_idx,
704 })
705 {
706 return Err(Error::MissingBranchSuccParamUse {
707 ssa_val_idx: branch_ssa_val_idx,
708 inst,
709 succ: succ_idx,
710 param_idx,
711 });
712 }
713 if block_ssa_val.ty != branch_ssa_val.ty {
714 return Err(Error::BranchSuccParamTyMismatch {
715 inst,
716 succ: succ_idx,
717 param_idx,
718 block_param_ty: block_ssa_val.ty,
719 branch_param_ty: branch_ssa_val.ty,
720 });
721 }
722 }
723 }
724 }
725 }
726 Ok(())
727 }
728 pub fn try_get_operand(&self, inst: InstIdx, operand_idx: OperandIdx) -> Result<&Operand> {
729 self.operands
730 .try_index(operand_idx)
731 .map_err(|e| e.with_inst(inst).into())
732 }
733 }
734
/// A basic block.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
pub struct Block {
    /// SSA values defined as this block's parameters (must be empty for the
    /// entry block).
    pub params: Vec<SSAValIdx>,
    /// The block's instructions. Blocks' ranges are contiguous and in block
    /// order: each starts where the previous one ends, and the last ends at
    /// the end of the instruction list. The range must be non-empty and its
    /// last instruction is the terminator.
    pub insts: InstRange,
    /// Predecessor blocks; each must end in a terminator that branches here
    /// (must be empty for the entry block).
    pub preds: BTreeSet<BlockIdx>,
    /// Immediate dominator, or `None` (e.g. for the entry block); checked
    /// against a freshly computed dominator tree by
    /// `Function::new_with_global_state`.
    pub immediate_dominator: Option<BlockIdx>,
}
742
impl Block {
    /// Validates block `block` of `func`, returning the first failure.
    ///
    /// Checks, in order:
    /// - the instruction range starts at instruction 0 for the entry block, and
    ///   otherwise where the previous block's range ends;
    /// - the range is non-empty and its last instruction exists;
    /// - the last block's range ends at the end of `func.insts`, so no
    ///   instruction is left outside every block;
    /// - the entry block has no parameters and no predecessors;
    /// - every instruction passes `Inst::validate`;
    /// - every parameter's SSA value exists and records this parameter as its
    ///   definition;
    /// - every predecessor's terminator branches to this block, and no such
    ///   edge is critical (a predecessor with several successors into a block
    ///   with several predecessors).
    fn validate(&self, block: BlockIdx, func: &FnFields, global_state: &GlobalState) -> Result<()> {
        let Self {
            params,
            insts,
            preds,
            immediate_dominator: _, // validated by Function::new_with_global_state
        } = self;
        // Block 0 is the only block without a predecessor in layout order, so
        // it alone starts at instruction 0 (and `block.prev()` is safe otherwise).
        const _: () = assert!(BlockIdx::ENTRY_BLOCK.get() == 0);
        let expected_start = if block == BlockIdx::ENTRY_BLOCK {
            InstIdx::new(0)
        } else {
            func.blocks[block.prev()].insts.end
        };
        if insts.start != expected_start {
            return Err(Error::BlockHasInvalidStart {
                start: insts.start,
                expected_start,
            });
        }
        let term_inst_idx = insts.last().ok_or(Error::BlockIsEmpty { block })?;
        func.insts
            .get(term_inst_idx.get())
            .ok_or(Error::BlockEndOutOfRange { end: insts.end })?;
        // Any instruction after the last block's range would belong to no block.
        if block.get() == func.blocks.len() - 1 && insts.end.get() != func.insts.len() {
            return Err(Error::InstHasNoBlock { inst: insts.end });
        }
        if block == BlockIdx::ENTRY_BLOCK {
            if !params.is_empty() {
                return Err(Error::EntryBlockCantHaveParams);
            }
            if !preds.is_empty() {
                return Err(Error::EntryBlockCantHavePreds);
            }
        }
        for inst in *insts {
            func.insts[inst].validate(block, inst, func, global_state)?;
        }
        for (param_idx, &ssa_val_idx) in params.entries() {
            let ssa_val = func
                .ssa_vals
                .try_index(ssa_val_idx)
                .map_err(|e| e.with_block_and_param(block, param_idx))?;
            let def = SSAValDef::BlockParam { block, param_idx };
            if ssa_val.def != def {
                return Err(Error::MismatchedBlockParamDef {
                    ssa_val_idx,
                    block,
                    param_idx,
                });
            }
        }
        for &pred in preds {
            let (term_inst, BlockTermInstKind { succs_and_params }) =
                func.try_get_block_term_inst_and_kind(pred)?;
            if !succs_and_params.contains_key(&block) {
                return Err(Error::PredMissingFromPredsTermBranchsTargets {
                    src_block: pred,
                    branch_inst: term_inst,
                    tgt_block: block,
                });
            }
            // Critical edge: pred has multiple successors and this block has
            // multiple predecessors.
            if preds.len() > 1 && succs_and_params.len() > 1 {
                return Err(Error::CriticalEdgeNotAllowed {
                    src_block: pred,
                    branch_inst: term_inst,
                    tgt_block: block,
                });
            }
        }
        Ok(())
    }
}
816
// `Function` is a validated wrapper around the raw `FnFields`; construct it
// with `Function::new`/`Function::new_with_global_state`, which recompute the
// derived data and validate everything. (The exact code generated by
// `validated_fields!` is defined elsewhere — NOTE(review): presumably it also
// derives serde for `FnFields`, given the `#[serde(skip)]` below.)
validated_fields! {
    #[fields_ty = FnFields]
    #[derive(Clone, PartialEq, Eq, Debug, Hash)]
    pub struct Function {
        /// All SSA values, indexed by `SSAValIdx`.
        pub ssa_vals: Vec<SSAVal>,
        /// All instructions, indexed by `InstIdx`, laid out block by block.
        pub insts: Vec<Inst>,
        /// All blocks, indexed by `BlockIdx`; block 0 is the entry block.
        pub blocks: Vec<Block>,
        #[serde(skip)]
        /// map from blocks' start instruction's index to their block index, doesn't contain the entry block
        pub start_inst_to_block_map: BTreeMap<InstIdx, BlockIdx>,
    }
}
829
830 impl Function {
831 pub fn new(fields: FnFields) -> Result<Self> {
832 GlobalState::get(|global_state| Self::new_with_global_state(fields, global_state))
833 }
834 pub fn new_with_global_state(mut fields: FnFields, global_state: &GlobalState) -> Result<Self> {
835 fields.fill_start_inst_to_block_map();
836 fields.fill_ssa_defs_uses()?;
837 let FnFields {
838 ssa_vals,
839 insts: _,
840 blocks,
841 start_inst_to_block_map: _,
842 } = &fields;
843 blocks
844 .get(BlockIdx::ENTRY_BLOCK.get())
845 .ok_or(Error::MissingEntryBlock)?;
846 for (block_idx, block) in blocks.entries() {
847 block.validate(block_idx, &fields, global_state)?;
848 }
849 let dominators = dominators::simple_fast(&fields, BlockIdx::ENTRY_BLOCK);
850 for (block_idx, block) in blocks.entries() {
851 let expected = dominators.immediate_dominator(block_idx);
852 if block.immediate_dominator != expected {
853 return Err(Error::IncorrectImmediateDominator {
854 block_idx,
855 found: block.immediate_dominator,
856 expected,
857 });
858 }
859 }
860 for (ssa_val_idx, ssa_val) in ssa_vals.entries() {
861 ssa_val.validate(ssa_val_idx, &fields)?;
862 }
863 Ok(Self(fields))
864 }
865 pub fn entry_block(&self) -> &Block {
866 &self.blocks[0]
867 }
868 pub fn block_term_kind(&self, block: BlockIdx) -> &BlockTermInstKind {
869 self.insts[self.blocks[block].insts.last().unwrap()]
870 .kind
871 .block_term()
872 .unwrap()
873 }
874 }
875
impl FnFields {
    /// Rebuilds `start_inst_to_block_map` from the blocks' instruction ranges.
    ///
    /// The entry block is omitted; `inst_to_block` falls back to it for
    /// instructions before the first non-entry block's start.
    pub fn fill_start_inst_to_block_map(&mut self) {
        self.start_inst_to_block_map.clear();
        for (block_idx, block) in self.blocks.entries() {
            if block_idx != BlockIdx::ENTRY_BLOCK {
                self.start_inst_to_block_map
                    .insert(block.insts.start, block_idx);
            }
        }
    }
    /// Recomputes `def`, `operand_uses` and `branch_succ_param_uses` of every SSA
    /// value from the blocks and instructions.
    ///
    /// All three are first reset (`def` to `SSAValDef::invalid()`). Block
    /// parameters are visited before instruction operands. A later definition
    /// of the same value overwrites an earlier one; such conflicts are left for
    /// the validators to report.
    ///
    /// # Errors
    ///
    /// Returns an error if a block parameter, operand, or branch argument names
    /// an out-of-range SSA value index. Work done before the error is kept.
    pub fn fill_ssa_defs_uses(&mut self) -> Result<()> {
        for ssa_val in &mut self.ssa_vals {
            ssa_val.branch_succ_param_uses.clear();
            ssa_val.operand_uses.clear();
            ssa_val.def = SSAValDef::invalid();
        }
        for (block_idx, block) in self.blocks.entries() {
            for (param_idx, &param) in block.params.entries() {
                self.ssa_vals
                    .try_index_mut(param)
                    .map_err(|e| e.with_block_and_param(block_idx, param_idx))?
                    .def = SSAValDef::BlockParam {
                    block: block_idx,
                    param_idx,
                };
            }
        }
        for (inst_idx, inst) in self.insts.entries() {
            for (operand_idx, operand) in inst.operands.entries() {
                let ssa_val = self
                    .ssa_vals
                    .try_index_mut(operand.ssa_val)
                    .map_err(|e| e.with_inst_and_operand(inst_idx, operand_idx))?;
                match operand.kind_and_constraint.kind() {
                    OperandKind::Use => {
                        ssa_val.operand_uses.insert(OperandUse {
                            inst: inst_idx,
                            operand_idx,
                        });
                    }
                    OperandKind::Def => {
                        ssa_val.def = SSAValDef::Operand {
                            inst: inst_idx,
                            operand_idx,
                        };
                    }
                }
            }
            // Branch arguments are uses too, tracked separately from operands.
            match &inst.kind {
                InstKind::Normal | InstKind::Copy(_) => {}
                InstKind::BlockTerm(BlockTermInstKind { succs_and_params }) => {
                    for (&succ, params) in succs_and_params {
                        for (param_idx, &param) in params.entries() {
                            let ssa_val = self.ssa_vals.try_index_mut(param).map_err(|e| {
                                e.with_inst_succ_and_param(inst_idx, succ, param_idx)
                            })?;
                            ssa_val.branch_succ_param_uses.insert(BranchSuccParamUse {
                                branch_inst: inst_idx,
                                succ,
                                param_idx,
                            });
                        }
                    }
                }
            }
        }
        Ok(())
    }
    /// Returns operand `operand_idx` of instruction `inst`.
    ///
    /// # Errors
    ///
    /// Returns an out-of-range error if `inst` or `operand_idx` is invalid.
    pub fn try_get_operand(&self, inst: InstIdx, operand_idx: OperandIdx) -> Result<&Operand> {
        Ok(self
            .insts
            .try_index(inst)?
            .operands
            .try_index(operand_idx)
            .map_err(|e| e.with_inst(inst))?)
    }
    /// Returns the SSA value that is parameter `param_idx` of block `block`.
    ///
    /// # Errors
    ///
    /// Returns an out-of-range error if `block` or `param_idx` is invalid.
    pub fn try_get_block_param(
        &self,
        block: BlockIdx,
        param_idx: BlockParamIdx,
    ) -> Result<SSAValIdx> {
        Ok(*self
            .blocks
            .try_index(block)?
            .params
            .try_index(param_idx)
            .map_err(|e| e.with_block(block))?)
    }
    /// Returns the index of the last instruction of `block`, which a valid
    /// function requires to be the block's terminator.
    ///
    /// # Errors
    ///
    /// Fails if `block` is out of range or `Error::BlockIsEmpty` if it has no
    /// instructions. Does not check that the instruction is a terminator.
    pub fn try_get_block_term_inst_idx(&self, block: BlockIdx) -> Result<InstIdx> {
        self.blocks
            .try_index(block)?
            .insts
            .last()
            .ok_or(Error::BlockIsEmpty { block })
    }
    /// Returns the index and terminator data of `block`'s last instruction.
    ///
    /// # Errors
    ///
    /// As `try_get_block_term_inst_idx`, plus an out-of-range instruction
    /// index, or `Error::BlocksLastInstMustBeTerm` if the last instruction is
    /// not a block terminator.
    pub fn try_get_block_term_inst_and_kind(
        &self,
        block: BlockIdx,
    ) -> Result<(InstIdx, &BlockTermInstKind)> {
        let term_idx = self.try_get_block_term_inst_idx(block)?;
        let term_kind = self
            .insts
            .try_index(term_idx)?
            .kind
            .block_term()
            .ok_or(Error::BlocksLastInstMustBeTerm { term_idx })?;
        Ok((term_idx, term_kind))
    }
    /// Mutable counterpart of [`FnFields::try_get_block_term_inst_and_kind`].
    pub fn try_get_block_term_inst_and_kind_mut(
        &mut self,
        block: BlockIdx,
    ) -> Result<(InstIdx, &mut BlockTermInstKind)> {
        let term_idx = self.try_get_block_term_inst_idx(block)?;
        let term_kind = self
            .insts
            .try_index_mut(term_idx)?
            .kind
            .block_term_mut()
            .ok_or(Error::BlocksLastInstMustBeTerm { term_idx })?;
        Ok((term_idx, term_kind))
    }
    /// Returns the SSA values that terminator `branch_inst` passes to successor
    /// `succ`.
    ///
    /// # Errors
    ///
    /// Fails if `branch_inst` is out of range, `Error::InstIsNotBlockTerm` if it
    /// is not a terminator, or `Error::BranchTargetNotFound` if `succ` is not
    /// one of its successors.
    pub fn try_get_branch_target_params(
        &self,
        branch_inst: InstIdx,
        succ: BlockIdx,
    ) -> Result<&[SSAValIdx]> {
        let inst = self.insts.try_index(branch_inst)?;
        let BlockTermInstKind { succs_and_params } = inst
            .kind
            .block_term()
            .ok_or(Error::InstIsNotBlockTerm { inst: branch_inst })?;
        Ok(succs_and_params
            .get(&succ)
            .ok_or(Error::BranchTargetNotFound {
                branch_inst,
                tgt_block: succ,
            })?)
    }
    /// Returns the SSA value that terminator `branch_inst` passes as parameter
    /// `param_idx` of successor `succ`.
    ///
    /// # Errors
    ///
    /// As `try_get_branch_target_params`, plus an out-of-range `param_idx`.
    pub fn try_get_branch_target_param(
        &self,
        branch_inst: InstIdx,
        succ: BlockIdx,
        param_idx: BlockParamIdx,
    ) -> Result<SSAValIdx> {
        Ok(*self
            .try_get_branch_target_params(branch_inst, succ)?
            .try_index(param_idx)
            .map_err(|e: BlockParamIdxOutOfRange| e.with_inst_and_succ(branch_inst, succ))?)
    }
    /// Returns the block containing `inst`: the block with the greatest start
    /// `<= inst` in `start_inst_to_block_map`, or the entry block if there is
    /// none. Does not check that `inst` is in range; relies on
    /// `start_inst_to_block_map` being up to date.
    pub fn inst_to_block(&self, inst: InstIdx) -> BlockIdx {
        self.start_inst_to_block_map
            .range(..=inst)
            .next_back()
            .map(|v| *v.1)
            .unwrap_or(BlockIdx::ENTRY_BLOCK)
    }
}
1033
/// Iteration over a collection indexable by the typed index `I`, yielding
/// each element together with its typed index.
pub trait Entries<'a, I>: Index<I>
where
    I: IndexTy,
    Self::Output: 'a,
{
    /// Iterator over `(index, &element)` pairs.
    type Iter: Iterator<Item = (I, &'a Self::Output)>
        + DoubleEndedIterator
        + ExactSizeIterator
        + FusedIterator;
    /// Returns an iterator over `(index, &element)` pairs.
    fn entries(&'a self) -> Self::Iter;
    /// Returns an iterator over the collection's valid indices.
    fn keys(&'a self) -> RangeIter<I>;
}
1046
/// Mutable counterpart of [`Entries`].
pub trait EntriesMut<'a, I>: Entries<'a, I> + IndexMut<I>
where
    I: IndexTy,
    Self::Output: 'a,
{
    /// Iterator over `(index, &mut element)` pairs.
    type IterMut: Iterator<Item = (I, &'a mut Self::Output)>
        + DoubleEndedIterator
        + ExactSizeIterator
        + FusedIterator;
    /// Returns an iterator over `(index, &mut element)` pairs.
    fn entries_mut(&'a mut self) -> Self::IterMut;
}
1058
/// Checked indexing by the typed index `I`.
pub trait TryIndex<I>: for<'a> Entries<'a, I>
where
    I: IndexTy,
{
    /// Error returned for an out-of-range index.
    type Error;
    /// Returns the element at `idx`, or `Self::Error` if it is out of range.
    fn try_index(&self, idx: I) -> Result<&Self::Output, Self::Error>;
}
1066
/// Checked mutable indexing by the typed index `I`.
pub trait TryIndexMut<I>: TryIndex<I> + for<'a> EntriesMut<'a, I>
where
    I: IndexTy,
{
    /// Returns the element at `idx` mutably, or `Self::Error` if it is out of
    /// range.
    fn try_index_mut(&mut self, idx: I) -> Result<&mut Self::Output, Self::Error>;
}
1073
// Implements typed-index access for both `Vec<$T>` and `[$T]` with index type
// `$I`:
// - `$Iter` / `$IterMut`: `(index, &element)` / `(index, &mut element)`
//   iterators wrapping `Enumerate<slice::Iter(Mut)>`, converting each `usize`
//   position with `<$I>::new`;
// - `Index` / `IndexMut` by `$I` (panicking when out of range, like indexing
//   by `usize`);
// - `Entries` / `EntriesMut`;
// - `TryIndex` / `TryIndexMut`, returning `$Error { idx }` when out of range.
macro_rules! impl_index {
    (
        #[error = $Error:ident, iter = $Iter:ident, iter_mut = $IterMut:ident]
        impl Index<$I:ty> for Vec<$T:ty> {}
    ) => {
        /// Iterator over `(index, &element)` pairs, returned by `Entries::entries`.
        #[derive(Clone, Debug)]
        pub struct $Iter<'a> {
            iter: std::iter::Enumerate<std::slice::Iter<'a, $T>>,
        }

        impl<'a> Iterator for $Iter<'a> {
            type Item = ($I, &'a $T);

            fn next(&mut self) -> Option<Self::Item> {
                self.iter.next().map(|(i, v)| (<$I>::new(i), v))
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                self.iter.size_hint()
            }

            // Forwarded so internal iteration uses the inner iterator's `fold`.
            fn fold<B, F>(self, init: B, mut f: F) -> B
            where
                F: FnMut(B, Self::Item) -> B,
            {
                self.iter
                    .fold(init, move |a, (i, v)| f(a, (<$I>::new(i), v)))
            }
        }

        impl DoubleEndedIterator for $Iter<'_> {
            fn next_back(&mut self) -> Option<Self::Item> {
                self.iter.next_back().map(|(i, v)| (<$I>::new(i), v))
            }

            fn rfold<B, F>(self, init: B, mut f: F) -> B
            where
                F: FnMut(B, Self::Item) -> B,
            {
                self.iter
                    .rfold(init, move |a, (i, v)| f(a, (<$I>::new(i), v)))
            }
        }

        impl ExactSizeIterator for $Iter<'_> {
            fn len(&self) -> usize {
                self.iter.len()
            }
        }

        impl FusedIterator for $Iter<'_> {}

        /// Iterator over `(index, &mut element)` pairs, returned by
        /// `EntriesMut::entries_mut`.
        #[derive(Debug)]
        pub struct $IterMut<'a> {
            iter: std::iter::Enumerate<std::slice::IterMut<'a, $T>>,
        }

        impl<'a> Iterator for $IterMut<'a> {
            type Item = ($I, &'a mut $T);

            fn next(&mut self) -> Option<Self::Item> {
                self.iter.next().map(|(i, v)| (<$I>::new(i), v))
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                self.iter.size_hint()
            }

            fn fold<B, F>(self, init: B, mut f: F) -> B
            where
                F: FnMut(B, Self::Item) -> B,
            {
                self.iter
                    .fold(init, move |a, (i, v)| f(a, (<$I>::new(i), v)))
            }
        }

        impl DoubleEndedIterator for $IterMut<'_> {
            fn next_back(&mut self) -> Option<Self::Item> {
                self.iter.next_back().map(|(i, v)| (<$I>::new(i), v))
            }

            fn rfold<B, F>(self, init: B, mut f: F) -> B
            where
                F: FnMut(B, Self::Item) -> B,
            {
                self.iter
                    .rfold(init, move |a, (i, v)| f(a, (<$I>::new(i), v)))
            }
        }

        impl ExactSizeIterator for $IterMut<'_> {
            fn len(&self) -> usize {
                self.iter.len()
            }
        }

        impl FusedIterator for $IterMut<'_> {}

        // `Vec<$T>` impls: index with `$I`, delegating to `usize` indexing.
        impl Index<$I> for Vec<$T> {
            type Output = $T;

            fn index(&self, index: $I) -> &Self::Output {
                &self[index.get()]
            }
        }

        impl IndexMut<$I> for Vec<$T> {
            fn index_mut(&mut self, index: $I) -> &mut Self::Output {
                &mut self[index.get()]
            }
        }

        impl<'a> Entries<'a, $I> for Vec<$T> {
            type Iter = $Iter<'a>;
            fn entries(&'a self) -> Self::Iter {
                $Iter {
                    iter: (**self).iter().enumerate(),
                }
            }
            fn keys(&'a self) -> RangeIter<$I> {
                RangeIter::from_usize_range(0..self.len())
            }
        }

        impl<'a> EntriesMut<'a, $I> for Vec<$T> {
            type IterMut = $IterMut<'a>;
            fn entries_mut(&'a mut self) -> Self::IterMut {
                $IterMut {
                    iter: (**self).iter_mut().enumerate(),
                }
            }
        }

        impl TryIndex<$I> for Vec<$T> {
            type Error = $Error;

            fn try_index(&self, idx: $I) -> Result<&Self::Output, Self::Error> {
                self.get(idx.get()).ok_or($Error { idx })
            }
        }

        impl TryIndexMut<$I> for Vec<$T> {
            fn try_index_mut(&mut self, idx: $I) -> Result<&mut Self::Output, Self::Error> {
                self.get_mut(idx.get()).ok_or($Error { idx })
            }
        }

        // `[$T]` impls: the same API on slices.
        impl Index<$I> for [$T] {
            type Output = $T;

            fn index(&self, index: $I) -> &Self::Output {
                &self[index.get()]
            }
        }

        impl IndexMut<$I> for [$T] {
            fn index_mut(&mut self, index: $I) -> &mut Self::Output {
                &mut self[index.get()]
            }
        }

        impl<'a> Entries<'a, $I> for [$T] {
            type Iter = $Iter<'a>;
            fn entries(&'a self) -> Self::Iter {
                $Iter {
                    iter: self.iter().enumerate(),
                }
            }
            fn keys(&'a self) -> RangeIter<$I> {
                RangeIter::from_usize_range(0..self.len())
            }
        }

        impl<'a> EntriesMut<'a, $I> for [$T] {
            type IterMut = $IterMut<'a>;
            fn entries_mut(&'a mut self) -> Self::IterMut {
                $IterMut {
                    iter: self.iter_mut().enumerate(),
                }
            }
        }

        impl TryIndex<$I> for [$T] {
            type Error = $Error;

            fn try_index(&self, idx: $I) -> Result<&Self::Output, Self::Error> {
                self.get(idx.get()).ok_or($Error { idx })
            }
        }

        impl TryIndexMut<$I> for [$T] {
            fn try_index_mut(&mut self, idx: $I) -> Result<&mut Self::Output, Self::Error> {
                self.get_mut(idx.get()).ok_or($Error { idx })
            }
        }
    };
}
1272
// Each `impl_index!` invocation below lets a typed index (`SSAValIdx`,
// `BlockParamIdx`, ...) be used directly on `Vec<T>` and `[T]`. It generates:
// `Index`/`IndexMut`; `TryIndex`/`TryIndexMut`, which return the named
// `#[error = ...]` type (built as `Error { idx }`) when the index is out of
// range; and `Entries`/`EntriesMut`, whose iterator types (`iter` /
// `iter_mut`) pair each element with its typed index.

// `Vec<SSAVal>` / `[SSAVal]` indexed by `SSAValIdx`.
impl_index! {
    #[error = SSAValIdxOutOfRange, iter = SSAValEntriesIter, iter_mut = SSAValEntriesIterMut]
    impl Index<SSAValIdx> for Vec<SSAVal> {}
}

// `Vec<SSAValIdx>` / `[SSAValIdx]` indexed by `BlockParamIdx`.
impl_index! {
    #[error = BlockParamIdxOutOfRange, iter = BlockParamEntriesIter, iter_mut = BlockParamEntriesIterMut]
    impl Index<BlockParamIdx> for Vec<SSAValIdx> {}
}

// `Vec<Operand>` / `[Operand]` indexed by `OperandIdx`.
impl_index! {
    #[error = OperandIdxOutOfRange, iter = OperandEntriesIter, iter_mut = OperandEntriesIterMut]
    impl Index<OperandIdx> for Vec<Operand> {}
}

// `Vec<Inst>` / `[Inst]` indexed by `InstIdx`.
impl_index! {
    #[error = InstIdxOutOfRange, iter = InstEntriesIter, iter_mut = InstEntriesIterMut]
    impl Index<InstIdx> for Vec<Inst> {}
}

// `Vec<Block>` / `[Block]` indexed by `BlockIdx`.
impl_index! {
    #[error = BlockIdxOutOfRange, iter = BlockEntriesIter, iter_mut = BlockEntriesIterMut]
    impl Index<BlockIdx> for Vec<Block> {}
}
1297
1298 impl GraphBase for FnFields {
1299 type EdgeId = (BlockIdx, BlockIdx);
1300 type NodeId = BlockIdx;
1301 }
1302
1303 pub struct Neighbors<'a> {
1304 iter: Option<btree_map::Keys<'a, BlockIdx, Vec<SSAValIdx>>>,
1305 }
1306
1307 impl Iterator for Neighbors<'_> {
1308 type Item = BlockIdx;
1309
1310 fn next(&mut self) -> Option<Self::Item> {
1311 Some(*self.iter.as_mut()?.next()?)
1312 }
1313 }
1314
1315 impl<'a> IntoNeighbors for &'a FnFields {
1316 type Neighbors = Neighbors<'a>;
1317
1318 fn neighbors(self, block_idx: Self::NodeId) -> Self::Neighbors {
1319 Neighbors {
1320 iter: self
1321 .try_get_block_term_inst_and_kind(block_idx)
1322 .ok()
1323 .map(|(_, BlockTermInstKind { succs_and_params })| succs_and_params.keys()),
1324 }
1325 }
1326 }
1327
1328 pub struct VisitedMap(HashSet<BlockIdx>);
1329
1330 impl VisitMap<BlockIdx> for VisitedMap {
1331 fn visit(&mut self, block: BlockIdx) -> bool {
1332 self.0.insert(block)
1333 }
1334
1335 fn is_visited(&self, block: &BlockIdx) -> bool {
1336 self.0.contains(block)
1337 }
1338 }
1339
1340 impl Visitable for FnFields {
1341 type Map = VisitedMap;
1342
1343 fn visit_map(&self) -> Self::Map {
1344 VisitedMap(HashSet::new())
1345 }
1346
1347 fn reset_map(&self, map: &mut Self::Map) {
1348 map.0.clear();
1349 }
1350 }
1351
1352 impl GraphProp for FnFields {
1353 type EdgeType = Directed;
1354 }
1355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loc::TyFields;
    use std::num::NonZeroU32;

    /// For every `BaseTy` and several register lengths, checks that
    /// `Constraint::non_fixed_choices_for_ty`:
    /// - starts with `Any` and ends with `Stack`;
    /// - contains no fixed-location constraints;
    /// - only contains constraints that pass `check_for_ty_mismatch`.
    ///
    /// It also checks that every listed non-`FixedLoc` variant appears for at
    /// least one type.
    #[test]
    fn test_constraint_non_fixed_choices_for_ty() {
        // Expands to a `Seen` struct with one `bool` flag per listed
        // `Constraint` variant. `add` records a variant and `check` asserts
        // that all of them were recorded. `add`'s `match` has no wildcard arm,
        // so adding a `Constraint` variant without listing it here makes the
        // test fail to compile.
        macro_rules! seen {
            (
                enum ConstraintWithoutFixedLoc {
                    $($field:ident,)*
                }
            ) => {
                #[derive(Default)]
                #[allow(non_snake_case)]
                struct Seen {
                    $($field: bool,)*
                }

                impl Seen {
                    fn add(&mut self, constraint: &Constraint) {
                        match constraint {
                            Constraint::FixedLoc(_) => {}
                            $(Constraint::$field => self.$field = true,)*
                        }
                    }
                    fn check(self) {
                        $(assert!(self.$field, "never seen field: {}", stringify!($field));)*
                    }
                }
            };
        }
        seen! {
            enum ConstraintWithoutFixedLoc {
                Any,
                BaseGpr,
                SVExtra2VGpr,
                SVExtra2SGpr,
                SVExtra3Gpr,
                Stack,
            }
        }

        let mut coverage = Seen::default();
        for base_ty in (0..BaseTy::LENGTH).map(BaseTy::from_usize) {
            for reg_len in [1, 2, 100] {
                let ty = Ty::new_or_scalar(TyFields {
                    base_ty,
                    reg_len: NonZeroU32::new(reg_len).unwrap(),
                });
                let choices = Constraint::non_fixed_choices_for_ty(ty);
                assert_eq!(choices.first(), Some(&Constraint::Any));
                assert_eq!(choices.last(), Some(&Constraint::Stack));
                for constraint in choices {
                    assert_eq!(constraint.fixed_loc(), None);
                    coverage.add(constraint);
                    if constraint.check_for_ty_mismatch(ty).is_err() {
                        panic!("constraint ty mismatch: constraint={constraint:?} ty={ty:?}");
                    }
                }
            }
        }
        coverage.check();
    }
}