# Proof of correctness for shift/rotate FU
# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
"""
Links:
* https://bugs.libre-soc.org/show_bug.cgi?id=340

run tests with:
pip install pytest
pip install pytest-xdist
pytest -n auto src/soc/fu/shift_rot/formal/proof_main_stage.py
the -n auto option tells pytest to run the tests in parallel, so the
proofs take a few minutes instead of an hour.
"""

import unittest
import enum
from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
                    signed, Const, unsigned)
from nmigen.asserts import Assert, AnyConst, Assume
from nmutil.formaltest import FHDLTestCase
from nmutil.sim_util import do_sim
from nmigen.sim import Delay

from soc.fu.shift_rot.main_stage import ShiftRotMainStage
from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
from openpower.decoder.power_enums import MicrOp


@enum.unique
class TstOp(enum.Enum):
    """Ops we're testing. The idea is that by running a separate formal
    proof for each instruction we still cover them all, while each proof
    runs much faster; the proofs can also be run in parallel."""
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC32 = MicrOp.OP_RLC, 32
    RLC64 = MicrOp.OP_RLC, 64
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    GREV32 = MicrOp.OP_GREV, 32
    GREV64 = MicrOp.OP_GREV, 64

    @property
    def op(self):
        if isinstance(self.value, tuple):
            return self.value[0]
        return self.value


def eq_any_const(sig: Signal):
    """wire sig to a formal AnyConst: the solver may pick any value,
    but it stays constant for the whole trace"""
    return sig.eq(AnyConst(sig.shape(), src_loc_at=1))


class Mask(Elaboratable):
    # copied from qemu's mask fn:
    # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
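    # semantics (a sketch, using the Power ISA's MSB0 bit numbering):
    # MASK(start, end) sets bits start..end inclusive; when start > end
    # the mask wraps around, i.e. it is the complement of
    # MASK(end + 1, start - 1).  TstMask below checks this exhaustively.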
    def __init__(self):
        self.start = Signal(6)
        self.end = Signal(6)
        self.out = Signal(64)

    def elaborate(self, platform):
        m = Module()
        max_val = Const(~0, unsigned(64))
        max_bit = 63
        with m.If(self.start == 0):
            m.d.comb += self.out.eq(max_val << (max_bit - self.end))
        with m.Elif(self.end == max_bit):
            m.d.comb += self.out.eq(max_val >> self.start)
        with m.Else():
            ret = (max_val >> self.start) ^ ((max_val >> self.end) >> 1)
            m.d.comb += self.out.eq(Mux(self.start > self.end, ~ret, ret))
        return m


class TstMask(unittest.TestCase):
    def test_mask(self):
        dut = Mask()

        def case(start, end, expected):
            with self.subTest(start=start, end=end):
                yield dut.start.eq(start)
                yield dut.end.eq(end)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=hex(out), expected=hex(expected)):
                    self.assertEqual(expected, out)

        def process():
            for start in range(64):
                for end in range(64):
                    expected = 0
                    if start > end:
                        for i in range(start, 64):
                            expected |= 1 << (63 - i)
                        for i in range(0, end + 1):
                            expected |= 1 << (63 - i)
                    else:
                        for i in range(start, end + 1):
                            expected |= 1 << (63 - i)
                    yield from case(start, end, expected)
        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
            sim.add_process(process)
            sim.run()


def rotl64(v, amt):
    v |= Const(0, 64)  # convert to value at least 64-bits wide
    amt |= Const(0, 6)  # convert to value at least 6-bits wide
    return (Cat(v[:64], v[:64]) >> (64 - amt[:6]))[:64]


def rotl32(v, amt):
    v |= Const(0, 32)  # convert to value at least 32-bits wide
    return rotl64(Cat(v[:32], v[:32]), amt)
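

# rotl64/rotl32 build a left-rotate from a double-width concatenation
# shifted right by (width - amt).  As a plain-Python reference model
# (a sketch for illustration only, not used by the proof):
#
#   def py_rotl(v, amt, width=64):
#       v &= (1 << width) - 1
#       return ((v << amt) | (v >> (width - amt))) & ((1 << width) - 1)
#
# for amt in range(width) this matches (Cat(v, v) >> (width - amt))[:width].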


# This defines a module to drive the device under test and assert
# properties about its outputs.
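# One Driver instance checks a single TstOp member; passing which=None
# instead checks that, for any op outside the tested set, the output's
# "ok" flag stays de-asserted.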
class Driver(Elaboratable):
    def __init__(self, which):
        assert isinstance(which, TstOp) or which is None
        self.which = which

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        pspec.draft_bitmanip = True
        m.submodules.dut = dut = ShiftRotMainStage(pspec)

        # Set inputs to formal variables
        comb += [
            eq_any_const(dut.i.ctx.op.insn_type),
            eq_any_const(dut.i.ctx.op.fn_unit),
            eq_any_const(dut.i.ctx.op.imm_data.data),
            eq_any_const(dut.i.ctx.op.imm_data.ok),
            eq_any_const(dut.i.ctx.op.rc.rc),
            eq_any_const(dut.i.ctx.op.rc.ok),
            eq_any_const(dut.i.ctx.op.oe.oe),
            eq_any_const(dut.i.ctx.op.oe.ok),
            eq_any_const(dut.i.ctx.op.write_cr0),
            eq_any_const(dut.i.ctx.op.input_carry),
            eq_any_const(dut.i.ctx.op.output_carry),
            eq_any_const(dut.i.ctx.op.input_cr),
            eq_any_const(dut.i.ctx.op.is_32bit),
            eq_any_const(dut.i.ctx.op.is_signed),
            eq_any_const(dut.i.ctx.op.insn),
            eq_any_const(dut.i.xer_ca),
            eq_any_const(dut.i.ra),
            eq_any_const(dut.i.rb),
            eq_any_const(dut.i.rc),
        ]

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        if self.which is None:
            for i in TstOp:
                comb += Assume(dut.i.ctx.op.insn_type != i.op)
            comb += Assert(~dut.o.o.ok)
        else:
            # we're only checking a particular operation:
            comb += Assume(dut.i.ctx.op.insn_type == self.which.op)
            comb += Assert(dut.o.o.ok)

            # dispatch to check fn for each op
            getattr(self, f"_check_{self.which.name.lower()}")(m, dut)

        return m

    def _check_shl(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32])
        with m.Else():
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_shr(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        carry = Signal()
        shift_in_s = Signal(signed(128))
        shift_roundtrip = Signal(signed(128))
        shift_in_u = Signal(128)
        shift_amt = Signal(7)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:6]),
                shift_in_s.eq(dut.i.rs[:32].as_signed()),
                shift_in_u.eq(dut.i.rs[:32]),
            ]
        with m.Else():
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:7]),
                shift_in_s.eq(dut.i.rs.as_signed()),
                shift_in_u.eq(dut.i.rs),
            ]

        with m.If(dut.i.ctx.op.is_signed):
            m.d.comb += [
                expected.eq(shift_in_s >> shift_amt),
                shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt),
                carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)),
            ]
        with m.Else():
            m.d.comb += [
                expected.eq(shift_in_u >> shift_amt),
                carry.eq(0),
            ]
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))

    def _check_rlc32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # rlwimi, rlwinm, and rlwnm

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl32(dut.i.rs[:32], dut.i.rb[:5]))
        m.d.comb += mask.start.eq(dut.fields.FormM.MB[:] + 32)
        m.d.comb += mask.end.eq(dut.fields.FormM.ME[:] + 32)

        # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlc64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldic and rldimi

        # `rb` is always a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))
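        # the MD form encodes the 6-bit mask-begin value as mb5 || mb0:4,
        # i.e. its most-significant bit sits at the end of the instruction
        # field, so swizzle it back to the front before feeding Mask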
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))
        m.d.comb += mask.end.eq(63 - dut.i.rb[:6])

        # for rldic, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcl(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicl and rldcl

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.end.eq(63)
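        # same MD-form mb swizzle as in _check_rlc64 above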
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcr(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicr and rldcr

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.start.eq(0)
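        # the MD-form me field uses the same split encoding as mb: its
        # most-significant bit is stored last, so swizzle it back in front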
        me = dut.fields.FormMD.me[:]
        m.d.comb += mask.end.eq(Cat(me[1:6], me[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_extswsli(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        m.d.comb += Assume(dut.i.rb[6:] == 0)
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
        expected = Signal(64)
        m.d.comb += expected.eq(dut.i.rs[0:32].as_signed() << dut.i.rb[:6])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_ternlog(self, m, dut):
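        # each result bit i is looked up in the 8-bit TLI immediate, using
        # the corresponding bits of RB, RA and RC as the 3-bit table index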
        lut = dut.fields.FormTLI.TLI[:]
        for i in range(64):
            idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
            for j in range(8):
                with m.If(j == idx):
                    m.d.comb += Assert(dut.o.o.data[i] == lut[j])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # assert zero-extended
        m.d.comb += Assert(dut.o.o.data[32:] == 0)
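        # rather than enumerating all 32 result bits, pick an arbitrary
        # formal-constant bit index i and assert the grev property for it:
        # output bit i equals input bit (i XOR rb)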
        i = Signal(5)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:5] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
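        # same arbitrary-bit-index property as _check_grev32, over 64 bits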
        i = Signal(6)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:6] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)


class ALUTestCase(FHDLTestCase):
    def run_it(self, which):
        module = Driver(which)
        self.assertFormal(module, mode="bmc", depth=2)
        self.assertFormal(module, mode="cover", depth=2)

    def test_none(self):
        self.run_it(None)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc32(self):
        self.run_it(TstOp.RLC32)

    def test_rlc64(self):
        self.run_it(TstOp.RLC64)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    def test_grev64(self):
        self.run_it(TstOp.GREV64)


# check that all test cases are covered
for i in TstOp:
    assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}"))


if __name__ == '__main__':
    unittest.main()