2 error::{Error, Result},
3 interned::{GlobalState, Intern, InternTarget, Interned},
4 loc::{BaseTy, Loc, LocFields, LocKind, Ty, TyFields},
6 use enum_map::{enum_map, EnumMap};
7 use num_bigint::BigUint;
9 use serde::{Deserialize, Serialize};
11 borrow::{Borrow, Cow},
13 collections::BTreeMap,
16 iter::{FusedIterator, Peekable},
19 BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, ControlFlow, Range, Sub,
/// Wire-format stand-in that serde deserializes first; the data is then
/// converted into a `LocSet` by the `TryFrom<LocSetSerialized>` impl in this file.
24 #[derive(Deserialize)]
25 struct LocSetSerialized {
// Keyed by a `NonZeroU32` (a register length, going by the field name); each
// entry holds one `BigUint` per `LocKind`.
26 reg_len_to_starts_map: BTreeMap<NonZeroU32, EnumMap<LocKind, BigUint>>,
// Fallible conversion from the deserialized form into a `LocSet`; this is the
// path named by the `#[serde(try_from = "LocSetSerialized")]` attribute in this file.
29 impl TryFrom<LocSetSerialized> for LocSet {
// Hands the raw map to `from_reg_len_to_starts_map` and returns its result
// unchanged (that constructor is not part of this excerpt).
32 fn try_from(value: LocSetSerialized) -> Result<Self, Self::Error> {
33 Self::from_reg_len_to_starts_map(value.reg_len_to_starts_map)
// Attributes and first field of a struct whose header line is not in this
// excerpt -- presumably `LocSet`, since that is the type implementing
// `TryFrom<LocSetSerialized>`; confirm.
// `Deserialize` is routed through `LocSetSerialized` by the `try_from`
// attribute, so deserialized data passes through the `TryFrom` conversion.
37 #[derive(Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
38 #[serde(try_from = "LocSetSerialized")]
// Same shape as the field of `LocSetSerialized`.
// NOTE(review): the methods visible elsewhere in this file read `self.starts`
// and `self.ty` instead -- confirm which field layout is current.
40 reg_len_to_starts_map: BTreeMap<NonZeroU32, EnumMap<LocKind, BigUint>>,
43 /// computes same value as `a & !b`, but more efficiently
///
/// `a` only needs to be borrowable as a `BigUint`. The bounds additionally
/// require `BigUint ^ A` and `B & &BigUint` to each produce a `BigUint`.
44 fn and_not<A: Borrow<BigUint>, B>(a: A, b: B) -> BigUint
// NOTE(review): the `for<'a>` on the next bound introduces a lifetime that the
// bound itself never mentions, so it looks redundant -- confirm.
46 BigUint: for<'a> BitXor<A, Output = BigUint>,
47 B: for<'a> BitAnd<&'a BigUint, Output = BigUint>,
49 // use logical equivalent that avoids needing to use BigInt
// (the expression that follows this comment is not part of this excerpt)
// Lets a single `Loc` be used wherever a `LocSet` is expected.
53 impl From<Loc> for LocSet {
/// Builds the set by collecting a one-element array through `FromIterator`.
54 fn from(value: Loc) -> Self {
55 Self::from_iter([value])
/// Generates a `LocSet` from fuzzer-style unstructured input `u`.
///
/// The `ty` parameter (its declaration line is not in this excerpt) is
/// optional: when it is `None` the empty set is returned. Otherwise a subset
/// of the location kinds reported by `ty.base_ty.loc_kinds()` is picked and a
/// start bitmap is read from `u` for each of them.
///
/// Errors from `u` are propagated with `?`.
60 pub fn arbitrary_with_ty(
62 u: &mut arbitrary::Unstructured<'_>,
63 ) -> arbitrary::Result<Self> {
// No type requested: the result is the empty set.
64 let Some(ty) = ty else {
65 return Ok(Self::new());
67 let kinds = ty.base_ty.loc_kinds();
// Two selection strategies: when there are more kinds than bits in `Mask`,
// one arbitrary `bool` is drawn per kind; otherwise (presumably the `else`
// branch, whose opening line is not visible) a single non-zero bitmask is drawn.
69 let kinds: Vec<_> = if kinds.len() > Mask::BITS as usize {
70 let chosen_kinds = kinds
72 .zip(u.arbitrary_iter::<bool>()?)
// Only a definite `Ok(false)` drops a kind; `Err` items are kept so that the
// `collect` into `arbitrary::Result` below can surface them.
73 .filter(|(_, cond)| !matches!(cond, Ok(false)))
74 .map(|(&kind, cond)| {
78 .collect::<arbitrary::Result<Vec<_>>>()?;
// Never end up with zero kinds: fall back to one kind chosen by `u`.
79 if chosen_kinds.is_empty() {
80 vec![*u.choose(kinds)?]
// Mask with the low `kinds.len()` bits set.
// NOTE(review): if `Mask` is a primitive integer and `kinds.len() == Mask::BITS`,
// `wrapping_shl` wraps the shift amount to 0, so `max_mask` becomes 0 and the
// range `1..=max_mask` on the next line is empty -- confirm this cannot happen.
85 let max_mask = Mask::wrapping_shl(1, kinds.len() as u32).wrapping_sub(1);
86 let mask = u.int_in_range(1..=max_mask)?; // non-zero
// Keep the kinds whose index bit is set in `mask`.
90 .filter_map(|(idx, &kind)| {
91 if mask & (1 << idx) != 0 {
99 let mut starts = EnumMap::<LocKind, BigUint>::default();
// Tracks whether every bitmap read so far is zero.
100 let mut all_zero = true;
// One bit per possible start position `0..=max_start`, rounded up to whole
// bytes before reading from `u`.
102 let bit_count = Loc::max_start(kind, ty.reg_len)? + 1;
103 let byte_count = (bit_count + u8::BITS - 1) / u8::BITS;
// NOTE(review): whole bytes are read, so bits at or above `bit_count` may be
// set; no masking is visible here -- confirm `from_parts` is meant to reject those.
104 let bytes = u.bytes(byte_count as usize)?;
105 starts[kind] = BigUint::from_bytes_le(bytes);
106 all_zero &= starts[kind].is_zero();
// Presumably the all-zero fallback: a single arbitrary `Loc` of type `ty`
// converted into a one-element set (the guarding condition is not visible).
109 Ok(Loc::arbitrary_with_ty(ty, u)?.into())
// Otherwise the bitmaps are validated and wrapped by `from_parts`.
111 Ok(Self::from_parts(starts, Some(ty))?)
/// Borrowing accessor typed to return the per-`LocKind` start bitmaps
/// (the body is not part of this excerpt).
114 pub fn starts(&self) -> &EnumMap<LocKind, BigUint> {
/// Returns, per `LocKind`, the start bitmap shifted left by the set's
/// `reg_len`, so bit `start + reg_len` is set for every `start` in the set.
///
/// A set without a type yields an all-default map.
117 pub fn stops(&self) -> EnumMap<LocKind, BigUint> {
118 let Some(ty) = self.ty else {
119 return EnumMap::default();
121 enum_map! {kind => &self.starts[kind] << ty.reg_len.get()}
/// Accessor typed to return the set's optional `Ty`
/// (the body is not part of this excerpt).
123 pub fn ty(&self) -> Option<Ty> {
/// Iterates over `LocKind`s, borrowing from `self`.
/// The iterator being filtered is not visible in this excerpt.
126 pub fn kinds(&self) -> impl Iterator<Item = LocKind> + '_ {
// Kinds whose bitmap is zero are skipped; the rest are yielded.
129 .filter_map(|(kind, starts)| if starts.is_zero() { None } else { Some(kind) })
/// Returns the `reg_len` of the set's type, or `None` when the set has no type.
131 pub fn reg_len(&self) -> Option<NonZeroU32> {
132 self.ty.map(|v| v.reg_len)
/// Returns the `base_ty` of the set's type, or `None` when the set has no type.
134 pub fn base_ty(&self) -> Option<BaseTy> {
135 self.ty.map(|v| v.base_ty)
/// Constructor taking no arguments (the body is not part of this excerpt);
/// other code in this file uses it as the starting, empty set.
137 pub fn new() -> Self {
/// Validating constructor from raw per-kind start bitmaps and an optional type.
///
/// # Errors
/// * `Error::TyMismatch` when a non-zero bitmap's kind implies a type other
///   than `ty`, or when every bitmap is zero but `ty` is `Some`.
/// * `Error::StartNotInValidRange` when a bitmap fails the range check below.
/// * Whatever `Loc::max_start` returns on failure.
140 pub fn from_parts(starts: EnumMap<LocKind, BigUint>, ty: Option<Ty>) -> Result<Self> {
141 let mut empty = true;
142 for (kind, starts) in &starts {
// Only kinds that actually contain locations are validated.
143 if !starts.is_zero() {
// The type this kind implies, using the caller's `reg_len` (1 when `ty` is `None`).
145 let expected_ty = Ty::new_or_scalar(TyFields {
146 base_ty: kind.base_ty(),
147 reg_len: ty.map(|v| v.reg_len).unwrap_or(nzu32_lit!(1)),
149 if ty != Some(expected_ty) {
150 return Err(Error::TyMismatch {
152 expected_ty: Some(expected_ty),
155 // bits() is one past max bit set, so use >= rather than >
// NOTE(review): with `bits()` = highest set index + 1, this rejects any bitmap
// whose highest start is >= `max_start - 1`, while `arbitrary_with_ty` in this
// file generates bitmaps of `max_start + 1` bits -- confirm the bound is intended.
156 if starts.bits() >= Loc::max_start(kind, expected_ty.reg_len)? as u64 {
157 return Err(Error::StartNotInValidRange);
// An empty set must not carry a type.
161 if empty && ty.is_some() {
162 Err(Error::TyMismatch {
167 Ok(Self { starts, ty })
/// Zeroes every kind's start bitmap in place.
///
/// NOTE(review): none of the visible statements reset `self.ty`, whereas
/// `from_parts` rejects an empty set that still carries a type -- confirm
/// whether `clear` should also set `ty` to `None`.
170 pub fn clear(&mut self) {
171 for v in self.starts.values_mut() {
// Assigning from an empty digit slice sets the value to zero.
172 v.assign_from_slice(&[]);
/// Returns `true` only when `value`'s type equals the set's type and the bit
/// at `value.start` is set in the bitmap for `value.kind`.
175 pub fn contains_exact(&self, value: Loc) -> bool {
176 Some(value.ty()) == self.ty && self.starts[value.kind].bit(value.start as _)
/// Inserts `value`, returning an `Error::TyMismatch` instead of panicking when
/// its type disagrees with the set's type.
///
/// The `bool` computed below is `true` when the bit was not already set
/// (the lines that return it are not part of this excerpt).
178 pub fn try_insert(&mut self, value: Loc) -> Result<bool> {
// Presumably the empty-set path (its condition is not visible): the set
// adopts `value`'s type and records the start bit.
180 self.ty = Some(value.ty());
181 self.starts[value.kind].set_bit(value.start as u64, true);
// Typed-set path: the types are compared and a mismatch is reported with the
// set's own type as `expected_ty`.
184 let ty = Some(value.ty());
186 return Err(Error::TyMismatch {
188 expected_ty: self.ty,
// Remember whether the bit was clear before setting it.
191 let retval = !self.starts[value.kind].bit(value.start as u64);
192 self.starts[value.kind].set_bit(value.start as u64, true);
/// Panicking wrapper around `try_insert`.
///
/// # Panics
/// Panics when `try_insert` returns an error.
195 pub fn insert(&mut self, value: Loc) -> bool {
196 self.try_insert(value).unwrap()
/// Removes `value` when it is present with exactly matching type.
/// The returned `bool` is produced on lines outside this excerpt.
198 pub fn remove(&mut self, value: Loc) -> bool {
199 if self.contains_exact(value) {
200 self.starts[value.kind].set_bit(value.start as u64, false);
// Detects the set having become empty; the action taken is not visible here.
201 if self.starts.values().all(BigUint::is_zero) {
/// Predicate typed to return `bool` (the body is not part of this excerpt).
209 pub fn is_empty(&self) -> bool {
/// Returns a borrowing iterator over the set.
///
/// A set without a type produces an `Iter` with no internals, which yields nothing.
212 pub fn iter(&self) -> Iter<'_> {
213 if let Some(ty) = self.ty {
// Peekable so the current kind's bitmap can be inspected without advancing.
214 let mut starts = self.starts.iter().peekable();
216 internals: Some(IterInternals {
// The candidate start positions are primed from the first kind's bitmap.
218 start_range: get_start_range(starts.peek()),
223 Iter { internals: None }
/// Counts the elements of the set.
/// The conversion of the `u64` total to `usize` is not visible in this excerpt.
226 pub fn len(&self) -> usize {
// One set bit per stored location, summed over all kinds.
227 let retval: u64 = self.starts.values().map(BigUint::count_ones).sum();
/// Iteration state shared by `Iter` and `IntoIter`.
///
/// `I` yields `(LocKind, T)` pairs and `T` is anything cloneable that borrows
/// as a `BigUint`. Fields other than `start_range` are not visible here.
232 #[derive(Clone, Debug)]
233 struct IterInternals<I, T>
235 I: Iterator<Item = (LocKind, T)>,
236 T: Clone + Borrow<BigUint>,
// Start positions still to be examined for the current kind.
240 start_range: Range<u32>,
// Stepping logic for `IterInternals`, used by both public iterator types.
243 impl<I, T> IterInternals<I, T>
245 I: Iterator<Item = (LocKind, T)>,
246 T: Clone + Borrow<BigUint>,
/// Produces the next `Loc`, or `None` once no kinds remain.
248 fn next(&mut self) -> Option<Loc> {
// Look at the current kind without consuming it; `?` ends iteration when
// there are no more kinds.
255 let (kind, ref v) = *starts.peek()?;
// When the positions for this kind are exhausted, the range is recomputed
// from whatever `peek` now returns (the advancing step is not visible here).
256 let Some(start) = start_range.next() else {
258 *start_range = get_start_range(starts.peek());
// A set bit means this start position is in the set.
261 if v.borrow().bit(start as u64) {
268 .expect("known to be valid"),
/// Returns the range of bit indexes worth scanning in the given bitmap:
/// `0..bits()` for `Some`, and the empty range `0..0` for `None`.
275 fn get_start_range(v: Option<&(LocKind, impl Borrow<BigUint>)>) -> Range<u32> {
// NOTE(review): `bits()` is cast to `u32` with `as`, which truncates for
// bitmaps of 2^32 bits or more -- confirm that cannot occur.
276 0..v.map(|(_, v)| v.borrow().bits() as u32).unwrap_or(0)
/// Borrowing iterator returned by `LocSet::iter`.
279 #[derive(Clone, Debug)]
280 pub struct Iter<'a> {
// `None` makes the iterator yield nothing (see `next`, which uses `as_mut()?`).
281 internals: Option<IterInternals<enum_map::Iter<'a, LocKind, BigUint>, &'a BigUint>>,
// Forwards iteration to the shared `IterInternals` state.
284 impl Iterator for Iter<'_> {
/// Returns `None` immediately when there are no internals; otherwise
/// delegates to `IterInternals::next`.
287 fn next(&mut self) -> Option<Self::Item> {
288 self.internals.as_mut()?.next()
// Marker impl: declares that `Iter` keeps returning `None` after its first `None`.
292 impl FusedIterator for Iter<'_> {}
/// Owning iterator produced by `LocSet::into_iter`.
294 pub struct IntoIter {
// `None` makes the iterator yield nothing (see `next`, which uses `as_mut()?`).
295 internals: Option<IterInternals<enum_map::IntoIter<LocKind, BigUint>, BigUint>>,
// Forwards iteration to the shared `IterInternals` state.
298 impl Iterator for IntoIter {
/// Returns `None` immediately when there are no internals; otherwise
/// delegates to `IterInternals::next`.
301 fn next(&mut self) -> Option<Self::Item> {
302 self.internals.as_mut()?.next()
// Marker impl: declares that `IntoIter` keeps returning `None` after its first `None`.
306 impl FusedIterator for IntoIter {}
// By-value iteration: consumes the set and its bitmaps.
308 impl IntoIterator for LocSet {
310 type IntoIter = IntoIter;
/// Mirrors `LocSet::iter`, but moves the bitmaps into the iterator.
/// A set without a type produces an iterator that yields nothing.
312 fn into_iter(self) -> Self::IntoIter {
313 if let Some(ty) = self.ty {
314 let mut starts = self.starts.into_iter().peekable();
316 internals: Some(IterInternals {
// The candidate start positions are primed from the first kind's bitmap.
318 start_range: get_start_range(starts.peek()),
323 IntoIter { internals: None }
// By-reference iteration, so `for loc in &set` works; the iterator type is
// the same `Iter` returned by `LocSet::iter`.
328 impl<'a> IntoIterator for &'a LocSet {
330 type IntoIter = Iter<'a>;
// (the body is not part of this excerpt)
332 fn into_iter(self) -> Self::IntoIter {
// Infallible-signature `Extend`: there is no way to report a type mismatch
// through this trait's return type.
337 impl Extend<Loc> for LocSet {
/// Visits every item of `iter`; the per-item action is not visible in this excerpt.
338 fn extend<T: IntoIterator<Item = Loc>>(&mut self, iter: T) {
339 iter.into_iter().for_each(|item| {
// Fallible `Extend`: the first insertion error replaces the whole `Result`.
345 impl<E: From<Error>> Extend<Loc> for Result<LocSet, E> {
/// Inserts items until one fails. If `self` is already `Err`, nothing is inserted.
346 fn extend<T: IntoIterator<Item = Loc>>(&mut self, iter: T) {
347 iter.into_iter().try_for_each(|item| {
// Already failed: stop consuming the iterator.
348 let Ok(loc_set) = self else {
349 return ControlFlow::Break(());
351 match loc_set.try_insert(item) {
352 Ok(_) => ControlFlow::Continue(()),
// Store the error (converted via `From<Error>`) and stop.
354 *self = Err(e.into());
355 ControlFlow::Break(())
// Collecting `Loc`s directly into a `LocSet`.
362 impl FromIterator<Loc> for LocSet {
/// Starts from `LocSet::new()`; the lines that add the items and return the
/// set are not part of this excerpt.
363 fn from_iter<T: IntoIterator<Item = Loc>>(iter: T) -> Self {
364 let mut retval = LocSet::new();
// Collecting `Loc`s into a `Result`, for callers that want errors instead of panics.
370 impl<E: From<Error>> FromIterator<Loc> for Result<LocSet, E> {
/// Starts from `Ok(LocSet::new())`; the lines that add the items and return
/// the result are not part of this excerpt.
371 fn from_iter<T: IntoIterator<Item = Loc>>(iter: T) -> Self {
372 let mut retval = Ok(LocSet::new());
/// Borrowing wrapper whose `Debug` output is the wrapped number in hexadecimal.
378 struct HexBigUint<'a>(&'a BigUint);
380 impl fmt::Debug for HexBigUint<'_> {
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// `{:#x}` is the alternate lower-hex form, i.e. with a `0x` prefix.
382 write!(f, "{:#x}", self.0)
/// Borrowing wrapper that debug-prints a per-kind bitmap map with every
/// bitmap rendered through `HexBigUint`.
386 struct LocSetStarts<'a>(&'a EnumMap<LocKind, BigUint>);
388 impl fmt::Debug for LocSetStarts<'_> {
389 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// One `(kind, hex bitmap)` entry per kind; the builder these entries are
// added to is created on a line outside this excerpt.
391 .entries(self.0.iter().map(|(k, v)| (k, HexBigUint(v))))
// Hand-written `Debug` for `LocSet`.
396 impl fmt::Debug for LocSet {
397 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398 f.debug_struct("LocSet")
// NOTE(review): `self.starts` is passed directly here rather than through the
// `LocSetStarts` wrapper defined in this file -- confirm which is intended.
399 .field("starts", &self.starts)
/// Generates one binary operator and its assigning form for every
/// owned/borrowed combination of `LocSet` operands.
///
/// Visible parameters: the operator trait and method (`$bin_op::$bin_op_fn`),
/// the assigning trait and method (`$bin_assign_op::$bin_assign_op_fn`), and
/// two callbacks for operands whose `ty` differs: `$handle_unequal_types`
/// (produces the result) and `$update_unequal_types` (updates `self` in
/// place). `$starts_op`, used below to combine two bitmaps, is declared on a
/// line outside this excerpt.
404 macro_rules! impl_bin_op {
406 $bin_op:ident::$bin_op_fn:ident(),
407 $bin_assign_op:ident::$bin_assign_op_fn:ident(),
409 $handle_unequal_types:expr,
410 $update_unequal_types:expr,
// `&LocSet op &LocSet`: nothing can be reused, so a fresh map is built.
412 impl $bin_op<&'_ LocSet> for &'_ LocSet {
413 type Output = LocSet;
415 fn $bin_op_fn(self, rhs: &'_ LocSet) -> Self::Output {
416 if self.ty != rhs.ty {
417 $handle_unequal_types(self, Cow::<LocSet>::Borrowed(rhs))
420 starts: enum_map! {kind => $starts_op(&self.starts[kind], &rhs.starts[kind])},
// `LocSet op= &LocSet`: the core in-place implementation.
427 impl $bin_assign_op<&'_ LocSet> for LocSet {
428 fn $bin_assign_op_fn(&mut self, rhs: &'_ LocSet) {
429 if self.ty != rhs.ty {
430 $update_unequal_types(self, rhs);
432 for (kind, starts) in &mut self.starts {
// Move the bitmap out so `$starts_op` can take it by value, then store the result.
433 let v: BigUint = std::mem::take(starts);
434 *starts = $starts_op(v, &rhs.starts[kind]);
// `LocSet op= LocSet`: forwards to the by-reference form.
440 impl $bin_assign_op<LocSet> for LocSet {
441 fn $bin_assign_op_fn(&mut self, rhs: LocSet) {
442 self.$bin_assign_op_fn(&rhs);
// `LocSet op &LocSet`: updates the owned left operand in place.
446 impl $bin_op<&'_ LocSet> for LocSet {
447 type Output = LocSet;
449 fn $bin_op_fn(mut self, rhs: &'_ LocSet) -> Self::Output {
450 self.$bin_assign_op_fn(rhs);
// `LocSet op LocSet`: updates the owned left operand in place.
455 impl $bin_op<LocSet> for LocSet {
456 type Output = LocSet;
458 fn $bin_op_fn(mut self, rhs: LocSet) -> Self::Output {
459 self.$bin_assign_op_fn(rhs);
// `&LocSet op LocSet`: writes the results into the owned right operand's bitmaps.
464 impl $bin_op<LocSet> for &'_ LocSet {
465 type Output = LocSet;
467 fn $bin_op_fn(self, mut rhs: LocSet) -> Self::Output {
468 if self.ty != rhs.ty {
469 $handle_unequal_types(self, Cow::<LocSet>::Owned(rhs))
471 for (kind, starts) in &mut rhs.starts {
// Operand order is preserved: `self` stays on the left.
472 *starts = $starts_op(&self.starts[kind], std::mem::take(starts));
// Fragments of the `impl_bin_op!` invocations; the surrounding macro-call
// lines are outside this excerpt.
//
// Intersection: operands of different types share nothing, so the result is
// the empty set and the in-place form clears the left operand.
483 BitAndAssign::bitand_assign(),
485 |_, _| LocSet::new(),
486 |lhs, _| LocSet::clear(lhs),
// Union: operands of different types cause a panic carrying `Error::TyMismatch`.
491 BitOrAssign::bitor_assign(),
493 |lhs: &LocSet, rhs: Cow<LocSet>| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
494 |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
// Symmetric difference: same panic policy as union.
499 BitXorAssign::bitxor_assign(),
501 |lhs: &LocSet, rhs: Cow<LocSet>| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
502 |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }),
// Difference: with different types nothing is removed, so the result is a
// copy of the left operand.
507 SubAssign::sub_assign(),
509 |lhs: &LocSet, _| lhs.clone(),
513 /// the largest number of Locs in `lhs` that a single Loc
514 /// from `rhs` can conflict with
///
/// The value is computed on demand and cached in `result`.
516 pub struct LocSetMaxConflictsWith<Rhs> {
// Left-hand set, held as an interned handle.
517 lhs: Interned<LocSet>,
519 // result is not included in equality or hash
// `Cell` lets the cache be filled through a shared reference.
520 result: Cell<Option<u32>>,
// Hashes only the two operands; the cached `result` is deliberately left out,
// matching the `PartialEq` impl in this file.
523 impl<Rhs: Hash> Hash for LocSetMaxConflictsWith<Rhs> {
524 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
525 self.lhs.hash(state);
526 self.rhs.hash(state);
// Marker impl: the `PartialEq` comparison is a full equivalence whenever `Rhs: Eq`.
530 impl<Rhs: Eq> Eq for LocSetMaxConflictsWith<Rhs> {}
// Two queries are equal when both operands are equal; the cached `result`
// does not take part in the comparison.
532 impl<Rhs: PartialEq> PartialEq for LocSetMaxConflictsWith<Rhs> {
533 fn eq(&self, other: &Self) -> bool {
534 self.lhs == other.lhs && self.rhs == other.rhs
/// Implemented by every type usable as the right-hand side of a
/// `LocSetMaxConflictsWith` query.
538 pub trait LocSetMaxConflictsWithTrait: Clone {
// Parameters and return type of the interning hook (its `fn` line is outside
// this excerpt): turns a query value into its interned handle.
540 v: LocSetMaxConflictsWith<Self>,
541 global_state: &GlobalState,
542 ) -> Interned<LocSetMaxConflictsWith<Self>>;
/// Computes the uncached answer for `lhs` against `rhs`.
543 fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, global_state: &GlobalState) -> u32;
// Right-hand side is a single `Loc`: counts the locations in `lhs` that
// overlap it, using bit tricks on the start bitmap instead of iterating.
546 impl LocSetMaxConflictsWithTrait for Loc {
547 fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, _global_state: &GlobalState) -> u32 {
548 // now we do the equivalent of:
549 // return lhs.iter().map(|loc| rhs.conflicts(loc) as u32).sum().unwrap_or(0)
// A set without a type has no `reg_len`; the early-exit value is on a line
// outside this excerpt.
550 let Some(reg_len) = lhs.reg_len() else {
// Only locations of the same kind as `rhs` are examined.
553 let starts = &lhs.starts[rhs.kind];
554 if starts.is_zero() {
557 // now we do the equivalent of:
558 // return sum(rhs.start < start + reg_len
559 // and start < rhs.start + rhs.reg_len
560 // for start in starts)
// Work on stop positions (`start + reg_len`) so both comparisons become
// tests on a single bit index.
561 let stops = starts << reg_len.get();
563 // find all the bit indexes `i` where `i < rhs.start + 1`
564 let lt_rhs_start_plus_1 = (BigUint::from(1u32) << (rhs.start + 1)) - 1u32;
566 // find all the bit indexes `i` where
567 // `i < rhs.start + rhs.reg_len + reg_len`
568 let lt_rhs_start_plus_rhs_reg_len_plus_reg_len =
569 (BigUint::from(1u32) << (rhs.start + rhs.reg_len.get() + reg_len.get())) - 1u32;
// Drop stops at or below `rhs.start` (those locations end before `rhs` begins) ...
570 let mut included = and_not(&stops, &stops & lt_rhs_start_plus_1);
// ... and keep only stops below the upper bound (locations starting before `rhs` ends).
571 included &= lt_rhs_start_plus_rhs_reg_len_plus_reg_len;
// Each remaining bit is one conflicting location.
572 included.count_ones() as u32
// Parameters and body of the interning hook (its `fn` line is outside this
// excerpt): delegates to `into_interned`.
576 v: LocSetMaxConflictsWith<Self>,
577 global_state: &GlobalState,
578 ) -> Interned<LocSetMaxConflictsWith<Self>> {
579 v.into_interned(global_state)
// Right-hand side is a whole set: the answer is derived from per-`Loc` queries.
583 impl LocSetMaxConflictsWithTrait for Interned<LocSet> {
584 fn compute_result(lhs: &Interned<LocSet>, rhs: &Self, global_state: &GlobalState) -> u32 {
// One `max_conflicts_with` query per `loc`; the source of the `loc`s and the
// reduction of the results (presumably a maximum, going by the type's name)
// are on lines outside this excerpt.
586 .map(|loc| lhs.clone().max_conflicts_with(loc, global_state))
// Parameters and body of the interning hook (its `fn` line is outside this
// excerpt): delegates to `into_interned`.
592 v: LocSetMaxConflictsWith<Self>,
593 global_state: &GlobalState,
594 ) -> Interned<LocSetMaxConflictsWith<Self>> {
595 v.into_interned(global_state)
// Accessors plus the lazily computed result.
599 impl<Rhs: LocSetMaxConflictsWithTrait> LocSetMaxConflictsWith<Rhs> {
/// Borrowing accessor typed to return the left-hand set (body not in this excerpt).
600 pub fn lhs(&self) -> &Interned<LocSet> {
/// Borrowing accessor typed to return the right-hand operand (body not in this excerpt).
603 pub fn rhs(&self) -> &Rhs {
/// Returns the query's answer, consulting the `result` cell first.
606 pub fn result(&self, global_state: &GlobalState) -> u32 {
607 match self.result.get() {
// Visible computation path: run `compute_result` and store the value in the cell.
610 let retval = Rhs::compute_result(&self.lhs, &self.rhs, global_state);
611 self.result.set(Some(retval));
// Query entry points on interned sets.
618 impl Interned<LocSet> {
/// Interns a `LocSetMaxConflictsWith` query for `self` and `rhs` (with an
/// empty result cell) and returns that interned query's result.
619 pub fn max_conflicts_with<Rhs>(self, rhs: Rhs, global_state: &GlobalState) -> u32
621 Rhs: LocSetMaxConflictsWithTrait,
622 LocSetMaxConflictsWith<Rhs>: InternTarget,
624 LocSetMaxConflictsWithTrait::intern(
625 LocSetMaxConflictsWith {
628 result: Cell::default(),
632 .result(global_state)
/// Returns `true` when `max_conflicts_with` reports a non-zero count.
634 pub fn conflicts_with<Rhs>(self, rhs: Rhs, global_state: &GlobalState) -> bool
636 Rhs: LocSetMaxConflictsWithTrait,
637 LocSetMaxConflictsWith<Rhs>: InternTarget,
639 self.max_conflicts_with(rhs, global_state) != 0
// Fragment of a check on `and_not`: `a` ranges over all 4-bit values and the
// function is called with a borrowed `a` and an owned `b` (the loop over `b`
// and the expected value are on lines outside this excerpt).
649 for a in 0..0x10u32 {
652 and_not(&BigUint::from(a), BigUint::from(b)),