be0c4b169fd94a79795c5c7781ba8e66be8b7566
[soc.git] / src / soc / fu / shift_rot / formal / proof_main_stage.py
1 # Proof of correctness for shift/rotate FU
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=340
6 """
7
8 import enum
9 from shutil import which
10 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
11 signed, Array, Const, Value)
12 from nmigen.asserts import Assert, AnyConst, Assume, Cover
13 from nmutil.formaltest import FHDLTestCase
14 from nmigen.cli import rtlil
15
16 from soc.fu.shift_rot.main_stage import ShiftRotMainStage
17 from soc.fu.shift_rot.rotator import right_mask, left_mask
18 from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
19 from soc.fu.shift_rot.sr_input_record import CompSROpSubset
20 from openpower.decoder.power_enums import MicrOp
21 from openpower.consts import field
22
23 import unittest
24 from nmutil.extend import exts
25
26
@enum.unique
class TstOp(enum.Enum):
    """Operations under test, one formal proof per member.

    Running a separate proof per instruction keeps each proof small and fast,
    lets the proofs run in parallel, and together they cover every op.
    Members whose value is a tuple carry the ``MicrOp`` first, followed by
    a disambiguating width (both GREV variants share ``MicrOp.OP_GREV``).
    """
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC = MicrOp.OP_RLC
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    GREV32 = (MicrOp.OP_GREV, 32)
    GREV64 = (MicrOp.OP_GREV, 64)

    @property
    def op(self):
        """The ``MicrOp`` this member exercises (first element of a tuple)."""
        raw = self.value
        return raw[0] if isinstance(raw, tuple) else raw
47
48
def eq_any_const(sig: Signal):
    """Return a statement driving *sig* from a free formal constant.

    The ``AnyConst`` takes *sig*'s shape.  ``src_loc_at=1`` attributes its
    source location to the caller of this helper rather than to this line.
    """
    free_value = AnyConst(sig.shape(), src_loc_at=1)
    return sig.eq(free_value)
51
52
class Mask(Elaboratable):
    """Combinational port of QEMU's PowerPC ``MASK(start, end)`` helper.

    Copied from:
    https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21

    ``start`` and ``end`` are 6-bit inputs and ``out`` is the 64-bit mask.
    Positions count from the most-significant bit: start=0, end=0 yields
    only bit 63 of ``out`` set.  When ``start > end`` the mask is inverted.
    """

    def __init__(self):
        self.start = Signal(6)
        self.end = Signal(6)
        self.out = Signal(64)

    def elaborate(self, platform):
        m = Module()
        ones = Const(~0, 64)
        last = 63
        lo, hi = self.start, self.end

        with m.If(lo == 0):
            # fast path: everything from the MSB down to position `hi`
            m.d.comb += self.out.eq(ones << (last - hi))
        with m.Elif(hi == last):
            # fast path: everything from position `lo` down to the LSB
            m.d.comb += self.out.eq(ones >> lo)
        with m.Else():
            between = (ones >> lo) ^ ((ones >> hi) >> 1)
            m.d.comb += self.out.eq(Mux(lo > hi, ~between, between))
        return m
73
74
def rotl64(v, amt):
    """Rotate the low 64 bits of *v* left by the low 6 bits of *amt*.

    Both arguments may be Python ints or nmigen Values; the result is a
    64-bit Value.
    """
    # OR-ing with a zero constant turns ints into Values and guarantees the
    # minimum widths the slices below rely on.
    word = (v | Const(0, 64))[:64]
    sh = (amt | Const(0, 6))[:6]
    # With the word placed twice side by side, the low 64 bits of a right
    # shift by (64 - sh) are exactly the word rotated left by sh.
    doubled = Cat(word, word)
    return (doubled >> (64 - sh))[:64]
79
80
def rotl32(v, amt):
    """Rotate a 32-bit word left: ``rotl64`` of the word concatenated twice."""
    word = (v | Const(0, 32))[:32]  # ensure at least 32 bits, then take them
    return rotl64(Cat(word, word), amt)
84
85
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal-verification harness for one ShiftRotMainStage operation.

    ``elaborate`` instantiates the DUT, drives its inputs from free formal
    constants (``AnyConst``), assumes ``insn_type`` is the operation selected
    by ``which`` and then dispatches to the matching ``_check_<name>`` method,
    which adds the op-specific assumptions and assertions.
    """

    def __init__(self, which):
        """:param which: the :class:`TstOp` member whose proof is built."""
        assert isinstance(which, TstOp)
        self.which = which

    def elaborate(self, platform):
        """Build the harness module; *platform* is not used."""
        m = Module()
        comb = m.d.comb

        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        pspec.draft_bitmanip = True
        m.submodules.dut = dut = ShiftRotMainStage(pspec)

        # Set inputs to formal variables.  muxid is included so that the
        # muxid pass-through assertion below is proven for every value, not
        # only for the value the signal holds when left undriven.
        comb += [
            eq_any_const(dut.i.ctx.muxid),
            eq_any_const(dut.i.ctx.op.insn_type),
            eq_any_const(dut.i.ctx.op.fn_unit),
            eq_any_const(dut.i.ctx.op.imm_data.data),
            eq_any_const(dut.i.ctx.op.imm_data.ok),
            eq_any_const(dut.i.ctx.op.rc.rc),
            eq_any_const(dut.i.ctx.op.rc.ok),
            eq_any_const(dut.i.ctx.op.oe.oe),
            eq_any_const(dut.i.ctx.op.oe.ok),
            eq_any_const(dut.i.ctx.op.write_cr0),
            eq_any_const(dut.i.ctx.op.input_carry),
            eq_any_const(dut.i.ctx.op.output_carry),
            eq_any_const(dut.i.ctx.op.input_cr),
            eq_any_const(dut.i.ctx.op.is_32bit),
            eq_any_const(dut.i.ctx.op.is_signed),
            eq_any_const(dut.i.ctx.op.insn),
            eq_any_const(dut.i.xer_ca),
            eq_any_const(dut.i.ra),
            eq_any_const(dut.i.rb),
            eq_any_const(dut.i.rc),
        ]

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # we're only checking a particular operation:
        comb += Assume(dut.i.ctx.op.insn_type == self.which.op)

        # dispatch to check fn for each op
        getattr(self, f"_check_{self.which.name.lower()}")(m, dut)

        return m

        # Everything below is UNREACHABLE (it follows ``return m``) and refers
        # to names such as ``rec`` that are not defined in this method.  It is
        # kept only for ease of reference, to be deleted once this proof is
        # completed.

        # convenience variables
        rs = dut.i.rs  # register to shift
        b = dut.i.rb  # register containing amount to shift by
        ra = dut.i.a  # source register if masking is to be done
        carry_in = dut.i.xer_ca[0]
        carry_in32 = dut.i.xer_ca[1]
        carry_out = dut.o.xer_ca
        o = dut.o.o.data
        print("fields", rec.fields)
        itype = rec.insn_type

        # instruction fields
        m_fields = dut.fields.FormM
        md_fields = dut.fields.FormMD

        # setup random inputs
        comb += rs.eq(AnyConst(64))
        comb += ra.eq(AnyConst(64))
        comb += b.eq(AnyConst(64))
        comb += carry_in.eq(AnyConst(1))
        comb += carry_in32.eq(AnyConst(1))

        # copy operation
        comb += dut.i.ctx.op.eq(rec)

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # signed and signed/32 versions of input rs
        a_signed = Signal(signed(64))
        a_signed_32 = Signal(signed(32))
        comb += a_signed.eq(rs)
        comb += a_signed_32.eq(rs[0:32])

        # masks: start-left
        mb = Signal(7, reset_less=True)
        ml = Signal(64, reset_less=True)

        # clear left?
        with m.If((itype == MicrOp.OP_RLC) | (itype == MicrOp.OP_RLCL)):
            with m.If(rec.is_32bit):
                comb += mb.eq(m_fields.MB[:])
            with m.Else():
                comb += mb.eq(md_fields.mb[:])
        with m.Else():
            with m.If(rec.is_32bit):
                comb += mb.eq(b[0:6])
            with m.Else():
                comb += mb.eq(b+32)
        comb += ml.eq(left_mask(m, mb))

        # masks: end-right
        me = Signal(7, reset_less=True)
        mr = Signal(64, reset_less=True)

        # clear right?
        with m.If((itype == MicrOp.OP_RLC) | (itype == MicrOp.OP_RLCR)):
            with m.If(rec.is_32bit):
                comb += me.eq(m_fields.ME[:])
            with m.Else():
                comb += me.eq(md_fields.me[:])
        with m.Else():
            with m.If(rec.is_32bit):
                comb += me.eq(b[0:6])
            with m.Else():
                comb += me.eq(63-b)
        comb += mr.eq(right_mask(m, me))

        # must check Data.ok
        o_ok = Signal()
        comb += o_ok.eq(1)

        # main assertion of arithmetic operations
        with m.Switch(itype):

            # left-shift: 64/32-bit
            with m.Case(MicrOp.OP_SHL):
                comb += Assume(ra == 0)
                with m.If(rec.is_32bit):
                    comb += Assert(o[0:32] == ((rs << b[0:6]) & 0xffffffff))
                    comb += Assert(o[32:64] == 0)
                with m.Else():
                    comb += Assert(o == ((rs << b[0:7]) & ((1 << 64)-1)))

            # right-shift: 64/32-bit / signed
            with m.Case(MicrOp.OP_SHR):
                comb += Assume(ra == 0)
                with m.If(~rec.is_signed):
                    with m.If(rec.is_32bit):
                        comb += Assert(o[0:32] == (rs[0:32] >> b[0:6]))
                        comb += Assert(o[32:64] == 0)
                    with m.Else():
                        comb += Assert(o == (rs >> b[0:7]))
                with m.Else():
                    with m.If(rec.is_32bit):
                        comb += Assert(o[0:32] == (a_signed_32 >> b[0:6]))
                        comb += Assert(o[32:64] == Repl(rs[31], 32))
                    with m.Else():
                        comb += Assert(o == (a_signed >> b[0:7]))

            # extswsli: 32/64-bit moded
            with m.Case(MicrOp.OP_EXTSWSLI):
                comb += Assume(ra == 0)
                with m.If(rec.is_32bit):
                    comb += Assert(o[0:32] == ((rs << b[0:6]) & 0xffffffff))
                    comb += Assert(o[32:64] == 0)
                with m.Else():
                    # sign-extend to 64 bit
                    a_s = Signal(64, reset_less=True)
                    comb += a_s.eq(exts(rs, 32, 64))
                    comb += Assert(o == ((a_s << b[0:7]) & ((1 << 64)-1)))

            # rlwinm, rlwnm, rlwimi
            # *CAN* these even be 64-bit capable?  I don't think they are.
            with m.Case(MicrOp.OP_RLC):
                comb += Assume(ra == 0)
                comb += Assume(rec.is_32bit)

                # Duplicate some signals so that they're much easier to find
                # in gtkwave.
                # Pro-tip: when debugging, factor out expressions into
                # explicitly named
                # signals, and search using a unique grep-tag (RLC in my case).
                #   After
                # debugging, resubstitute values to comply with surrounding
                # code norms.

                mrl = Signal(64, reset_less=True, name='MASK_FOR_RLC')
                with m.If(mb > me):
                    comb += mrl.eq(ml | mr)
                with m.Else():
                    comb += mrl.eq(ml & mr)

                ainp = Signal(64, reset_less=True, name='A_INP_FOR_RLC')
                comb += ainp.eq(field(rs, 32, 63))

                sh = Signal(6, reset_less=True, name='SH_FOR_RLC')
                comb += sh.eq(b[0:6])

                exp_shl = Signal(64, reset_less=True,
                                 name='A_SHIFTED_LEFT_BY_SH_FOR_RLC')
                comb += exp_shl.eq((ainp << sh) & 0xFFFFFFFF)

                exp_shr = Signal(64, reset_less=True,
                                 name='A_SHIFTED_RIGHT_FOR_RLC')
                comb += exp_shr.eq((ainp >> (32 - sh)) & 0xFFFFFFFF)

                exp_rot = Signal(64, reset_less=True,
                                 name='A_ROTATED_LEFT_FOR_RLC')
                comb += exp_rot.eq(exp_shl | exp_shr)

                exp_ol = Signal(32, reset_less=True,
                                name='EXPECTED_OL_FOR_RLC')
                comb += exp_ol.eq(field((exp_rot & mrl) | (ainp & ~mrl),
                                        32, 63))

                act_ol = Signal(32, reset_less=True, name='ACTUAL_OL_FOR_RLC')
                comb += act_ol.eq(field(o, 32, 63))

                # If I uncomment the following lines, I can confirm that all
                # 32-bit rotations work.  If I uncomment only one of the
                # following lines, I can confirm that all 32-bit rotations
                # work.  When I remove/recomment BOTH lines, however, the
                # assertion fails.  Why??

                # comb += Assume(mr == 0xFFFFFFFF)
                # comb += Assume(ml == 0xFFFFFFFF)
                # with m.If(rec.is_32bit):
                #    comb += Assert(act_ol == exp_ol)
                #    comb += Assert(field(o, 0, 31) == 0)

            # TODO
            with m.Case(MicrOp.OP_RLCR):
                pass
            with m.Case(MicrOp.OP_RLCL):
                pass
            with m.Case(MicrOp.OP_TERNLOG):
                lut = dut.fields.FormTLI.TLI[:]
                for i in range(64):
                    idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
                    for j in range(8):
                        with m.If(j == idx):
                            comb += Assert(dut.o.o.data[i] == lut[j])
            with m.Case(MicrOp.OP_GREV):
                ra_bits = Array(dut.i.ra[i] for i in range(64))
                with m.If(dut.i.ctx.op.is_32bit):
                    # assert zero-extended
                    comb += Assert(dut.o.o.data[32:] == 0)
                    for i in range(32):
                        idx = dut.i.rb[0:5] ^ i
                        comb += Assert(dut.o.o.data[i]
                                       == ra_bits[idx])
                with m.Else():
                    for i in range(64):
                        idx = dut.i.rb[0:6] ^ i
                        comb += Assert(dut.o.o.data[i]
                                       == ra_bits[idx])

            with m.Default():
                comb += o_ok.eq(0)

        # check that data ok was only enabled when op actioned
        comb += Assert(dut.o.o.ok == o_ok)

        return m

    def _check_shl(self, m, dut):
        """Left shift (OP_SHL).

        RA is assumed zero.  The 32-bit form shifts rs by rb[:6] and keeps
        the low 32 bits, zero-extended.  The 64-bit form shifts by rb[:7] and
        keeps 64 bits.  Amounts at or beyond the width therefore give zero.
        XER carry must be zero.
        """
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32])
        with m.Else():
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_shr(self, m, dut):
        """Right shift (OP_SHR), logical or arithmetic per ``is_signed``.

        The source is the low 32 bits of rs (6-bit amount) or all 64 bits
        (7-bit amount), extended to 128 bits so that oversized amounts are
        well defined.  For arithmetic shifts the expected carry is set when
        the source is negative and at least one 1-bit is shifted out, and
        it is expected in both xer_ca bits.  Logical shifts expect zero
        carry.
        """
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        carry = Signal()
        shift_in_s = Signal(signed(128))
        shift_roundtrip = Signal(signed(128))
        shift_in_u = Signal(128)
        shift_amt = Signal(7)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:6]),
                shift_in_s.eq(dut.i.rs[:32].as_signed()),
                shift_in_u.eq(dut.i.rs[:32]),
            ]
        with m.Else():
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:7]),
                shift_in_s.eq(dut.i.rs.as_signed()),
                shift_in_u.eq(dut.i.rs),
            ]

        with m.If(dut.i.ctx.op.is_signed):
            m.d.comb += [
                expected.eq(shift_in_s >> shift_amt),
                # shifting back clears the low shift_amt bits; any difference
                # from the input means a 1-bit was shifted out
                shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt),
                carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)),
            ]
        with m.Else():
            m.d.comb += [
                expected.eq(shift_in_u >> shift_amt),
                carry.eq(0),
            ]
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))

    def _check_rlc(self, m, dut):
        """Rotate with mask (OP_RLC): not yet implemented."""
        raise NotImplementedError
        # Unreachable sketch of the intended check.  NOTE: ``m.If()`` below
        # is missing its condition and must be completed before the
        # ``raise`` above is removed.
        m.submodules.mask = mask = Mask()
        with m.If():
            pass
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcl(self, m, dut):
        """OP_RLCL: not yet implemented."""
        raise NotImplementedError

    def _check_rlcr(self, m, dut):
        """OP_RLCR: not yet implemented."""
        raise NotImplementedError

    def _check_extswsli(self, m, dut):
        """OP_EXTSWSLI, 64-bit mode only.

        RA is assumed zero and rb is assumed to fit in 6 bits.  The expected
        result is the sign-extended low word of rs shifted left by rb[:6],
        truncated to 64 bits.  XER carry must be zero.
        """
        m.d.comb += Assume(dut.i.ra == 0)
        m.d.comb += Assume(dut.i.rb[6:] == 0)
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
        expected = Signal(64)
        m.d.comb += expected.eq((dut.i.rs[0:32].as_signed() << dut.i.rb[:6]))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_ternlog(self, m, dut):
        """Ternary logic (OP_TERNLOG).

        Each output bit i must equal bit ``idx`` of the TLI field, where
        ``idx`` is the 3-bit value Cat(rb[i], ra[i], rc[i]), with the rb bit
        as the LSB.  XER carry must be zero.
        """
        lut = dut.fields.FormTLI.TLI[:]
        for i in range(64):
            idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
            # a single variable bit-select replaces 8 per-bit m.If blocks
            m.d.comb += Assert(dut.o.o.data[i] == (lut >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev32(self, m, dut):
        """32-bit OP_GREV.

        The upper 32 output bits must be zero.  Output bit i must equal
        ra bit (rb[0:5] ^ i).  Because i is a free constant, this holds for
        every bit position.  XER carry must be zero.
        """
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # assert zero-extended
        m.d.comb += Assert(dut.o.o.data[32:] == 0)
        i = Signal(5)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0: 5] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev64(self, m, dut):
        """64-bit OP_GREV.

        Output bit i must equal ra bit (rb[0:6] ^ i), with i a free
        constant.  XER carry must be zero.
        """
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        i = Signal(6)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0: 6] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)
439
440
class ALUTestCase(FHDLTestCase):
    """One bounded-model-check plus one cover run per ``TstOp`` member."""

    def run_it(self, which):
        """Run the ``Driver`` proof for *which* in bmc mode, then cover mode."""
        harness = Driver(which)
        for mode in ("bmc", "cover"):
            self.assertFormal(harness, mode=mode, depth=2)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc(self):
        self.run_it(TstOp.RLC)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    def test_grev64(self):
        self.run_it(TstOp.GREV64)
473
474
# Import-time sanity check: every TstOp member needs a callable
# ALUTestCase.test_<member name in lower case>.
assert all(callable(getattr(ALUTestCase, f"test_{op.name.lower()}"))
           for op in TstOp)


if __name__ == '__main__':
    unittest.main()