wip
[bigint-presentation-code.git] / register_allocator / src / loc.rs
1 use crate::error::{Error, Result};
2 use arbitrary::{size_hint, Arbitrary};
3 use enum_map::Enum;
4 use serde::{Deserialize, Serialize};
5 use std::{iter::FusedIterator, num::NonZeroU32};
6
7 #[derive(
8 Serialize,
9 Deserialize,
10 Copy,
11 Clone,
12 PartialEq,
13 Eq,
14 PartialOrd,
15 Ord,
16 Debug,
17 Hash,
18 Enum,
19 Arbitrary,
20 )]
21 #[repr(u8)]
22 pub enum LocKind {
23 Gpr,
24 StackBits64,
25 Ca,
26 VlMaxvl,
27 }
28
29 impl LocKind {
30 /// since `==` doesn't work with enums in const context
31 pub const fn const_eq(self, other: Self) -> bool {
32 self as u8 == other as u8
33 }
34 pub const fn base_ty(self) -> BaseTy {
35 match self {
36 Self::Gpr | Self::StackBits64 => BaseTy::Bits64,
37 Self::Ca => BaseTy::Ca,
38 Self::VlMaxvl => BaseTy::VlMaxvl,
39 }
40 }
41
42 pub const fn loc_count(self) -> NonZeroU32 {
43 match self {
44 Self::StackBits64 => nzu32_lit!(512),
45 Self::Gpr | Self::Ca | Self::VlMaxvl => self.base_ty().max_reg_len(),
46 }
47 }
48 }
49
50 #[derive(
51 Serialize,
52 Deserialize,
53 Copy,
54 Clone,
55 PartialEq,
56 Eq,
57 PartialOrd,
58 Ord,
59 Debug,
60 Hash,
61 Enum,
62 Arbitrary,
63 )]
64 #[repr(u8)]
65 pub enum BaseTy {
66 Bits64,
67 Ca,
68 VlMaxvl,
69 }
70
71 impl BaseTy {
72 /// since `==` doesn't work with enums in const context
73 pub const fn const_eq(self, other: Self) -> bool {
74 self as u8 == other as u8
75 }
76
77 pub const fn only_scalar(self) -> bool {
78 self.max_reg_len().get() == 1
79 }
80
81 pub const fn max_reg_len(self) -> NonZeroU32 {
82 match self {
83 Self::Bits64 => nzu32_lit!(128),
84 Self::Ca | Self::VlMaxvl => nzu32_lit!(1),
85 }
86 }
87
88 pub const fn loc_kinds(self) -> &'static [LocKind] {
89 match self {
90 BaseTy::Bits64 => &[LocKind::Gpr, LocKind::StackBits64],
91 BaseTy::Ca => &[LocKind::Ca],
92 BaseTy::VlMaxvl => &[LocKind::VlMaxvl],
93 }
94 }
95
96 pub fn arbitrary_reg_len(
97 self,
98 u: &mut arbitrary::Unstructured<'_>,
99 ) -> arbitrary::Result<NonZeroU32> {
100 Ok(NonZeroU32::new(u.int_in_range(1..=self.max_reg_len().get())?).unwrap())
101 }
102
103 pub fn arbitrary_reg_len_size_hint(depth: usize) -> (usize, Option<usize>) {
104 (0, NonZeroU32::size_hint(depth).1)
105 }
106 }
107
108 validated_fields! {
109 #[fields_ty = TyFields]
110 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
111 pub struct Ty {
112 pub base_ty: BaseTy,
113 pub reg_len: NonZeroU32,
114 }
115 }
116
117 impl<'a> Arbitrary<'a> for TyFields {
118 fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
119 let base_ty: BaseTy = u.arbitrary()?;
120 let reg_len = base_ty.arbitrary_reg_len(u)?;
121 Ok(Self { base_ty, reg_len })
122 }
123 fn size_hint(depth: usize) -> (usize, Option<usize>) {
124 let base_ty = BaseTy::size_hint(depth);
125 let reg_len = BaseTy::arbitrary_reg_len_size_hint(depth);
126 size_hint::and(base_ty, reg_len)
127 }
128 }
129
130 impl<'a> Arbitrary<'a> for Ty {
131 fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
132 Ok(Ty::new(u.arbitrary()?)?)
133 }
134 fn size_hint(depth: usize) -> (usize, Option<usize>) {
135 TyFields::size_hint(depth)
136 }
137 }
138
139 impl Ty {
140 pub const fn new(fields: TyFields) -> Result<Ty> {
141 let TyFields { base_ty, reg_len } = fields;
142 if base_ty.only_scalar() && reg_len.get() != 1 {
143 Err(Error::TriedToCreateVectorOfOnlyScalarType { base_ty })
144 } else if reg_len.get() > base_ty.max_reg_len().get() {
145 Err(Error::RegLenOutOfRange)
146 } else {
147 Ok(Self(fields))
148 }
149 }
150 /// returns the `Ty` for `fields` or if there was an error, returns the corresponding scalar type (where `reg_len` is `1`)
151 pub const fn new_or_scalar(fields: TyFields) -> Self {
152 match Self::new(fields) {
153 Ok(v) => v,
154 Err(_) => Self::scalar(fields.base_ty),
155 }
156 }
157 pub const fn scalar(base_ty: BaseTy) -> Self {
158 Self(TyFields {
159 base_ty,
160 reg_len: nzu32_lit!(1),
161 })
162 }
163 pub const fn bits64(reg_len: NonZeroU32) -> Self {
164 Self(TyFields {
165 base_ty: BaseTy::Bits64,
166 reg_len,
167 })
168 }
169 pub const fn try_concat(self, rhs: Self) -> Result<Ty> {
170 if !self.get().base_ty.const_eq(rhs.get().base_ty) {
171 Err(Error::BaseTyMismatch)
172 } else {
173 let Some(reg_len) = self.get().reg_len.checked_add(rhs.get().reg_len.get()) else {
174 return Err(Error::RegLenOutOfRange);
175 };
176 Ty::new(TyFields {
177 base_ty: self.get().base_ty,
178 reg_len,
179 })
180 }
181 }
182 }
183
184 validated_fields! {
185 #[fields_ty = SubLocFields]
186 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
187 pub struct SubLoc {
188 pub kind: LocKind,
189 pub start: u32,
190 }
191 }
192
193 impl SubLoc {
194 pub const fn new(fields: SubLocFields) -> Result<SubLoc> {
195 const_try!(Loc::new(LocFields {
196 reg_len: nzu32_lit!(1),
197 kind: fields.kind,
198 start: fields.start
199 }));
200 Ok(SubLoc(fields))
201 }
202 }
203
204 validated_fields! {
205 #[fields_ty = LocFields]
206 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
207 pub struct Loc {
208 pub reg_len: NonZeroU32,
209 pub kind: LocKind,
210 pub start: u32,
211 }
212 }
213
214 impl<'a> Arbitrary<'a> for LocFields {
215 fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
216 let kind: LocKind = u.arbitrary()?;
217 let reg_len = kind.base_ty().arbitrary_reg_len(u)?;
218 let start = Loc::arbitrary_start(kind, reg_len, u)?;
219 Ok(Self {
220 kind,
221 start,
222 reg_len,
223 })
224 }
225
226 fn size_hint(depth: usize) -> (usize, Option<usize>) {
227 let kind = LocKind::size_hint(depth);
228 let reg_len = BaseTy::arbitrary_reg_len_size_hint(depth);
229 let start = Loc::arbitrary_start_size_hint(depth);
230 size_hint::and(size_hint::and(kind, reg_len), start)
231 }
232 }
233
234 impl<'a> Arbitrary<'a> for Loc {
235 fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
236 Ok(Loc::new(u.arbitrary()?)?)
237 }
238 fn size_hint(depth: usize) -> (usize, Option<usize>) {
239 LocFields::size_hint(depth)
240 }
241 }
242
243 impl LocFields {
244 pub const fn ty(self) -> Result<Ty> {
245 Ty::new(TyFields {
246 base_ty: self.kind.base_ty(),
247 reg_len: self.reg_len,
248 })
249 }
250 pub const fn stop(self) -> NonZeroU32 {
251 const_unwrap_opt!(self.reg_len.checked_add(self.start), "overflow")
252 }
253 pub const fn first_subloc(self) -> SubLocFields {
254 SubLocFields {
255 kind: self.kind,
256 start: self.start,
257 }
258 }
259 }
260
261 impl Loc {
262 pub const fn first_subloc(self) -> SubLoc {
263 SubLoc(self.get().first_subloc())
264 }
265 pub fn arbitrary_with_ty(
266 ty: Ty,
267 u: &mut arbitrary::Unstructured<'_>,
268 ) -> arbitrary::Result<Self> {
269 let kind = *u.choose(ty.base_ty.loc_kinds())?;
270 let start = Self::arbitrary_start(kind, ty.reg_len, u)?;
271 Ok(Self::new(LocFields {
272 kind,
273 start,
274 reg_len: ty.reg_len,
275 })?)
276 }
277 pub fn arbitrary_start(
278 kind: LocKind,
279 reg_len: NonZeroU32,
280 u: &mut arbitrary::Unstructured<'_>,
281 ) -> arbitrary::Result<u32> {
282 u.int_in_range(0..=Loc::max_start(kind, reg_len)?)
283 }
284 pub fn arbitrary_start_size_hint(depth: usize) -> (usize, Option<usize>) {
285 (0, u32::size_hint(depth).1)
286 }
287 pub const fn ty(self) -> Ty {
288 const_unwrap_res!(self.0.ty(), "Loc can only be constructed with valid fields")
289 }
290 /// does all `Loc` validation except checking `start`, returns the maximum
291 /// value `start` can have, so a `Loc` is valid if
292 /// `start < Loc::max_start(kind, reg_len)?`
293 pub const fn max_start(kind: LocKind, reg_len: NonZeroU32) -> Result<u32, Error> {
294 // validate Ty
295 const_try!(Ty::new(TyFields {
296 base_ty: kind.base_ty(),
297 reg_len
298 }));
299 let loc_count: u32 = kind.loc_count().get();
300 let Some(max_start) = loc_count.checked_sub(reg_len.get()) else {
301 return Err(Error::InvalidRegLen)
302 };
303 Ok(max_start)
304 }
305 pub const fn new(fields: LocFields) -> Result<Loc> {
306 let LocFields {
307 kind,
308 start,
309 reg_len,
310 } = fields;
311
312 if start > const_try!(Self::max_start(kind, reg_len)) {
313 Err(Error::StartNotInValidRange)
314 } else {
315 Ok(Self(fields))
316 }
317 }
318 pub const fn conflicts(self, other: Loc) -> bool {
319 self.0.kind.const_eq(other.0.kind)
320 && self.0.start < other.0.stop().get()
321 && other.0.start < self.0.stop().get()
322 }
323 pub const fn get_sub_loc_at_offset(self, sub_loc_ty: Ty, offset: u32) -> Result<Self> {
324 if !sub_loc_ty.get().base_ty.const_eq(self.get().kind.base_ty()) {
325 return Err(Error::BaseTyMismatch);
326 }
327 let Some(stop) = sub_loc_ty.get().reg_len.checked_add(offset) else {
328 return Err(Error::InvalidSubLocOutOfRange)
329 };
330 if stop.get() > self.get().reg_len.get() {
331 Err(Error::InvalidSubLocOutOfRange)
332 } else {
333 Self::new(LocFields {
334 kind: self.get().kind,
335 start: self.get().start + offset,
336 reg_len: sub_loc_ty.get().reg_len,
337 })
338 }
339 }
340 /// get the Loc containing `self` such that:
341 /// `retval.get_sub_loc_at_offset(self.ty(), offset) == self`
342 /// and `retval.ty() == super_loc_ty`
343 pub const fn get_super_loc_with_self_at_offset(
344 self,
345 super_loc_ty: Ty,
346 offset: u32,
347 ) -> Result<Self> {
348 if !super_loc_ty
349 .get()
350 .base_ty
351 .const_eq(self.get().kind.base_ty())
352 {
353 return Err(Error::BaseTyMismatch);
354 }
355 let Some(stop) = self.get().reg_len.checked_add(offset) else {
356 return Err(Error::InvalidSubLocOutOfRange)
357 };
358 if stop.get() > super_loc_ty.get().reg_len.get() {
359 Err(Error::InvalidSubLocOutOfRange)
360 } else {
361 Self::new(LocFields {
362 kind: self.get().kind,
363 start: self.get().start - offset,
364 reg_len: super_loc_ty.get().reg_len,
365 })
366 }
367 }
368 pub const SPECIAL_GPRS: &[Loc] = &[
369 Loc(LocFields {
370 kind: LocKind::Gpr,
371 start: 0,
372 reg_len: nzu32_lit!(1),
373 }),
374 Loc(LocFields {
375 kind: LocKind::Gpr,
376 start: 1,
377 reg_len: nzu32_lit!(1),
378 }),
379 Loc(LocFields {
380 kind: LocKind::Gpr,
381 start: 2,
382 reg_len: nzu32_lit!(1),
383 }),
384 Loc(LocFields {
385 kind: LocKind::Gpr,
386 start: 13,
387 reg_len: nzu32_lit!(1),
388 }),
389 ];
390 pub fn sub_locs(
391 self,
392 ) -> impl Iterator<Item = SubLoc> + FusedIterator + ExactSizeIterator + DoubleEndedIterator
393 {
394 let LocFields {
395 reg_len: _,
396 kind,
397 start,
398 } = *self;
399 (start..self.stop().get()).map(move |start| SubLoc(SubLocFields { kind, start }))
400 }
401 }
402
#[cfg(test)]
mod tests {
    use super::*;

    /// `LocKind::base_ty` and `BaseTy::loc_kinds` must agree in both
    /// directions.
    #[test]
    fn test_base_ty_loc_kinds() {
        // every kind must be listed by its own base type
        for loc_kind in (0..LocKind::LENGTH).map(LocKind::from_usize) {
            let base_ty = loc_kind.base_ty();
            let loc_kinds = base_ty.loc_kinds();
            assert!(
                loc_kinds.contains(&loc_kind),
                "loc_kind:{loc_kind:?} base_ty:{base_ty:?} loc_kinds:{loc_kinds:?}"
            );
        }

        // every kind listed by a base type must map back to that base type
        for base_ty in (0..BaseTy::LENGTH).map(BaseTy::from_usize) {
            let loc_kinds = base_ty.loc_kinds();
            for &loc_kind in loc_kinds {
                assert_eq!(
                    loc_kind.base_ty(),
                    base_ty,
                    "loc_kind:{loc_kind:?} base_ty:{base_ty:?} loc_kinds:{loc_kinds:?}"
                );
            }
        }
    }
}