register allocation and simulation work for simple mul 192x192!
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_register_allocator.py
1 import unittest
2
3 from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
4 GlobalMem, GPRRange, GPRType,
5 OpBigIntAddSub, OpConcat,
6 OpCopy, OpFuncArg,
7 OpInputMem, OpLI, OpLoad,
8 OpSetCA, OpSetVLImm, OpStore,
9 XERBit)
10 from bigint_presentation_code.register_allocator import (
11 AllocationFailed, MergedRegSet, allocate_registers,
12 try_allocate_registers_without_spilling)
13
14
class TestMergedRegSet(unittest.TestCase):
    """Tests for ``MergedRegSet``."""

    maxDiff = None

    def test_from_equality_constraint(self):
        """Offsets follow the order of the values in the equality constraint.

        Each value is placed right after the ones listed before it, so its
        offset is the total length of every earlier value in the constraint.
        """
        fn = Fn()
        # ``li 0`` results with vector lengths 1, 2 and 3, created in that
        # order (each OpSetVLImm immediately precedes the OpLI that uses it).
        len1, len2, len3 = (
            OpLI(fn, 0, vl=OpSetVLImm(fn, length).out).out
            for length in (1, 2, 3)
        )

        cases = (
            # (constraint order, expected offset of each value)
            ([len1, len2, len3], {len1: 0, len2: 1, len3: 3}),
            ([len2, len1, len3], {len2: 0, len1: 2, len3: 3}),
        )
        for order, offsets in cases:
            actual = MergedRegSet.from_equality_constraint(order)
            self.assertEqual(actual, MergedRegSet(offsets.items()))
