1 # Proof of correctness for shift/rotate FU
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
5 * https://bugs.libre-soc.org/show_bug.cgi?id=340
9 pip install pytest-xdist
10 pytest -n auto src/soc/fu/shift_rot/formal/proof_main_stage.py
11 because that tells pytest to run the tests in parallel; it will take a few
12 minutes instead of an hour.
17 from nmigen
import (Module
, Signal
, Elaboratable
, Mux
, Cat
, Repl
,
18 signed
, Const
, unsigned
)
19 from nmigen
.asserts
import Assert
, AnyConst
, Assume
20 from nmutil
.formaltest
import FHDLTestCase
21 from nmutil
.sim_util
import do_sim
22 from nmigen
.sim
import Delay
24 from soc
.fu
.shift_rot
.main_stage
import ShiftRotMainStage
25 from soc
.fu
.shift_rot
.pipe_data
import ShiftRotPipeSpec
26 from openpower
.decoder
.power_enums
import MicrOp
30 class TstOp(enum
.Enum
):
31 """ops we're testing, the idea is if we run a separate formal proof for
32 each instruction, we end up covering them all and each runs much faster,
33 also the formal proofs can be run in parallel."""
36 RLC32
= MicrOp
.OP_RLC
, 32
37 RLC64
= MicrOp
.OP_RLC
, 64
40 EXTSWSLI
= MicrOp
.OP_EXTSWSLI
41 TERNLOG
= MicrOp
.OP_TERNLOG
42 # grev removed -- leaving code for later use in grevlut
43 # GREV32 = MicrOp.OP_GREV, 32
44 # GREV64 = MicrOp.OP_GREV, 64
48 if isinstance(self
.value
, tuple):
def eq_any_const(sig: Signal):
    """Return a combinatorial statement binding *sig* to a formal constant.

    AnyConst gives the solver a free symbolic value of the same shape as
    *sig*; src_loc_at=1 attributes the constant to the caller's source
    line rather than to this helper.
    """
    shape = sig.shape()
    return sig.eq(AnyConst(shape, src_loc_at=1))
