TOOM-2 256x256->512-bit [un]signed*[un]signed mul works!
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
index 75891f2d5deeacf47a9924368c1b3408392f6865..a7c5450f6f88ec5e3f4d204ad5449d7eccb669e4 100644 (file)
@@ -250,6 +250,19 @@ class EvalOpGenIrOutput:
         # type: () -> int
         return self.value_range.output_size
 
+    @property
+    def current_debugging_value(self):
+        # type: () -> tuple[int, ...]
+        """Get the current value, for debugging in pdb or similar.
+
+        Forwards `self.output.current_debugging_value`; intended for use
+        with `PreRASimState.set_current_debugging_state`.
+
+        This is only intended for debugging; do not use in unit tests or
+        production code.
+        """
+        return self.output.current_debugging_value
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
@@ -284,6 +297,19 @@ class EvalOpGenIrInput:
         if self.min_value > self.max_value:
             raise ValueError("invalid value range")
 
+    @property
+    def current_debugging_value(self):
+        # type: () -> tuple[int, ...]
+        """Get the current value, for debugging in pdb or similar.
+
+        Forwards `self.ssa_val.current_debugging_value`; intended for use
+        with `PreRASimState.set_current_debugging_state`.
+
+        This is only intended for debugging; do not use in unit tests or
+        production code.
+        """
+        return self.ssa_val.current_debugging_value
+
 
 @plain_data(frozen=True)
 @final
@@ -901,7 +927,7 @@ def cast_to_size_spread(fn, ssa_vals, src_signed, dest_size, name):
 
 
 def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
-    # type: (Fn, SSAVal, int, int, str) -> list[SSAVal]
+    # type: (Fn, SSAVal, int, int, str) -> tuple[SSAVal, ...]
     """split ssa_val into part_count parts, where all but the last part have
     `part.ty.reg_len == part_size`.
     """
@@ -910,7 +936,7 @@ def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
     if part_count <= 0:
         raise ValueError("invalid part count, must be positive")
     if part_count == 1:
-        return [ssa_val]
+        return (ssa_val,)
     too_short_reg_len = (part_count - 1) * part_size
     if ssa_val.ty.reg_len <= too_short_reg_len:
         raise ValueError(f"ssa_val is too short to split, must have "
@@ -934,88 +960,127 @@ def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
             input_vals=[*spread.outputs[start:stop], part_setvl.outputs[0]],
             name=f"{name}_{part}_concat", maxvl=part_maxvl)
         retval.append(concat.outputs[0])
