simple_mul now works with both signed and unsigned multiplication; also made the IR repr easier to read
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
index de5b0f69776da791bd60db692d9dee011b2fff22..c261d82f776c5d897caa0b3449dc6bec9888a159 100644 (file)
@@ -457,7 +457,7 @@ class EvalOpInput(EvalOp):
         output = cast_to_size(
             fn=state.fn, ssa_val=inp.ssa_val, src_signed=inp.is_signed,
             dest_size=output_value_range.output_size,
-            name="input_{self.part_index}_cast")
+            name=f"input_{self.part_index}_cast")
         return EvalOpGenIrOutput(output=output, value_range=output_value_range)
 
 
@@ -648,10 +648,10 @@ class ToomCookInstance:
 @plain_data(frozen=True, unsafe_hash=True)
 @final
 class PartialProduct:
-    __slots__ = "ssa_val_spread", "shift_in_words", "is_signed"
+    __slots__ = "ssa_val_spread", "shift_in_words", "is_signed", "subtract"
 
-    def __init__(self, ssa_val_spread, shift_in_words, is_signed):
-        # type: (Iterable[SSAVal], int, bool) -> None
+    def __init__(self, ssa_val_spread, shift_in_words, is_signed, subtract):
+        # type: (Iterable[SSAVal], int, bool, bool) -> None
         if shift_in_words < 0:
             raise ValueError("invalid shift_in_words")
         self.ssa_val_spread = tuple(ssa_val_spread)
@@ -660,10 +660,11 @@ class PartialProduct:
                 raise ValueError("invalid ssa_val.ty")
         self.shift_in_words = shift_in_words
         self.is_signed = is_signed
+        self.subtract = subtract
 
 
-def sum_partial_products(fn, partial_products, name):
-    # type: (Fn, Iterable[PartialProduct], str) -> SSAVal
+def sum_partial_products(fn, partial_products, retval_size, name):
+    # type: (Fn, Iterable[PartialProduct], int, str) -> SSAVal
     retval_spread = []  # type: list[SSAVal]
     retval_signed = False
     zero = fn.append_new_op(OpKind.LI, immediates=[0],
@@ -672,7 +673,8 @@ def sum_partial_products(fn, partial_products, name):
     for idx, partial_product in enumerate(partial_products):
         shift_in_words = partial_product.shift_in_words
         spread = list(partial_product.ssa_val_spread)
-        if not retval_signed and shift_in_words >= len(retval_spread):
+        if (not retval_signed and shift_in_words >= len(retval_spread)
+                and not partial_product.subtract):
             retval_spread.extend(
                 [zero] * (shift_in_words - len(retval_spread)))
             retval_spread.extend(spread)
@@ -680,10 +682,21 @@ def sum_partial_products(fn, partial_products, name):
             has_carry_word = False
             continue
         assert len(retval_spread) != 0, "logic error"
-        maxvl = max(len(retval_spread) - shift_in_words, len(spread))
+        retval_hi_len = len(retval_spread) - shift_in_words
+        if retval_hi_len <= len(spread):
+            maxvl = len(spread) + 1
+            has_carry_word = True
+        elif has_carry_word:
+            maxvl = retval_hi_len
+        else:
+            maxvl = retval_hi_len + 1
+            has_carry_word = True
         if not has_carry_word:
             maxvl += 1
             has_carry_word = True
+        if maxvl > retval_size - shift_in_words:
+            maxvl = retval_size - shift_in_words
+            has_carry_word = False
         retval_spread = cast_to_size_spread(
             fn=fn, ssa_vals=retval_spread, src_signed=retval_signed,
             dest_size=maxvl + shift_in_words, name=f"{name}_{idx}_cast_retval")
@@ -699,26 +712,38 @@ def sum_partial_products(fn, partial_products, name):
             name=f"{name}_{idx}_retval_concat", maxvl=maxvl)
         pp_concat = fn.append_new_op(
             kind=OpKind.Concat,
-            input_vals=[*retval_spread[shift_in_words:], setvl.outputs[0]],
+            input_vals=[*spread, setvl.outputs[0]],
             name=f"{name}_{idx}_pp_concat", maxvl=maxvl)
-        clear_ca = fn.append_new_op(kind=OpKind.ClearCA,
-                                    name=f"{name}_{idx}_clear_ca")
-        add = fn.append_new_op(
-            kind=OpKind.SvAddE, input_vals=[
-                retval_concat.outputs[0], pp_concat.outputs[0],
-                clear_ca.outputs[0], setvl.outputs[0]],
-            maxvl=maxvl, name=f"{name}_{idx}_add")
+        if partial_product.subtract:
+            set_ca = fn.append_new_op(kind=OpKind.SetCA,
+                                      name=f"{name}_{idx}_set_ca")
+            add_sub = fn.append_new_op(
+                kind=OpKind.SvSubFE, input_vals=[
+                    pp_concat.outputs[0], retval_concat.outputs[0],
+                    set_ca.outputs[0], setvl.outputs[0]],
+                maxvl=maxvl, name=f"{name}_{idx}_sub")
+        else:
+            clear_ca = fn.append_new_op(kind=OpKind.ClearCA,
+                                        name=f"{name}_{idx}_clear_ca")
+            add_sub = fn.append_new_op(
+                kind=OpKind.SvAddE, input_vals=[
+                    retval_concat.outputs[0], pp_concat.outputs[0],
+                    clear_ca.outputs[0], setvl.outputs[0]],
+                maxvl=maxvl, name=f"{name}_{idx}_add")
         retval_spread[shift_in_words:] = fn.append_new_op(
             kind=OpKind.Spread,
-            input_vals=[add.outputs[0], setvl.outputs[0]],
+            input_vals=[add_sub.outputs[0], setvl.outputs[0]],
             name=f"{name}_{idx}_sum_spread", maxvl=maxvl).outputs