41
42
class TestRegisterAllocator(unittest.TestCase):
    """Tests for ``try_allocate_registers_without_spilling``."""

    maxDiff = None

    def test_try_alloc_fail(self):
        """Concatenating a 52-wide and a 64-wide value fails to allocate.

        The concatenation needs a single 116-register range (shown as
        ``gpr_ty[116]`` in the expected repr below). Allocation is expected to
        fail without spilling, presumably because no such contiguous range is
        available.
        """
        fn = Fn()
        op0 = OpSetVLImm(fn, 52)
        op1 = OpLI(fn, 0, vl=op0.out)
        op2 = OpSetVLImm(fn, 64)
        op3 = OpLI(fn, 0, vl=op2.out)
        # Registered with ``fn`` on construction; it appears as op #4 in the
        # repr below, so it must stay last.
        OpConcat(fn, [op1.out, op3.out])

        result = try_allocate_registers_without_spilling(fn.ops)
        # Check the result type first: if this fails, the message is much
        # clearer than a diff against the long repr below.
        self.assertIsInstance(result, AllocationFailed)
        self.assertEqual(
            repr(result),
            "AllocationFailed("
            "node=IGNode(#0, merged_reg_set=MergedRegSet(["
            "(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)]), "
            "edges={}, reg=None), "
            "live_intervals=LiveIntervals(live_intervals={"
            "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): "
            "LiveInterval(first_write=0, last_use=1), "
            "MergedRegSet([(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)]): "
            "LiveInterval(first_write=1, last_use=4), "
            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): "
            "LiveInterval(first_write=2, last_use=3)}, "
            "merged_reg_sets=MergedRegSets(data={"
            "<#0.out: KnownVLType(length=52)>: "
            "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), "
            "<#1.out: <gpr_ty[52]>>: MergedRegSet(["
            "(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)]), "
            "<#2.out: KnownVLType(length=64)>: "
            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), "
            "<#3.out: <gpr_ty[64]>>: MergedRegSet(["
            "(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)]), "
            "<#4.dest: <gpr_ty[116]>>: MergedRegSet(["
            "(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)])}), "
            "reg_sets_live_after={"
            "0: OFSet([MergedRegSet(["
            "(<#0.out: KnownVLType(length=52)>, 0)])]), "
            "1: OFSet([MergedRegSet(["
            "(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)])]), "
            "2: OFSet([MergedRegSet(["
            "(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)]), "
            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), "
            "3: OFSet([MergedRegSet(["
            "(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)])]), "
            "4: OFSet()}), "
            "interference_graph=InterferenceGraph(nodes={"
            "...: IGNode(#0, merged_reg_set=MergedRegSet(["
            "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), "
            "...: IGNode(#1, merged_reg_set=MergedRegSet(["
            "(<#4.dest: <gpr_ty[116]>>, 0), "
            "(<#1.out: <gpr_ty[52]>>, 0), "
            "(<#3.out: <gpr_ty[64]>>, 52)]), edges={}, reg=None), "
            "...: IGNode(#2, merged_reg_set=MergedRegSet(["
            "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))"
        )

    def test_try_alloc_bigint_inc(self):
        """Allocate registers for an in-place 32-word bigint increment.

        The pointer arrives in the fixed range ``GPRRange(3)`` and is copied to
        a free GPR. Then 32 words are loaded, 0 is added with carry-in set (an
        increment), and the sum is stored back through the same pointer.
        """
        fn = Fn()
        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
        op1 = OpCopy(fn, op0.out, GPRType())
        arg = op1.dest
        op2 = OpInputMem(fn)
        mem = op2.out
        op3 = OpSetVLImm(fn, 32)
        vl = op3.out
        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
        a = op4.RT
        op5 = OpLI(fn, 0, vl=vl)
        b = op5.out
        op6 = OpSetCA(fn, True)
        ca = op6.out
        op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
        s = op7.out
        op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)

        reg_assignments = try_allocate_registers_without_spilling(fn.ops)

        expected_reg_assignments = {
            op0.out: GPRRange(start=3, length=1),
            op1.dest: GPRRange(start=3, length=1),
            op2.out: GlobalMem.GlobalMem,
            op3.out: VL.VL_MAXVL,
            op4.RT: GPRRange(start=78, length=32),
            op5.out: GPRRange(start=46, length=32),
            op6.out: XERBit.CA,
            op7.out: GPRRange(start=14, length=32),
            op7.CA_out: XERBit.CA,
            op8.mem_out: GlobalMem.GlobalMem,
        }

        self.assertEqual(reg_assignments, expected_reg_assignments)

    def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
        # type: (list[GPRRange], GPRRange) -> None
        """Check allocation of an ``OpConcat`` built from one ``OpLI`` per range.

        For each range in ``expected_regs``, this emits an ``OpSetVLImm`` with
        that range's length, followed by an ``OpLI`` of the range's index using
        that VL. All the ``OpLI`` outputs are then concatenated.

        The allocator must place every input at its expected range, the
        concatenation at ``expected_dest_reg``, and every VL value in
        ``VL.VL_MAXVL``.

        The method is named ``tst_`` rather than ``test_`` so that unittest does
        not collect it as a test on its own. The ``test_try_alloc_concat_*``
        methods call it instead.
        """
        fn = Fn()
        inputs = []
        expected_reg_assignments = {}
        for i, r in enumerate(expected_regs):
            vl = OpSetVLImm(fn, r.length).out
            expected_reg_assignments[vl] = VL.VL_MAXVL
            inp = OpLI(fn, i, vl=vl).out
            inputs.append(inp)
            expected_reg_assignments[inp] = r
        concat = OpConcat(fn, inputs)
        expected_reg_assignments[concat.dest] = expected_dest_reg

        reg_assignments = try_allocate_registers_without_spilling(fn.ops)

        self.assertEqual(reg_assignments, expected_reg_assignments)

    def test_try_alloc_concat_1(self):
        self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3))

    def test_try_alloc_concat_3(self):
        self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3))

    def test_try_alloc_concat_3_5(self):
        self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
                                  GPRRange(3, 8))

    def test_try_alloc_concat_5_3(self):
        self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
                                  GPRRange(3, 8))

    def test_try_alloc_concat_1_2_3_4_5_6(self):
        self.tst_try_alloc_concat([
            GPRRange(14, 1),
            GPRRange(15, 2),
            GPRRange(17, 3),
            GPRRange(20, 4),
            GPRRange(24, 5),
            GPRRange(29, 6),
        ], GPRRange(14, 21))
198
199
if __name__ == "__main__":
    # Allow running this file directly instead of through a test runner;
    # ``module="__main__"`` is unittest.main's default, spelled out here.
    unittest.main(module="__main__")