-    return retval
+    return tuple(retval)
 
 
-__TCIs = Tuple[ToomCookInstance, ...]
+_TCIs = Tuple[ToomCookInstance, ...]  # type alias used in `# type:` comments
+
+
+@plain_data(frozen=True)
+@final
+class ToomCookMul:  # builds the whole multiply in __init__; see .retval
+    __slots__ = (
+        "fn", "lhs", "lhs_signed", "rhs", "rhs_signed", "instances",
+        "retval_size", "start_instance_index", "instance", "part_size",
+        "lhs_parts", "lhs_inputs", "lhs_eval_state", "lhs_outputs",
+        "rhs_parts", "rhs_inputs", "rhs_eval_state", "rhs_outputs",
+        "prod_inputs", "prod_eval_state", "prod_parts",
+        "partial_products", "retval",
+    )
+
+    def __init__(self, fn, lhs, lhs_signed, rhs, rhs_signed, instances,
+                 retval_size=None, start_instance_index=0):
+        # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> None
+        self.fn = fn
+        self.lhs = lhs
+        self.lhs_signed = lhs_signed
+        self.rhs = rhs
+        self.rhs_signed = rhs_signed
+        self.instances = instances
+        if retval_size is None:  # default: lhs length + rhs length
+            retval_size = lhs.ty.reg_len + rhs.ty.reg_len
+        self.retval_size = retval_size
+        if start_instance_index < 0:
+            raise ValueError("start_instance_index must be non-negative")
+        self.start_instance_index = start_instance_index  # value as passed
+        self.instance = None  # stays None if no instance is usable
+        self.part_size = 0  # type: int
+        while start_instance_index < len(instances):  # find first usable
+            self.instance = instances[start_instance_index]
+            self.part_size = max(
+                lhs.ty.reg_len // self.instance.lhs_part_count,
+                rhs.ty.reg_len // self.instance.rhs_part_count)
+            if self.part_size <= 0:  # both operands too short; try next
+                self.instance = None
+                start_instance_index += 1
+            else:
+                break
+        if self.instance is None:  # base case: no instance applies
+            self.retval = simple_mul(fn=fn,
+                                     lhs=lhs, lhs_signed=lhs_signed,
+                                     rhs=rhs, rhs_signed=rhs_signed,
+                                     name="toom_cook_base_case")
+            return  # NOTE: lhs_parts..partial_products stay unassigned
+        self.lhs_parts = split_into_exact_sized_parts(
+            fn=fn, ssa_val=lhs, part_count=self.instance.lhs_part_count,
+            part_size=self.part_size, name="lhs")
+        self.lhs_inputs = []  # type: list[EvalOpGenIrInput]
+        for part, ssa_val in enumerate(self.lhs_parts):  # last part holds sign
+            self.lhs_inputs.append(EvalOpGenIrInput(
+                ssa_val=ssa_val,
+                is_signed=lhs_signed and part == len(self.lhs_parts) - 1))
+        self.lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.lhs_inputs)
+        lhs_eval_ops = self.instance.lhs_eval_ops
+        self.lhs_outputs = [
+            self.lhs_eval_state.get_output(i) for i in lhs_eval_ops]
+        self.rhs_parts = split_into_exact_sized_parts(
+            fn=fn, ssa_val=rhs, part_count=self.instance.rhs_part_count,
+            part_size=self.part_size, name="rhs")
+        self.rhs_inputs = []  # type: list[EvalOpGenIrInput]
+        for part, ssa_val in enumerate(self.rhs_parts):  # last part holds sign
+            self.rhs_inputs.append(EvalOpGenIrInput(
+                ssa_val=ssa_val,
+                is_signed=rhs_signed and part == len(self.rhs_parts) - 1))
+        self.rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.rhs_inputs)
+        rhs_eval_ops = self.instance.rhs_eval_ops
+        self.rhs_outputs = [
+            self.rhs_eval_state.get_output(i) for i in rhs_eval_ops]
+        self.prod_inputs = []  # type: list[EvalOpGenIrInput]
+        for lhs_output, rhs_output in zip(self.lhs_outputs, self.rhs_outputs):
+            ssa_val = toom_cook_mul(
+                fn=fn,
+                lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
+                rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
+                instances=instances,
+                start_instance_index=start_instance_index + 1)  # recurse
+            products = (lhs_output.min_value * rhs_output.min_value,
+                        lhs_output.min_value * rhs_output.max_value,
+                        lhs_output.max_value * rhs_output.min_value,
+                        lhs_output.max_value * rhs_output.max_value)
+            self.prod_inputs.append(EvalOpGenIrInput(
+                ssa_val=ssa_val,
+                is_signed=None,
+                min_value=min(products),  # range bounds from the 4 corners
+                max_value=max(products)))
+        self.prod_eval_state = EvalOpGenIrState(fn=fn, inputs=self.prod_inputs)
+        prod_eval_ops = self.instance.prod_eval_ops
+        self.prod_parts = [
+            self.prod_eval_state.get_output(i) for i in prod_eval_ops]
+
+        def partial_products():  # appends a SetVLI + Spread op per part
+            # type: () -> Iterable[PartialProduct]
+            for part, prod_part in enumerate(self.prod_parts):
+                part_maxvl = prod_part.output.ty.reg_len
+                part_setvl = fn.append_new_op(
+                    OpKind.SetVLI, immediates=[part_maxvl],
+                    name=f"prod_{part}_setvl", maxvl=part_maxvl)
+                spread_part = fn.append_new_op(
+                    OpKind.Spread,
+                    input_vals=[prod_part.output, part_setvl.outputs[0]],
+                    name=f"prod_{part}_spread", maxvl=part_maxvl)
+                yield PartialProduct(
+                    spread_part.outputs, shift_in_words=part * self.part_size,
+                    is_signed=prod_part.is_signed, subtract=False)
+        self.partial_products = tuple(partial_products())  # runs generator
+        self.retval = sum_partial_products(
+            fn=fn, partial_products=self.partial_products,
+            retval_size=retval_size, name="prod")
 
 
 def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances,
                   retval_size=None, start_instance_index=0):
