add equality constraints
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 8 Oct 2022 00:34:41 +0000 (17:34 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 8 Oct 2022 00:34:41 +0000 (17:34 -0700)
src/bigint_presentation_code/toom_cook.py

index 58c9a449ac95551dddf7dfcd48bfce2b2357821b..2209d8b780f3189588dbe7b1317157615d8d836f 100644 (file)
@@ -273,6 +273,17 @@ class VecArg:
             yield GPR(r[index])
 
 
+@final
+@plain_data(unsafe_hash=True, frozen=True)
+class EqualityConstraint:
+    __slots__ = "lhs", "rhs"
+
+    def __init__(self, lhs, rhs):
+        # type: (SSAVal, SSAVal) -> None
+        self.lhs = lhs
+        self.rhs = rhs
+
+
 @plain_data(unsafe_hash=True, frozen=True)
 class Op(metaclass=ABCMeta):
     __slots__ = ()
@@ -308,6 +319,11 @@ class Op(metaclass=ABCMeta):
         # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
         ...
 
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        if False:
+            yield ...
+
     def __init__(self):
         pass
 
@@ -426,6 +442,10 @@ class OpAddSubE(Op):
         else:
             yield from self.RB.possible_reg_assignments(val, value_assignments)
 
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        yield EqualityConstraint(self.CY_in, self.CY_out)
+
 
 def to_reg_set(v):
     # type: (None | GPR | range) -> set[GPR]
@@ -500,6 +520,10 @@ class OpBigIntMulDiv(Op):
                 val, value_assignments,
                 conflicting_regs=to_reg_set(RT_range) | to_reg_set(RC_RS_reg))
 
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        yield EqualityConstraint(self.RC, self.RS)
+
 
 @final
 @unique
@@ -715,6 +739,10 @@ class OpStore(Op):
         else:
             yield from self.RS.possible_reg_assignments(value_assignments)
 
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        yield EqualityConstraint(self.mem_in, self.mem_out)
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final