2 error::{Error, Result},
3 interned::{GlobalState, Intern, InternTarget, Interned},
4 loc::{BaseTy, Loc, LocFields, LocKind, Ty, TyFields},
6 use enum_map::{enum_map, EnumMap};
7 use num_bigint::BigUint;
9 use serde::{Deserialize, Serialize};
11 borrow::{Borrow, Cow},
15 iter::{FusedIterator, Peekable},
18 BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, ControlFlow, Range, Sub,
23 #[derive(Deserialize)]
24 struct LocSetSerialized {
25 starts: EnumMap<LocKind, BigUint>,
29 impl TryFrom<LocSetSerialized> for LocSet {
32 fn try_from(value: LocSetSerialized) -> Result<Self, Self::Error> {
33 Self::from_parts(value.starts, value.ty)
37 #[derive(Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
38 #[serde(try_from = "LocSetSerialized")]
40 starts: EnumMap<LocKind, BigUint>,
44 /// computes same value as `a & !b`, but more efficiently
45 fn and_not<A: Borrow<BigUint>, B>(a: A, b: B) -> BigUint
47 BigUint: for<'a> BitXor<A, Output = BigUint>,
48 B: for<'a> BitAnd<&'a BigUint, Output = BigUint>,
50 // use logical equivalent that avoids needing to use BigInt
55 pub fn starts(&self) -> &EnumMap<LocKind, BigUint> {
58 pub fn stops(&self) -> EnumMap<LocKind, BigUint> {
59 let Some(ty) = self.ty else {
60 return EnumMap::default();
62 enum_map! {kind => &self.starts[kind] << ty.reg_len.get()}
64 pub fn ty(&self) -> Option<Ty> {
67 pub fn kinds(&self) -> impl Iterator<Item = LocKind> + '_ {
70 .filter_map(|(kind, starts)| if starts.is_zero() { None } else { Some(kind) })
72 pub fn reg_len(&self) -> Option<NonZeroU32> {
73 self.ty.map(|v| v.reg_len)
75 pub fn base_ty(&self) -> Option<BaseTy> {
76 self.ty.map(|v| v.base_ty)
78 pub fn new() -> Self {
81 pub fn from_parts(starts: EnumMap<LocKind, BigUint>, ty: Option<Ty>) -> Result<Self> {
83 for (kind, starts) in &starts {
84 if !starts.is_zero() {
86 let expected_ty = Ty::new(TyFields {
87 base_ty: kind.base_ty(),
88 reg_len: ty.map(|v| v.reg_len).unwrap_or(nzu32_lit!(1)),
92 base_ty: kind.base_ty(),
93 reg_len: nzu32_lit!(1),
97 if ty != Some(expected_ty) {
98 return Err(Error::TyMismatch {
100 expected_ty: Some(expected_ty),
103 // bits() is one past the index of the highest set bit, so use >= rather than >
104 if starts.bits() >= Loc::max_start(kind, expected_ty.reg_len)? as u64 {
105 return Err(Error::StartNotInValidRange);
109 if empty && ty.is_some() {
110 Err(Error::TyMismatch {
115 Ok(Self { starts, ty })
118 pub fn clear(&mut self) {
119 for v in self.starts.values_mut() {
120 v.assign_from_slice(&[]);
123 pub fn contains(&self, value: Loc) -> bool {
124 Some(value.ty()) == self.ty && self.starts[value.kind].bit(value.start as _)
126 pub fn try_insert(&mut self, value: Loc) -> Result<bool> {
128 self.ty = Some(value.ty());
129 self.starts[value.kind].set_bit(value.start as u64, true);
132 let ty = Some(value.ty());
134 return Err(Error::TyMismatch {
136 expected_ty: self.ty,
139 let retval = !self.starts[value.kind].bit(value.start as u64);
140 self.starts[value.kind].set_bit(value.start as u64, true);
143 pub fn insert(&mut self, value: Loc) -> bool {
144 self.try_insert(value).unwrap()
146 pub fn remove(&mut self, value: Loc) -> bool {
147 if self.contains(value) {
148 self.starts[value.kind].set_bit(value.start as u64, false);
149 if self.starts.values().all(BigUint::is_zero) {
157 pub fn is_empty(&self) -> bool {
160 pub fn is_disjoint(&self, other: &LocSet) -> bool {
161 if self.ty != other.ty || self.is_empty() {
164 for (k, lhs) in self.starts.iter() {
165 let rhs = &other.starts[k];
166 if !(lhs & rhs).is_zero() {
172 pub fn is_subset(&self, containing_set: &LocSet) -> bool {
176 if self.ty != containing_set.ty {
179 for (k, v) in self.starts.iter() {
180 let containing_set = &containing_set.starts[k];
181 if !and_not(v, containing_set).is_zero() {
187 pub fn is_superset(&self, contained_set: &LocSet) -> bool {
188 contained_set.is_subset(self)
190 pub fn iter(&self) -> Iter<'_> {
191 if let Some(ty) = self.ty {
192 let mut starts = self.starts.iter().peekable();
194 internals: Some(IterInternals {
196 start_range: get_start_range(starts.peek()),
201 Iter { internals: None }
204 pub fn len(&self) -> usize {
205 let retval: u64 = self.starts.values().map(BigUint::count_ones).sum();
210 #[derive(Clone, Debug)]
211 struct IterInternals<I, T>
213 I: Iterator<Item = (LocKind, T)>,
214 T: Clone + Borrow<BigUint>,
218 start_range: Range<u32>,
221 impl<I, T> IterInternals<I, T>
223 I: Iterator<Item = (LocKind, T)>,
224 T: Clone + Borrow<BigUint>,
226 fn next(&mut self) -> Option<Loc> {
233 let (kind, ref v) = *starts.peek()?;
234 let Some(start) = start_range.next() else {
236 *start_range = get_start_range(starts.peek());
239 if v.borrow().bit(start as u64) {
246 .expect("known to be valid"),
253 fn get_start_range(v: Option<&(LocKind, impl Borrow<BigUint>)>) -> Range<u32> {
254 0..v.map(|(_, v)| v.borrow().bits() as u32).unwrap_or(0)
257 #[derive(Clone, Debug)]
258 pub struct Iter<'a> {
259 internals: Option<IterInternals<enum_map::Iter<'a, LocKind, BigUint>, &'a BigUint>>,
262 impl Iterator for Iter<'_> {
265 fn next(&mut self) -> Option<Self::Item> {
266 self.internals.as_mut()?.next()
// Marker impl: promises that once `Iter::next` has returned `None` it keeps returning `None`.
270 impl FusedIterator for Iter<'_> {}
272 pub struct IntoIter {
273 internals: Option<IterInternals<enum_map::IntoIter<LocKind, BigUint>, BigUint>>,
276 impl Iterator for IntoIter {
279 fn next(&mut self) -> Option<Self::Item> {
280 self.internals.as_mut()?.next()
// Marker impl: promises that once `IntoIter::next` has returned `None` it keeps returning `None`.
284 impl FusedIterator for IntoIter {}
286 impl IntoIterator for LocSet {
288 type IntoIter = IntoIter;
290 fn into_iter(self) -> Self::IntoIter {
291 if let Some(ty) = self.ty {
292 let mut starts = self.starts.into_iter().peekable();
294 internals: Some(IterInternals {
296 start_range: get_start_range(starts.peek()),
301 IntoIter { internals: None }
306 impl<'a> IntoIterator for &'a LocSet {
308 type IntoIter = Iter<'a>;
310 fn into_iter(self) -> Self::IntoIter {
315 impl Extend<Loc> for LocSet {
316 fn extend<T: IntoIterator<Item = Loc>>(&mut self, iter: T) {
317 iter.into_iter().for_each(|item| {
323 impl<E: From<Error>> Extend<Loc> for Result<LocSet, E> {
324 fn extend<T: IntoIterator<Item = Loc>>(&mut self, iter: T) {
325 iter.into_iter().try_for_each(|item| {
326 let Ok(loc_set) = self else {
327 return ControlFlow::Break(());
329 match loc_set.try_insert(item) {
330 Ok(_) => ControlFlow::Continue(()),
332 *self = Err(e.into());
333 ControlFlow::Break(())
340 impl FromIterator<Loc> for LocSet {
341 fn from_iter<T: IntoIterator<Item = Loc>>(iter: T) -> Self {
342 let mut retval = LocSet::new();
348 impl<E: From<Error>> FromIterator<Loc> for Result<LocSet, E> {
349 fn from_iter<T: IntoIterator<Item = Loc>>(iter: T) -> Self {
350 let mut retval = Ok(LocSet::new());
356 struct HexBigUint<'a>(&'a BigUint);
358 impl fmt::Debug for HexBigUint<'_> {
359 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360 write!(f, "{:#x}", self.0)
364 struct LocSetStarts<'a>(&'a EnumMap<LocKind, BigUint>);
366 impl fmt::Debug for LocSetStarts<'_> {
367 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369 .entries(self.0.iter().map(|(k, v)| (k, HexBigUint(v))))
374 impl fmt::Debug for LocSet {
375 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376 f.debug_struct("LocSet")
377 .field("starts", &self.starts)
382 macro_rules! impl_bin_op {
384 $bin_op:ident::$bin_op_fn:ident(),
385 $bin_assign_op:ident::$bin_assign_op_fn:ident(),
387 $handle_unequal_types:expr,
388 $update_unequal_types:expr,
390 impl $bin_op<&'_ LocSet> for &'_ LocSet {
391 type Output = LocSet;
393 fn $bin_op_fn(self, rhs: &'_ LocSet) -> Self::Output {
394 if self.ty != rhs.ty {
395 $handle_unequal_types(self, Cow::<LocSet>::Borrowed(rhs))
398 starts: enum_map! {kind => $starts_op(&self.starts[kind], &rhs.starts[kind])},
405 impl $bin_assign_op<&'_ LocSet> for LocSet {
406 fn $bin_assign_op_fn(&mut self, rhs: &'_ LocSet) {
407 if self.ty != rhs.ty {
408 $update_unequal_types(self, rhs);
410 for (kind, starts) in &mut self.starts {
411 let v: BigUint = std::mem::take(starts);
412 *starts = $starts_op(v, &rhs.starts[kind]);
418 impl $bin_assign_op<LocSet> for LocSet {
419 fn $bin_assign_op_fn(&mut self, rhs: LocSet) {
420 self.$bin_assign_op_fn(&rhs);
424 impl $bin_op<&'_ LocSet> for LocSet {
425 type Output = LocSet;
427 fn $bin_op_fn(mut self, rhs: &'_ LocSet) -> Self::Output {
428 self.$bin_assign_op_fn(rhs);
433 impl $bin_op<LocSet> for LocSet {
434 type Output = LocSet;
436 fn $bin_op_fn(mut self, rhs: LocSet) -> Self::Output {
437 self.$bin_assign_op_fn(rhs);
442 impl $bin_op<LocSet> for &'_ LocSet {
443 type Output = LocSet;
445 fn $bin_op_fn(self, mut rhs: LocSet) -> Self::Output {
446 if self.ty != rhs.ty {
447 $handle_unequal_types(self, Cow::<LocSet>::Owned(rhs))
449 for (kind, starts) in &mut rhs.starts {
450 *starts = $starts_op(&self.starts[kind], std::mem::take(starts));
461 BitAndAssign::bitand_assign(),
463 |_, _| LocSet::new(),
464 |lhs, _| LocSet::clear(lhs),
469 BitOrAssign::bitor_assign(),
471 |lhs: &LocSet, rhs: Cow<LocSet>| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
472 |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
477 BitXorAssign::bitxor_assign(),
479 |lhs: &LocSet, rhs: Cow<LocSet>| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
480 |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
485 SubAssign::sub_assign(),
487 |lhs: &LocSet, _| lhs.clone(),
491 /// the largest number of Locs in `lhs` that a single Loc
492 /// from `rhs` can conflict with
494 pub struct LocSetMaxConflictsWith<Rhs> {
495 lhs: Interned<LocSet>,
497 // result is not included in equality or hash
498 result: Cell<Option<u32>>,
501 impl<Rhs: Hash> Hash for LocSetMaxConflictsWith<Rhs> {
502 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
503 self.lhs.hash(state);
504 self.rhs.hash(state);
// Marker impl: the hand-written `PartialEq` compares only `lhs` and `rhs`
// (the cached `result` cell is ignored), so equality is total whenever `Rhs: Eq`.
508 impl<Rhs: Eq> Eq for LocSetMaxConflictsWith<Rhs> {}
510 impl<Rhs: PartialEq> PartialEq for LocSetMaxConflictsWith<Rhs> {
511 fn eq(&self, other: &Self) -> bool {
512 self.lhs == other.lhs && self.rhs == other.rhs
516 pub trait LocSetMaxConflictsWithTrait: Clone {
518 v: LocSetMaxConflictsWith<Self>,
519 global_state: &GlobalState,
520 ) -> Interned<LocSetMaxConflictsWith<Self>>;
521 fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, global_state: &GlobalState) -> u32;
524 impl LocSetMaxConflictsWithTrait for Loc {
525 fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, _global_state: &GlobalState) -> u32 {
526 // this function computes the equivalent of:
527 // return lhs.iter().map(|loc| rhs.conflicts(loc) as u32).sum::<u32>()
528 let Some(reg_len) = lhs.reg_len() else {
531 let starts = &lhs.starts[rhs.kind];
532 if starts.is_zero() {
535 // now we do the equivalent of:
536 // return sum(rhs.start < start + reg_len
537 // and start < rhs.start + rhs.reg_len
538 // for start in starts)
539 let stops = starts << reg_len.get();
541 // find all the bit indexes `i` where `i < rhs.start + 1`
542 let lt_rhs_start_plus_1 = (BigUint::from(1u32) << (rhs.start + 1)) - 1u32;
544 // find all the bit indexes `i` where
545 // `i < rhs.start + rhs.reg_len + reg_len`
546 let lt_rhs_start_plus_rhs_reg_len_plus_reg_len =
547 (BigUint::from(1u32) << (rhs.start + rhs.reg_len.get() + reg_len.get())) - 1u32;
548 let mut included = and_not(&stops, &stops & lt_rhs_start_plus_1);
549 included &= lt_rhs_start_plus_rhs_reg_len_plus_reg_len;
550 included.count_ones() as u32
554 v: LocSetMaxConflictsWith<Self>,
555 global_state: &GlobalState,
556 ) -> Interned<LocSetMaxConflictsWith<Self>> {
557 v.into_interned(global_state)
561 impl LocSetMaxConflictsWithTrait for Interned<LocSet> {
562 fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, global_state: &GlobalState) -> u32 {
564 .map(|loc| lhs.clone().max_conflicts_with(loc, global_state))
570 v: LocSetMaxConflictsWith<Self>,
571 global_state: &GlobalState,
572 ) -> Interned<LocSetMaxConflictsWith<Self>> {
573 v.into_interned(global_state)
577 impl<Rhs: LocSetMaxConflictsWithTrait> LocSetMaxConflictsWith<Rhs> {
578 pub fn lhs(&self) -> &Interned<LocSet> {
581 pub fn rhs(&self) -> &Rhs {
584 pub fn result(&self, global_state: &GlobalState) -> u32 {
585 match self.result.get() {
588 let retval = Rhs::compute_result(&self.lhs, &self.rhs, global_state);
589 self.result.set(Some(retval));
596 impl Interned<LocSet> {
597 pub fn max_conflicts_with<Rhs>(self, rhs: Rhs, global_state: &GlobalState) -> u32
599 Rhs: LocSetMaxConflictsWithTrait,
600 LocSetMaxConflictsWith<Rhs>: InternTarget,
602 LocSetMaxConflictsWithTrait::intern(
603 LocSetMaxConflictsWith {
606 result: Cell::default(),
610 .result(global_state)
620 for a in 0..0x10u32 {
623 and_not(&BigUint::from(a), BigUint::from(b)),