working on code
[bigint-presentation-code.git] / register_allocator / src / function.rs
1 use crate::{
2 error::{Error, Result},
3 index::{BlockIdx, InstIdx, InstRange, SSAValIdx},
4 interned::Interned,
5 loc::{Loc, Ty},
6 loc_set::LocSet,
7 };
8 use core::fmt;
9 use serde::{Deserialize, Serialize};
10 use std::ops::Index;
11
12 #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
13 pub struct SSAVal {
14 pub ty: Ty,
15 }
16
17 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
18 #[repr(u8)]
19 pub enum InstStage {
20 Early = 0,
21 Late = 1,
22 }
23
24 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
25 #[serde(try_from = "SerializedProgPoint", into = "SerializedProgPoint")]
26 pub struct ProgPoint(usize);
27
28 impl ProgPoint {
29 pub const fn new(inst: InstIdx, stage: InstStage) -> Self {
30 const_unwrap_res!(Self::try_new(inst, stage))
31 }
32 pub const fn try_new(inst: InstIdx, stage: InstStage) -> Result<Self> {
33 let Some(inst) = inst.get().checked_shl(1) else {
34 return Err(Error::InstIdxTooBig);
35 };
36 Ok(Self(inst | stage as usize))
37 }
38 pub const fn inst(self) -> InstIdx {
39 InstIdx::new(self.0 >> 1)
40 }
41 pub const fn stage(self) -> InstStage {
42 if self.0 & 1 != 0 {
43 InstStage::Late
44 } else {
45 InstStage::Early
46 }
47 }
48 pub const fn next(self) -> Self {
49 Self(self.0 + 1)
50 }
51 pub const fn prev(self) -> Self {
52 Self(self.0 - 1)
53 }
54 }
55
56 impl fmt::Debug for ProgPoint {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 f.debug_struct("ProgPoint")
59 .field("inst", &self.inst())
60 .field("stage", &self.stage())
61 .finish()
62 }
63 }
64
65 #[derive(Serialize, Deserialize)]
66 struct SerializedProgPoint {
67 inst: InstIdx,
68 stage: InstStage,
69 }
70
71 impl From<ProgPoint> for SerializedProgPoint {
72 fn from(value: ProgPoint) -> Self {
73 Self {
74 inst: value.inst(),
75 stage: value.stage(),
76 }
77 }
78 }
79
80 impl TryFrom<SerializedProgPoint> for ProgPoint {
81 type Error = Error;
82
83 fn try_from(value: SerializedProgPoint) -> Result<Self, Self::Error> {
84 ProgPoint::try_new(value.inst, value.stage)
85 }
86 }
87
88 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
89 #[repr(u8)]
90 pub enum OperandKind {
91 Use = 0,
92 Def = 1,
93 }
94
95 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
96 pub enum Constraint {
97 /// any register or stack location
98 Any,
99 /// r1-r32
100 BaseGpr,
101 /// r2,r4,r6,r8,...r126
102 SVExtra2VGpr,
103 /// r1-63
104 SVExtra2SGpr,
105 /// r1-127
106 SVExtra3Gpr,
107 /// any stack location
108 Stack,
109 FixedLoc(Loc),
110 Reuse(usize),
111 }
112
113 #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
114 pub struct Operand {
115 pub ssa_val: SSAValIdx,
116 pub constraint: Constraint,
117 pub kind: OperandKind,
118 pub stage: InstStage,
119 }
120
121 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
122 pub struct BranchSucc {
123 pub block: BlockIdx,
124 pub params: Vec<SSAValIdx>,
125 }
126
127 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
128 pub enum InstKind {
129 Normal,
130 /// copy concatenates all `srcs` together and de-concatenates the result into all `dests`.
131 Copy {
132 srcs: Vec<Operand>,
133 dests: Vec<Operand>,
134 },
135 Return,
136 Branch {
137 succs: Vec<BranchSucc>,
138 },
139 }
140
141 impl InstKind {
142 pub fn is_normal(&self) -> bool {
143 matches!(self, Self::Normal)
144 }
145 pub fn is_block_term(&self) -> bool {
146 matches!(self, Self::Return | Self::Branch { .. })
147 }
148 pub fn succs(&self) -> Option<&[BranchSucc]> {
149 match self {
150 InstKind::Normal | InstKind::Copy { .. } => None,
151 InstKind::Return => Some(&[]),
152 InstKind::Branch { succs } => Some(succs),
153 }
154 }
155 }
156
157 impl Default for InstKind {
158 fn default() -> Self {
159 InstKind::Normal
160 }
161 }
162
163 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
164 pub struct Inst {
165 #[serde(default, skip_serializing_if = "InstKind::is_normal")]
166 pub kind: InstKind,
167 pub operands: Vec<Operand>,
168 pub clobbers: Interned<LocSet>,
169 }
170
171 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
172 pub struct Block {
173 pub params: Vec<SSAValIdx>,
174 pub insts: InstRange,
175 pub preds: Vec<BlockIdx>,
176 }
177
// Declares `Function` together with its unvalidated field bundle `FnFields`,
// which `Function::new` takes and checks.
// NOTE(review): the exact expansion lives in `validated_fields!` (not shown
// here). Presumably a `Function` can only be obtained through `Function::new`
// — confirm.
validated_fields! {
    #[fields_ty = FnFields]
    #[derive(Clone, PartialEq, Eq, Debug, Hash)]
    pub struct Function {
        // Metadata for every SSA value, indexed by `SSAValIdx`.
        pub ssa_vals: Vec<SSAVal>,
        // All instructions, indexed by `InstIdx`; each block owns a contiguous sub-range.
        pub insts: Vec<Inst>,
        // Basic blocks, indexed by `BlockIdx`; the entry block is at `BlockIdx::ENTRY_BLOCK`.
        pub blocks: Vec<Block>,
    }
}
187
188 impl Function {
189 pub fn new(fields: FnFields) -> Result<Self> {
190 let FnFields {
191 ssa_vals,
192 insts: insts_vec,
193 blocks,
194 } = &fields;
195 let entry_block = blocks
196 .get(BlockIdx::ENTRY_BLOCK.get())
197 .ok_or(Error::MissingEntryBlock)?;
198 if !entry_block.params.is_empty() {
199 return Err(Error::EntryBlockCantHaveParams);
200 }
201 if !entry_block.preds.is_empty() {
202 return Err(Error::EntryBlockCantHavePreds);
203 }
204 let mut expected_start = InstIdx::new(0);
205 for (block_idx, block) in fields.blocks.iter().enumerate() {
206 let block_idx = BlockIdx::new(block_idx);
207 let Block {
208 params,
209 insts: inst_range,
210 preds,
211 } = block;
212 if inst_range.start != expected_start {
213 return Err(Error::BlockHasInvalidStart {
214 start: inst_range.start,
215 expected_start,
216 });
217 }
218 let Some((term_idx, non_term_inst_range)) = inst_range.split_last() else {
219 return Err(Error::BlockIsEmpty { block: block_idx });
220 };
221 expected_start = inst_range.end;
222 let Some(Inst { kind: term_kind, .. }) = insts_vec.get(term_idx.get()) else {
223 return Err(Error::BlockEndOutOfRange { end: inst_range.end });
224 };
225 if !term_kind.is_block_term() {
226 return Err(Error::BlocksLastInstMustBeTerm { term_idx });
227 }
228 for inst_idx in non_term_inst_range {
229 if insts_vec[inst_idx].kind.is_block_term() {
230 return Err(Error::TermInstOnlyAllowedAtBlockEnd { inst_idx });
231 }
232 }
233 }
234 todo!()
235 }
236 pub fn entry_block(&self) -> &Block {
237 &self.blocks[0]
238 }
239 pub fn block_succs(&self, block: BlockIdx) -> &[BranchSucc] {
240 self.insts[self.blocks[block].insts.last().unwrap()]
241 .kind
242 .succs()
243 .unwrap()
244 }
245 }
246
247 impl Index<SSAValIdx> for Vec<SSAVal> {
248 type Output = SSAVal;
249
250 fn index(&self, index: SSAValIdx) -> &Self::Output {
251 &self[index.get()]
252 }
253 }
254
255 impl Index<InstIdx> for Vec<Inst> {
256 type Output = Inst;
257
258 fn index(&self, index: InstIdx) -> &Self::Output {
259 &self[index.get()]
260 }
261 }
262
263 impl Index<BlockIdx> for Vec<Block> {
264 type Output = Block;
265
266 fn index(&self, index: BlockIdx) -> &Self::Output {
267 &self[index.get()]
268 }
269 }