9af021e118d54e3c74bb1fcbfcfe62702975799e
[bigint-presentation-code.git] / register_allocator / src / loc.rs
1 use crate::error::{Error, Result};
2 use enum_map::Enum;
3 use serde::{Deserialize, Serialize};
4 use std::num::NonZeroU32;
5
/// The kind of a location (see [`Loc`]).
///
/// Each kind maps to a base type via [`LocKind::base_ty`] and has a fixed
/// number of locations given by [`LocKind::loc_count`].
///
/// `#[repr(u8)]` is what allows [`LocKind::const_eq`] to compare variants by
/// casting them to `u8`. The declaration order of the variants determines the
/// derived `PartialOrd`/`Ord` ordering, so do not reorder them casually.
#[derive(
    Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Enum,
)]
#[repr(u8)]
pub enum LocKind {
    /// Base type `Bits64`; presumably the general-purpose registers -- TODO confirm.
    Gpr,
    /// Base type `Bits64`; presumably 64-bit stack slots -- TODO confirm.
    StackBits64,
    /// Base type `Ca`; presumably the carry flag -- TODO confirm.
    Ca,
    /// Base type `VlMaxvl`; presumably the VL/MAXVL state -- TODO confirm.
    VlMaxvl,
}
16
17 impl LocKind {
18 /// since `==` doesn't work with enums in const context
19 pub const fn const_eq(self, other: Self) -> bool {
20 self as u8 == other as u8
21 }
22 pub const fn base_ty(self) -> BaseTy {
23 match self {
24 Self::Gpr | Self::StackBits64 => BaseTy::Bits64,
25 Self::Ca => BaseTy::Ca,
26 Self::VlMaxvl => BaseTy::VlMaxvl,
27 }
28 }
29
30 pub const fn loc_count(self) -> NonZeroU32 {
31 match self {
32 Self::StackBits64 => nzu32_lit!(512),
33 Self::Gpr | Self::Ca | Self::VlMaxvl => self.base_ty().max_reg_len(),
34 }
35 }
36 }
37
/// The base (element) type of a [`Ty`].
///
/// `#[repr(u8)]` is what allows [`BaseTy::const_eq`] to compare variants by
/// casting them to `u8`. The declaration order of the variants determines the
/// derived `PartialOrd`/`Ord` ordering, so do not reorder them casually.
#[derive(
    Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Enum,
)]
#[repr(u8)]
pub enum BaseTy {
    /// Maximum register length 128 (see [`BaseTy::max_reg_len`]);
    /// presumably a 64-bit value -- TODO confirm.
    Bits64,
    /// Scalar only (maximum register length 1); presumably the carry flag -- TODO confirm.
    Ca,
    /// Scalar only (maximum register length 1); presumably the VL/MAXVL state -- TODO confirm.
    VlMaxvl,
}
47
48 impl BaseTy {
49 /// since `==` doesn't work with enums in const context
50 pub const fn const_eq(self, other: Self) -> bool {
51 self as u8 == other as u8
52 }
53
54 pub const fn only_scalar(self) -> bool {
55 self.max_reg_len().get() == 1
56 }
57
58 pub const fn max_reg_len(self) -> NonZeroU32 {
59 match self {
60 Self::Bits64 => nzu32_lit!(128),
61 Self::Ca | Self::VlMaxvl => nzu32_lit!(1),
62 }
63 }
64 }
65
// `Ty` is a base type together with a register length.
//
// The `validated_fields!` macro is defined elsewhere; `#[fields_ty = TyFields]`
// names the plain fields struct. Code in this file wraps the fields as
// `Ty(fields)` and reads them back through `.get()`. Construct values with
// `Ty::new`, which checks the fields before wrapping them.
validated_fields! {
    #[fields_ty = TyFields]
    #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
    pub struct Ty {
        pub base_ty: BaseTy,
        pub reg_len: NonZeroU32,
    }
}
74
75 impl Ty {
76 pub const fn new(fields: TyFields) -> Result<Ty> {
77 let TyFields { base_ty, reg_len } = fields;
78 if base_ty.only_scalar() && reg_len.get() != 1 {
79 Err(Error::TriedToCreateVectorOfOnlyScalarType { base_ty })
80 } else if reg_len.get() > base_ty.max_reg_len().get() {
81 Err(Error::RegLenOutOfRange)
82 } else {
83 Ok(Self(fields))
84 }
85 }
86 }
87
// `Loc` is a run of `reg_len` consecutive locations of one `LocKind`,
// beginning at index `start`.
//
// The `validated_fields!` macro is defined elsewhere; `#[fields_ty = LocFields]`
// names the plain fields struct. Code in this file wraps the fields as
// `Loc(fields)` and reads them back through `.get()` / `.0`. Construct values
// with `Loc::new`, which checks the fields before wrapping them.
validated_fields! {
    #[fields_ty = LocFields]
    #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
    pub struct Loc {
        pub kind: LocKind,
        pub start: u32,
        pub reg_len: NonZeroU32,
    }
}
97
98 impl LocFields {
99 pub const fn ty(self) -> Result<Ty> {
100 Ty::new(TyFields {
101 base_ty: self.kind.base_ty(),
102 reg_len: self.reg_len,
103 })
104 }
105 pub const fn stop(self) -> NonZeroU32 {
106 const_unwrap_opt!(self.reg_len.checked_add(self.start), "overflow")
107 }
108 }
109
110 impl Loc {
111 pub const fn ty(self) -> Ty {
112 const_unwrap_res!(self.0.ty(), "Loc can only be constructed with valid fields")
113 }
114 pub const fn max_start(kind: LocKind, reg_len: NonZeroU32) -> Result<u32, Error> {
115 // validate Ty
116 const_try!(Ty::new(TyFields {
117 base_ty: kind.base_ty(),
118 reg_len
119 }));
120 let loc_count: u32 = kind.loc_count().get();
121 let Some(max_start) = loc_count.checked_sub(reg_len.get()) else {
122 return Err(Error::InvalidRegLen)
123 };
124 Ok(max_start)
125 }
126 pub const fn new(fields: LocFields) -> Result<Loc> {
127 let LocFields {
128 kind,
129 start,
130 reg_len,
131 } = fields;
132
133 if start > const_try!(Self::max_start(kind, reg_len)) {
134 Err(Error::StartNotInValidRange)
135 } else {
136 Ok(Self(fields))
137 }
138 }
139 pub const fn conflicts(self, other: Loc) -> bool {
140 self.0.kind.const_eq(other.0.kind)
141 && self.0.start < other.0.stop().get()
142 && other.0.start < self.0.stop().get()
143 }
144 pub const fn get_sub_loc_at_offset(self, sub_loc_ty: Ty, offset: u32) -> Result<Self> {
145 if !sub_loc_ty.get().base_ty.const_eq(self.get().kind.base_ty()) {
146 return Err(Error::BaseTyMismatch);
147 }
148 let Some(stop) = sub_loc_ty.get().reg_len.checked_add(offset) else {
149 return Err(Error::InvalidSubLocOutOfRange)
150 };
151 if stop.get() > self.get().reg_len.get() {
152 Err(Error::InvalidSubLocOutOfRange)
153 } else {
154 Self::new(LocFields {
155 kind: self.get().kind,
156 start: self.get().start + offset,
157 reg_len: sub_loc_ty.get().reg_len,
158 })
159 }
160 }
161 /// get the Loc containing `self` such that:
162 /// `retval.get_sub_loc_at_offset(self.ty(), offset) == self`
163 /// and `retval.ty() == super_loc_ty`
164 pub const fn get_super_loc_with_self_at_offset(
165 self,
166 super_loc_ty: Ty,
167 offset: u32,
168 ) -> Result<Self> {
169 if !super_loc_ty
170 .get()
171 .base_ty
172 .const_eq(self.get().kind.base_ty())
173 {
174 return Err(Error::BaseTyMismatch);
175 }
176 let Some(stop) = self.get().reg_len.checked_add(offset) else {
177 return Err(Error::InvalidSubLocOutOfRange)
178 };
179 if stop.get() > super_loc_ty.get().reg_len.get() {
180 Err(Error::InvalidSubLocOutOfRange)
181 } else {
182 Self::new(LocFields {
183 kind: self.get().kind,
184 start: self.get().start - offset,
185 reg_len: super_loc_ty.get().reg_len,
186 })
187 }
188 }
189 pub const SPECIAL_GPRS: &[Loc] = &[
190 Loc(LocFields {
191 kind: LocKind::Gpr,
192 start: 0,
193 reg_len: nzu32_lit!(1),
194 }),
195 Loc(LocFields {
196 kind: LocKind::Gpr,
197 start: 1,
198 reg_len: nzu32_lit!(1),
199 }),
200 Loc(LocFields {
201 kind: LocKind::Gpr,
202 start: 2,
203 reg_len: nzu32_lit!(1),
204 }),
205 Loc(LocFields {
206 kind: LocKind::Gpr,
207 start: 13,
208 reg_len: nzu32_lit!(1),
209 }),
210 ];
211 }