576a1658021d2d5fdbd578263d85ff62cf0a3dc5
[soc.git] / src / soc / fu / shift_rot / formal / proof_main_stage.py
1 # Proof of correctness for shift/rotate FU
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=340
6 """
7
8 import enum
9 from shutil import which
10 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
11 signed, Array, Const, Value, unsigned)
12 from nmigen.asserts import Assert, AnyConst, Assume, Cover
13 from nmutil.formaltest import FHDLTestCase
14 from nmutil.sim_util import do_sim
15 from nmigen.sim import Delay
16
17 from soc.fu.shift_rot.main_stage import ShiftRotMainStage
18 from soc.fu.shift_rot.rotator import right_mask, left_mask
19 from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
20 from soc.fu.shift_rot.sr_input_record import CompSROpSubset
21 from openpower.decoder.power_enums import MicrOp
22 from openpower.consts import field
23
24 import unittest
25 from nmutil.extend import exts
26
27
@enum.unique
class TstOp(enum.Enum):
    """Enumerates the operations under test.

    Running a separate formal proof per instruction covers them all,
    makes each individual proof much faster, and lets the proofs run
    in parallel.  Members whose value is a ``(MicrOp, width)`` tuple
    split one MicrOp into separate 32-bit and 64-bit proofs.
    """
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC32 = MicrOp.OP_RLC, 32
    RLC64 = MicrOp.OP_RLC, 64
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    GREV32 = MicrOp.OP_GREV, 32
    GREV64 = MicrOp.OP_GREV, 64

    @property
    def op(self):
        """Return the bare MicrOp, stripping any width tag."""
        value = self.value
        return value[0] if isinstance(value, tuple) else value
49
50
def eq_any_const(sig: Signal):
    """Drive *sig* from a fresh formal ``AnyConst`` of the same shape.

    ``src_loc_at=1`` attributes the constant to our caller's source
    line, which makes formal traces easier to read.
    """
    shape = sig.shape()
    return sig.eq(AnyConst(shape, src_loc_at=1))
53
54
class Mask(Elaboratable):
    """Reference mask generator, MSB0 bit numbering.

    Copied from qemu's MASK fn:
    https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
    """

    def __init__(self):
        self.start = Signal(6)  # first (leftmost, MSB0) mask bit
        self.end = Signal(6)    # last (rightmost, MSB0) mask bit
        self.out = Signal(64)   # resulting 64-bit mask

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        ones = Const(~0, unsigned(64))
        msb = 63
        with m.If(self.start == 0):
            comb += self.out.eq(ones << (msb - self.end))
        with m.Elif(self.end == msb):
            comb += self.out.eq(ones >> self.start)
        with m.Else():
            span = (ones >> self.start) ^ ((ones >> self.end) >> 1)
            # when start > end the mask wraps around: complement the span
            comb += self.out.eq(Mux(self.start > self.end, ~span, span))
        return m
75
76
class TstMask(unittest.TestCase):
    """Exhaustively compares the Mask module against a python model."""

    def test_mask(self):
        dut = Mask()

        def case(start, end, expected):
            with self.subTest(start=start, end=end):
                yield dut.start.eq(start)
                yield dut.end.eq(end)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=hex(out), expected=hex(expected)):
                    self.assertEqual(expected, out)

        def process():
            for start in range(64):
                for end in range(64):
                    # python model: set every MSB0 bit from start to
                    # end inclusive, wrapping around when start > end
                    if start > end:
                        positions = [*range(start, 64), *range(0, end + 1)]
                    else:
                        positions = range(start, end + 1)
                    expected = 0
                    for i in positions:
                        expected |= 1 << (63 - i)
                    yield from case(start, end, expected)
        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
            sim.add_process(process)
            sim.run()
106
107
def rotl64(v, amt):
    """Return *v* (64-bit) rotated left by *amt*, as an nmigen expression."""
    v |= Const(0, 64)  # widen to at least 64 bits
    amt |= Const(0, 6)  # widen to at least 6 bits
    # a rotate is a 64-bit window taken from two copies laid end-to-end
    doubled = Cat(v[:64], v[:64])
    return (doubled >> (64 - amt[:6]))[:64]
112
113
def rotl32(v, amt):
    """Return *v* (32-bit) rotated left by *amt*, as an nmigen expression."""
    v |= Const(0, 32)  # widen to at least 32 bits
    # duplicating the 32-bit value makes a 64-bit rotate act as two
    # independent 32-bit rotates; rotl64 slices out the result
    return rotl64(Cat(v[:32], v[:32]), amt)
117
118
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal driver for ShiftRotMainStage.

    Ties every DUT input to an AnyConst formal variable, restricts
    ``insn_type`` to the single TstOp being proven, then dispatches to
    the matching ``_check_*`` method containing that operation's
    assertions.
    """

    def __init__(self, which):
        # which: TstOp member selecting the single operation to prove
        assert isinstance(which, TstOp)
        self.which = which

    def elaborate(self, platform):
        """Build the DUT plus assumptions/assertions for self.which."""
        m = Module()
        comb = m.d.comb

        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        # enable the draft bitmanip ops (grev/ternlog) in the DUT
        pspec.draft_bitmanip = True
        m.submodules.dut = dut = ShiftRotMainStage(pspec)

        # Set inputs to formal variables
        comb += [
            eq_any_const(dut.i.ctx.op.insn_type),
            eq_any_const(dut.i.ctx.op.fn_unit),
            eq_any_const(dut.i.ctx.op.imm_data.data),
            eq_any_const(dut.i.ctx.op.imm_data.ok),
            eq_any_const(dut.i.ctx.op.rc.rc),
            eq_any_const(dut.i.ctx.op.rc.ok),
            eq_any_const(dut.i.ctx.op.oe.oe),
            eq_any_const(dut.i.ctx.op.oe.ok),
            eq_any_const(dut.i.ctx.op.write_cr0),
            eq_any_const(dut.i.ctx.op.input_carry),
            eq_any_const(dut.i.ctx.op.output_carry),
            eq_any_const(dut.i.ctx.op.input_cr),
            eq_any_const(dut.i.ctx.op.is_32bit),
            eq_any_const(dut.i.ctx.op.is_signed),
            eq_any_const(dut.i.ctx.op.insn),
            eq_any_const(dut.i.xer_ca),
            eq_any_const(dut.i.ra),
            eq_any_const(dut.i.rb),
            eq_any_const(dut.i.rc),
        ]

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # we're only checking a particular operation:
        comb += Assume(dut.i.ctx.op.insn_type == self.which.op)

        # dispatch to check fn for each op
        getattr(self, f"_check_{self.which.name.lower()}")(m, dut)

        return m

        # NOTE(review): everything from here to the second `return m`
        # below is UNREACHABLE -- elaborate has already returned.  It is
        # the old monolithic proof, and references names (`rec`,
        # `dut.i.rs`, `dut.i.a`) that are not defined in this scope, so
        # it would raise NameError if it were ever executed.

        # all following code in elaborate is kept for ease of reference, to be
        # deleted once this proof is completed.

        # convenience variables
        rs = dut.i.rs  # register to shift
        b = dut.i.rb  # register containing amount to shift by
        ra = dut.i.a  # source register if masking is to be done
        carry_in = dut.i.xer_ca[0]
        carry_in32 = dut.i.xer_ca[1]
        carry_out = dut.o.xer_ca
        o = dut.o.o.data
        print("fields", rec.fields)
        itype = rec.insn_type

        # instruction fields
        m_fields = dut.fields.FormM
        md_fields = dut.fields.FormMD

        # setup random inputs
        comb += rs.eq(AnyConst(64))
        comb += ra.eq(AnyConst(64))
        comb += b.eq(AnyConst(64))
        comb += carry_in.eq(AnyConst(1))
        comb += carry_in32.eq(AnyConst(1))

        # copy operation
        comb += dut.i.ctx.op.eq(rec)

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # signed and signed/32 versions of input rs
        a_signed = Signal(signed(64))
        a_signed_32 = Signal(signed(32))
        comb += a_signed.eq(rs)
        comb += a_signed_32.eq(rs[0:32])

        # masks: start-left
        mb = Signal(7, reset_less=True)
        ml = Signal(64, reset_less=True)

        # clear left?
        with m.If((itype == MicrOp.OP_RLC) | (itype == MicrOp.OP_RLCL)):
            with m.If(rec.is_32bit):
                comb += mb.eq(m_fields.MB[:])
            with m.Else():
                comb += mb.eq(md_fields.mb[:])
        with m.Else():
            with m.If(rec.is_32bit):
                comb += mb.eq(b[0:6])
            with m.Else():
                comb += mb.eq(b+32)
        comb += ml.eq(left_mask(m, mb))

        # masks: end-right
        me = Signal(7, reset_less=True)
        mr = Signal(64, reset_less=True)

        # clear right?
        with m.If((itype == MicrOp.OP_RLC) | (itype == MicrOp.OP_RLCR)):
            with m.If(rec.is_32bit):
                comb += me.eq(m_fields.ME[:])
            with m.Else():
                comb += me.eq(md_fields.me[:])
        with m.Else():
            with m.If(rec.is_32bit):
                comb += me.eq(b[0:6])
            with m.Else():
                comb += me.eq(63-b)
        comb += mr.eq(right_mask(m, me))

        # must check Data.ok
        o_ok = Signal()
        comb += o_ok.eq(1)

        # main assertion of arithmetic operations
        with m.Switch(itype):

            # left-shift: 64/32-bit
            with m.Case(MicrOp.OP_SHL):
                comb += Assume(ra == 0)
                with m.If(rec.is_32bit):
                    comb += Assert(o[0:32] == ((rs << b[0:6]) & 0xffffffff))
                    comb += Assert(o[32:64] == 0)
                with m.Else():
                    comb += Assert(o == ((rs << b[0:7]) & ((1 << 64)-1)))

            # right-shift: 64/32-bit / signed
            with m.Case(MicrOp.OP_SHR):
                comb += Assume(ra == 0)
                with m.If(~rec.is_signed):
                    with m.If(rec.is_32bit):
                        comb += Assert(o[0:32] == (rs[0:32] >> b[0:6]))
                        comb += Assert(o[32:64] == 0)
                    with m.Else():
                        comb += Assert(o == (rs >> b[0:7]))
                with m.Else():
                    with m.If(rec.is_32bit):
                        comb += Assert(o[0:32] == (a_signed_32 >> b[0:6]))
                        comb += Assert(o[32:64] == Repl(rs[31], 32))
                    with m.Else():
                        comb += Assert(o == (a_signed >> b[0:7]))

            # extswsli: 32/64-bit moded
            with m.Case(MicrOp.OP_EXTSWSLI):
                comb += Assume(ra == 0)
                with m.If(rec.is_32bit):
                    comb += Assert(o[0:32] == ((rs << b[0:6]) & 0xffffffff))
                    comb += Assert(o[32:64] == 0)
                with m.Else():
                    # sign-extend to 64 bit
                    a_s = Signal(64, reset_less=True)
                    comb += a_s.eq(exts(rs, 32, 64))
                    comb += Assert(o == ((a_s << b[0:7]) & ((1 << 64)-1)))

            # rlwinm, rlwnm, rlwimi
            # *CAN* these even be 64-bit capable? I don't think they are.
            with m.Case(MicrOp.OP_RLC):
                comb += Assume(ra == 0)
                comb += Assume(rec.is_32bit)

                # Duplicate some signals so that they're much easier to find
                # in gtkwave.
                # Pro-tip: when debugging, factor out expressions into
                # explicitly named
                # signals, and search using a unique grep-tag (RLC in my case).
                # After
                # debugging, resubstitute values to comply with surrounding
                # code norms.

                mrl = Signal(64, reset_less=True, name='MASK_FOR_RLC')
                with m.If(mb > me):
                    comb += mrl.eq(ml | mr)
                with m.Else():
                    comb += mrl.eq(ml & mr)

                ainp = Signal(64, reset_less=True, name='A_INP_FOR_RLC')
                comb += ainp.eq(field(rs, 32, 63))

                sh = Signal(6, reset_less=True, name='SH_FOR_RLC')
                comb += sh.eq(b[0:6])

                exp_shl = Signal(64, reset_less=True,
                                 name='A_SHIFTED_LEFT_BY_SH_FOR_RLC')
                comb += exp_shl.eq((ainp << sh) & 0xFFFFFFFF)

                exp_shr = Signal(64, reset_less=True,
                                 name='A_SHIFTED_RIGHT_FOR_RLC')
                comb += exp_shr.eq((ainp >> (32 - sh)) & 0xFFFFFFFF)

                exp_rot = Signal(64, reset_less=True,
                                 name='A_ROTATED_LEFT_FOR_RLC')
                comb += exp_rot.eq(exp_shl | exp_shr)

                exp_ol = Signal(32, reset_less=True,
                                name='EXPECTED_OL_FOR_RLC')
                comb += exp_ol.eq(field((exp_rot & mrl) | (ainp & ~mrl),
                                        32, 63))

                act_ol = Signal(32, reset_less=True, name='ACTUAL_OL_FOR_RLC')
                comb += act_ol.eq(field(o, 32, 63))

                # If I uncomment the following lines, I can confirm that all
                # 32-bit rotations work. If I uncomment only one of the
                # following lines, I can confirm that all 32-bit rotations
                # work. When I remove/recomment BOTH lines, however, the
                # assertion fails. Why??

                # comb += Assume(mr == 0xFFFFFFFF)
                # comb += Assume(ml == 0xFFFFFFFF)
                # with m.If(rec.is_32bit):
                #     comb += Assert(act_ol == exp_ol)
                #     comb += Assert(field(o, 0, 31) == 0)

            # TODO
            with m.Case(MicrOp.OP_RLCR):
                pass
            with m.Case(MicrOp.OP_RLCL):
                pass
            with m.Case(MicrOp.OP_TERNLOG):
                lut = dut.fields.FormTLI.TLI[:]
                for i in range(64):
                    idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
                    for j in range(8):
                        with m.If(j == idx):
                            comb += Assert(dut.o.o.data[i] == lut[j])
            with m.Case(MicrOp.OP_GREV):
                ra_bits = Array(dut.i.ra[i] for i in range(64))
                with m.If(dut.i.ctx.op.is_32bit):
                    # assert zero-extended
                    comb += Assert(dut.o.o.data[32:] == 0)
                    for i in range(32):
                        idx = dut.i.rb[0:5] ^ i
                        comb += Assert(dut.o.o.data[i]
                                       == ra_bits[idx])
                with m.Else():
                    for i in range(64):
                        idx = dut.i.rb[0:6] ^ i
                        comb += Assert(dut.o.o.data[i]
                                       == ra_bits[idx])

            with m.Default():
                comb += o_ok.eq(0)

        # check that data ok was only enabled when op actioned
        comb += Assert(dut.o.o.ok == o_ok)

        return m

    def _check_shl(self, m, dut):
        """OP_SHL: left shift; 32-bit mode keeps only the low 32 bits."""
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        with m.If(dut.i.ctx.op.is_32bit):
            # 6-bit shift amount; result truncated to 32 bits
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32])
        with m.Else():
            # 7-bit shift amount; result truncated to 64 bits
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64])
        m.d.comb += Assert(dut.o.o.data == expected)
        # left shifts must leave carry clear
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_shr(self, m, dut):
        """OP_SHR: right shift, logical or arithmetic (signed).

        For the signed case, carry is set when the input is negative
        AND bits were shifted out -- detected by shifting back
        (roundtrip) and comparing with the original.
        """
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        carry = Signal()
        shift_in_s = Signal(signed(128))
        shift_roundtrip = Signal(signed(128))
        shift_in_u = Signal(128)
        shift_amt = Signal(7)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:6]),
                shift_in_s.eq(dut.i.rs[:32].as_signed()),
                shift_in_u.eq(dut.i.rs[:32]),
            ]
        with m.Else():
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:7]),
                shift_in_s.eq(dut.i.rs.as_signed()),
                shift_in_u.eq(dut.i.rs),
            ]

        with m.If(dut.i.ctx.op.is_signed):
            m.d.comb += [
                expected.eq(shift_in_s >> shift_amt),
                shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt),
                carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)),
            ]
        with m.Else():
            m.d.comb += [
                expected.eq(shift_in_u >> shift_amt),
                carry.eq(0),
            ]
        m.d.comb += Assert(dut.o.o.data == expected)
        # both carry bits (ca and ca32) must match the computed carry
        m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))

    def _check_rlc32(self, m, dut):
        """OP_RLC, 32-bit: rotate-left word then mask/insert."""
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # rlwimi, rlwinm, and rlwnm

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        # 5-bit rotate amount; MB/ME are offset by 32 into the 64-bit
        # MSB0 mask numbering since only the low word is rotated
        m.d.comb += rot.eq(rotl32(dut.i.rs[:32], dut.i.rb[:5]))
        m.d.comb += mask.start.eq(dut.fields.FormM.MB[:] + 32)
        m.d.comb += mask.end.eq(dut.fields.FormM.ME[:] + 32)

        # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlc64(self, m, dut):
        """OP_RLC, 64-bit: rotate-left doubleword then mask/insert."""
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldic and rldimi

        # `rb` is always a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))
        mb = dut.fields.FormMD.mb[:]
        # MD-form mb field: reassemble with its top bit moved from the
        # last position -- assumed to match the MD-form split-field
        # encoding; TODO confirm against Power ISA MD-form
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))
        m.d.comb += mask.end.eq(63 - dut.i.rb[:6])

        # for rldic, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcl(self, m, dut):
        """OP_RLCL: proof not yet written (TODO)."""
        raise NotImplementedError

    def _check_rlcr(self, m, dut):
        """OP_RLCR: proof not yet written (TODO)."""
        raise NotImplementedError

    def _check_extswsli(self, m, dut):
        """OP_EXTSWSLI: sign-extend low word, then shift left."""
        m.d.comb += Assume(dut.i.ra == 0)
        # shift amount is a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
        expected = Signal(64)
        # sign-extend rs[0:32] to 64 bits, shift; eq truncates to 64 bits
        m.d.comb += expected.eq((dut.i.rs[0:32].as_signed() << dut.i.rb[:6]))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_ternlog(self, m, dut):
        """OP_TERNLOG: per-bit 3-input LUT selected by the TLI field.

        For each bit i the LUT index is (bit0=rb[i], bit1=ra[i],
        bit2=rc[i]); the output bit must equal the selected LUT entry.
        """
        lut = dut.fields.FormTLI.TLI[:]
        for i in range(64):
            idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
            for j in range(8):
                with m.If(j == idx):
                    m.d.comb += Assert(dut.o.o.data[i] == lut[j])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev32(self, m, dut):
        """OP_GREV, 32-bit: generalized bit-reverse of the low word.

        Uses a free variable `i` (AnyConst) so the single assertion is
        universally quantified over all 32 bit positions.
        """
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # assert zero-extended
        m.d.comb += Assert(dut.o.o.data[32:] == 0)
        i = Signal(5)
        m.d.comb += eq_any_const(i)
        # output bit i comes from input bit (i XOR rb[0:5])
        idx = dut.i.rb[0: 5] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev64(self, m, dut):
        """OP_GREV, 64-bit: generalized bit-reverse, quantified over
        all 64 bit positions via a free variable `i` (AnyConst)."""
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        i = Signal(6)
        m.d.comb += eq_any_const(i)
        # output bit i comes from input bit (i XOR rb[0:6])
        idx = dut.i.rb[0: 6] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)
503
504
class ALUTestCase(FHDLTestCase):
    """One unittest per TstOp member; each runs a bmc and a cover proof.

    Method names must follow the ``test_<TstOp.name.lower()>`` pattern:
    the module-level loop below this class verifies the correspondence.
    """

    def run_it(self, which):
        """Prove the Driver for *which* under bmc, then cover."""
        driver = Driver(which)
        for mode in ("bmc", "cover"):
            self.assertFormal(driver, mode=mode, depth=2)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc32(self):
        self.run_it(TstOp.RLC32)

    def test_rlc64(self):
        self.run_it(TstOp.RLC64)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    def test_grev64(self):
        self.run_it(TstOp.GREV64)
540
541
# check that all test cases are covered
# (fails at import time with AttributeError if a TstOp member has no
# matching test_* method on ALUTestCase)
for i in TstOp:
    assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}"))


if __name__ == '__main__':
    unittest.main()