3 from bigint_presentation_code
.compiler_ir
import (VL
, FixedGPRRangeType
, Fn
,
4 GlobalMem
, GPRRange
, GPRType
,
5 OpBigIntAddSub
, OpConcat
,
7 OpInputMem
, OpLI
, OpLoad
,
8 OpSetCA
, OpSetVLImm
, OpStore
,
10 from bigint_presentation_code
.register_allocator
import (
11 AllocationFailed
, MergedRegSet
, allocate_registers
,
12 try_allocate_registers_without_spilling
)
15 class TestMergedRegSet(unittest
.TestCase
):
18 def test_from_equality_constraint(self
):
20 li0x1
= OpLI(fn
, 0, vl
=OpSetVLImm(fn
, 1).out
)
21 li0x2
= OpLI(fn
, 0, vl
=OpSetVLImm(fn
, 2).out
)
22 li0x3
= OpLI(fn
, 0, vl
=OpSetVLImm(fn
, 3).out
)
23 self
.assertEqual(MergedRegSet
.from_equality_constraint([
32 self
.assertEqual(MergedRegSet
.from_equality_constraint([
43 class TestRegisterAllocator(unittest
.TestCase
):
46 def test_try_alloc_fail(self
):
48 op0
= OpSetVLImm(fn
, 52)
49 op1
= OpLI(fn
, 0, vl
=op0
.out
)
50 op2
= OpSetVLImm(fn
, 64)
51 op3
= OpLI(fn
, 0, vl
=op2
.out
)
52 op4
= OpConcat(fn
, [op1
.out
, op3
.out
])
54 reg_assignments
= try_allocate_registers_without_spilling(fn
.ops
)
56 repr(reg_assignments
),
58 "node=IGNode(#0, merged_reg_set=MergedRegSet(["
59 "(<#4.dest: <gpr_ty[116]>>, 0), "
60 "(<#1.out: <gpr_ty[52]>>, 0), "
61 "(<#3.out: <gpr_ty[64]>>, 52)]), "
62 "edges={}, reg=None), "
63 "live_intervals=LiveIntervals(live_intervals={"
64 "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): "
65 "LiveInterval(first_write=0, last_use=1), "
66 "MergedRegSet([(<#4.dest: <gpr_ty[116]>>, 0), "
67 "(<#1.out: <gpr_ty[52]>>, 0), "
68 "(<#3.out: <gpr_ty[64]>>, 52)]): "
69 "LiveInterval(first_write=1, last_use=4), "
70 "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): "
71 "LiveInterval(first_write=2, last_use=3)}, "
72 "merged_reg_sets=MergedRegSets(data={"
73 "<#0.out: KnownVLType(length=52)>: "
74 "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), "
75 "<#1.out: <gpr_ty[52]>>: MergedRegSet(["
76 "(<#4.dest: <gpr_ty[116]>>, 0), "
77 "(<#1.out: <gpr_ty[52]>>, 0), "
78 "(<#3.out: <gpr_ty[64]>>, 52)]), "
79 "<#2.out: KnownVLType(length=64)>: "
80 "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), "
81 "<#3.out: <gpr_ty[64]>>: MergedRegSet(["
82 "(<#4.dest: <gpr_ty[116]>>, 0), "
83 "(<#1.out: <gpr_ty[52]>>, 0), "
84 "(<#3.out: <gpr_ty[64]>>, 52)]), "
85 "<#4.dest: <gpr_ty[116]>>: MergedRegSet(["
86 "(<#4.dest: <gpr_ty[116]>>, 0), "
87 "(<#1.out: <gpr_ty[52]>>, 0), "
88 "(<#3.out: <gpr_ty[64]>>, 52)])}), "
89 "reg_sets_live_after={"
90 "0: OFSet([MergedRegSet(["
91 "(<#0.out: KnownVLType(length=52)>, 0)])]), "
92 "1: OFSet([MergedRegSet(["
93 "(<#4.dest: <gpr_ty[116]>>, 0), "
94 "(<#1.out: <gpr_ty[52]>>, 0), "
95 "(<#3.out: <gpr_ty[64]>>, 52)])]), "
96 "2: OFSet([MergedRegSet(["
97 "(<#4.dest: <gpr_ty[116]>>, 0), "
98 "(<#1.out: <gpr_ty[52]>>, 0), "
99 "(<#3.out: <gpr_ty[64]>>, 52)]), "
100 "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), "
101 "3: OFSet([MergedRegSet(["
102 "(<#4.dest: <gpr_ty[116]>>, 0), "
103 "(<#1.out: <gpr_ty[52]>>, 0), "
104 "(<#3.out: <gpr_ty[64]>>, 52)])]), "
106 "interference_graph=InterferenceGraph(nodes={"
107 "...: IGNode(#0, merged_reg_set=MergedRegSet(["
108 "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), "
109 "...: IGNode(#1, merged_reg_set=MergedRegSet(["
110 "(<#4.dest: <gpr_ty[116]>>, 0), "
111 "(<#1.out: <gpr_ty[52]>>, 0), "
112 "(<#3.out: <gpr_ty[64]>>, 52)]), edges={}, reg=None), "
113 "...: IGNode(#2, merged_reg_set=MergedRegSet(["
114 "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))"
117 def test_try_alloc_bigint_inc(self
):
119 op0
= OpFuncArg(fn
, FixedGPRRangeType(GPRRange(3)))
120 op1
= OpCopy(fn
, op0
.out
, GPRType())
124 op3
= OpSetVLImm(fn
, 32)
126 op4
= OpLoad(fn
, arg
, offset
=0, mem
=mem
, vl
=vl
)
128 op5
= OpLI(fn
, 0, vl
=vl
)
130 op6
= OpSetCA(fn
, True)
132 op7
= OpBigIntAddSub(fn
, a
, b
, ca
, is_sub
=False, vl
=vl
)
134 op8
= OpStore(fn
, s
, arg
, offset
=0, mem_in
=mem
, vl
=vl
)
137 reg_assignments
= try_allocate_registers_without_spilling(fn
.ops
)
139 expected_reg_assignments
= {
140 op0
.out
: GPRRange(start
=3, length
=1),
141 op1
.dest
: GPRRange(start
=3, length
=1),
142 op2
.out
: GlobalMem
.GlobalMem
,
143 op3
.out
: VL
.VL_MAXVL
,
144 op4
.RT
: GPRRange(start
=78, length
=32),
145 op5
.out
: GPRRange(start
=46, length
=32),
147 op7
.out
: GPRRange(start
=14, length
=32),
148 op7
.CA_out
: XERBit
.CA
,
149 op8
.mem_out
: GlobalMem
.GlobalMem
,
152 self
.assertEqual(reg_assignments
, expected_reg_assignments
)
154 def tst_try_alloc_concat(self
, expected_regs
, expected_dest_reg
):
155 # type: (list[GPRRange], GPRRange) -> None
158 expected_reg_assignments
= {}
159 for i
, r
in enumerate(expected_regs
):
160 vl
= OpSetVLImm(fn
, r
.length
).out
161 expected_reg_assignments
[vl
] = VL
.VL_MAXVL
162 inp
= OpLI(fn
, i
, vl
=vl
).out
164 expected_reg_assignments
[inp
] = r
165 concat
= OpConcat(fn
, inputs
)
166 expected_reg_assignments
[concat
.dest
] = expected_dest_reg
168 reg_assignments
= try_allocate_registers_without_spilling(fn
.ops
)
170 for inp
, reg
in zip(inputs
, expected_regs
):
171 expected_reg_assignments
[inp
] = reg
173 self
.assertEqual(reg_assignments
, expected_reg_assignments
)
def test_try_alloc_concat_1(self):
    # Single-input concat case: pass the shared concat helper one
    # expected input range, plus the range expected for the concat's
    # destination (both are GPRRange(3) here).
    self.tst_try_alloc_concat(
        expected_regs=[GPRRange(3)],
        expected_dest_reg=GPRRange(3),
    )
def test_try_alloc_concat_3(self):
    # Single-input concat case built from GPRRange(3, 3): the one
    # expected input range and the expected destination range are
    # constructed with the same arguments.
    self.tst_try_alloc_concat(
        expected_regs=[GPRRange(3, 3)],
        expected_dest_reg=GPRRange(3, 3),
    )
181 def test_try_alloc_concat_3_5(self
):
182 self
.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
185 def test_try_alloc_concat_5_3(self
):
186 self
.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
189 def test_try_alloc_concat_1_2_3_4_5_6(self
):
190 self
.tst_try_alloc_concat([
200 if __name__
== "__main__":