820c305a70d995ecd2912883d6798e364630bbcf
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_compiler_ir.py
1 import unittest
2
3 from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
4 GlobalMem, GPRRange, GPRType,
5 OpBigIntAddSub, OpConcat,
6 OpCopy, OpFuncArg,
7 OpInputMem, OpLI, OpLoad,
8 OpSetCA, OpSetVLImm, OpStore,
9 RegLoc, SSAVal, XERBit,
10 generate_assembly,
11 op_set_to_list)
12
13
class TestCompilerIR(unittest.TestCase):
    """Tests for op ordering and assembly generation in ``compiler_ir``.

    Each test hand-builds a small ``Fn`` out of IR ops and compares the
    output of ``op_set_to_list`` or ``generate_assembly`` against a
    hard-coded expectation.
    """

    # Show complete diffs on failure; the expected values are long lists.
    maxDiff = None

    def test_op_set_to_list(self) -> None:
        """Check that ``op_set_to_list`` orders the ops of a small function.

        The function copies the argument in r3, loads ``a`` through it with
        VL=32, builds ``b`` by concatenating ``OpLI(1)`` with an
        ``OpLI(0)`` at VL=31, adds ``a + b`` with CA cleared at VL=32, and
        stores the sum back through the same argument.
        """
        fn = Fn()
        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
        op1 = OpCopy(fn, op0.out, GPRType())
        arg = op1.dest
        op2 = OpInputMem(fn)
        mem = op2.out
        op3 = OpSetVLImm(fn, 32)
        vl = op3.out
        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
        a = op4.RT
        op5 = OpLI(fn, 1)
        b_0 = op5.out
        op6 = OpSetVLImm(fn, 31)
        vl = op6.out
        op7 = OpLI(fn, 0, vl=vl)
        b_rest = op7.out
        op8 = OpConcat(fn, [b_0, b_rest])
        b = op8.dest
        op9 = OpSetVLImm(fn, 32)
        vl = op9.out
        op10 = OpSetCA(fn, False)
        ca = op10.out
        op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
        s = op11.out
        op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
        mem = op12.mem_out

        expected_ops = [
            op10,  # OpSetCA(fn, False)
            op9,  # OpSetVLImm(fn, 32)
            op6,  # OpSetVLImm(fn, 31)
            op5,  # OpLI(fn, 1)
            op3,  # OpSetVLImm(fn, 32)
            op2,  # OpInputMem(fn)
            op0,  # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
            op7,  # OpLI(fn, 0, vl=op6.out)  -- VL=31
            op1,  # OpCopy(fn, op0.out, GPRType())
            op8,  # OpConcat(fn, [b_0, b_rest])
            op4,  # OpLoad(fn, arg, offset=0, mem=op2.out, vl=op3.out)
            op11,  # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=op9.out)
            op12,  # OpStore(fn, s, arg, offset=0, mem_in=op2.out, vl=op9.out)
        ]

        # The ops are fed in reverse creation order, presumably so that the
        # result cannot simply echo the order in which they were built.
        ops = op_set_to_list(fn.ops[::-1])
        if ops != expected_ops:
            # Compare reprs first: that produces a readable per-op diff.
            self.assertEqual(repr(ops), repr(expected_ops))
            # Distinct ops may share a repr (op3 and op9 are both
            # OpSetVLImm(fn, 32)), so equal reprs do not prove the lists
            # match. Compare the lists directly so that such a mismatch
            # still fails instead of passing silently.
            self.assertEqual(ops, expected_ops)

    def tst_generate_assembly(self, use_reg_alloc: bool = False) -> None:
        """Check ``generate_assembly`` on a small add-with-carry function.

        The helper is deliberately not named ``test_*``, so unittest does not
        collect it on its own. The ``test_generate_assembly*`` methods call it.

        The function loads ``a`` through the argument in r3 with VL=32.
        It adds a zero vector ``b`` with CA set and stores the sum back
        through the same argument.

        :param use_reg_alloc: if true, pass ``None`` instead of the
            hand-written register assignment below, so ``generate_assembly``
            must choose registers itself (presumably through the register
            allocator). The expected assembly is identical in both modes.
        """
        fn = Fn()
        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
        op1 = OpCopy(fn, op0.out, GPRType())
        arg = op1.dest
        op2 = OpInputMem(fn)
        mem = op2.out
        op3 = OpSetVLImm(fn, 32)
        vl = op3.out
        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
        a = op4.RT
        op5 = OpLI(fn, 0, vl=vl)
        b = op5.out
        op6 = OpSetCA(fn, True)
        ca = op6.out
        op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
        s = op7.out
        op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
        mem = op8.mem_out

        # Hand-picked locations; they must agree with the register numbers
        # that appear in the expected assembly below.
        assigned_registers = {
            op0.out: GPRRange(start=3, length=1),
            op1.dest: GPRRange(start=3, length=1),
            op2.out: GlobalMem.GlobalMem,
            op3.out: VL.VL_MAXVL,
            op4.RT: GPRRange(start=78, length=32),
            op5.out: GPRRange(start=46, length=32),
            op6.out: XERBit.CA,
            op7.out: GPRRange(start=14, length=32),
            op7.CA_out: XERBit.CA,
            op8.mem_out: GlobalMem.GlobalMem,
        }  # type: dict[SSAVal, RegLoc] | None

        if use_reg_alloc:
            assigned_registers = None

        asm = generate_assembly(fn.ops, assigned_registers)
        self.assertEqual(asm, [
            "setvl 0, 0, 32, 0, 1, 1",
            "sv.ld *78, 0(3)",
            "sv.addi *46, 0, 0",
            "subfic 0, 0, -1",
            "sv.adde *14, *78, *46",
            "sv.std *14, 0(3)",
            "bclr 20, 0, 0",
        ])

    def test_generate_assembly(self) -> None:
        """Generate assembly using the hand-written register assignment."""
        self.tst_generate_assembly()

    def test_generate_assembly_with_register_allocator(self) -> None:
        """Generate assembly with registers chosen by ``generate_assembly``."""
        self.tst_generate_assembly(use_reg_alloc=True)
117
118
if __name__ == "__main__":
    # Run the tests when this file is executed directly
    # (``python test_compiler_ir.py``), not only under a test runner.
    unittest.main()