From df8cf71d060258ff24339caba08c3d2556c7bb72 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 15 Nov 2022 22:58:14 -0800 Subject: [PATCH] add general TOOM-2 test --- .../_tests/test_toom_cook.py | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py index 76b9a5e..41a4e22 100644 --- a/src/bigint_presentation_code/_tests/test_toom_cook.py +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -1,6 +1,6 @@ from contextlib import contextmanager import unittest -from typing import Any, Callable, ContextManager, Iterator, Tuple +from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, GPR_SIZE_IN_BYTES, @@ -1847,6 +1847,102 @@ class TestToomCook(unittest.TestCase): 'sv.std *4, 0(3)' ]) + def tst_toom_mul_sim( + self, code, # type: Mul + lhs_signed, # type: bool + rhs_signed, # type: bool + get_state_factory, # type: Callable[[Mul], _StateFactory] + test_cases, # type: Iterable[tuple[int, int]] + ): + print(code.retval[1]) + print(code.fn.ops_to_str()) + state_factory = get_state_factory(code) + ptr_in = 0x100 + dest_ptr = ptr_in + code.dest_offset + lhs_ptr = ptr_in + code.lhs_offset + rhs_ptr = ptr_in + code.rhs_offset + lhs_size_in_bits = code.lhs_size_in_words * GPR_SIZE_IN_BITS + rhs_size_in_bits = code.rhs_size_in_words * GPR_SIZE_IN_BITS + for lhs_value, rhs_value in test_cases: + lhs_value %= 1 << lhs_size_in_bits + rhs_value %= 1 << rhs_size_in_bits + if lhs_signed and lhs_value >> (lhs_size_in_bits - 1): + lhs_value -= 1 << lhs_size_in_bits + if rhs_signed and rhs_value >> (rhs_size_in_bits - 1): + rhs_value -= 1 << rhs_size_in_bits + prod_value = lhs_value * rhs_value + lhs_value %= 1 << lhs_size_in_bits + rhs_value %= 1 << rhs_size_in_bits + prod_value %= 1 << (lhs_size_in_bits + rhs_size_in_bits) + with self.subTest(lhs_signed=lhs_signed, rhs_signed=rhs_signed, + lhs_value=hex(lhs_value), + rhs_value=hex(rhs_value), + prod_value=hex(prod_value)): + with state_factory() as state: + state[code.ptr_in] = ptr_in, + for i in range(code.lhs_size_in_words): + v = lhs_value >> GPR_SIZE_IN_BITS * i + v &= GPR_VALUE_MASK + state.store(lhs_ptr + i * GPR_SIZE_IN_BYTES, v) + for i in range(code.rhs_size_in_words): + v = rhs_value >> GPR_SIZE_IN_BITS * i + v &= GPR_VALUE_MASK + state.store(rhs_ptr + i * GPR_SIZE_IN_BYTES, v) + code.fn.sim(state) + prod = 0 + for i in range(code.dest_size_in_words): + v = state.load(dest_ptr + GPR_SIZE_IN_BYTES * i) + prod += v << (GPR_SIZE_IN_BITS * i) + self.assertEqual(hex(prod), hex(prod_value), + f"failed: state={state}") + + def tst_toom_mul_all_sizes_pre_ra_sim(self, instances): + # type: (tuple[ToomCookInstance, ...]) -> None + for lhs_signed in False, True: + for rhs_signed in False, True: + def mul(fn, lhs, rhs): + # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul] + v = ToomCookMul( + fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs, + rhs_signed=rhs_signed, instances=instances) + return v.retval, v + for lhs_size_in_words in range(1, 32): + for rhs_size_in_words in range(1, 32): + lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words + rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words + with self.subTest(lhs_size_in_words=lhs_size_in_words, + rhs_size_in_words=rhs_size_in_words, + lhs_signed=lhs_signed, + rhs_signed=rhs_signed): + test_cases = [] # type: list[tuple[int, int]] + test_cases.append((-1, -1)) + test_cases.append(((0x80 << 2048) // 0xFF, + (0x80 << 2048) // 0xFF)) + test_cases.append(((0x40 << 2048) // 0xFF, + (0x80 << 2048) // 0xFF)) + test_cases.append(((0x80 << 2048) // 0xFF, + (0x40 << 2048) // 0xFF)) + test_cases.append(((0x40 << 2048) // 0xFF, + (0x40 << 2048) // 0xFF)) + test_cases.append((1 << (lhs_size_in_bits - 1), + 1 << (rhs_size_in_bits - 1))) + test_cases.append((1, 1 << (rhs_size_in_bits - 1))) + test_cases.append((1 << (lhs_size_in_bits - 1), 1)) + test_cases.append((1, 1)) + self.tst_toom_mul_sim( + code=Mul(mul=mul, + lhs_size_in_words=lhs_size_in_words, + rhs_size_in_words=rhs_size_in_words), + lhs_signed=lhs_signed, rhs_signed=rhs_signed, + get_state_factory=get_pre_ra_state_factory, + test_cases=test_cases) + + def test_toom_2_mul_all_sizes_pre_ra_sim(self): + self.skipTest("broken") # FIXME: fix + TOOM_2 = ToomCookInstance.make_toom_2() + self.tst_toom_mul_all_sizes_pre_ra_sim( + (TOOM_2, TOOM_2, TOOM_2, TOOM_2)) + if __name__ == "__main__": unittest.main() -- 2.30.2