optimize LocSet.max_conflicts_with
author Jacob Lifshay <programmerjake@gmail.com>
Fri, 16 Dec 2022 10:03:51 +0000 (02:03 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Fri, 16 Dec 2022 10:03:51 +0000 (02:03 -0800)
src/bigint_presentation_code/compiler_ir.py

index 9e4d13d96fc28f021a16860fcef5717910dd2296..724841caa06af352abea69e839225084b420d216 100644 (file)
@@ -15,7 +15,7 @@ from nmutil import plain_data  # type: ignore
 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
                                                 final)
 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, Interned,
-                                           OFSet, OSet)
+                                           OFSet, OSet, bit_count)
 
 GPR_SIZE_IN_BYTES = 8
 BITS_IN_BYTE = 8
@@ -762,7 +762,7 @@ class Loc(Interned):
         # type: () -> Ty
         return self.make_ty(kind=self.kind, reg_len=self.reg_len)
 
-    @property
+    @cached_property
     def stop(self):
         # type: () -> int
         return self.start + self.reg_len
@@ -912,7 +912,31 @@ class LocSet(OFSet[Loc], Interned):
         if isinstance(other, LocSet):
             return max(self.max_conflicts_with(i) for i in other)
         else:
-            return sum(other.conflicts(i) for i in self)
+            reg_len = self.reg_len
+            if reg_len is None:
+                return 0
+            starts = self.starts.get(other.kind)
+            if starts is None:
+                return 0
+            # now we do the equivalent of:
+            # return sum(other.conflicts(i) for i in self)
+            # which is the equivalent of:
+            # return sum(other.start < start + reg_len
+            #            and start < other.start + other.reg_len
+            #            for start in starts)
+            stops = starts.bits << reg_len
+
+            # find all the bit indexes `i` where `i < other.start + 1`
+            lt_other_start_plus_1 = ~(~0 << (other.start + 1))
+
+            # find all the bit indexes `i` where
+            # `i < other.start + other.reg_len + reg_len`
+            lt_other_start_plus_other_reg_len_plus_reg_len = (
+                ~(~0 << (other.start + other.reg_len + reg_len)))
+            included = ~(stops & lt_other_start_plus_1)
+            included &= stops
+            included &= lt_other_start_plus_other_reg_len_plus_reg_len
+            return bit_count(included)
 
     def __repr__(self):
         return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"