From: Jacob Lifshay
Date: Mon, 28 Nov 2022 07:41:18 +0000 (-0800)
Subject: TOOM-2 multiplication works for all sizes
X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=df51119062561a5980a2b0ac587db8f53fc5f231;p=bigint-presentation-code.git

TOOM-2 multiplication works for all sizes
---

diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py
index 41a4e22..42ab769 100644
--- a/src/bigint_presentation_code/_tests/test_toom_cook.py
+++ b/src/bigint_presentation_code/_tests/test_toom_cook.py
@@ -11,6 +11,7 @@ from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
 from bigint_presentation_code.register_allocator import allocate_registers
 from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
                                                 simple_mul)
+from bigint_presentation_code.util import OSet
 
 _StateFactory = Callable[[], ContextManager[BaseSimState]]
 
@@ -1896,52 +1897,94 @@ class TestToomCook(unittest.TestCase):
             self.assertEqual(hex(prod), hex(prod_value),
                              f"failed: state={state}")
 
-    def tst_toom_mul_all_sizes_pre_ra_sim(self, instances):
-        # type: (tuple[ToomCookInstance, ...]) -> None
-        for lhs_signed in False, True:
-            for rhs_signed in False, True:
-                def mul(fn, lhs, rhs):
-                    # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
-                    v = ToomCookMul(
-                        fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
-                        rhs_signed=rhs_signed, instances=instances)
-                    return v.retval, v
-                for lhs_size_in_words in range(1, 32):
-                    for rhs_size_in_words in range(1, 32):
-                        lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words
-                        rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words
-                        with self.subTest(lhs_size_in_words=lhs_size_in_words,
-                                          rhs_size_in_words=rhs_size_in_words,
-                                          lhs_signed=lhs_signed,
-                                          rhs_signed=rhs_signed):
-                            test_cases = []  # type: list[tuple[int, int]]
-                            test_cases.append((-1, -1))
-                            test_cases.append(((0x80 << 2048) // 0xFF,
-                                               (0x80 << 2048) // 0xFF))
-                            test_cases.append(((0x40 << 2048) // 0xFF,
-                                               (0x80 << 2048) // 0xFF))
-                            test_cases.append(((0x80 << 2048) // 0xFF,
-                                               (0x40 << 2048) // 0xFF))
-                            test_cases.append(((0x40 << 2048) // 0xFF,
-                                               (0x40 << 2048) // 0xFF))
-                            test_cases.append((1 << (lhs_size_in_bits - 1),
-                                               1 << (rhs_size_in_bits - 1)))
-                            test_cases.append((1, 1 << (rhs_size_in_bits - 1)))
-                            test_cases.append((1 << (lhs_size_in_bits - 1), 1))
-                            test_cases.append((1, 1))
-                            self.tst_toom_mul_sim(
-                                code=Mul(mul=mul,
-                                         lhs_size_in_words=lhs_size_in_words,
-                                         rhs_size_in_words=rhs_size_in_words),
-                                lhs_signed=lhs_signed, rhs_signed=rhs_signed,
-                                get_state_factory=get_pre_ra_state_factory,
-                                test_cases=test_cases)
+    def tst_toom_mul_all_sizes_pre_ra_sim(self, instances, lhs_signed, rhs_signed):
+        # type: (tuple[ToomCookInstance, ...], bool, bool) -> None
+        def mul(fn, lhs, rhs):
+            # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
+            v = ToomCookMul(
+                fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
+                rhs_signed=rhs_signed, instances=instances)
+            return v.retval, v
+        sizes_in_words = OSet()  # type: OSet[int]
+        for i in range(6):
+            sizes_in_words.add(1 << i)
+            sizes_in_words.add(3 << i)
+        sizes_in_words = OSet(
+            i for i in sorted(sizes_in_words) if 1 <= i <= 16)
+        for lhs_size_in_words in sizes_in_words:
+            for rhs_size_in_words in sizes_in_words:
+                lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words
+                rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words
+                with self.subTest(lhs_size_in_words=lhs_size_in_words,
+                                  rhs_size_in_words=rhs_size_in_words,
+                                  lhs_signed=lhs_signed,
+                                  rhs_signed=rhs_signed):
+                    test_cases = []  # type: list[tuple[int, int]]
+                    test_cases.append((-1, -1))
+                    test_cases.append(((0x80 << 2048) // 0xFF,
+                                       (0x80 << 2048) // 0xFF))
+                    test_cases.append(((0x40 << 2048) // 0xFF,
+                                       (0x80 << 2048) // 0xFF))
+                    test_cases.append(((0x80 << 2048) // 0xFF,
+                                       (0x40 << 2048) // 0xFF))
+                    test_cases.append(((0x40 << 2048) // 0xFF,
+                                       (0x40 << 2048) // 0xFF))
+                    test_cases.append((1 << (lhs_size_in_bits - 1),
+                                       1 << (rhs_size_in_bits - 1)))
+                    test_cases.append((1, 1 << (rhs_size_in_bits - 1)))
+                    test_cases.append((1 << (lhs_size_in_bits - 1), 1))
+                    test_cases.append((1, 1))
+                    self.tst_toom_mul_sim(
+                        code=Mul(mul=mul,
+                                 lhs_size_in_words=lhs_size_in_words,
+                                 rhs_size_in_words=rhs_size_in_words),
+                        lhs_signed=lhs_signed, rhs_signed=rhs_signed,
+                        get_state_factory=get_pre_ra_state_factory,
+                        test_cases=test_cases)
+
+    def test_toom_2_once_mul_uu_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2,), lhs_signed=False, rhs_signed=False)
+
+    def test_toom_2_once_mul_us_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2,), lhs_signed=False, rhs_signed=True)
+
+    def test_toom_2_once_mul_su_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2,), lhs_signed=True, rhs_signed=False)
+
+    def test_toom_2_once_mul_ss_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2,), lhs_signed=True, rhs_signed=True)
+
+    def test_toom_2_mul_uu_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=False, rhs_signed=False)
+
+    def test_toom_2_mul_us_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=False, rhs_signed=True)
+
+    def test_toom_2_mul_su_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=True, rhs_signed=False)
 
-    def test_toom_2_mul_all_sizes_pre_ra_sim(self):
-        self.skipTest("broken")  # FIXME: fix
+    def test_toom_2_mul_ss_all_sizes_pre_ra_sim(self):
         TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
         self.tst_toom_mul_all_sizes_pre_ra_sim(
-            (TOOM_2, TOOM_2, TOOM_2, TOOM_2))
+            instances, lhs_signed=True, rhs_signed=True)
 
 
 if __name__ == "__main__":
diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py
index a7c5450..4ceb3d7 100644
--- a/src/bigint_presentation_code/toom_cook.py
+++ b/src/bigint_presentation_code/toom_cook.py
@@ -165,7 +165,7 @@ class EvalOpPoly:
 @final
 class EvalOpValueRange:
     __slots__ = ("eval_op", "inputs", "min_value", "max_value",
-                 "is_signed", "output_size")
+                 "is_signed", "output_size", "name_part")
 
     def __init__(self, eval_op, inputs):
         # type: (EvalOp | int, tuple[EvalOpGenIrInput, ...]) -> None
@@ -199,6 +199,10 @@ class EvalOpValueRange:
             min_v <<= GPR_SIZE_IN_BITS
             max_v <<= GPR_SIZE_IN_BITS
         self.output_size = output_size
+        if isinstance(eval_op, int):
+            self.name_part = f"const_{eval_op}"
+        else:
+            self.name_part = eval_op.name_part
 
     @cached_property
     def poly(self):
@@ -314,13 +318,14 @@ class EvalOpGenIrInput:
 @plain_data(frozen=True)
 @final
 class EvalOpGenIrState:
-    __slots__ = "fn", "inputs", "outputs_map"
+    __slots__ = "fn", "inputs", "outputs_map", "name"
 
-    def __init__(self, fn, inputs):
-        # type: (Fn, Iterable[EvalOpGenIrInput]) -> None
+    def __init__(self, fn, inputs, name):
+        # type: (Fn, Iterable[EvalOpGenIrInput], str) -> None
         super().__init__()
         self.fn = fn
         self.inputs = tuple(inputs)
+        self.name = name
         self.outputs_map = {}  # type: dict[EvalOp | int, EvalOpGenIrOutput]
 
     def get_output(self, eval_op):
@@ -330,12 +335,13 @@ class EvalOpGenIrState:
             return retval
         value_range = EvalOpValueRange(eval_op=eval_op, inputs=self.inputs)
         if isinstance(eval_op, int):
+            name = f"{self.name}_{EvalOp.get_name_part(eval_op)}"
             li = self.fn.append_new_op(OpKind.LI, immediates=[eval_op],
-                                       name=f"li_{eval_op}")
+                                       name=f"{name}_li")
             output = cast_to_size(
                 fn=self.fn, ssa_val=li.outputs[0],
                 dest_size=value_range.output_size,
-                src_signed=value_range.is_signed, name=f"cast_{eval_op}")
+                src_signed=value_range.is_signed, name=f"{name}_case")
             retval = EvalOpGenIrOutput(output=output, value_range=value_range)
         else:
             retval = eval_op.make_output(state=self,
@@ -373,6 +379,31 @@ class EvalOp(metaclass=InternedMeta):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         ...
 
+    @cached_property
+    @abstractmethod
+    def name_part(self):
+        # type: () -> str
+        ...
+
+    @staticmethod
+    def get_name_part(eval_op):
+        # type: (EvalOp | int) -> str
+        if isinstance(eval_op, int):
+            return f"const_{eval_op}"
+        return eval_op.name_part
+
+    @property
+    @final
+    def lhs_name_part(self):
+        # type: () -> str
+        return EvalOp.get_name_part(self.lhs)
+
+    @property
+    @final
+    def rhs_name_part(self):
+        # type: () -> str
+        return EvalOp.get_name_part(self.rhs)
+
     def __init__(self, lhs, rhs):
         # type: (EvalOp | int, EvalOp | int) -> None
         super().__init__()
@@ -390,26 +421,34 @@ class EvalOpAdd(EvalOp):
         # type: () -> EvalOpPoly
         return self.lhs_poly + self.rhs_poly
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"({self.lhs_name_part}+{self.rhs_name_part})"
+
     def make_output(self, state, output_value_range):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         lhs = state.get_output(self.lhs)
         lhs_output = cast_to_size(
             fn=state.fn, ssa_val=lhs.output,
             dest_size=output_value_range.output_size, src_signed=lhs.is_signed,
-            name="add_lhs_cast")
+            name=f"{state.name}_{self.name_part}_lhs_cast")
         rhs = state.get_output(self.rhs)
         rhs_output = cast_to_size(
             fn=state.fn, ssa_val=rhs.output,
             dest_size=output_value_range.output_size, src_signed=rhs.is_signed,
-            name="add_rhs_cast")
+            name=f"{state.name}_{self.name_part}_rhs_cast")
         setvl = state.fn.append_new_op(
             OpKind.SetVLI, immediates=[output_value_range.output_size],
-            name="setvl", maxvl=output_value_range.output_size)
-        clear_ca = state.fn.append_new_op(OpKind.ClearCA, name="clear_ca")
+            name=f"{state.name}_{self.name_part}_setvl",
+            maxvl=output_value_range.output_size)
+        clear_ca = state.fn.append_new_op(
+            OpKind.ClearCA, name=f"{state.name}_{self.name_part}_clear_ca")
         add = state.fn.append_new_op(
             OpKind.SvAddE, input_vals=[
                 lhs_output, rhs_output, clear_ca.outputs[0], setvl.outputs[0]],
-            maxvl=output_value_range.output_size, name="add")
+            maxvl=output_value_range.output_size,
+            name=f"{state.name}_{self.name_part}_add")
         return EvalOpGenIrOutput(
             output=add.outputs[0], value_range=output_value_range)
 
@@ -423,26 +462,34 @@ class EvalOpSub(EvalOp):
         # type: () -> EvalOpPoly
         return self.lhs_poly - self.rhs_poly
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"({self.lhs_name_part}-{self.rhs_name_part})"
+
     def make_output(self, state, output_value_range):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         lhs = state.get_output(self.lhs)
         lhs_output = cast_to_size(
             fn=state.fn, ssa_val=lhs.output,
             dest_size=output_value_range.output_size, src_signed=lhs.is_signed,
-            name="add_lhs_cast")
+            name=f"{state.name}_{self.name_part}_lhs_cast")
         rhs = state.get_output(self.rhs)
         rhs_output = cast_to_size(
             fn=state.fn, ssa_val=rhs.output,
             dest_size=output_value_range.output_size, src_signed=rhs.is_signed,
-            name="add_rhs_cast")
+            name=f"{state.name}_{self.name_part}_rhs_cast")
         setvl = state.fn.append_new_op(
             OpKind.SetVLI, immediates=[output_value_range.output_size],
-            name="setvl", maxvl=output_value_range.output_size)
-        set_ca = state.fn.append_new_op(OpKind.SetCA, name="set_ca")
+            name=f"{state.name}_{self.name_part}_setvl",
+            maxvl=output_value_range.output_size)
+        set_ca = state.fn.append_new_op(
+            OpKind.SetCA, name=f"{state.name}_{self.name_part}_set_ca")
         sub = state.fn.append_new_op(
             OpKind.SvSubFE, input_vals=[
                 rhs_output, lhs_output, set_ca.outputs[0], setvl.outputs[0]],
-            maxvl=output_value_range.output_size, name="sub")
+            maxvl=output_value_range.output_size,
+            name=f"{state.name}_{self.name_part}_sub")
         return EvalOpGenIrOutput(
             output=sub.outputs[0], value_range=output_value_range)
 
@@ -459,6 +506,11 @@ class EvalOpMul(EvalOp):
             raise TypeError("invalid rhs type")
         return self.lhs_poly * self.rhs
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"({self.lhs_name_part}*{self.rhs})"
+
     def make_output(self, state, output_value_range):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         raise NotImplementedError  # FIXME: finish
@@ -476,6 +528,11 @@ class EvalOpExactDiv(EvalOp):
             raise TypeError("invalid rhs type")
         return self.lhs_poly / self.rhs
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"({self.lhs_name_part}/{self.rhs})"
+
     def make_output(self, state, output_value_range):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         raise NotImplementedError  # FIXME: finish
@@ -500,6 +557,11 @@ class EvalOpInput(EvalOp):
     def part_index(self):
         return self.lhs
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"part[{self.part_index}]"
+
     def _make_poly(self):
         # type: () -> EvalOpPoly
         return EvalOpPoly({self.part_index: 1})
@@ -510,7 +572,7 @@ class EvalOpInput(EvalOp):
         output = cast_to_size(
             fn=state.fn, ssa_val=inp.ssa_val, src_signed=inp.is_signed,
             dest_size=output_value_range.output_size,
-            name=f"input_{self.part_index}_cast")
+            name=f"{state.name}_{self.name_part}_cast")
         return EvalOpGenIrOutput(output=output,
                                  value_range=output_value_range)
 
@@ -951,6 +1013,8 @@ def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
     for part in range(part_count):
         start = part * part_size
         stop = min(maxvl, start + part_size)
+        if part == part_count - 1:
+            stop = maxvl
         part_maxvl = stop - start
         part_setvl = fn.append_new_op(
             OpKind.SetVLI, immediates=[part_size], maxvl=part_size,
@@ -971,16 +1035,25 @@ _TCIs = Tuple[ToomCookInstance, ...]
 class ToomCookMul:
     __slots__ = (
         "fn", "lhs", "lhs_signed", "rhs", "rhs_signed", "instances",
-        "retval_size", "start_instance_index", "instance", "part_size",
+        "retval_size", "start_instance_index", "name", "instance", "part_size",
         "lhs_parts", "lhs_inputs", "lhs_eval_state", "lhs_outputs",
         "rhs_parts", "rhs_inputs", "rhs_eval_state", "rhs_outputs",
         "prod_inputs", "prod_eval_state", "prod_parts", "partial_products",
         "retval",
     )
 
-    def __init__(self, fn, lhs, lhs_signed, rhs, rhs_signed, instances,
-                 retval_size=None, start_instance_index=0):
-        # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> None
+    def __init__(
+            self, fn,  # type: Fn
+            lhs,  # type: SSAVal
+            lhs_signed,  # type: bool
+            rhs,  # type: SSAVal
+            rhs_signed,  # type: bool
+            instances,  # type: _TCIs
+            retval_size=None,  # type: None | int
+            name=None,  # type: None | str
+            start_instance_index=0,  # type: int
+    ):
+        # type: (...) -> None
         self.fn = fn
         self.lhs = lhs
         self.lhs_signed = lhs_signed
@@ -990,6 +1063,9 @@ class ToomCookMul:
         if retval_size is None:
             retval_size = lhs.ty.reg_len + rhs.ty.reg_len
         self.retval_size = retval_size
+        if name is None:
+            name = "mul"
+        self.name = name
         if start_instance_index < 0:
             raise ValueError("start_instance_index must be non-negative")
         self.start_instance_index = start_instance_index
@@ -997,9 +1073,16 @@ class ToomCookMul:
         self.part_size = 0  # type: int
         while start_instance_index < len(instances):
             self.instance = instances[start_instance_index]
-            self.part_size = max(
-                lhs.ty.reg_len // self.instance.lhs_part_count,
-                rhs.ty.reg_len // self.instance.rhs_part_count)
+            self.part_size = 0
+            # FIXME: this loop is some kind of integer division,
+            # figure out the correct formula
+            for shift in reversed(range(6)):
+                next_part_size = self.part_size + (1 << shift)
+                if (lhs.ty.reg_len > (
+                        self.instance.lhs_part_count - 1) * next_part_size
+                        and rhs.ty.reg_len > (
+                        self.instance.rhs_part_count - 1) * next_part_size):
+                    self.part_size = next_part_size
             if self.part_size <= 0:
                 self.instance = None
                 start_instance_index += 1
@@ -1009,40 +1092,45 @@ class ToomCookMul:
             self.retval = simple_mul(fn=fn, lhs=lhs, lhs_signed=lhs_signed,
                                      rhs=rhs, rhs_signed=rhs_signed,
-                                     name="toom_cook_base_case")
+                                     name=f"{name}_base_case")
             return
 
         self.lhs_parts = split_into_exact_sized_parts(
             fn=fn, ssa_val=lhs, part_count=self.instance.lhs_part_count,
-            part_size=self.part_size, name="lhs")
+            part_size=self.part_size, name=f"{name}_lhs")
         self.lhs_inputs = []  # type: list[EvalOpGenIrInput]
         for part, ssa_val in enumerate(self.lhs_parts):
             self.lhs_inputs.append(EvalOpGenIrInput(
                 ssa_val=ssa_val,
                 is_signed=lhs_signed and part == len(self.lhs_parts) - 1))
-        self.lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.lhs_inputs)
+        self.lhs_eval_state = EvalOpGenIrState(
+            fn=fn, inputs=self.lhs_inputs, name=f"{name}_lhs_eval")
         lhs_eval_ops = self.instance.lhs_eval_ops
         self.lhs_outputs = [
             self.lhs_eval_state.get_output(i) for i in lhs_eval_ops]
         self.rhs_parts = split_into_exact_sized_parts(
             fn=fn, ssa_val=rhs, part_count=self.instance.rhs_part_count,
-            part_size=self.part_size, name="rhs")
+            part_size=self.part_size, name=f"{name}_rhs")
         self.rhs_inputs = []  # type: list[EvalOpGenIrInput]
         for part, ssa_val in enumerate(self.rhs_parts):
            self.rhs_inputs.append(EvalOpGenIrInput(
                 ssa_val=ssa_val,
                 is_signed=rhs_signed and part == len(self.rhs_parts) - 1))
-        self.rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.rhs_inputs)
+        self.rhs_eval_state = EvalOpGenIrState(
+            fn=fn, inputs=self.rhs_inputs, name=f"{name}_rhs_eval")
         rhs_eval_ops = self.instance.rhs_eval_ops
         self.rhs_outputs = [
             self.rhs_eval_state.get_output(i) for i in rhs_eval_ops]
         self.prod_inputs = []  # type: list[EvalOpGenIrInput]
-        for lhs_output, rhs_output in zip(self.lhs_outputs, self.rhs_outputs):
-            ssa_val = toom_cook_mul(
+        for point_index, (lhs_output, rhs_output) in enumerate(
+                zip(self.lhs_outputs, self.rhs_outputs)):
+            ssa_val = ToomCookMul(
                 fn=fn, lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
                 rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
                 instances=instances,
-                start_instance_index=start_instance_index + 1)
+                start_instance_index=start_instance_index + 1,
+                retval_size=None,
+                name=f"{name}_pt{point_index}").retval
             products = (lhs_output.min_value * rhs_output.min_value,
                         lhs_output.min_value * rhs_output.max_value,
                         lhs_output.max_value * rhs_output.min_value,
                         lhs_output.max_value * rhs_output.max_value)
@@ -1052,7 +1140,8 @@ class ToomCookMul:
                 is_signed=None, min_value=min(products),
                 max_value=max(products)))
-        self.prod_eval_state = EvalOpGenIrState(fn=fn, inputs=self.prod_inputs)
+        self.prod_eval_state = EvalOpGenIrState(
+            fn=fn, inputs=self.prod_inputs, name=f"{name}_prod_eval")
         prod_eval_ops = self.instance.prod_eval_ops
         self.prod_parts = [
             self.prod_eval_state.get_output(i) for i in prod_eval_ops]
@@ -1063,24 +1152,31 @@ class ToomCookMul:
                 part_maxvl = prod_part.output.ty.reg_len
                 part_setvl = fn.append_new_op(
                     OpKind.SetVLI, immediates=[part_maxvl],
-                    name=f"prod_{part}_setvl", maxvl=part_maxvl)
+                    name=f"{name}_prod_{part}_setvl", maxvl=part_maxvl)
                 spread_part = fn.append_new_op(
                     OpKind.Spread,
                     input_vals=[prod_part.output, part_setvl.outputs[0]],
-                    name=f"prod_{part}_spread", maxvl=part_maxvl)
+                    name=f"{name}_prod_{part}_spread", maxvl=part_maxvl)
                 yield PartialProduct(
                     spread_part.outputs, shift_in_words=part * self.part_size,
                     is_signed=prod_part.is_signed, subtract=False)
         self.partial_products = tuple(partial_products())
         self.retval = sum_partial_products(
             fn=fn, partial_products=self.partial_products,
-            retval_size=retval_size, name="prod")
-
-
-def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances,
-                  retval_size=None, start_instance_index=0):
-    # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> SSAVal
+            retval_size=retval_size, name=f"{name}_sum_p_prods")
+
+
+def toom_cook_mul(
+        fn,  # type: Fn
+        lhs,  # type: SSAVal
+        lhs_signed,  # type: bool
+        rhs,  # type: SSAVal
+        rhs_signed,  # type: bool
+        instances,  # type: _TCIs
+        retval_size=None,  # type: None | int
+        name=None,  # type: None | str
+):
+    # type: (...) -> SSAVal
     return ToomCookMul(
         fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs, rhs_signed=rhs_signed,
-        instances=instances, retval_size=retval_size,
-        start_instance_index=start_instance_index).retval
+        instances=instances, retval_size=retval_size, name=name).retval
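
A note on the "# FIXME: this loop is some kind of integer division" block in
ToomCookMul.__init__: the bit-by-bit loop picks the largest part_size (at most
63) for which the final part of both operands stays non-empty, i.e. the largest
part_size with lhs.ty.reg_len > (lhs_part_count - 1) * part_size and
rhs.ty.reg_len > (rhs_part_count - 1) * part_size. A closed form that appears
to be equivalent is sketched below; the helper name max_part_size and the
63-word cap are illustrative assumptions, not part of this patch.

def max_part_size(lhs_len, rhs_len, lhs_part_count, rhs_part_count, limit=63):
    # type: (int, int, int, int, int) -> int
    # largest part_size <= limit such that the final part of each operand is
    # non-empty, i.e. (part_count - 1) * part_size < operand length in words
    part_size = limit
    if lhs_part_count > 1:
        part_size = min(part_size, (lhs_len - 1) // (lhs_part_count - 1))
    if rhs_part_count > 1:
        part_size = min(part_size, (rhs_len - 1) // (rhs_part_count - 1))
    return part_size

A brute-force comparison against the existing loop over small operand lengths
and part counts (for example lengths 1..64 with part counts 2..4) is an easy
way to gain confidence before replacing the FIXME.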