get LocSet hash working correctly
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 9 Nov 2022 02:18:00 +0000 (18:18 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 9 Nov 2022 02:18:00 +0000 (18:18 -0800)
src/bigint_presentation_code/_tests/test_compiler_ir.py
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/util.py

index ba29ee098e31eb86cfd39d8885053fe38431ccce..9763a07d1c639c82d4dc484867b04161a0bb4f79 100644 (file)
@@ -2,9 +2,10 @@ import unittest
 
 from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, BaseTy,
                                                   Fn, FnAnalysis, GenAsmState,
-                                                  Loc, LocKind, OpKind,
+                                                  Loc, LocKind, LocSet, OpKind,
                                                   OpStage, PreRASimState,
                                                   ProgramPoint, SSAVal, Ty)
+from bigint_presentation_code.util import OFSet
 
 
 class TestCompilerIR(unittest.TestCase):
@@ -23,6 +24,28 @@ class TestCompilerIR(unittest.TestCase):
 
         self.assertEqual(sorted(expected), expected)
 
+    def test_loc_set_hash_intern(self):
+        # type: () -> None
+        # hashes should match all other collections.abc.Set types, which are
+        # supposed to match frozenset but don't until Python 3.11 because of a
+        # bug fixed in:
+        # https://github.com/python/cpython/commit/c878f5d81772dc6f718d6608c78baa4be9a4f176
+        a = LocSet([])
+        self.assertEqual(hash(a), hash(OFSet()))
+        starts = 0, 1, 0, 1, 2
+        GPR = LocKind.GPR
+        expected = OFSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts)
+        b = LocSet(expected)
+        c = LocSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts)
+        d = LocSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts)
+        # hashes should be equal to OFSet's hash
+        self.assertEqual(hash(b), hash(expected))
+        self.assertEqual(hash(c), hash(expected))
+        self.assertEqual(hash(d), hash(expected))
+        # they should intern to the same object
+        self.assertIs(b, d)
+        self.assertIs(c, d)
+
     def make_add_fn(self):
         # type: () -> tuple[Fn, SSAVal]
         fn = Fn()
index d3b52e85c4a5cbb654bf7bedd39c25f42924da2f..756615df8e4ce02a4fcff7125c7fa5318e7005c3 100644 (file)
@@ -643,59 +643,35 @@ SPECIAL_GPRS = (
 
 
 @final
-class _LocSetHashHelper(AbstractSet[Loc]):
-    """helper to more quickly compute LocSet's hash"""
-
-    def __init__(self, locs):
-        # type: (Iterable[Loc]) -> None
-        super().__init__()
-        self.locs = list(locs)
-
-    def __hash__(self):
-        # type: () -> int
-        return super()._hash()
-
-    def __contains__(self, x):
-        # type: (Loc | Any) -> bool
-        return x in self.locs
-
-    def __iter__(self):
-        # type: () -> Iterator[Loc]
-        return iter(self.locs)
-
-    def __len__(self):
-        return len(self.locs)
-
-
-@plain_data(frozen=True, eq=False, repr=False)
-@final
-class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
-    __slots__ = "starts", "ty", "_LocSet__hash"
-
+class LocSet(OFSet[Loc], metaclass=InternedMeta):
     def __init__(self, __locs=()):
         # type: (Iterable[Loc]) -> None
+        super().__init__(__locs)
         if isinstance(__locs, LocSet):
-            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
-            self.ty = __locs.ty  # type: Ty | None
-            self._LocSet__hash = __locs._LocSet__hash  # type: int
+            self.__starts = __locs.starts
+            self.__ty = __locs.ty
             return
         starts = {i: BitSet() for i in LocKind}
         ty = None  # type: None | Ty
-
-        def locs():
-            # type: () -> Iterable[Loc]
-            nonlocal ty
-            for loc in __locs:
-                if ty is None:
-                    ty = loc.ty
-                if ty != loc.ty:
-                    raise ValueError(f"conflicting types: {ty} != {loc.ty}")
-                starts[loc.kind].add(loc.start)
-                yield loc
-        self._LocSet__hash = _LocSetHashHelper(locs()).__hash__()
-        self.starts = FMap(
+        for loc in self:
+            if ty is None:
+                ty = loc.ty
+            if ty != loc.ty:
+                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+            starts[loc.kind].add(loc.start)
+        self.__starts = FMap(
             (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
-        self.ty = ty
+        self.__ty = ty
+
+    @property
+    def starts(self):
+        # type: () -> FMap[LocKind, FBitSet]
+        return self.__starts
+
+    @property
+    def ty(self):
+        # type: () -> Ty | None
+        return self.__ty
 
     @cached_property
     def stops(self):
@@ -756,38 +732,6 @@ class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
                         yield loc
         return LocSet(locs())
 
-    def __contains__(self, loc):
-        # type: (Loc | Any) -> bool
-        if not isinstance(loc, Loc) or loc.ty != self.ty:
-            return False
-        if loc.kind not in self.starts:
-            return False
-        return loc.start in self.starts[loc.kind]
-
-    def __iter__(self):
-        # type: () -> Iterator[Loc]
-        if self.ty is None:
-            return
-        for kind, starts in self.starts.items():
-            for start in starts:
-                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)
-
-    @cached_property
-    def __len(self):
-        return sum((len(v) for v in self.starts.values()), 0)
-
-    def __len__(self):
-        return self.__len
-
-    def __hash__(self):
-        return self._LocSet__hash
-
-    def __eq__(self, __other):
-        # type: (LocSet | Any) -> bool
-        if isinstance(__other, LocSet):
-            return self.ty == __other.ty and self.starts == __other.starts
-        return super().__eq__(__other)
-
     @lru_cache(maxsize=None, typed=True)
     def max_conflicts_with(self, other):
         # type: (LocSet | Loc) -> int
@@ -800,12 +744,7 @@ class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
             return sum(other.conflicts(i) for i in self)
 
     def __repr__(self):
-        items = []  # type: list[str]
-        for name in fields(self):
-            if name.startswith("_"):
-                continue
-            items.append(f"{name}={getattr(self, name)!r}")
-        return f"LocSet({', '.join(items)})"
+        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
 
 
 @plain_data(frozen=True, unsafe_hash=True)
index 87c6975425dbfd4fdec64df750a0110924c648ee..03eaeff9e91ca2b0ae6214031652f83f9ccc3863 100644 (file)
@@ -57,7 +57,10 @@ class OFSet(AbstractSet[_T_co], metaclass=InternedMeta):
     def __init__(self, items=()):
         # type: (Iterable[_T_co]) -> None
         super().__init__()
-        self.__items = {v: None for v in items}
+        if isinstance(items, OFSet):
+            self.__items = items.__items
+        else:
+            self.__items = {v: None for v in items}
 
     def __contains__(self, x):
         # type: (Any) -> bool