wip
[bigint-presentation-code.git] / register_allocator / src / loc_set.rs
1 use crate::{
2 error::{Error, Result},
3 interned::{GlobalState, Intern, InternTarget, Interned},
4 loc::{BaseTy, Loc, LocFields, LocKind, Ty, TyFields},
5 };
6 use enum_map::{enum_map, EnumMap};
7 use num_bigint::BigUint;
8 use num_traits::Zero;
9 use serde::{Deserialize, Serialize};
10 use std::{
11 borrow::{Borrow, Cow},
12 cell::Cell,
13 collections::BTreeMap,
14 fmt,
15 hash::Hash,
16 iter::{FusedIterator, Peekable},
17 num::NonZeroU32,
18 ops::{
19 BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, ControlFlow, Range, Sub,
20 SubAssign,
21 },
22 };
23
/// Deserialization-only mirror of [`LocSet`].
///
/// `LocSet` is deserialized through this type (see the
/// `#[serde(try_from = "LocSetSerialized")]` attribute on `LocSet`), so the
/// decoded map is handed to `LocSet::from_reg_len_to_starts_map` before a
/// `LocSet` value is produced.
#[derive(Deserialize)]
struct LocSetSerialized {
    /// Raw map exactly as it appears in the serialized data; it has the same
    /// type as the field of the same name on `LocSet`.
    reg_len_to_starts_map: BTreeMap<NonZeroU32, EnumMap<LocKind, BigUint>>,
}
28
29 impl TryFrom<LocSetSerialized> for LocSet {
30 type Error = Error;
31
32 fn try_from(value: LocSetSerialized) -> Result<Self, Self::Error> {
33 Self::from_reg_len_to_starts_map(value.reg_len_to_starts_map)
34 }
35 }
36
/// A set of [`Loc`]s stored as [`BigUint`] bitmasks, one per [`LocKind`].
///
/// Deserialization goes through `LocSetSerialized` (see the `try_from`
/// attribute below) rather than filling the field in directly.
///
/// NOTE(review): the methods on this type further down the file read and
/// write `self.starts` and `self.ty`, which are not declared here — this
/// looks like an unfinished refactor; confirm which representation is meant.
#[derive(Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(try_from = "LocSetSerialized")]
pub struct LocSet {
    /// Per-kind bitmasks keyed by a `NonZeroU32` — presumably the register
    /// length, going by the field name; verify against the constructor.
    reg_len_to_starts_map: BTreeMap<NonZeroU32, EnumMap<LocKind, BigUint>>,
}
42
43 /// computes same value as `a & !b`, but more efficiently
44 fn and_not<A: Borrow<BigUint>, B>(a: A, b: B) -> BigUint
45 where
46 BigUint: for<'a> BitXor<A, Output = BigUint>,
47 B: for<'a> BitAnd<&'a BigUint, Output = BigUint>,
48 {
49 // use logical equivalent that avoids needing to use BigInt
50 (b & a.borrow()) ^ a
51 }
52
53 impl From<Loc> for LocSet {
54 fn from(value: Loc) -> Self {
55 Self::from_iter([value])
56 }
57 }
58
59 impl LocSet {
60 pub fn arbitrary_with_ty(
61 ty: Option<Ty>,
62 u: &mut arbitrary::Unstructured<'_>,
63 ) -> arbitrary::Result<Self> {
64 let Some(ty) = ty else {
65 return Ok(Self::new());
66 };
67 let kinds = ty.base_ty.loc_kinds();
68 type Mask = u128;
69 let kinds: Vec<_> = if kinds.len() > Mask::BITS as usize {
70 let chosen_kinds = kinds
71 .iter()
72 .zip(u.arbitrary_iter::<bool>()?)
73 .filter(|(_, cond)| !matches!(cond, Ok(false)))
74 .map(|(&kind, cond)| {
75 cond?;
76 Ok(kind)
77 })
78 .collect::<arbitrary::Result<Vec<_>>>()?;
79 if chosen_kinds.is_empty() {
80 vec![*u.choose(kinds)?]
81 } else {
82 chosen_kinds
83 }
84 } else {
85 let max_mask = Mask::wrapping_shl(1, kinds.len() as u32).wrapping_sub(1);
86 let mask = u.int_in_range(1..=max_mask)?; // non-zero
87 kinds
88 .iter()
89 .enumerate()
90 .filter_map(|(idx, &kind)| {
91 if mask & (1 << idx) != 0 {
92 Some(kind)
93 } else {
94 None
95 }
96 })
97 .collect()
98 };
99 let mut starts = EnumMap::<LocKind, BigUint>::default();
100 let mut all_zero = true;
101 for kind in kinds {
102 let bit_count = Loc::max_start(kind, ty.reg_len)? + 1;
103 let byte_count = (bit_count + u8::BITS - 1) / u8::BITS;
104 let bytes = u.bytes(byte_count as usize)?;
105 starts[kind] = BigUint::from_bytes_le(bytes);
106 all_zero &= starts[kind].is_zero();
107 }
108 if all_zero {
109 Ok(Loc::arbitrary_with_ty(ty, u)?.into())
110 } else {
111 Ok(Self::from_parts(starts, Some(ty))?)
112 }
113 }
114 pub fn starts(&self) -> &EnumMap<LocKind, BigUint> {
115 &self.starts
116 }
117 pub fn stops(&self) -> EnumMap<LocKind, BigUint> {
118 let Some(ty) = self.ty else {
119 return EnumMap::default();
120 };
121 enum_map! {kind => &self.starts[kind] << ty.reg_len.get()}
122 }
123 pub fn ty(&self) -> Option<Ty> {
124 self.ty
125 }
126 pub fn kinds(&self) -> impl Iterator<Item = LocKind> + '_ {
127 self.starts
128 .iter()
129 .filter_map(|(kind, starts)| if starts.is_zero() { None } else { Some(kind) })
130 }
131 pub fn reg_len(&self) -> Option<NonZeroU32> {
132 self.ty.map(|v| v.reg_len)
133 }
134 pub fn base_ty(&self) -> Option<BaseTy> {
135 self.ty.map(|v| v.base_ty)
136 }
137 pub fn new() -> Self {
138 Self::default()
139 }
140 pub fn from_parts(starts: EnumMap<LocKind, BigUint>, ty: Option<Ty>) -> Result<Self> {
141 let mut empty = true;
142 for (kind, starts) in &starts {
143 if !starts.is_zero() {
144 empty = false;
145 let expected_ty = Ty::new_or_scalar(TyFields {
146 base_ty: kind.base_ty(),
147 reg_len: ty.map(|v| v.reg_len).unwrap_or(nzu32_lit!(1)),
148 });
149 if ty != Some(expected_ty) {
150 return Err(Error::TyMismatch {
151 ty,
152 expected_ty: Some(expected_ty),
153 });
154 }
155 // bits() is one past max bit set, so use >= rather than >
156 if starts.bits() >= Loc::max_start(kind, expected_ty.reg_len)? as u64 {
157 return Err(Error::StartNotInValidRange);
158 }
159 }
160 }
161 if empty && ty.is_some() {
162 Err(Error::TyMismatch {
163 ty,
164 expected_ty: None,
165 })
166 } else {
167 Ok(Self { starts, ty })
168 }
169 }
170 pub fn clear(&mut self) {
171 for v in self.starts.values_mut() {
172 v.assign_from_slice(&[]);
173 }
174 }
175 pub fn contains_exact(&self, value: Loc) -> bool {
176 Some(value.ty()) == self.ty && self.starts[value.kind].bit(value.start as _)
177 }
178 pub fn try_insert(&mut self, value: Loc) -> Result<bool> {
179 if self.is_empty() {
180 self.ty = Some(value.ty());
181 self.starts[value.kind].set_bit(value.start as u64, true);
182 return Ok(true);
183 };
184 let ty = Some(value.ty());
185 if ty != self.ty {
186 return Err(Error::TyMismatch {
187 ty,
188 expected_ty: self.ty,
189 });
190 }
191 let retval = !self.starts[value.kind].bit(value.start as u64);
192 self.starts[value.kind].set_bit(value.start as u64, true);
193 Ok(retval)
194 }
195 pub fn insert(&mut self, value: Loc) -> bool {
196 self.try_insert(value).unwrap()
197 }
198 pub fn remove(&mut self, value: Loc) -> bool {
199 if self.contains_exact(value) {
200 self.starts[value.kind].set_bit(value.start as u64, false);
201 if self.starts.values().all(BigUint::is_zero) {
202 self.ty = None;
203 }
204 true
205 } else {
206 false
207 }
208 }
209 pub fn is_empty(&self) -> bool {
210 self.ty.is_none()
211 }
212 pub fn iter(&self) -> Iter<'_> {
213 if let Some(ty) = self.ty {
214 let mut starts = self.starts.iter().peekable();
215 Iter {
216 internals: Some(IterInternals {
217 ty,
218 start_range: get_start_range(starts.peek()),
219 starts,
220 }),
221 }
222 } else {
223 Iter { internals: None }
224 }
225 }
226 pub fn len(&self) -> usize {
227 let retval: u64 = self.starts.values().map(BigUint::count_ones).sum();
228 retval as usize
229 }
230 }
231
/// Iteration state shared by the borrowing (`Iter`) and owning (`IntoIter`)
/// iterators; `T` is `&BigUint` for the former and `BigUint` for the latter.
#[derive(Clone, Debug)]
struct IterInternals<I, T>
where
    I: Iterator<Item = (LocKind, T)>,
    T: Clone + Borrow<BigUint>,
{
    /// Type given to every yielded `Loc` (only its `reg_len` is read).
    ty: Ty,
    /// Remaining `(kind, bitmask)` pairs; the peeked entry is the one
    /// currently being scanned.
    starts: Peekable<I>,
    /// Bit positions of the current bitmask that have not been examined yet.
    start_range: Range<u32>,
}
242
243 impl<I, T> IterInternals<I, T>
244 where
245 I: Iterator<Item = (LocKind, T)>,
246 T: Clone + Borrow<BigUint>,
247 {
248 fn next(&mut self) -> Option<Loc> {
249 let IterInternals {
250 ty,
251 ref mut starts,
252 ref mut start_range,
253 } = *self;
254 loop {
255 let (kind, ref v) = *starts.peek()?;
256 let Some(start) = start_range.next() else {
257 starts.next();
258 *start_range = get_start_range(starts.peek());
259 continue;
260 };
261 if v.borrow().bit(start as u64) {
262 return Some(
263 Loc::new(LocFields {
264 kind,
265 start,
266 reg_len: ty.reg_len,
267 })
268 .expect("known to be valid"),
269 );
270 }
271 }
272 }
273 }
274
275 fn get_start_range(v: Option<&(LocKind, impl Borrow<BigUint>)>) -> Range<u32> {
276 0..v.map(|(_, v)| v.borrow().bits() as u32).unwrap_or(0)
277 }
278
/// Borrowing iterator over the members of a [`LocSet`], created by
/// `LocSet::iter`.
#[derive(Clone, Debug)]
pub struct Iter<'a> {
    /// `None` when the set was empty at creation time.
    internals: Option<IterInternals<enum_map::Iter<'a, LocKind, BigUint>, &'a BigUint>>,
}
283
284 impl Iterator for Iter<'_> {
285 type Item = Loc;
286
287 fn next(&mut self) -> Option<Self::Item> {
288 self.internals.as_mut()?.next()
289 }
290 }
291
// Marker impl: promises that `Iter::next` keeps returning `None` once it has
// returned `None`.
// NOTE(review): this relies on the wrapped `enum_map` iterator being fused — confirm.
impl FusedIterator for Iter<'_> {}
293
/// Owning iterator over the members of a [`LocSet`], created by
/// `LocSet::into_iter`.
pub struct IntoIter {
    /// `None` when the set was empty at creation time.
    internals: Option<IterInternals<enum_map::IntoIter<LocKind, BigUint>, BigUint>>,
}
297
298 impl Iterator for IntoIter {
299 type Item = Loc;
300
301 fn next(&mut self) -> Option<Self::Item> {
302 self.internals.as_mut()?.next()
303 }
304 }
305
// Marker impl: promises that `IntoIter::next` keeps returning `None` once it
// has returned `None`.
// NOTE(review): this relies on the wrapped `enum_map` iterator being fused — confirm.
impl FusedIterator for IntoIter {}
307
308 impl IntoIterator for LocSet {
309 type Item = Loc;
310 type IntoIter = IntoIter;
311
312 fn into_iter(self) -> Self::IntoIter {
313 if let Some(ty) = self.ty {
314 let mut starts = self.starts.into_iter().peekable();
315 IntoIter {
316 internals: Some(IterInternals {
317 ty,
318 start_range: get_start_range(starts.peek()),
319 starts,
320 }),
321 }
322 } else {
323 IntoIter { internals: None }
324 }
325 }
326 }
327
328 impl<'a> IntoIterator for &'a LocSet {
329 type Item = Loc;
330 type IntoIter = Iter<'a>;
331
332 fn into_iter(self) -> Self::IntoIter {
333 self.iter()
334 }
335 }
336
337 impl Extend<Loc> for LocSet {
338 fn extend<T: IntoIterator<Item = Loc>>(&mut self, iter: T) {
339 iter.into_iter().for_each(|item| {
340 self.insert(item);
341 });
342 }
343 }
344
345 impl<E: From<Error>> Extend<Loc> for Result<LocSet, E> {
346 fn extend<T: IntoIterator<Item = Loc>>(&mut self, iter: T) {
347 iter.into_iter().try_for_each(|item| {
348 let Ok(loc_set) = self else {
349 return ControlFlow::Break(());
350 };
351 match loc_set.try_insert(item) {
352 Ok(_) => ControlFlow::Continue(()),
353 Err(e) => {
354 *self = Err(e.into());
355 ControlFlow::Break(())
356 }
357 }
358 });
359 }
360 }
361
362 impl FromIterator<Loc> for LocSet {
363 fn from_iter<T: IntoIterator<Item = Loc>>(iter: T) -> Self {
364 let mut retval = LocSet::new();
365 retval.extend(iter);
366 retval
367 }
368 }
369
370 impl<E: From<Error>> FromIterator<Loc> for Result<LocSet, E> {
371 fn from_iter<T: IntoIterator<Item = Loc>>(iter: T) -> Self {
372 let mut retval = Ok(LocSet::new());
373 retval.extend(iter);
374 retval
375 }
376 }
377
378 struct HexBigUint<'a>(&'a BigUint);
379
380 impl fmt::Debug for HexBigUint<'_> {
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382 write!(f, "{:#x}", self.0)
383 }
384 }
385
386 struct LocSetStarts<'a>(&'a EnumMap<LocKind, BigUint>);
387
388 impl fmt::Debug for LocSetStarts<'_> {
389 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390 f.debug_map()
391 .entries(self.0.iter().map(|(k, v)| (k, HexBigUint(v))))
392 .finish()
393 }
394 }
395
396 impl fmt::Debug for LocSet {
397 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398 f.debug_struct("LocSet")
399 .field("starts", &self.starts)
400 .finish()
401 }
402 }
403
/// Implements a set operator and its assigning form for every owned/borrowed
/// combination of [`LocSet`] operands.
///
/// Arguments, in order:
/// * `$bin_op::$bin_op_fn()` — operator trait and method (e.g. `BitAnd::bitand()`).
/// * `$bin_assign_op::$bin_assign_op_fn()` — the assigning trait and method.
/// * `$starts_op` — combines two per-kind bitmasks when both operands have
///   the same `ty`; it is called with both borrowed and owned `BigUint`s.
/// * `$handle_unequal_types` — called as `(lhs: &LocSet, rhs: Cow<LocSet>)`
///   to produce the result when the operands' `ty` differ.
/// * `$update_unequal_types` — called as `(lhs: &mut LocSet, rhs: &LocSet)`
///   to update `lhs` in place when the operands' `ty` differ.
///
/// After a same-`ty` operation, `ty` is reset to `None` if no bits remain, so
/// that a result without members is the canonical empty set (`is_empty()` is
/// defined in terms of `ty`).
macro_rules! impl_bin_op {
    (
        $bin_op:ident::$bin_op_fn:ident(),
        $bin_assign_op:ident::$bin_assign_op_fn:ident(),
        $starts_op:expr,
        $handle_unequal_types:expr,
        $update_unequal_types:expr,
    ) => {
        impl $bin_op<&'_ LocSet> for &'_ LocSet {
            type Output = LocSet;

            fn $bin_op_fn(self, rhs: &'_ LocSet) -> Self::Output {
                if self.ty != rhs.ty {
                    $handle_unequal_types(self, Cow::<LocSet>::Borrowed(rhs))
                } else {
                    let starts: EnumMap<LocKind, BigUint> =
                        enum_map! {kind => $starts_op(&self.starts[kind], &rhs.starts[kind])};
                    // a result with no members must not keep a type
                    let ty = if starts.values().all(BigUint::is_zero) {
                        None
                    } else {
                        self.ty
                    };
                    LocSet { starts, ty }
                }
            }
        }

        impl $bin_assign_op<&'_ LocSet> for LocSet {
            fn $bin_assign_op_fn(&mut self, rhs: &'_ LocSet) {
                if self.ty != rhs.ty {
                    $update_unequal_types(self, rhs);
                } else {
                    for (kind, starts) in &mut self.starts {
                        // take the old value so `$starts_op` can reuse its allocation
                        let v: BigUint = std::mem::take(starts);
                        *starts = $starts_op(v, &rhs.starts[kind]);
                    }
                    // a result with no members must not keep a type
                    if self.starts.values().all(BigUint::is_zero) {
                        self.ty = None;
                    }
                }
            }
        }

        impl $bin_assign_op<LocSet> for LocSet {
            fn $bin_assign_op_fn(&mut self, rhs: LocSet) {
                self.$bin_assign_op_fn(&rhs);
            }
        }

        impl $bin_op<&'_ LocSet> for LocSet {
            type Output = LocSet;

            fn $bin_op_fn(mut self, rhs: &'_ LocSet) -> Self::Output {
                self.$bin_assign_op_fn(rhs);
                self
            }
        }

        impl $bin_op<LocSet> for LocSet {
            type Output = LocSet;

            fn $bin_op_fn(mut self, rhs: LocSet) -> Self::Output {
                self.$bin_assign_op_fn(rhs);
                self
            }
        }

        impl $bin_op<LocSet> for &'_ LocSet {
            type Output = LocSet;

            fn $bin_op_fn(self, mut rhs: LocSet) -> Self::Output {
                if self.ty != rhs.ty {
                    $handle_unequal_types(self, Cow::<LocSet>::Owned(rhs))
                } else {
                    // reuse `rhs`'s storage for the result
                    for (kind, starts) in &mut rhs.starts {
                        *starts = $starts_op(&self.starts[kind], std::mem::take(starts));
                    }
                    // a result with no members must not keep a type
                    if rhs.starts.values().all(BigUint::is_zero) {
                        rhs.ty = None;
                    }
                    rhs
                }
            }
        }
    };
}
480
481 impl_bin_op! {
482 BitAnd::bitand(),
483 BitAndAssign::bitand_assign(),
484 BitAnd::bitand,
485 |_, _| LocSet::new(),
486 |lhs, _| LocSet::clear(lhs),
487 }
488
489 impl_bin_op! {
490 BitOr::bitor(),
491 BitOrAssign::bitor_assign(),
492 BitOr::bitor,
493 |lhs: &LocSet, rhs: Cow<LocSet>| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
494 |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
495 }
496
497 impl_bin_op! {
498 BitXor::bitxor(),
499 BitXorAssign::bitxor_assign(),
500 BitXor::bitxor,
501 |lhs: &LocSet, rhs: Cow<LocSet>| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
502 |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
503 }
504
505 impl_bin_op! {
506 Sub::sub(),
507 SubAssign::sub_assign(),
508 and_not,
509 |lhs: &LocSet, _| lhs.clone(),
510 |_, _| {},
511 }
512
/// The largest number of `Loc`s in `lhs` that a single `Loc`
/// from `rhs` can conflict with.
///
/// The value is computed on first request and cached in `result`.
/// Equality and hashing look only at `lhs` and `rhs`.
#[derive(Clone)]
pub struct LocSetMaxConflictsWith<Rhs> {
    /// The set whose members are counted.
    lhs: Interned<LocSet>,
    /// What `lhs` is checked against (a `Loc` or an interned `LocSet` in this file).
    rhs: Rhs,
    // result is not included in equality or hash
    /// Cached answer; `None` until `result()` has been called once.
    result: Cell<Option<u32>>,
}
522
523 impl<Rhs: Hash> Hash for LocSetMaxConflictsWith<Rhs> {
524 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
525 self.lhs.hash(state);
526 self.rhs.hash(state);
527 }
528 }
529
// Marker impl: equality (defined by the `PartialEq` impl over `lhs` and `rhs`)
// is a full equivalence relation whenever `Rhs`'s is.
impl<Rhs: Eq> Eq for LocSetMaxConflictsWith<Rhs> {}
531
532 impl<Rhs: PartialEq> PartialEq for LocSetMaxConflictsWith<Rhs> {
533 fn eq(&self, other: &Self) -> bool {
534 self.lhs == other.lhs && self.rhs == other.rhs
535 }
536 }
537
/// Implemented by the types that can be the right-hand side of a
/// `LocSetMaxConflictsWith` query.
pub trait LocSetMaxConflictsWithTrait: Clone {
    /// Interns the query `v` using `global_state` and returns the interned handle.
    fn intern(
        v: LocSetMaxConflictsWith<Self>,
        global_state: &GlobalState,
    ) -> Interned<LocSetMaxConflictsWith<Self>>;
    /// Computes the query's answer for `lhs` against `rhs` (no caching here;
    /// `LocSetMaxConflictsWith::result` caches the value).
    fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, global_state: &GlobalState) -> u32;
}
545
546 impl LocSetMaxConflictsWithTrait for Loc {
547 fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, _global_state: &GlobalState) -> u32 {
548 // now we do the equivalent of:
549 // return lhs.iter().map(|loc| rhs.conflicts(loc) as u32).sum().unwrap_or(0)
550 let Some(reg_len) = lhs.reg_len() else {
551 return 0;
552 };
553 let starts = &lhs.starts[rhs.kind];
554 if starts.is_zero() {
555 return 0;
556 }
557 // now we do the equivalent of:
558 // return sum(rhs.start < start + reg_len
559 // and start < rhs.start + rhs.reg_len
560 // for start in starts)
561 let stops = starts << reg_len.get();
562
563 // find all the bit indexes `i` where `i < rhs.start + 1`
564 let lt_rhs_start_plus_1 = (BigUint::from(1u32) << (rhs.start + 1)) - 1u32;
565
566 // find all the bit indexes `i` where
567 // `i < rhs.start + rhs.reg_len + reg_len`
568 let lt_rhs_start_plus_rhs_reg_len_plus_reg_len =
569 (BigUint::from(1u32) << (rhs.start + rhs.reg_len.get() + reg_len.get())) - 1u32;
570 let mut included = and_not(&stops, &stops & lt_rhs_start_plus_1);
571 included &= lt_rhs_start_plus_rhs_reg_len_plus_reg_len;
572 included.count_ones() as u32
573 }
574
575 fn intern(
576 v: LocSetMaxConflictsWith<Self>,
577 global_state: &GlobalState,
578 ) -> Interned<LocSetMaxConflictsWith<Self>> {
579 v.into_interned(global_state)
580 }
581 }
582
583 impl LocSetMaxConflictsWithTrait for Interned<LocSet> {
584 fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, global_state: &GlobalState) -> u32 {
585 rhs.iter()
586 .map(|loc| lhs.clone().max_conflicts_with(loc, global_state))
587 .max()
588 .unwrap_or(0)
589 }
590
591 fn intern(
592 v: LocSetMaxConflictsWith<Self>,
593 global_state: &GlobalState,
594 ) -> Interned<LocSetMaxConflictsWith<Self>> {
595 v.into_interned(global_state)
596 }
597 }
598
599 impl<Rhs: LocSetMaxConflictsWithTrait> LocSetMaxConflictsWith<Rhs> {
600 pub fn lhs(&self) -> &Interned<LocSet> {
601 &self.lhs
602 }
603 pub fn rhs(&self) -> &Rhs {
604 &self.rhs
605 }
606 pub fn result(&self, global_state: &GlobalState) -> u32 {
607 match self.result.get() {
608 Some(v) => v,
609 None => {
610 let retval = Rhs::compute_result(&self.lhs, &self.rhs, global_state);
611 self.result.set(Some(retval));
612 retval
613 }
614 }
615 }
616 }
617
618 impl Interned<LocSet> {
619 pub fn max_conflicts_with<Rhs>(self, rhs: Rhs, global_state: &GlobalState) -> u32
620 where
621 Rhs: LocSetMaxConflictsWithTrait,
622 LocSetMaxConflictsWith<Rhs>: InternTarget,
623 {
624 LocSetMaxConflictsWithTrait::intern(
625 LocSetMaxConflictsWith {
626 lhs: self,
627 rhs,
628 result: Cell::default(),
629 },
630 global_state,
631 )
632 .result(global_state)
633 }
634 pub fn conflicts_with<Rhs>(self, rhs: Rhs, global_state: &GlobalState) -> bool
635 where
636 Rhs: LocSetMaxConflictsWithTrait,
637 LocSetMaxConflictsWith<Rhs>: InternTarget,
638 {
639 self.max_conflicts_with(rhs, global_state) != 0
640 }
641 }
642
#[cfg(test)]
mod tests {
    use super::*;

    /// Exhaustively checks `and_not` against `a & !b` for all 4-bit operands.
    #[test]
    fn test_and_not() {
        for a in 0..0x10u32 {
            for b in 0..0x10u32 {
                let expected = BigUint::from(a & !b);
                let actual = and_not(&BigUint::from(a), BigUint::from(b));
                assert_eq!(actual, expected);
            }
        }
    }
}