57 class Mask(Elaboratable
):
58 # copied from qemu's mask fn:
59 # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
61 self
.start
= Signal(6)
65 def elaborate(self
, platform
):
67 max_val
= Const(~
0, unsigned(64))
69 with m
.If(self
.start
== 0):
70 m
.d
.comb
+= self
.out
.eq(max_val
<< (max_bit
- self
.end
))
71 with m
.Elif(self
.end
== max_bit
):
72 m
.d
.comb
+= self
.out
.eq(max_val
>> self
.start
)
74 ret
= (max_val
>> self
.start
) ^
((max_val
>> self
.end
) >> 1)
75 m
.d
.comb
+= self
.out
.eq(Mux(self
.start
> self
.end
, ~ret
, ret
))
79 class TstMask(unittest
.TestCase
):
83 def case(start
, end
, expected
):
84 with self
.subTest(start
=start
, end
=end
):
85 yield dut
.start
.eq(start
)
89 with self
.subTest(out
=hex(out
), expected
=hex(expected
)):
90 self
.assertEqual(expected
, out
)
93 for start
in range(64):
97 for i
in range(start
, 64):
98 expected |
= 1 << (63 - i
)
99 for i
in range(0, end
+ 1):
100 expected |
= 1 << (63 - i
)
102 for i
in range(start
, end
+ 1):
103 expected |
= 1 << (63 - i
)
104 yield from case(start
, end
, expected
)
105 with
do_sim(self
, dut
, [dut
.start
, dut
.end
, dut
.out
]) as sim
:
106 sim
.add_process(process
)
111 v |
= Const(0, 64) # convert to value at least 64-bits wide
112 amt |
= Const(0, 6) # convert to value at least 6-bits wide
113 return (Cat(v
[:64], v
[:64]) >> (64 - amt
[:6]))[:64]
117 v |
= Const(0, 32) # convert to value at least 32-bits wide
118 return rotl64(Cat(v
[:32], v
[:32]), amt
)
121 # This defines a module to drive the device under test and assert
122 # properties about its outputs
123 class Driver(Elaboratable
):
124 def __init__(self
, which
):
125 assert isinstance(which
, TstOp
) or which
is None
128 def elaborate(self
, platform
):
132 pspec
= ShiftRotPipeSpec(id_wid
=2, parent_pspec
=None)
133 pspec
.draft_bitmanip
= True
134 m
.submodules
.dut
= dut
= ShiftRotMainStage(pspec
)
136 # Set inputs to formal variables
138 eq_any_const(dut
.i
.ctx
.op
.insn_type
),
139 eq_any_const(dut
.i
.ctx
.op
.fn_unit
),
140 eq_any_const(dut
.i
.ctx
.op
.imm_data
.data
),
141 eq_any_const(dut
.i
.ctx
.op
.imm_data
.ok
),
142 eq_any_const(dut
.i
.ctx
.op
.rc
.rc
),
143 eq_any_const(dut
.i
.ctx
.op
.rc
.ok
),
144 eq_any_const(dut
.i
.ctx
.op
.oe
.oe
),
145 eq_any_const(dut
.i
.ctx
.op
.oe
.ok
),
146 eq_any_const(dut
.i
.ctx
.op
.write_cr0
),
147 eq_any_const(dut
.i
.ctx
.op
.input_carry
),
148 eq_any_const(dut
.i
.ctx
.op
.output_carry
),
149 eq_any_const(dut
.i
.ctx
.op
.input_cr
),
150 eq_any_const(dut
.i
.ctx
.op
.is_32bit
),
151 eq_any_const(dut
.i
.ctx
.op
.is_signed
),
152 eq_any_const(dut
.i
.ctx
.op
.insn
),
153 eq_any_const(dut
.i
.xer_ca
),
154 eq_any_const(dut
.i
.ra
),
155 eq_any_const(dut
.i
.rb
),
156 eq_any_const(dut
.i
.rc
),
159 # check that the operation (op) is passed through (and muxid)
160 comb
+= Assert(dut
.o
.ctx
.op
== dut
.i
.ctx
.op
)
161 comb
+= Assert(dut
.o
.ctx
.muxid
== dut
.i
.ctx
.muxid
)
163 if self
.which
is None:
165 comb
+= Assume(dut
.i
.ctx
.op
.insn_type
!= i
.op
)
166 comb
+= Assert(~dut
.o
.o
.ok
)
168 # we're only checking a particular operation:
169 comb
+= Assume(dut
.i
.ctx
.op
.insn_type
== self
.which
.op
)
170 comb
+= Assert(dut
.o
.o
.ok
)
172 # dispatch to check fn for each op
173 getattr(self
, f
"_check_{self.which.name.lower()}")(m
, dut
)
177 def _check_shl(self
, m
, dut
):
178 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
179 expected
= Signal(64)
180 with m
.If(dut
.i
.ctx
.op
.is_32bit
):
181 m
.d
.comb
+= expected
.eq((dut
.i
.rs
<< dut
.i
.rb
[:6])[:32])
183 m
.d
.comb
+= expected
.eq((dut
.i
.rs
<< dut
.i
.rb
[:7])[:64])
184 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
185 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
187 def _check_shr(self
, m
, dut
):
188 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
189 expected
= Signal(64)
191 shift_in_s
= Signal(signed(128))
192 shift_roundtrip
= Signal(signed(128))
193 shift_in_u
= Signal(128)
194 shift_amt
= Signal(7)
195 with m
.If(dut
.i
.ctx
.op
.is_32bit
):
197 shift_amt
.eq(dut
.i
.rb
[:6]),
198 shift_in_s
.eq(dut
.i
.rs
[:32].as_signed()),
199 shift_in_u
.eq(dut
.i
.rs
[:32]),
203 shift_amt
.eq(dut
.i
.rb
[:7]),
204 shift_in_s
.eq(dut
.i
.rs
.as_signed()),
205 shift_in_u
.eq(dut
.i
.rs
),
208 with m
.If(dut
.i
.ctx
.op
.is_signed
):
210 expected
.eq(shift_in_s
>> shift_amt
),
211 shift_roundtrip
.eq((shift_in_s
>> shift_amt
) << shift_amt
),
212 carry
.eq((shift_in_s
< 0) & (shift_roundtrip
!= shift_in_s
)),
216 expected
.eq(shift_in_u
>> shift_amt
),
219 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
220 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== Repl(carry
, 2))
222 def _check_rlc32(self
, m
, dut
):
223 m
.d
.comb
+= Assume(dut
.i
.ctx
.op
.is_32bit
)
224 # rlwimi, rlwinm, and rlwnm
226 m
.submodules
.mask
= mask
= Mask()
227 expected
= Signal(64)
229 m
.d
.comb
+= rot
.eq(rotl32(dut
.i
.rs
[:32], dut
.i
.rb
[:5]))
230 m
.d
.comb
+= mask
.start
.eq(dut
.fields
.FormM
.MB
[:] + 32)
231 m
.d
.comb
+= mask
.end
.eq(dut
.fields
.FormM
.ME
[:] + 32)
233 # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
234 # the expression turns into a no-op
235 m
.d
.comb
+= expected
.eq((rot
& mask
.out
) |
(dut
.i
.ra
& ~mask
.out
))
236 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
237 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
239 def _check_rlc64(self
, m
, dut
):
240 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
)
243 # `rb` is always a 6-bit immediate
244 m
.d
.comb
+= Assume(dut
.i
.rb
[6:] == 0)
246 m
.submodules
.mask
= mask
= Mask()
247 expected
= Signal(64)
249 m
.d
.comb
+= rot
.eq(rotl64(dut
.i
.rs
, dut
.i
.rb
[:6]))
250 mb
= dut
.fields
.FormMD
.mb
[:]
251 m
.d
.comb
+= mask
.start
.eq(Cat(mb
[1:6], mb
[0]))
252 m
.d
.comb
+= mask
.end
.eq(63 - dut
.i
.rb
[:6])
254 # for rldic, ra is guaranteed to be 0, so that part of
255 # the expression turns into a no-op
256 m
.d
.comb
+= expected
.eq((rot
& mask
.out
) |
(dut
.i
.ra
& ~mask
.out
))
257 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
258 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
260 def _check_rlcl(self
, m
, dut
):
261 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
)
264 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_signed
)
265 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
267 m
.submodules
.mask
= mask
= Mask()
268 m
.d
.comb
+= mask
.end
.eq(63)
269 mb
= dut
.fields
.FormMD
.mb
[:]
270 m
.d
.comb
+= mask
.start
.eq(Cat(mb
[1:6], mb
[0]))
273 m
.d
.comb
+= rot
.eq(rotl64(dut
.i
.rs
, dut
.i
.rb
[:6]))
275 expected
= Signal(64)
276 m
.d
.comb
+= expected
.eq(rot
& mask
.out
)
278 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
279 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
281 def _check_rlcr(self
, m
, dut
):
282 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
)
285 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_signed
)
286 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
288 m
.submodules
.mask
= mask
= Mask()
289 m
.d
.comb
+= mask
.start
.eq(0)
290 me
= dut
.fields
.FormMD
.me
[:]
291 m
.d
.comb
+= mask
.end
.eq(Cat(me
[1:6], me
[0]))
294 m
.d
.comb
+= rot
.eq(rotl64(dut
.i
.rs
, dut
.i
.rb
[:6]))
296 expected
= Signal(64)
297 m
.d
.comb
+= expected
.eq(rot
& mask
.out
)
299 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
300 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
def _check_extswsli(self, m, dut):
    """Formal check for OP_EXTSWSLI: sign-extend low word, shift left.

    Builds the expected result (low 32 bits of rs, sign-extended, shifted
    left by the low 6 bits of rb) and asserts the DUT output matches it
    and that the carry output is zero.
    """
    comb = m.d.comb
    # the proof assumes ra is unused (zero) and that rb holds only a
    # 6-bit shift amount -- TODO confirm the decoder guarantees both
    comb += Assume(dut.i.ra == 0)
    comb += Assume(dut.i.rb[6:] == 0)
    comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
    expected = Signal(64)
    # reference model: sign-extend rs[0:32], then shift left by rb[0:6]
    comb += expected.eq(dut.i.rs[0:32].as_signed() << dut.i.rb[:6])
    comb += Assert(dut.o.o.data == expected)
    # carry out must be clear for this operation
    comb += Assert(dut.o.xer_ca.data == 0)