-    # type: (Fn, SSAVal, bool, SSAVal, bool, __TCIs, None | int, int) -> SSAVal
-    if retval_size is None:
-        retval_size = lhs.ty.reg_len + rhs.ty.reg_len
-    if start_instance_index < 0:
-        raise ValueError("start_instance_index must be non-negative")
-    instance = None
-    part_size = 0
-    while start_instance_index < len(instances):
-        instance = instances[start_instance_index]
-        part_size = max(lhs.ty.reg_len // instance.lhs_part_count,
-                        rhs.ty.reg_len // instance.rhs_part_count)
-        if part_size <= 0:
-            instance = None
-            start_instance_index += 1
-        else:
-            break
-    if instance is None:
-        return simple_mul(fn=fn,
-                          lhs=lhs, lhs_signed=lhs_signed,
-                          rhs=rhs, rhs_signed=rhs_signed,
-                          name="toom_cook_base_case")
-    lhs_parts = split_into_exact_sized_parts(
-        fn=fn, ssa_val=lhs, part_count=instance.lhs_part_count,
-        part_size=part_size, name="lhs")
-    lhs_inputs = []  # type: list[EvalOpGenIrInput]
-    for part, ssa_val in enumerate(lhs_parts):
-        lhs_inputs.append(EvalOpGenIrInput(
-            ssa_val=ssa_val,
-            is_signed=lhs_signed and part == len(lhs_parts) - 1))
-    lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=lhs_inputs)
-    lhs_outputs = [lhs_eval_state.get_output(i) for i in instance.lhs_eval_ops]
-    rhs_parts = split_into_exact_sized_parts(
-        fn=fn, ssa_val=rhs, part_count=instance.rhs_part_count,
-        part_size=part_size, name="rhs")
-    rhs_inputs = []  # type: list[EvalOpGenIrInput]
-    for part, ssa_val in enumerate(rhs_parts):
-        rhs_inputs.append(EvalOpGenIrInput(
-            ssa_val=ssa_val,
-            is_signed=rhs_signed and part == len(rhs_parts) - 1))
-    rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=rhs_inputs)
-    rhs_outputs = [rhs_eval_state.get_output(i) for i in instance.rhs_eval_ops]
-    prod_inputs = []  # type: list[EvalOpGenIrInput]
-    for lhs_output, rhs_output in zip(lhs_outputs, rhs_outputs):
-        ssa_val = toom_cook_mul(
-            fn=fn,
-            lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
-            rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
-            instances=instances, start_instance_index=start_instance_index + 1)
-        products = (lhs_output.min_value * rhs_output.min_value,
-                    lhs_output.min_value * rhs_output.max_value,
-                    lhs_output.max_value * rhs_output.min_value,
-                    lhs_output.max_value * rhs_output.max_value)
-        prod_inputs.append(EvalOpGenIrInput(
-            ssa_val=ssa_val,
-            is_signed=None,
-            min_value=min(products),
-            max_value=max(products)))
-    prod_eval_state = EvalOpGenIrState(fn=fn, inputs=prod_inputs)
-    prod_parts = [
-        prod_eval_state.get_output(i) for i in instance.prod_eval_ops]
-
-    def partial_products():
-        # type: () -> Iterable[PartialProduct]
-        for part, prod_part in enumerate(prod_parts):
-            part_maxvl = prod_part.output.ty.reg_len
-            part_setvl = fn.append_new_op(
-                OpKind.SetVLI, immediates=[part_maxvl],
-                name=f"prod_{part}_setvl", maxvl=part_maxvl)
-            spread_part = fn.append_new_op(
-                OpKind.Spread,
-                input_vals=[prod_part.output, part_setvl.outputs[0]],
-                name=f"prod_{part}_spread", maxvl=part_maxvl)
-            yield PartialProduct(
-                spread_part.outputs, shift_in_words=part * part_size,
-                is_signed=prod_part.is_signed, subtract=False)
-    return sum_partial_products(fn=fn, partial_products=partial_products(),
-                                retval_size=retval_size, name="prod")
+    # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> SSAVal
+    return ToomCookMul(  # the constructor does all the work; keep .retval
+        fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs, rhs_signed=rhs_signed,
+        instances=instances, retval_size=retval_size,
+        start_instance_index=start_instance_index).retval