from contextlib import contextmanager
+import sys
import unittest
from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
PostRASimState,
PreRASimState, SSAVal)
from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.register_allocator_test_util import GraphDumper
from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
simple_mul)
from bigint_presentation_code.util import OSet
name="store_dest")
-def get_post_ra_state_factory(code):
- # type: (Mul) -> _StateFactory
- ssa_val_to_loc_map = allocate_registers(code.fn)
-
- @contextmanager
- def state_factory():
- yield PostRASimState(
- ssa_val_to_loc_map=ssa_val_to_loc_map,
- memory={}, loc_values={})
- return state_factory
-
-
class TestToomCook(unittest.TestCase):
maxDiff = None
+ def get_post_ra_state_factory(self, code):
+ # type: (Mul) -> _StateFactory
+ ssa_val_to_loc_map = allocate_registers(
+ code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
+
+ @contextmanager
+ def state_factory():
+ yield PostRASimState(
+ ssa_val_to_loc_map=ssa_val_to_loc_map,
+ memory={}, loc_values={})
+ return state_factory
+
def test_toom_2_repr(self):
TOOM_2 = ToomCookInstance.make_toom_2()
# print(repr(repr(TOOM_2)))
for rhs_signed in False, True:
self.tst_simple_mul_192x192_sim(
lhs_signed=lhs_signed, rhs_signed=rhs_signed,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def tst_simple_mul_192x192_sim(
self, lhs_signed, # type: bool
def test_simple_mul_192x192_reg_alloc(self):
code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
- assigned_registers = allocate_registers(fn)
+ assigned_registers = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
self.assertEqual(
repr(assigned_registers), "{"
"<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
def test_simple_mul_192x192_asm(self):
code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
- assigned_registers = allocate_registers(fn)
+ assigned_registers = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
gen_asm_state = GenAsmState(assigned_registers)
fn.gen_asm(gen_asm_state)
self.assertEqual(gen_asm_state.output, [
- 'or 27, 3, 3',
+ 'or 9, 3, 3',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 6, 27, 27',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.ld *3, 48(6)',
+ 'sv.ld *20, 48(9)',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *24, *3, *3',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 6, 27, 27',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.ld *3, 72(6)',
+ 'sv.ld *10, 72(9)',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or/mrr *5, *3, *3',
- 'or 4, 5, 5',
- 'or 14, 6, 6',
- 'or 23, 7, 7',
- 'addi 3, 0, 0',
- 'or 22, 3, 3',
+ 'addi 6, 0, 0',
+ 'or 8, 6, 6',
'setvl 0, 0, 3, 0, 1, 1',
'addi 3, 0, 0',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *8, *24, *24',
- 'or 7, 4, 4',
- 'or 6, 22, 22',
+ 'or 6, 8, 8',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.maddedu *3, *8, 7, 6',
+ 'sv.maddedu *14, *20, 10, 6',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 19, 6, 6',
+ 'or 17, 6, 6',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 21, 3, 3',
- 'or 12, 4, 4',
- 'or 11, 5, 5',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *8, *24, *24',
- 'or 7, 14, 14',
- 'or 6, 22, 22',
+ 'or 6, 8, 8',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.maddedu *3, *8, 7, 6',
+ 'sv.maddedu *3, *20, 11, 6',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 18, 6, 6',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 17, 3, 3',
- 'or 16, 4, 4',
- 'or 15, 5, 5',
- 'addi 3, 0, 0',
- 'or 8, 3, 3',
- 'addi 3, 0, 0',
- 'or 14, 3, 3',
+ 'addi 18, 0, 0',
+ 'addi 7, 0, 0',
'setvl 0, 0, 5, 0, 1, 1',
- 'or 3, 12, 12',
- 'or 4, 11, 11',
- 'or 5, 19, 19',
- 'or 6, 8, 8',
- 'or 7, 8, 8',
+ 'or 19, 18, 18',
'setvl 0, 0, 5, 0, 1, 1',
'setvl 0, 0, 5, 0, 1, 1',
- 'sv.or *8, *3, *3',
- 'or 3, 17, 17',
- 'or 4, 16, 16',
- 'or 5, 15, 15',
- 'or 6, 18, 18',
- 'or 7, 14, 14',
'setvl 0, 0, 5, 0, 1, 1',
'setvl 0, 0, 5, 0, 1, 1',
'addic 0, 0, 0',
'setvl 0, 0, 5, 0, 1, 1',
- 'sv.or *14, *8, *8',
'setvl 0, 0, 5, 0, 1, 1',
- 'sv.or *8, *3, *3',
'setvl 0, 0, 5, 0, 1, 1',
- 'sv.adde *3, *14, *8',
+ 'sv.adde *15, *15, *3',
'setvl 0, 0, 5, 0, 1, 1',
'setvl 0, 0, 5, 0, 1, 1',
'setvl 0, 0, 5, 0, 1, 1',
- 'or 20, 3, 3',
- 'or 19, 4, 4',
- 'or 18, 5, 5',
- 'or 17, 6, 6',
- 'or 16, 7, 7',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *8, *24, *24',
- 'or 7, 23, 23',
- 'or 6, 22, 22',
+ 'or 6, 8, 8',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.maddedu *3, *8, 7, 6',
+ 'sv.maddedu *3, *20, 12, 6',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 15, 6, 6',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 14, 3, 3',
- 'or 12, 4, 4',
- 'or 11, 5, 5',
'setvl 0, 0, 4, 0, 1, 1',
- 'or 3, 19, 19',
- 'or 4, 18, 18',
- 'or 5, 17, 17',
- 'or 6, 16, 16',
'setvl 0, 0, 4, 0, 1, 1',
'setvl 0, 0, 4, 0, 1, 1',
- 'sv.or *7, *3, *3',
- 'or 3, 14, 14',
- 'or 4, 12, 12',
- 'or 5, 11, 11',
- 'or 6, 15, 15',
'setvl 0, 0, 4, 0, 1, 1',
'setvl 0, 0, 4, 0, 1, 1',
'addic 0, 0, 0',
'setvl 0, 0, 4, 0, 1, 1',
- 'sv.or *14, *7, *7',
'setvl 0, 0, 4, 0, 1, 1',
- 'sv.or *7, *3, *3',
'setvl 0, 0, 4, 0, 1, 1',
- 'sv.adde *3, *14, *7',
+ 'sv.adde *16, *16, *3',
'setvl 0, 0, 4, 0, 1, 1',
'setvl 0, 0, 4, 0, 1, 1',
'setvl 0, 0, 4, 0, 1, 1',
- 'or 12, 3, 3',
- 'or 11, 4, 4',
- 'or 10, 5, 5',
- 'or 9, 6, 6',
'setvl 0, 0, 6, 0, 1, 1',
- 'or 3, 21, 21',
- 'or 4, 20, 20',
- 'or 5, 12, 12',
- 'or 6, 11, 11',
- 'or 7, 10, 10',
- 'or 8, 9, 9',
'setvl 0, 0, 6, 0, 1, 1',
'setvl 0, 0, 6, 0, 1, 1',
'setvl 0, 0, 6, 0, 1, 1',
'setvl 0, 0, 6, 0, 1, 1',
- 'sv.or/mrr *4, *3, *3',
- 'or 3, 27, 27',
'setvl 0, 0, 6, 0, 1, 1',
- 'sv.std *4, 0(3)'
+ 'sv.std *14, 0(9)',
])
def toom_2_mul_256x256(self, lhs_signed, rhs_signed):
def test_toom_2_mul_256x256_uu_post_ra_sim(self):
self.tst_toom_2_mul_256x256_sim(
lhs_signed=False, rhs_signed=False,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def test_toom_2_mul_256x256_su_post_ra_sim(self):
self.tst_toom_2_mul_256x256_sim(
lhs_signed=True, rhs_signed=False,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def test_toom_2_mul_256x256_us_post_ra_sim(self):
self.tst_toom_2_mul_256x256_sim(
lhs_signed=False, rhs_signed=True,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def test_toom_2_mul_256x256_ss_post_ra_sim(self):
self.tst_toom_2_mul_256x256_sim(
lhs_signed=True, rhs_signed=True,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def test_toom_2_mul_256x256_asm(self):
code = self.toom_2_mul_256x256(lhs_signed=False, rhs_signed=False)
fn = code.fn
- assigned_registers = allocate_registers(fn)
+ assigned_registers = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
gen_asm_state = GenAsmState(assigned_registers)
fn.gen_asm(gen_asm_state)
self.assertEqual(gen_asm_state.output, [