311 def _check_ternlog(self
, m
, dut
):
312 lut
= dut
.fields
.FormTLI
.TLI
[:]
314 idx
= Cat(dut
.i
.rb
[i
], dut
.i
.ra
[i
], dut
.i
.rc
[i
])
317 m
.d
.comb
+= Assert(dut
.o
.o
.data
[i
] == lut
[j
])
318 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
320 # grev removed -- leaving code for later use in grevlut
321 def _check_grev32(self
, m
, dut
):
322 m
.d
.comb
+= Assume(dut
.i
.ctx
.op
.is_32bit
)
323 # assert zero-extended
324 m
.d
.comb
+= Assert(dut
.o
.o
.data
[32:] == 0)
326 m
.d
.comb
+= eq_any_const(i
)
327 idx
= dut
.i
.rb
[0: 5] ^ i
328 m
.d
.comb
+= Assert((dut
.o
.o
.data
>> i
)[0] == (dut
.i
.ra
>> idx
)[0])
329 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
331 # grev removed -- leaving code for later use in grevlut
332 def _check_grev64(self
, m
, dut
):
333 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
)
335 m
.d
.comb
+= eq_any_const(i
)
336 idx
= dut
.i
.rb
[0: 6] ^ i
337 m
.d
.comb
+= Assert((dut
.o
.o
.data
>> i
)[0] == (dut
.i
.ra
>> idx
)[0])
338 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
341 class ALUTestCase(FHDLTestCase
):
def run_it(self, which):
    """Build a Driver for *which* and formally verify it.

    Runs assertFormal twice on the same Driver instance: once in "bmc"
    mode and once in "cover" mode, both at depth 2.
    """
    drv = Driver(which)
    for mode in ("bmc", "cover"):
        self.assertFormal(drv, mode=mode, depth=2)
