attempt to speed up code
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 16 Dec 2022 08:22:18 +0000 (00:22 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 16 Dec 2022 08:22:18 +0000 (00:22 -0800)
src/bigint_presentation_code/register_allocator.py
src/bigint_presentation_code/util.py

index 959e1c39cb978105460037d4c139639025debfbe..6732953f68926123b3f5312a2d3183f7032ac697 100644 (file)
@@ -303,21 +303,45 @@ class MergedSSAVal(Interned):
                             return lhs, rhs
         return None
 
-    @lru_cache(maxsize=None, typed=True)
     def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation):
         # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
+        retval = self.try_copy_merged(lhs_loc, rhs, rhs_loc, copy_relation)
+        if isinstance(retval, MergedSSAVal):
+            return retval
+        raise retval
+
+    def try_copy_merged(self, lhs_loc,  # type: Loc | None
+                        rhs,  # type: MergedSSAVal
+                        rhs_loc,  # type: Loc | None
+                        copy_relation,  # type: _CopyRelation
+                        ):
+        # type: (...) -> MergedSSAVal | BadMergedSSAVal
         cr_lhs, cr_rhs = copy_relation
         if cr_lhs.ssa_val not in self.ssa_vals:
             cr_lhs, cr_rhs = cr_rhs, cr_lhs
-        lhs_merged = self.with_offset_to_match(
-            cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx)
-        if lhs_loc is not None:
-            lhs_merged = lhs_merged.with_loc(lhs_loc)
-        rhs_merged = rhs.with_offset_to_match(
-            cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx)
-        if rhs_loc is not None:
-            rhs_merged = rhs_merged.with_loc(rhs_loc)
-        return lhs_merged.merged(rhs_merged).normalized()
+        return self.__try_copy_merged(lhs_loc=lhs_loc, cr_lhs=cr_lhs,
+                                      rhs=rhs, rhs_loc=rhs_loc, cr_rhs=cr_rhs)
+
+    @lru_cache(maxsize=None, typed=True)
+    def __try_copy_merged(self, lhs_loc,  # type: Loc | None
+                          cr_lhs,  # type: SSAValSubReg
+                          rhs,  # type: MergedSSAVal
+                          rhs_loc,  # type: Loc | None
+                          cr_rhs,  # type: SSAValSubReg
+                          ):
+        # type: (...) -> MergedSSAVal | BadMergedSSAVal
+        try:
+            lhs_merged = self.with_offset_to_match(
+                cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx)
+            if lhs_loc is not None:
+                lhs_merged = lhs_merged.with_loc(lhs_loc)
+            rhs_merged = rhs.with_offset_to_match(
+                cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx)
+            if rhs_loc is not None:
+                rhs_merged = rhs_merged.with_loc(rhs_loc)
+            return lhs_merged.merged(rhs_merged).normalized()
+        except BadMergedSSAVal as e:
+            return e
 
 
 @final
index 9d2b7b1af8fdc4cb099311d696965959305e3b2b..9f1e5ab8c92165476d99594ad1267af53b51a1d4 100644 (file)
@@ -7,6 +7,7 @@ from bigint_presentation_code.type_util import Self, final
 
 _T_co = TypeVar("_T_co", covariant=True)
 _T = TypeVar("_T")
+_T2 = TypeVar("_T2")
 
 __all__ = [
     "BaseBitSet",
@@ -141,6 +142,18 @@ class OSet(MutableSet[_T]):
         # type: (_T) -> None
         self.__items.pop(value, None)
 
+    def remove(self, value):
+        # type: (_T) -> None
+        del self.__items[value]
+
+    def pop(self):
+        # type: () -> _T
+        return self.__items.popitem()[0]
+
+    def clear(self):
+        # type: () -> None
+        self.__items.clear()
+
     def __repr__(self):
         # type: () -> str
         if len(self) == 0:
@@ -201,6 +214,14 @@ class FMap(Mapping[_T, _T_co], Interned):
         # type: () -> str
         return f"FMap({self.__items})"
 
+    def get(self, key, default=None):
+        # type: (_T, _T_co | _T2) -> _T_co | _T2
+        return self.__items.get(key, default)
+
+    def __contains__(self, key):
+        # type: (_T | object) -> bool
+        return key in self.__items
+
 
 def trailing_zero_count(v, default=-1):
     # type: (int, int) -> int