Function verification should be complete, no tests yet
[bigint-presentation-code.git] / register_allocator / src / loc.rs
1 use crate::error::{Error, Result};
2 use enum_map::Enum;
3 use serde::{Deserialize, Serialize};
4 use std::num::NonZeroU32;
5
6 #[derive(
7 Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Enum,
8 )]
9 #[repr(u8)]
10 pub enum LocKind {
11 Gpr,
12 StackBits64,
13 Ca,
14 VlMaxvl,
15 }
16
17 impl LocKind {
18 /// since `==` doesn't work with enums in const context
19 pub const fn const_eq(self, other: Self) -> bool {
20 self as u8 == other as u8
21 }
22 pub const fn base_ty(self) -> BaseTy {
23 match self {
24 Self::Gpr | Self::StackBits64 => BaseTy::Bits64,
25 Self::Ca => BaseTy::Ca,
26 Self::VlMaxvl => BaseTy::VlMaxvl,
27 }
28 }
29
30 pub const fn loc_count(self) -> NonZeroU32 {
31 match self {
32 Self::StackBits64 => nzu32_lit!(512),
33 Self::Gpr | Self::Ca | Self::VlMaxvl => self.base_ty().max_reg_len(),
34 }
35 }
36 }
37
38 #[derive(
39 Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Enum,
40 )]
41 #[repr(u8)]
42 pub enum BaseTy {
43 Bits64,
44 Ca,
45 VlMaxvl,
46 }
47
48 impl BaseTy {
49 /// since `==` doesn't work with enums in const context
50 pub const fn const_eq(self, other: Self) -> bool {
51 self as u8 == other as u8
52 }
53
54 pub const fn only_scalar(self) -> bool {
55 self.max_reg_len().get() == 1
56 }
57
58 pub const fn max_reg_len(self) -> NonZeroU32 {
59 match self {
60 Self::Bits64 => nzu32_lit!(128),
61 Self::Ca | Self::VlMaxvl => nzu32_lit!(1),
62 }
63 }
64 }
65
66 validated_fields! {
67 #[fields_ty = TyFields]
68 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
69 pub struct Ty {
70 pub base_ty: BaseTy,
71 pub reg_len: NonZeroU32,
72 }
73 }
74
75 impl Ty {
76 pub const fn new(fields: TyFields) -> Result<Ty> {
77 let TyFields { base_ty, reg_len } = fields;
78 if base_ty.only_scalar() && reg_len.get() != 1 {
79 Err(Error::TriedToCreateVectorOfOnlyScalarType { base_ty })
80 } else if reg_len.get() > base_ty.max_reg_len().get() {
81 Err(Error::RegLenOutOfRange)
82 } else {
83 Ok(Self(fields))
84 }
85 }
86 /// returns the `Ty` for `fields` or if there was an error, returns the corresponding scalar type (where `reg_len` is `1`)
87 pub const fn new_or_scalar(fields: TyFields) -> Self {
88 match Self::new(fields) {
89 Ok(v) => v,
90 Err(_) => Self::scalar(fields.base_ty),
91 }
92 }
93 pub const fn scalar(base_ty: BaseTy) -> Self {
94 Self(TyFields {
95 base_ty,
96 reg_len: nzu32_lit!(1),
97 })
98 }
99 pub const fn bits64(reg_len: NonZeroU32) -> Self {
100 Self(TyFields {
101 base_ty: BaseTy::Bits64,
102 reg_len,
103 })
104 }
105 pub const fn try_concat(self, rhs: Self) -> Result<Ty> {
106 if !self.get().base_ty.const_eq(rhs.get().base_ty) {
107 Err(Error::BaseTyMismatch)
108 } else {
109 let Some(reg_len) = self.get().reg_len.checked_add(rhs.get().reg_len.get()) else {
110 return Err(Error::RegLenOutOfRange);
111 };
112 Ty::new(TyFields {
113 base_ty: self.get().base_ty,
114 reg_len,
115 })
116 }
117 }
118 }
119
120 validated_fields! {
121 #[fields_ty = LocFields]
122 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
123 pub struct Loc {
124 pub kind: LocKind,
125 pub start: u32,
126 pub reg_len: NonZeroU32,
127 }
128 }
129
130 impl LocFields {
131 pub const fn ty(self) -> Result<Ty> {
132 Ty::new(TyFields {
133 base_ty: self.kind.base_ty(),
134 reg_len: self.reg_len,
135 })
136 }
137 pub const fn stop(self) -> NonZeroU32 {
138 const_unwrap_opt!(self.reg_len.checked_add(self.start), "overflow")
139 }
140 }
141
142 impl Loc {
143 pub const fn ty(self) -> Ty {
144 const_unwrap_res!(self.0.ty(), "Loc can only be constructed with valid fields")
145 }
146 pub const fn max_start(kind: LocKind, reg_len: NonZeroU32) -> Result<u32, Error> {
147 // validate Ty
148 const_try!(Ty::new(TyFields {
149 base_ty: kind.base_ty(),
150 reg_len
151 }));
152 let loc_count: u32 = kind.loc_count().get();
153 let Some(max_start) = loc_count.checked_sub(reg_len.get()) else {
154 return Err(Error::InvalidRegLen)
155 };
156 Ok(max_start)
157 }
158 pub const fn new(fields: LocFields) -> Result<Loc> {
159 let LocFields {
160 kind,
161 start,
162 reg_len,
163 } = fields;
164
165 if start > const_try!(Self::max_start(kind, reg_len)) {
166 Err(Error::StartNotInValidRange)
167 } else {
168 Ok(Self(fields))
169 }
170 }
171 pub const fn conflicts(self, other: Loc) -> bool {
172 self.0.kind.const_eq(other.0.kind)
173 && self.0.start < other.0.stop().get()
174 && other.0.start < self.0.stop().get()
175 }
176 pub const fn get_sub_loc_at_offset(self, sub_loc_ty: Ty, offset: u32) -> Result<Self> {
177 if !sub_loc_ty.get().base_ty.const_eq(self.get().kind.base_ty()) {
178 return Err(Error::BaseTyMismatch);
179 }
180 let Some(stop) = sub_loc_ty.get().reg_len.checked_add(offset) else {
181 return Err(Error::InvalidSubLocOutOfRange)
182 };
183 if stop.get() > self.get().reg_len.get() {
184 Err(Error::InvalidSubLocOutOfRange)
185 } else {
186 Self::new(LocFields {
187 kind: self.get().kind,
188 start: self.get().start + offset,
189 reg_len: sub_loc_ty.get().reg_len,
190 })
191 }
192 }
193 /// get the Loc containing `self` such that:
194 /// `retval.get_sub_loc_at_offset(self.ty(), offset) == self`
195 /// and `retval.ty() == super_loc_ty`
196 pub const fn get_super_loc_with_self_at_offset(
197 self,
198 super_loc_ty: Ty,
199 offset: u32,
200 ) -> Result<Self> {
201 if !super_loc_ty
202 .get()
203 .base_ty
204 .const_eq(self.get().kind.base_ty())
205 {
206 return Err(Error::BaseTyMismatch);
207 }
208 let Some(stop) = self.get().reg_len.checked_add(offset) else {
209 return Err(Error::InvalidSubLocOutOfRange)
210 };
211 if stop.get() > super_loc_ty.get().reg_len.get() {
212 Err(Error::InvalidSubLocOutOfRange)
213 } else {
214 Self::new(LocFields {
215 kind: self.get().kind,
216 start: self.get().start - offset,
217 reg_len: super_loc_ty.get().reg_len,
218 })
219 }
220 }
221 pub const SPECIAL_GPRS: &[Loc] = &[
222 Loc(LocFields {
223 kind: LocKind::Gpr,
224 start: 0,
225 reg_len: nzu32_lit!(1),
226 }),
227 Loc(LocFields {
228 kind: LocKind::Gpr,
229 start: 1,
230 reg_len: nzu32_lit!(1),
231 }),
232 Loc(LocFields {
233 kind: LocKind::Gpr,
234 start: 2,
235 reg_len: nzu32_lit!(1),
236 }),
237 Loc(LocFields {
238 kind: LocKind::Gpr,
239 start: 13,
240 reg_len: nzu32_lit!(1),
241 }),
242 ];
243 }