// bigint-presentation-code.git: register_allocator/src/loc_set.rs
use crate::{
    error::{Error, Result},
    interned::{GlobalState, Intern, InternTarget, Interned},
    loc::{BaseTy, Loc, LocFields, LocKind, Ty, TyFields},
};
use enum_map::{enum_map, EnumMap};
use num_bigint::BigUint;
use num_traits::Zero;
use serde::{Deserialize, Serialize};
use std::{
    borrow::{Borrow, Cow},
    cell::Cell,
    fmt,
    hash::Hash,
    iter::{FusedIterator, Peekable},
    num::NonZeroU32,
    ops::{
        BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, ControlFlow, Range, Sub,
        SubAssign,
    },
};

#[derive(Deserialize)]
struct LocSetSerialized {
    starts: EnumMap<LocKind, BigUint>,
    ty: Option<Ty>,
}

impl TryFrom<LocSetSerialized> for LocSet {
    type Error = Error;

    fn try_from(value: LocSetSerialized) -> Result<Self, Self::Error> {
        Self::from_parts(value.starts, value.ty)
    }
}

#[derive(Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(try_from = "LocSetSerialized")]
pub struct LocSet {
    starts: EnumMap<LocKind, BigUint>,
    ty: Option<Ty>,
}

/// computes same value as `a & !b`, but more efficiently
fn and_not<A: Borrow<BigUint>, B>(a: A, b: B) -> BigUint
where
    BigUint: for<'a> BitXor<A, Output = BigUint>,
    B: for<'a> BitAnd<&'a BigUint, Output = BigUint>,
{
    // use logical equivalent that avoids needing to use BigInt
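    // identity used: `a & !b == (a & b) ^ a`
    // e.g. a = 0b1100, b = 0b1010: (b & a) = 0b1000, and 0b1000 ^ 0b1100 = 0b0100 == a & !b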
    (b & a.borrow()) ^ a
}

impl LocSet {
    pub fn starts(&self) -> &EnumMap<LocKind, BigUint> {
        &self.starts
    }
    pub fn stops(&self) -> EnumMap<LocKind, BigUint> {
        let Some(ty) = self.ty else {
            return EnumMap::default();
        };
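        // shifting a starts bitmap left by `reg_len` moves each set bit from
        // index `start` to index `start + reg_len`, i.e. the one-past-the-end
        // index of the corresponding Loc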
        enum_map! {kind => &self.starts[kind] << ty.reg_len.get()}
    }
    pub fn ty(&self) -> Option<Ty> {
        self.ty
    }
    pub fn kinds(&self) -> impl Iterator<Item = LocKind> + '_ {
        self.starts
            .iter()
            .filter_map(|(kind, starts)| if starts.is_zero() { None } else { Some(kind) })
    }
    pub fn reg_len(&self) -> Option<NonZeroU32> {
        self.ty.map(|v| v.reg_len)
    }
    pub fn base_ty(&self) -> Option<BaseTy> {
        self.ty.map(|v| v.base_ty)
    }
    pub fn new() -> Self {
        Self::default()
    }
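    /// creates a `LocSet` from raw parts, checking that `ty` is consistent
    /// with the `LocKind`s that actually have start bits set, that every
    /// start bit is in the valid range for its kind, and that an empty set
    /// has `ty == None`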
    pub fn from_parts(starts: EnumMap<LocKind, BigUint>, ty: Option<Ty>) -> Result<Self> {
        let mut empty = true;
        for (kind, starts) in &starts {
            if !starts.is_zero() {
                empty = false;
                let expected_ty = Ty::new(TyFields {
                    base_ty: kind.base_ty(),
                    reg_len: ty.map(|v| v.reg_len).unwrap_or(nzu32_lit!(1)),
                })
                .unwrap_or_else(|_| {
                    Ty::new(TyFields {
                        base_ty: kind.base_ty(),
                        reg_len: nzu32_lit!(1),
                    })
                    .unwrap()
                });
                if ty != Some(expected_ty) {
                    return Err(Error::TyMismatch {
                        ty,
                        expected_ty: Some(expected_ty),
                    });
                }
                // bits() is one past max bit set, so use >= rather than >
                if starts.bits() >= Loc::max_start(kind, expected_ty.reg_len)? as u64 {
                    return Err(Error::StartNotInValidRange);
                }
            }
        }
        if empty && ty.is_some() {
            Err(Error::TyMismatch {
                ty,
                expected_ty: None,
            })
        } else {
            Ok(Self { starts, ty })
        }
    }
    pub fn clear(&mut self) {
        // also reset `ty`, preserving the invariant that an empty set has `ty == None`
        self.ty = None;
        for v in self.starts.values_mut() {
            v.assign_from_slice(&[]);
        }
    }
    pub fn contains(&self, value: Loc) -> bool {
        Some(value.ty()) == self.ty && self.starts[value.kind].bit(value.start as _)
    }
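    /// inserts `value`, returning `Ok(true)` if it was newly inserted and
    /// `Ok(false)` if it was already present; inserting into an empty set
    /// sets the set's type, otherwise `value`'s type must match it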
    pub fn try_insert(&mut self, value: Loc) -> Result<bool> {
        if self.is_empty() {
            self.ty = Some(value.ty());
            self.starts[value.kind].set_bit(value.start as u64, true);
            return Ok(true);
        };
        let ty = Some(value.ty());
        if ty != self.ty {
            return Err(Error::TyMismatch {
                ty,
                expected_ty: self.ty,
            });
        }
        let retval = !self.starts[value.kind].bit(value.start as u64);
        self.starts[value.kind].set_bit(value.start as u64, true);
        Ok(retval)
    }
    pub fn insert(&mut self, value: Loc) -> bool {
        self.try_insert(value).unwrap()
    }
    pub fn remove(&mut self, value: Loc) -> bool {
        if self.contains(value) {
            self.starts[value.kind].set_bit(value.start as u64, false);
            if self.starts.values().all(BigUint::is_zero) {
                self.ty = None;
            }
            true
        } else {
            false
        }
    }
    pub fn is_empty(&self) -> bool {
        self.ty.is_none()
    }
    pub fn is_disjoint(&self, other: &LocSet) -> bool {
        if self.ty != other.ty || self.is_empty() {
            return true;
        }
        for (k, lhs) in self.starts.iter() {
            let rhs = &other.starts[k];
            if !(lhs & rhs).is_zero() {
                return false;
            }
        }
        true
    }
    pub fn is_subset(&self, containing_set: &LocSet) -> bool {
        if self.is_empty() {
            return true;
        }
        if self.ty != containing_set.ty {
            return false;
        }
        for (k, v) in self.starts.iter() {
            let containing_set = &containing_set.starts[k];
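            // `v` is a subset of the containing starts iff `v & !containing_set == 0`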
            if !and_not(v, containing_set).is_zero() {
                return false;
            }
        }
        true
    }
    pub fn is_superset(&self, contained_set: &LocSet) -> bool {
        contained_set.is_subset(self)
    }
    pub fn iter(&self) -> Iter<'_> {
        if let Some(ty) = self.ty {
            let mut starts = self.starts.iter().peekable();
            Iter {
                internals: Some(IterInternals {
                    ty,
                    start_range: get_start_range(starts.peek()),
                    starts,
                }),
            }
        } else {
            Iter { internals: None }
        }
    }
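    /// returns the number of `Loc`s in the set: the total count of set start
    /// bits across all `LocKind`s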
    pub fn len(&self) -> usize {
        let retval: u64 = self.starts.values().map(BigUint::count_ones).sum();
        retval as usize
    }
}

#[derive(Clone, Debug)]
struct IterInternals<I, T>
where
    I: Iterator<Item = (LocKind, T)>,
    T: Clone + Borrow<BigUint>,
{
    ty: Ty,
    starts: Peekable<I>,
    start_range: Range<u32>,
}

impl<I, T> IterInternals<I, T>
where
    I: Iterator<Item = (LocKind, T)>,
    T: Clone + Borrow<BigUint>,
{
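    /// scans the candidate start indexes `0..bits()` of the current kind's
    /// bitmap, yielding a `Loc` for each set bit, then advances to the next kind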
    fn next(&mut self) -> Option<Loc> {
        let IterInternals {
            ty,
            ref mut starts,
            ref mut start_range,
        } = *self;
        loop {
            let (kind, ref v) = *starts.peek()?;
            let Some(start) = start_range.next() else {
                starts.next();
                *start_range = get_start_range(starts.peek());
                continue;
            };
            if v.borrow().bit(start as u64) {
                return Some(
                    Loc::new(LocFields {
                        kind,
                        start,
                        reg_len: ty.reg_len,
                    })
                    .expect("known to be valid"),
                );
            }
        }
    }
}

fn get_start_range(v: Option<&(LocKind, impl Borrow<BigUint>)>) -> Range<u32> {
    0..v.map(|(_, v)| v.borrow().bits() as u32).unwrap_or(0)
}

#[derive(Clone, Debug)]
pub struct Iter<'a> {
    internals: Option<IterInternals<enum_map::Iter<'a, LocKind, BigUint>, &'a BigUint>>,
}

impl Iterator for Iter<'_> {
    type Item = Loc;

    fn next(&mut self) -> Option<Self::Item> {
        self.internals.as_mut()?.next()
    }
}

impl FusedIterator for Iter<'_> {}

pub struct IntoIter {
    internals: Option<IterInternals<enum_map::IntoIter<LocKind, BigUint>, BigUint>>,
}

impl Iterator for IntoIter {
    type Item = Loc;

    fn next(&mut self) -> Option<Self::Item> {
        self.internals.as_mut()?.next()
    }
}

impl FusedIterator for IntoIter {}

impl IntoIterator for LocSet {
    type Item = Loc;
    type IntoIter = IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        if let Some(ty) = self.ty {
            let mut starts = self.starts.into_iter().peekable();
            IntoIter {
                internals: Some(IterInternals {
                    ty,
                    start_range: get_start_range(starts.peek()),
                    starts,
                }),
            }
        } else {
            IntoIter { internals: None }
        }
    }
}

impl<'a> IntoIterator for &'a LocSet {
    type Item = Loc;
    type IntoIter = Iter<'a>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}

impl Extend<Loc> for LocSet {
    fn extend<T: IntoIterator<Item = Loc>>(&mut self, iter: T) {
        iter.into_iter().for_each(|item| {
            self.insert(item);
        });
    }
}

impl<E: From<Error>> Extend<Loc> for Result<LocSet, E> {
    fn extend<T: IntoIterator<Item = Loc>>(&mut self, iter: T) {
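        // insert until the first error; on error, record it in `self` and use
        // `ControlFlow::Break` so `try_for_each` stops early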
        iter.into_iter().try_for_each(|item| {
            let Ok(loc_set) = self else {
                return ControlFlow::Break(());
            };
            match loc_set.try_insert(item) {
                Ok(_) => ControlFlow::Continue(()),
                Err(e) => {
                    *self = Err(e.into());
                    ControlFlow::Break(())
                }
            }
        });
    }
}

impl FromIterator<Loc> for LocSet {
    fn from_iter<T: IntoIterator<Item = Loc>>(iter: T) -> Self {
        let mut retval = LocSet::new();
        retval.extend(iter);
        retval
    }
}

impl<E: From<Error>> FromIterator<Loc> for Result<LocSet, E> {
    fn from_iter<T: IntoIterator<Item = Loc>>(iter: T) -> Self {
        let mut retval = Ok(LocSet::new());
        retval.extend(iter);
        retval
    }
}

struct HexBigUint<'a>(&'a BigUint);

impl fmt::Debug for HexBigUint<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:#x}", self.0)
    }
}

struct LocSetStarts<'a>(&'a EnumMap<LocKind, BigUint>);

impl fmt::Debug for LocSetStarts<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_map()
            .entries(self.0.iter().map(|(k, v)| (k, HexBigUint(v))))
            .finish()
    }
}

impl fmt::Debug for LocSet {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LocSet")
            // use the hex-printing wrapper defined above so the bitmaps stay readable
            .field("starts", &LocSetStarts(&self.starts))
            .field("ty", &self.ty)
            .finish()
    }
}

macro_rules! impl_bin_op {
    (
        $bin_op:ident::$bin_op_fn:ident(),
        $bin_assign_op:ident::$bin_assign_op_fn:ident(),
        $starts_op:expr,
        $handle_unequal_types:expr,
        $update_unequal_types:expr,
    ) => {
        impl $bin_op<&'_ LocSet> for &'_ LocSet {
            type Output = LocSet;

            fn $bin_op_fn(self, rhs: &'_ LocSet) -> Self::Output {
                if self.ty != rhs.ty {
                    $handle_unequal_types(self, Cow::<LocSet>::Borrowed(rhs))
                } else {
                    LocSet {
                        starts: enum_map! {kind => $starts_op(&self.starts[kind], &rhs.starts[kind])},
                        ty: self.ty,
                    }
                }
            }
        }

        impl $bin_assign_op<&'_ LocSet> for LocSet {
            fn $bin_assign_op_fn(&mut self, rhs: &'_ LocSet) {
                if self.ty != rhs.ty {
                    $update_unequal_types(self, rhs);
                } else {
                    for (kind, starts) in &mut self.starts {
                        let v: BigUint = std::mem::take(starts);
                        *starts = $starts_op(v, &rhs.starts[kind]);
                    }
                }
            }
        }

        impl $bin_assign_op<LocSet> for LocSet {
            fn $bin_assign_op_fn(&mut self, rhs: LocSet) {
                self.$bin_assign_op_fn(&rhs);
            }
        }

        impl $bin_op<&'_ LocSet> for LocSet {
            type Output = LocSet;

            fn $bin_op_fn(mut self, rhs: &'_ LocSet) -> Self::Output {
                self.$bin_assign_op_fn(rhs);
                self
            }
        }

        impl $bin_op<LocSet> for LocSet {
            type Output = LocSet;

            fn $bin_op_fn(mut self, rhs: LocSet) -> Self::Output {
                self.$bin_assign_op_fn(rhs);
                self
            }
        }

        impl $bin_op<LocSet> for &'_ LocSet {
            type Output = LocSet;

            fn $bin_op_fn(self, mut rhs: LocSet) -> Self::Output {
                if self.ty != rhs.ty {
                    $handle_unequal_types(self, Cow::<LocSet>::Owned(rhs))
                } else {
                    for (kind, starts) in &mut rhs.starts {
                        *starts = $starts_op(&self.starts[kind], std::mem::take(starts));
                    }
                    rhs
                }
            }
        }
    };
}

impl_bin_op! {
    BitAnd::bitand(),
    BitAndAssign::bitand_assign(),
    BitAnd::bitand,
    |_, _| LocSet::new(),
    |lhs, _| LocSet::clear(lhs),
}

impl_bin_op! {
    BitOr::bitor(),
    BitOrAssign::bitor_assign(),
    BitOr::bitor,
    |lhs: &LocSet, rhs: Cow<LocSet>| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
    |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
}

impl_bin_op! {
    BitXor::bitxor(),
    BitXorAssign::bitxor_assign(),
    BitXor::bitxor,
    |lhs: &LocSet, rhs: Cow<LocSet>| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
    |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
}

impl_bin_op! {
    Sub::sub(),
    SubAssign::sub_assign(),
    and_not,
    |lhs: &LocSet, _| lhs.clone(),
    |_, _| {},
}
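
// How the binary operators above treat operands whose types differ:
// `&` yields the empty set (sets of different types share no Locs),
// `-` leaves the left operand unchanged (there is nothing to remove), and
// `|` / `^` panic with a `TyMismatch` error, since the result could not
// have a single consistent type.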

/// the largest number of Locs in `lhs` that a single Loc
/// from `rhs` can conflict with
#[derive(Clone)]
pub struct LocSetMaxConflictsWith<Rhs> {
    lhs: Interned<LocSet>,
    rhs: Rhs,
    // result is not included in equality or hash
    result: Cell<Option<u32>>,
}

impl<Rhs: Hash> Hash for LocSetMaxConflictsWith<Rhs> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.lhs.hash(state);
        self.rhs.hash(state);
    }
}

impl<Rhs: Eq> Eq for LocSetMaxConflictsWith<Rhs> {}

impl<Rhs: PartialEq> PartialEq for LocSetMaxConflictsWith<Rhs> {
    fn eq(&self, other: &Self) -> bool {
        self.lhs == other.lhs && self.rhs == other.rhs
    }
}

pub trait LocSetMaxConflictsWithTrait: Clone {
    fn intern(
        v: LocSetMaxConflictsWith<Self>,
        global_state: &GlobalState,
    ) -> Interned<LocSetMaxConflictsWith<Self>>;
    fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, global_state: &GlobalState) -> u32;
}

impl LocSetMaxConflictsWithTrait for Loc {
    fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, _global_state: &GlobalState) -> u32 {
        // now we do the equivalent of:
        // return lhs.iter().map(|loc| rhs.conflicts(loc) as u32).sum().unwrap_or(0)
        let Some(reg_len) = lhs.reg_len() else {
            return 0;
        };
        let starts = &lhs.starts[rhs.kind];
        if starts.is_zero() {
            return 0;
        }
        // now we do the equivalent of:
        // return sum(rhs.start < start + reg_len
        //     and start < rhs.start + rhs.reg_len
        //     for start in starts)
        let stops = starts << reg_len.get();

        // find all the bit indexes `i` where `i < rhs.start + 1`
        let lt_rhs_start_plus_1 = (BigUint::from(1u32) << (rhs.start + 1)) - 1u32;

        // find all the bit indexes `i` where
        // `i < rhs.start + rhs.reg_len + reg_len`
        let lt_rhs_start_plus_rhs_reg_len_plus_reg_len =
            (BigUint::from(1u32) << (rhs.start + rhs.reg_len.get() + reg_len.get())) - 1u32;
        let mut included = and_not(&stops, &stops & lt_rhs_start_plus_1);
        included &= lt_rhs_start_plus_rhs_reg_len_plus_reg_len;
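        // worked example (illustrative values): reg_len = 2, lhs starts = {0, 3},
        // rhs.start = 2, rhs.reg_len = 2:
        //   stops = 0b10_0100 (bits 2 and 5)
        //   lt_rhs_start_plus_1 = 0b111, which removes stop bit 2 (the Loc at
        //     start 0 ends before rhs begins)
        //   lt_rhs_start_plus_rhs_reg_len_plus_reg_len = 0b11_1111, which keeps
        //     stop bit 5
        //   => 1 conflicting Loc (the one starting at 3)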
        included.count_ones() as u32
    }

    fn intern(
        v: LocSetMaxConflictsWith<Self>,
        global_state: &GlobalState,
    ) -> Interned<LocSetMaxConflictsWith<Self>> {
        v.into_interned(global_state)
    }
}

impl LocSetMaxConflictsWithTrait for Interned<LocSet> {
    fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, global_state: &GlobalState) -> u32 {
        rhs.iter()
            .map(|loc| lhs.clone().max_conflicts_with(loc, global_state))
            .max()
            .unwrap_or(0)
    }

    fn intern(
        v: LocSetMaxConflictsWith<Self>,
        global_state: &GlobalState,
    ) -> Interned<LocSetMaxConflictsWith<Self>> {
        v.into_interned(global_state)
    }
}

impl<Rhs: LocSetMaxConflictsWithTrait> LocSetMaxConflictsWith<Rhs> {
    pub fn lhs(&self) -> &Interned<LocSet> {
        &self.lhs
    }
    pub fn rhs(&self) -> &Rhs {
        &self.rhs
    }
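    /// returns the cached result if already computed, otherwise computes it
    /// via `Rhs::compute_result` and memoizes it in the `Cell`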
    pub fn result(&self, global_state: &GlobalState) -> u32 {
        match self.result.get() {
            Some(v) => v,
            None => {
                let retval = Rhs::compute_result(&self.lhs, &self.rhs, global_state);
                self.result.set(Some(retval));
                retval
            }
        }
    }
}

impl Interned<LocSet> {
    pub fn max_conflicts_with<Rhs>(self, rhs: Rhs, global_state: &GlobalState) -> u32
    where
        Rhs: LocSetMaxConflictsWithTrait,
        LocSetMaxConflictsWith<Rhs>: InternTarget,
    {
        LocSetMaxConflictsWithTrait::intern(
            LocSetMaxConflictsWith {
                lhs: self,
                rhs,
                result: Cell::default(),
            },
            global_state,
        )
        .result(global_state)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_and_not() {
        for a in 0..0x10u32 {
            for b in 0..0x10 {
                assert_eq!(
                    and_not(&BigUint::from(a), BigUint::from(b)),
                    (a & !b).into()
                );
            }
        }
    }
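
    // additional sanity check (sketch): the `a & !b == (a & b) ^ a` identity
    // should also hold for values wider than a single machine word
    #[test]
    fn test_and_not_wide() {
        let a = BigUint::from(0xFFFFu32) << 100u32;
        let b = (BigUint::from(0xF0F0u32) << 100u32) | BigUint::from(1u32);
        // only the bits of `a` that are not also in `b` should remain
        assert_eq!(and_not(&a, &b), BigUint::from(0x0F0Fu32) << 100u32);
    }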
}