351 self
.run_it(TstOp
.SHL
)
354 self
.run_it(TstOp
.SHR
)
def test_rlc32(self):
    """Formal proof for OP_RLC in 32-bit mode."""
    op = TstOp.RLC32
    self.run_it(op)
def test_rlc64(self):
    """Formal proof for OP_RLC in 64-bit mode."""
    op = TstOp.RLC64
    self.run_it(op)
363 self
.run_it(TstOp
.RLCL
)
366 self
.run_it(TstOp
.RLCR
)
def test_extswsli(self):
    """Formal proof for OP_EXTSWSLI (sign-extend word, shift left imm)."""
    op = TstOp.EXTSWSLI
    self.run_it(op)
def test_ternlog(self):
    """Formal proof for OP_TERNLOG."""
    op = TstOp.TERNLOG
    self.run_it(op)
@unittest.skip("grev removed -- leaving code for later use in grevlut")
def test_grev32(self):
    """Formal proof for OP_GREV in 32-bit mode (op currently removed)."""
    op = TstOp.GREV32
    self.run_it(op)
@unittest.skip("grev removed -- leaving code for later use in grevlut")
def test_grev64(self):
    """Formal proof for OP_GREV in 64-bit mode (op currently removed)."""
    op = TstOp.GREV64
    self.run_it(op)
383 # check that all test cases are covered
385 assert callable(getattr(ALUTestCase
, f
"test_{i.name.lower()}"))
388 if __name__
== '__main__':