import sys
import unittest
-import shutil
from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind,
SSAVal)
from bigint_presentation_code.register_allocator import allocate_registers
-from nmutil.get_test_path import get_test_path
-
-
-def dump_graphs(test_case, graphs):
- # type: (unittest.TestCase, dict[str, str]) -> None
- base_path = get_test_path(test_case, "dumped_graphs")
- shutil.rmtree(base_path, ignore_errors=True)
- base_path.mkdir(parents=True, exist_ok=True)
- for name, dot in graphs.items():
- path = base_path / name
- dot_path = path.with_suffix(".dot")
- dot_path.write_text(dot)
+from bigint_presentation_code.register_allocator_test_util import GraphDumper
class TestRegisterAllocator(unittest.TestCase):
def test_register_allocate(self):
fn, _arg = self.make_add_fn()
- reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+ reg_assignments = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
self.assertEqual(
repr(reg_assignments),
def test_gen_asm(self):
fn, _arg = self.make_add_fn()
- reg_assignments = allocate_registers(fn)
+ reg_assignments = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
self.assertEqual(
repr(reg_assignments),
"Loc(kind=LocKind.GPR, start=14, reg_len=32), "
"<ld.outputs[0]: <I64*32>>: "
"Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<ld.inp0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
"<arg.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<arg.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<ld.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
"<st.inp1.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<arg.out0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
"<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
state = GenAsmState(reg_assignments)
fn.gen_asm(state)
self.assertEqual(state.output, [
- 'or 4, 3, 3',
'setvl 0, 0, 32, 0, 1, 1',
- 'or 3, 4, 4',
'setvl 0, 0, 32, 0, 1, 1',
'sv.ld *14, 0(3)',
'setvl 0, 0, 32, 0, 1, 1',
'sv.adde *14, *78, *46',
'setvl 0, 0, 32, 0, 1, 1',
'setvl 0, 0, 32, 0, 1, 1',
- 'or 3, 4, 4',
'setvl 0, 0, 32, 0, 1, 1',
- 'sv.std *14, 0(3)',
+ 'sv.std *14, 0(3)'
])
def test_register_allocate_graphs(self):
fn, _arg = self.make_add_fn()
graphs = {} # type: dict[str, str]
+ graph_dumper = GraphDumper(self)
+
def dump_graph(name, dot):
# type: (str, str) -> None
self.assertNotIn(name, graphs, "duplicate graph name")
graphs[name] = dot
- allocated = allocate_registers(fn, dump_graph=dump_graph,
- debug_out=sys.stdout)
- dump_graphs(self, graphs)
+ graph_dumper(name, dot)
+ allocated = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=dump_graph)
self.assertEqual(
repr(allocated),
"{"
"Loc(kind=LocKind.GPR, start=14, reg_len=32), "
"<ld.outputs[0]: <I64*32>>: "
"Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<ld.inp0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
"<arg.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<arg.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<ld.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
"<st.inp1.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<arg.out0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
"<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
_concat = fn.append_new_op(
OpKind.Concat, input_vals=[*spread[::-1], vl],
name="concat", maxvl=maxvl)
- reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+ reg_assignments = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
self.assertEqual(
repr(reg_assignments),