+    retval_spread = cast_to_size_spread(
+        fn=fn, ssa_vals=retval_spread, src_signed=retval_signed,
+        dest_size=retval_size, name=f"{name}_retval_cast")
     retval_setvl = fn.append_new_op(
-        OpKind.SetVLI, immediates=[len(retval_spread)],
-        maxvl=len(retval_spread), name=f"{name}_setvl")
+        OpKind.SetVLI, immediates=[retval_size],
+        maxvl=retval_size, name=f"{name}_setvl")
     retval_concat = fn.append_new_op(
         kind=OpKind.Concat,
         input_vals=[*retval_spread, retval_setvl.outputs[0]],
-        name=f"{name}_concat", maxvl=len(retval_spread))
+        name=f"{name}_concat", maxvl=retval_size)
     return retval_concat.outputs[0]
 
 
@@ -729,20 +754,20 @@ def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name):
         lhs, rhs = rhs, lhs
         lhs_signed, rhs_signed = rhs_signed, lhs_signed
     # split rhs into elements
-    rhs_setvl = fn.append_new_op(kind=OpKind.SetVLI,
-                                 immediates=[rhs.ty.reg_len], name="rhs_setvl")
+    rhs_setvl = fn.append_new_op(
+        kind=OpKind.SetVLI, immediates=[rhs.ty.reg_len],
+        name=f"{name}_rhs_setvl")
     rhs_spread = fn.append_new_op(
         kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]],
-        maxvl=rhs.ty.reg_len, name="rhs_spread")
+        maxvl=rhs.ty.reg_len, name=f"{name}_rhs_spread")
     rhs_words = rhs_spread.outputs
     zero = fn.append_new_op(
         kind=OpKind.LI, immediates=[0], name=f"{name}_zero").outputs[0]
     maxvl = lhs.ty.reg_len
     lhs_setvl = fn.append_new_op(
-        kind=OpKind.SetVLI, immediates=[maxvl], name="lhs_setvl", maxvl=maxvl)
+        kind=OpKind.SetVLI, immediates=[maxvl], name=f"{name}_lhs_setvl",
+        maxvl=maxvl)
     vl = lhs_setvl.outputs[0]
-    if lhs_signed or rhs_signed:
-        raise NotImplementedError  # FIXME: implement signed multiply
 
     def partial_products():
         # type: () -> Iterable[PartialProduct]
@@ -756,9 +781,46 @@ def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name):
             yield PartialProduct(
                 ssa_val_spread=[*mul_rt_spread.outputs, mul.outputs[1]],
                 shift_in_words=shift_in_words,
-                is_signed=False)
-    return sum_partial_products(fn=fn, partial_products=partial_products(),
-                                name=name)
+                is_signed=False, subtract=False)
+        if lhs_signed:
+            lhs_spread = fn.append_new_op(
+                kind=OpKind.Spread, input_vals=[lhs, lhs_setvl.outputs[0]],
+                maxvl=lhs.ty.reg_len, name=f"{name}_lhs_spread")
+            rhs_mask = fn.append_new_op(
+                kind=OpKind.SRADI, input_vals=[lhs_spread.outputs[-1]],
+                immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_rhs_mask")
+            lhs_and = fn.append_new_op(
+                kind=OpKind.SvAndVS,
+                input_vals=[rhs, rhs_mask.outputs[0], rhs_setvl.outputs[0]],
+                maxvl=rhs.ty.reg_len, name=f"{name}_rhs_and")
+            rhs_and_spread = fn.append_new_op(
+                kind=OpKind.Spread,
+                input_vals=[lhs_and.outputs[0], rhs_setvl.outputs[0]],
+                name=f"{name}_rhs_and_spread", maxvl=rhs.ty.reg_len)
+            yield PartialProduct(
+                ssa_val_spread=rhs_and_spread.outputs,
+                shift_in_words=lhs.ty.reg_len, is_signed=False, subtract=True)
+        if rhs_signed:
+            rhs_spread = fn.append_new_op(
+                kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]],
+                maxvl=rhs.ty.reg_len, name=f"{name}_rhs_spread")
+            lhs_mask = fn.append_new_op(
+                kind=OpKind.SRADI, input_vals=[rhs_spread.outputs[-1]],
+                immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_lhs_mask")
+            rhs_and = fn.append_new_op(
+                kind=OpKind.SvAndVS,
+                input_vals=[lhs, lhs_mask.outputs[0], lhs_setvl.outputs[0]],
+                maxvl=lhs.ty.reg_len, name=f"{name}_lhs_and")
+            lhs_and_spread = fn.append_new_op(
+                kind=OpKind.Spread,
+                input_vals=[rhs_and.outputs[0], lhs_setvl.outputs[0]],
+                name=f"{name}_lhs_and_spread", maxvl=lhs.ty.reg_len)
+            yield PartialProduct(
+                ssa_val_spread=lhs_and_spread.outputs,
+                shift_in_words=rhs.ty.reg_len, is_signed=False, subtract=True)
+    return sum_partial_products(
+        fn=fn, partial_products=partial_products(),
+        retval_size=lhs.ty.reg_len + rhs.ty.reg_len, name=name)
 
 
 def cast_to_size(fn, ssa_val, src_signed, dest_size, name):