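working on code: port simple_mul in toom_cook.py from compiler_ir to the compiler_ir2 API (Fn.append_new_op, OpKind, SSAVal)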
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
index 246f65454a861be15fbfb00954f9f9a37dac4899..aa18967e3443353851439002f2ed7458ab7a37c9 100644 (file)
@@ -8,10 +8,7 @@ from typing import Any, Generic, Iterable, Mapping, TypeVar, Union
 
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.compiler_ir import (Fn, OpBigIntAddSub,
-                                                  OpBigIntMulDiv, OpConcat,
-                                                  OpLI, OpSetCA, OpSetVLImm,
-                                                  OpSplit, SSAGPRRange)
+from bigint_presentation_code.compiler_ir2 import (Fn, OpKind, SSAVal)
 from bigint_presentation_code.matrix import Matrix
 from bigint_presentation_code.type_util import Literal, final
 
@@ -190,6 +187,7 @@ class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]):
 
     def __init__(self, lhs, rhs):
         # type: (_EvalOpLHS, _EvalOpRHS) -> None
+        super().__init__()
         self.lhs = lhs
         self.rhs = rhs
         self.poly = self._make_poly()
@@ -442,32 +440,74 @@ class ToomCookInstance:
 
 
 def simple_mul(fn, lhs, rhs):
-    # type: (Fn, SSAGPRRange, SSAGPRRange) -> SSAGPRRange
+    # type: (Fn, SSAVal, SSAVal) -> SSAVal
     """ simple O(n^2) big-int unsigned multiply """
-    if lhs.ty.length < rhs.ty.length:
+    if lhs.ty.reg_len < rhs.ty.reg_len:
         lhs, rhs = rhs, lhs
     # split rhs into elements
-    rhs_words = OpSplit(fn, rhs, range(1, rhs.ty.length)).results
-    retval = None
-    vl = OpSetVLImm(fn, lhs.ty.length).out
-    zero = OpLI(fn, 0).out
+    rhs_setvl = fn.append_new_op(kind=OpKind.SetVLI,
+                                 immediates=[rhs.ty.reg_len], name="rhs_setvl")
+    rhs_spread = fn.append_new_op(
+        kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]],
+        maxvl=rhs.ty.reg_len, name="rhs_spread")
+    rhs_words = rhs_spread.outputs
+    spread_retval = None  # type: tuple[SSAVal, ...] | None
+    # every vector op in the loop below uses lhs's length as vl and maxvl
+    maxvl = lhs.ty.reg_len
+    lhs_setvl = fn.append_new_op(kind=OpKind.SetVLI,
+                                 immediates=[lhs.ty.reg_len], name="lhs_setvl")
+    vl = lhs_setvl.outputs[0]
+    zero_op = fn.append_new_op(kind=OpKind.LI, immediates=[0], name="zero")
+    zero = zero_op.outputs[0]
     for shift, rhs_word in enumerate(rhs_words):
-        mul = OpBigIntMulDiv(fn, RA=lhs, RB=rhs_word, RC=zero,
-                             is_div=False, vl=vl)
-        if retval is None:
-            retval = OpConcat(fn, [mul.RT, mul.RS]).dest
+        # multiply lhs by one word of rhs: outputs[0] is the vector part of
+        # the product, outputs[1] is a single extra scalar word
+        mul = fn.append_new_op(kind=OpKind.SvMAddEDU,
+                               input_vals=[lhs, rhs_word, zero, vl],
+                               maxvl=maxvl, name=f"mul{shift}")
+        if spread_retval is None:
+            mul_rt_spread = fn.append_new_op(
+                kind=OpKind.Spread, input_vals=[mul.outputs[0], vl],
+                name=f"mul{shift}_rt_spread", maxvl=maxvl)
+            spread_retval = (*mul_rt_spread.outputs, mul.outputs[1])
         else:
-            first_part, last_part = OpSplit(fn, retval, [shift]).results
-            add = OpBigIntAddSub(
-                fn, lhs=mul.RT, rhs=last_part, CA_in=OpSetCA(fn, False).out,
-                is_sub=False, vl=vl)
-            add_hi = OpBigIntAddSub(fn, lhs=mul.RS, rhs=zero, CA_in=add.CA_out,
-                                    is_sub=False)
-            retval = OpConcat(fn, [first_part, add.out, add_hi.out]).dest
-    assert retval is not None
-    return retval
+            # the lowest `shift` words are already final; the product is
+            # added onto the remaining words
+            first_part = spread_retval[:shift]  # type: tuple[SSAVal, ...]
+            last_part = spread_retval[shift:]
+
+            add_rb_concat = fn.append_new_op(
+                kind=OpKind.Concat, input_vals=[*last_part, vl],
+                name=f"add{shift}_rb_concat", maxvl=maxvl)
+            clear_ca = fn.append_new_op(kind=OpKind.ClearCA,
+                                        name=f"clear_ca{shift}")
+            add = fn.append_new_op(
+                kind=OpKind.SvAddE, input_vals=[
+                    mul.outputs[0], add_rb_concat.outputs[0],
+                    clear_ca.outputs[0], vl],
+                maxvl=maxvl, name=f"add{shift}")
+            add_rt_spread = fn.append_new_op(
+                kind=OpKind.Spread, input_vals=[add.outputs[0], vl],
+                name=f"add{shift}_rt_spread", maxvl=maxvl)
+            # fold the SvAddE's second output (presumably its carry out)
+            # into the multiply's scalar output to get the new last word
+            add_hi = fn.append_new_op(
+                kind=OpKind.AddZE, input_vals=[mul.outputs[1], add.outputs[1]],
+                name=f"add_hi{shift}")
+            spread_retval = (
+                *first_part, *add_rt_spread.outputs, add_hi.outputs[0])
+    assert spread_retval is not None
+    # pack the individual words back into a single value to return
+    retval_setvl = fn.append_new_op(
+        kind=OpKind.SetVLI, immediates=[len(spread_retval)],
+        name="retval_setvl")
+    concat_retval = fn.append_new_op(
+        kind=OpKind.Concat,
+        input_vals=[*spread_retval, retval_setvl.outputs[0]],
+        name="concat_retval", maxvl=len(spread_retval))
+    return concat_retval.outputs[0]
 
 
 def toom_cook_mul(fn, lhs, rhs, instances):
-    # type: (Fn, SSAGPRRange, SSAGPRRange, list[ToomCookInstance]) -> SSAGPRRange
+    # type: (Fn, SSAVal, SSAVal, list[ToomCookInstance]) -> SSAVal
     raise NotImplementedError