8e2952fc6fd015bf1b52ce662c3bb2f816fb9a88
[soc.git] / src / soc / fu / shift_rot / formal / proof_main_stage.py
1 # Proof of correctness for shift/rotate FU
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=340
6 """
7
8 import unittest
9 import enum
10 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
11 signed, Const, unsigned)
12 from nmigen.asserts import Assert, AnyConst, Assume
13 from nmutil.formaltest import FHDLTestCase
14 from nmutil.sim_util import do_sim
15 from nmigen.sim import Delay
16
17 from soc.fu.shift_rot.main_stage import ShiftRotMainStage
18 from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
19 from openpower.decoder.power_enums import MicrOp
20
21
@enum.unique
class TstOp(enum.Enum):
    """Enumeration of the operations under test.

    The idea is that running a separate formal proof for each
    instruction still covers them all, but each proof runs much
    faster, and the proofs can be run in parallel.

    Some members carry a `(MicrOp, width)` tuple so the same MicrOp
    can be proven separately for its 32-bit and 64-bit variants.
    """
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC32 = MicrOp.OP_RLC, 32
    RLC64 = MicrOp.OP_RLC, 64
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    GREV32 = MicrOp.OP_GREV, 32
    GREV64 = MicrOp.OP_GREV, 64

    @property
    def op(self):
        """The underlying MicrOp, with any width tag stripped off."""
        value = self.value
        return value[0] if isinstance(value, tuple) else value
43
44
def eq_any_const(sig: Signal):
    """Drive `sig` from a fresh formal constant of the same shape.

    `src_loc_at=1` attributes the AnyConst to our caller's source
    location, which makes formal traces easier to read.
    """
    return sig.eq(AnyConst(sig.shape(), src_loc_at=1))
47
48
class Mask(Elaboratable):
    """Combinatorial PowerISA-style (MSB0) mask generator.

    Port of qemu's MASK() helper:
    https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
    """

    def __init__(self):
        self.start = Signal(6)  # first mask bit, MSB0 numbering
        self.end = Signal(6)    # last mask bit, MSB0 numbering
        self.out = Signal(64)   # resulting mask

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        all_ones = Const(~0, unsigned(64))
        last_bit = 63
        with m.If(self.start == 0):
            comb += self.out.eq(all_ones << (last_bit - self.end))
        with m.Elif(self.end == last_bit):
            comb += self.out.eq(all_ones >> self.start)
        with m.Else():
            span = (all_ones >> self.start) ^ ((all_ones >> self.end) >> 1)
            # a wrapping mask (start > end) is the complement of the span
            comb += self.out.eq(Mux(self.start > self.end, ~span, span))
        return m
69
70
class TstMask(unittest.TestCase):
    """Exhaustive simulation check of Mask against a Python model."""

    def test_mask(self):
        dut = Mask()

        def msb0_bits(first, last):
            # mask with MSB0 bit positions first..last (inclusive) set
            result = 0
            for i in range(first, last + 1):
                result |= 1 << (63 - i)
            return result

        def check(start, end, expected):
            with self.subTest(start=start, end=end):
                yield dut.start.eq(start)
                yield dut.end.eq(end)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=hex(out), expected=hex(expected)):
                    self.assertEqual(expected, out)

        def process():
            for start in range(64):
                for end in range(64):
                    if start > end:
                        # wrapping mask: set both edges, clear the middle
                        expected = msb0_bits(start, 63) | msb0_bits(0, end)
                    else:
                        expected = msb0_bits(start, end)
                    yield from check(start, end, expected)

        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
            sim.add_process(process)
            sim.run()
100
101
def rotl64(v, amt):
    """Rotate the low 64 bits of `v` left by `amt` bits (mod 64)."""
    v = v | Const(0, 64)     # widen to a value at least 64 bits wide
    amt = amt | Const(0, 6)  # widen to a value at least 6 bits wide
    # duplicate the value, then a right shift by (64 - amt) of the
    # 128-bit concatenation yields the left-rotated low 64 bits
    doubled = Cat(v[:64], v[:64])
    return (doubled >> (64 - amt[:6]))[:64]
106
107
def rotl32(v, amt):
    """Rotate the low 32 bits of `v` left by `amt` bits (mod 32)."""
    v = v | Const(0, 32)  # widen to a value at least 32 bits wide
    # a 32-bit rotate is a 64-bit rotate of the value duplicated twice
    return rotl64(Cat(v[:32], v[:32]), amt)
111
112
113 # This defines a module to drive the device under test and assert
114 # properties about its outputs
class Driver(Elaboratable):
    """Formal driver for ShiftRotMainStage.

    Ties the DUT's inputs to free formal variables (AnyConst), restricts
    the decoded operation to the single TstOp selected at construction,
    and dispatches to the matching ``_check_*`` method, which asserts
    the expected output for that operation.
    """

    def __init__(self, which):
        # `which` selects the single micro-op this proof run covers
        assert isinstance(which, TstOp)
        self.which = which

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        # enable the draft bitmanip ops (ternlog/grev) in the DUT
        pspec.draft_bitmanip = True
        m.submodules.dut = dut = ShiftRotMainStage(pspec)

        # Set inputs to formal variables
        # NOTE(review): the _check_* methods below also read dut.i.rs,
        # which does not appear in this list -- presumably it aliases
        # one of the constrained inputs; confirm against pipe_data.
        comb += [
            eq_any_const(dut.i.ctx.op.insn_type),
            eq_any_const(dut.i.ctx.op.fn_unit),
            eq_any_const(dut.i.ctx.op.imm_data.data),
            eq_any_const(dut.i.ctx.op.imm_data.ok),
            eq_any_const(dut.i.ctx.op.rc.rc),
            eq_any_const(dut.i.ctx.op.rc.ok),
            eq_any_const(dut.i.ctx.op.oe.oe),
            eq_any_const(dut.i.ctx.op.oe.ok),
            eq_any_const(dut.i.ctx.op.write_cr0),
            eq_any_const(dut.i.ctx.op.input_carry),
            eq_any_const(dut.i.ctx.op.output_carry),
            eq_any_const(dut.i.ctx.op.input_cr),
            eq_any_const(dut.i.ctx.op.is_32bit),
            eq_any_const(dut.i.ctx.op.is_signed),
            eq_any_const(dut.i.ctx.op.insn),
            eq_any_const(dut.i.xer_ca),
            eq_any_const(dut.i.ra),
            eq_any_const(dut.i.rb),
            eq_any_const(dut.i.rc),
        ]

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # we're only checking a particular operation:
        comb += Assume(dut.i.ctx.op.insn_type == self.which.op)

        # dispatch to check fn for each op, e.g. TstOp.SHL -> _check_shl
        getattr(self, f"_check_{self.which.name.lower()}")(m, dut)

        return m

    def _check_shl(self, m, dut):
        """Shift left: result is rs << rb, truncated to 32/64 bits."""
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        with m.If(dut.i.ctx.op.is_32bit):
            # 32-bit form: 6-bit shift amount, result truncated to 32 bits
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32])
        with m.Else():
            # 64-bit form: 7-bit shift amount, result truncated to 64 bits
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64])
        m.d.comb += Assert(dut.o.o.data == expected)
        # shift-left never sets carry
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_shr(self, m, dut):
        """Shift right, signed (sra*) and unsigned (sr*) variants."""
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        carry = Signal()
        shift_in_s = Signal(signed(128))
        # round-trip value used to detect bits shifted out (for carry)
        shift_roundtrip = Signal(signed(128))
        shift_in_u = Signal(128)
        shift_amt = Signal(7)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:6]),
                shift_in_s.eq(dut.i.rs[:32].as_signed()),
                shift_in_u.eq(dut.i.rs[:32]),
            ]
        with m.Else():
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:7]),
                shift_in_s.eq(dut.i.rs.as_signed()),
                shift_in_u.eq(dut.i.rs),
            ]

        with m.If(dut.i.ctx.op.is_signed):
            # arithmetic shift: carry set iff input is negative and
            # any one-bits were shifted out (round trip loses bits)
            m.d.comb += [
                expected.eq(shift_in_s >> shift_amt),
                shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt),
                carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)),
            ]
        with m.Else():
            # logical shift never sets carry
            m.d.comb += [
                expected.eq(shift_in_u >> shift_amt),
                carry.eq(0),
            ]
        m.d.comb += Assert(dut.o.o.data == expected)
        # carry replicated into both CA and CA32
        m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))

    def _check_rlc32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # rlwimi, rlwinm, and rlwnm

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        # 32-bit rotate; M-form MB/ME are offsets into the low word,
        # hence the +32 to convert to 64-bit MSB0 positions
        m.d.comb += rot.eq(rotl32(dut.i.rs[:32], dut.i.rb[:5]))
        m.d.comb += mask.start.eq(dut.fields.FormM.MB[:] + 32)
        m.d.comb += mask.end.eq(dut.fields.FormM.ME[:] + 32)

        # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlc64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldic and rldimi

        # `rb` is always a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))
        # MD-form mb field is encoded split: high bit stored last,
        # so reassemble as Cat(mb[1:6], mb[0])
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))
        m.d.comb += mask.end.eq(63 - dut.i.rb[:6])

        # for rldic, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcl(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicl and rldcl

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        # clear-left: mask runs from MD-form mb to bit 63
        m.submodules.mask = mask = Mask()
        m.d.comb += mask.end.eq(63)
        # MD-form mb field is encoded split (see _check_rlc64)
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcr(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicr and rldcr

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        # clear-right: mask runs from bit 0 to MD-form me
        m.submodules.mask = mask = Mask()
        m.d.comb += mask.start.eq(0)
        # MD-form me field is encoded split (see _check_rlc64)
        me = dut.fields.FormMD.me[:]
        m.d.comb += mask.end.eq(Cat(me[1:6], me[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_extswsli(self, m, dut):
        """extswsli: sign-extend low word of rs, then shift left."""
        m.d.comb += Assume(dut.i.ra == 0)
        # shift amount is a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
        expected = Signal(64)
        m.d.comb += expected.eq((dut.i.rs[0:32].as_signed() << dut.i.rb[:6]))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_ternlog(self, m, dut):
        """ternlogi: per-bit 3-input lookup into the TLI immediate."""
        lut = dut.fields.FormTLI.TLI[:]
        for i in range(64):
            # bit i of (rb, ra, rc) forms a 3-bit index into the LUT
            idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
            for j in range(8):
                with m.If(j == idx):
                    m.d.comb += Assert(dut.o.o.data[i] == lut[j])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev32(self, m, dut):
        """grev (32-bit): bit i of output is input bit (i ^ rb)."""
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # assert zero-extended
        m.d.comb += Assert(dut.o.o.data[32:] == 0)
        # check a single, formally-arbitrary bit position; the solver
        # covers all positions since i is AnyConst
        i = Signal(5)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0: 5] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev64(self, m, dut):
        """grev (64-bit): bit i of output is input bit (i ^ rb)."""
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # check a single, formally-arbitrary bit position (see grev32)
        i = Signal(6)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0: 6] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)
323
324
class ALUTestCase(FHDLTestCase):
    """Runs one formal proof (bmc + cover) per TstOp member."""

    def run_it(self, which):
        """Build a Driver for `which` and run bmc and cover proofs."""
        driver = Driver(which)
        self.assertFormal(driver, mode="bmc", depth=2)
        self.assertFormal(driver, mode="cover", depth=2)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc32(self):
        self.run_it(TstOp.RLC32)

    def test_rlc64(self):
        self.run_it(TstOp.RLC64)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    def test_grev64(self):
        self.run_it(TstOp.GREV64)
360
361
362 # check that all test cases are covered
363 for i in TstOp:
364 assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}"))
365
366
367 if __name__ == '__main__':
368 unittest.main()