# Proof of correctness for shift/rotate FU
# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>

"""
* https://bugs.libre-soc.org/show_bug.cgi?id=340

run with:
    pip install pytest-xdist
    pytest -n auto src/soc/fu/shift_rot/formal/proof_main_stage.py
because that tells pytest to run the tests in parallel; it will take a few
minutes instead of an hour.
"""

import enum
import unittest

from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
                    signed, Const, unsigned)
from nmigen.asserts import Assert, AnyConst, Assume
from nmutil.formaltest import FHDLTestCase
from nmutil.sim_util import do_sim
from nmigen.sim import Delay

from soc.fu.shift_rot.main_stage import ShiftRotMainStage
from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
from openpower.decoder.power_enums import MicrOp


class TstOp(enum.Enum):
    """ops we're testing: the idea is that if we run a separate formal
    proof for each instruction, we end up covering them all, and each runs
    much faster; also the formal proofs can be run in parallel."""
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC32 = MicrOp.OP_RLC, 32
    RLC64 = MicrOp.OP_RLC, 64
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    GREV32 = MicrOp.OP_GREV, 32
    GREV64 = MicrOp.OP_GREV, 64

    @property
    def op(self):
        """the MicrOp alone, with any trailing XLEN tag stripped off"""
        if isinstance(self.value, tuple):
            return self.value[0]
        return self.value


def eq_any_const(sig: Signal):
    # drive sig with a formal "any constant": the solver picks its value
    return sig.eq(AnyConst(sig.shape(), src_loc_at=1))


class Mask(Elaboratable):
    # copied from qemu's mask fn:
    # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21
    def __init__(self):
        self.start = Signal(6)
        self.end = Signal(6)
        self.out = Signal(64)

    def elaborate(self, platform):
        m = Module()
        max_val = Const(~0, unsigned(64))
        max_bit = 63
        with m.If(self.start == 0):
            m.d.comb += self.out.eq(max_val << (max_bit - self.end))
        with m.Elif(self.end == max_bit):
            m.d.comb += self.out.eq(max_val >> self.start)
        with m.Else():
            ret = (max_val >> self.start) ^ ((max_val >> self.end) >> 1)
            m.d.comb += self.out.eq(Mux(self.start > self.end, ~ret, ret))
        return m
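

# for reading the Mask module above: the same computation as a plain-python
# sketch (mirroring qemu's MASK() macro, MSB0 bit numbering); illustrative
# only, not used by the proof
def mask_model(start, end):
    max_val = (1 << 64) - 1
    if start == 0:
        return (max_val << (63 - end)) & max_val
    if end == 63:
        return max_val >> start
    ret = (max_val >> start) ^ ((max_val >> end) >> 1)
    # start > end means the mask wraps around
    return (~ret & max_val) if start > end else ret

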
class TstMask(unittest.TestCase):
    def test_mask(self):
        dut = Mask()

        def process():
            def case(start, end, expected):
                with self.subTest(start=start, end=end):
                    yield dut.start.eq(start)
                    yield dut.end.eq(end)
                    yield Delay(1e-6)
                    out = yield dut.out
                    with self.subTest(out=hex(out), expected=hex(expected)):
                        self.assertEqual(expected, out)

            for start in range(64):
                for end in range(64):
                    expected = 0
                    if start > end:
                        # wrap-around: bits start..63 and 0..end (MSB0)
                        for i in range(start, 64):
                            expected |= 1 << (63 - i)
                        for i in range(0, end + 1):
                            expected |= 1 << (63 - i)
                    else:
                        for i in range(start, end + 1):
                            expected |= 1 << (63 - i)
                    yield from case(start, end, expected)
        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
            sim.add_process(process)
            sim.run()


def rotl64(v, amt):
    v |= Const(0, 64)  # convert to value at least 64-bits wide
    amt |= Const(0, 6)  # convert to value at least 6-bits wide
    return (Cat(v[:64], v[:64]) >> (64 - amt[:6]))[:64]


def rotl32(v, amt):
    v |= Const(0, 32)  # convert to value at least 32-bits wide
    return rotl64(Cat(v[:32], v[:32]), amt)
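

# the double-width trick above: with v duplicated into 128 bits, shifting
# right by (64 - amt) slides a 64-bit window that wraps around, giving a
# left-rotate.  e.g. rotl64(0x8000_0000_0000_0001, 4) == 0x18

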
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    def __init__(self, which):
        assert isinstance(which, TstOp) or which is None
        self.which = which

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        pspec.draft_bitmanip = True
        m.submodules.dut = dut = ShiftRotMainStage(pspec)

        # Set inputs to formal variables
        comb += [
            eq_any_const(dut.i.ctx.op.insn_type),
            eq_any_const(dut.i.ctx.op.fn_unit),
            eq_any_const(dut.i.ctx.op.imm_data.data),
            eq_any_const(dut.i.ctx.op.imm_data.ok),
            eq_any_const(dut.i.ctx.op.rc.rc),
            eq_any_const(dut.i.ctx.op.rc.ok),
            eq_any_const(dut.i.ctx.op.oe.oe),
            eq_any_const(dut.i.ctx.op.oe.ok),
            eq_any_const(dut.i.ctx.op.write_cr0),
            eq_any_const(dut.i.ctx.op.input_carry),
            eq_any_const(dut.i.ctx.op.output_carry),
            eq_any_const(dut.i.ctx.op.input_cr),
            eq_any_const(dut.i.ctx.op.is_32bit),
            eq_any_const(dut.i.ctx.op.is_signed),
            eq_any_const(dut.i.ctx.op.insn),
            eq_any_const(dut.i.xer_ca),
            eq_any_const(dut.i.ra),
            eq_any_const(dut.i.rb),
            eq_any_const(dut.i.rc),
            eq_any_const(dut.i.rs),
        ]

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        if self.which is None:
            # no op selected: assume none of the tested ops is decoded,
            # and check that the output stays inactive
            for i in TstOp:
                comb += Assume(dut.i.ctx.op.insn_type != i.op)
            comb += Assert(~dut.o.o.ok)
        else:
            # we're only checking a particular operation:
            comb += Assume(dut.i.ctx.op.insn_type == self.which.op)
            comb += Assert(dut.o.o.ok)

            # dispatch to the check fn for each op
            getattr(self, f"_check_{self.which.name.lower()}")(m, dut)

        return m

    def _check_shl(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:6])[:32])
        with m.Else():
            m.d.comb += expected.eq((dut.i.rs << dut.i.rb[:7])[:64])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_shr(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        expected = Signal(64)
        carry = Signal()
        shift_in_s = Signal(signed(128))
        shift_roundtrip = Signal(signed(128))
        shift_in_u = Signal(128)
        shift_amt = Signal(7)
        with m.If(dut.i.ctx.op.is_32bit):
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:6]),
                shift_in_s.eq(dut.i.rs[:32].as_signed()),
                shift_in_u.eq(dut.i.rs[:32]),
            ]
        with m.Else():
            m.d.comb += [
                shift_amt.eq(dut.i.rb[:7]),
                shift_in_s.eq(dut.i.rs.as_signed()),
                shift_in_u.eq(dut.i.rs),
            ]

        with m.If(dut.i.ctx.op.is_signed):
            m.d.comb += [
                expected.eq(shift_in_s >> shift_amt),
                shift_roundtrip.eq((shift_in_s >> shift_amt) << shift_amt),
                carry.eq((shift_in_s < 0) & (shift_roundtrip != shift_in_s)),
            ]
        with m.Else():
            m.d.comb += [
                expected.eq(shift_in_u >> shift_amt),
                carry.eq(0),
            ]
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == Repl(carry, 2))
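
    # note on the carry model above: CA for sra[wd] is set only when a
    # negative value drops 1-bits off the right; e.g. -5 >> 1 == -3, and
    # (-3 << 1) != -5, so the shift round-trip test flags CA=1
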
    def _check_rlc32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # rlwimi, rlwinm, and rlwnm

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl32(dut.i.rs[:32], dut.i.rb[:5]))
        m.d.comb += mask.start.eq(dut.fields.FormM.MB[:] + 32)
        m.d.comb += mask.end.eq(dut.fields.FormM.ME[:] + 32)

        # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)
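
    # rlwimi example for the expression above: with SH=8, MB=16, ME=23
    # (MSB0), the rotated RS supplies bits 16..23 and RA supplies the
    # rest, i.e. rotate-left-then-insert-under-mask
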
    def _check_rlc64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldic and rldimi

        # `rb` is always a 6-bit immediate
        m.d.comb += Assume(dut.i.rb[6:] == 0)

        m.submodules.mask = mask = Mask()
        expected = Signal(64)
        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))
        mb = dut.fields.FormMD.mb[:]
        # reassemble the split MD-form mb field: bit 0 of the decoded
        # field holds the value's MSB
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))
        m.d.comb += mask.end.eq(63 - dut.i.rb[:6])

        # for rldic, ra is guaranteed to be 0, so that part of
        # the expression turns into a no-op
        m.d.comb += expected.eq((rot & mask.out) | (dut.i.ra & ~mask.out))
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcl(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicl and rldcl

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.end.eq(63)
        mb = dut.fields.FormMD.mb[:]
        m.d.comb += mask.start.eq(Cat(mb[1:6], mb[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_rlcr(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        # rldicr and rldcr

        m.d.comb += Assume(~dut.i.ctx.op.is_signed)
        m.d.comb += Assume(dut.i.ra == 0)

        m.submodules.mask = mask = Mask()
        m.d.comb += mask.start.eq(0)
        me = dut.fields.FormMD.me[:]
        m.d.comb += mask.end.eq(Cat(me[1:6], me[0]))

        rot = Signal(64)
        m.d.comb += rot.eq(rotl64(dut.i.rs, dut.i.rb[:6]))

        expected = Signal(64)
        m.d.comb += expected.eq(rot & mask.out)

        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_extswsli(self, m, dut):
        m.d.comb += Assume(dut.i.ra == 0)
        m.d.comb += Assume(dut.i.rb[6:] == 0)
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)  # all instrs. are 64-bit
        expected = Signal(64)
        m.d.comb += expected.eq(dut.i.rs[0:32].as_signed() << dut.i.rb[:6])
        m.d.comb += Assert(dut.o.o.data == expected)
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_ternlog(self, m, dut):
        lut = dut.fields.FormTLI.TLI[:]
        for i in range(64):
            idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
            for j in range(8):
                with m.If(idx == j):
                    m.d.comb += Assert(dut.o.o.data[i] == lut[j])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)
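
    # i.e. each result bit i is lut[{rc[i], ra[i], rb[i]}]: the three
    # operand bits select one of the 8 LUT entries.  e.g. TLI=0b11111110
    # would implement a bitwise OR of the three operands
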
    def _check_grev32(self, m, dut):
        m.d.comb += Assume(dut.i.ctx.op.is_32bit)
        # assert zero-extended
        m.d.comb += Assert(dut.o.o.data[32:] == 0)
        i = Signal(5)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:5] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)

    def _check_grev64(self, m, dut):
        m.d.comb += Assume(~dut.i.ctx.op.is_32bit)
        i = Signal(6)
        m.d.comb += eq_any_const(i)
        idx = dut.i.rb[0:6] ^ i
        m.d.comb += Assert((dut.o.o.data >> i)[0] == (dut.i.ra >> idx)[0])
        m.d.comb += Assert(dut.o.xer_ca.data == 0)
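
    # grev permutes bits by XORing each bit's index with rb, checked here
    # one formally-arbitrary bit i at a time; e.g. rb=0b111111 reverses
    # all 64 bits, and rb=0b111000 reverses the byte order

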
class ALUTestCase(FHDLTestCase):
    def run_it(self, which):
        module = Driver(which)
        self.assertFormal(module, mode="bmc", depth=2)
        self.assertFormal(module, mode="cover", depth=2)

    def test_none(self):
        self.run_it(None)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc32(self):
        self.run_it(TstOp.RLC32)

    def test_rlc64(self):
        self.run_it(TstOp.RLC64)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    def test_grev64(self):
        self.run_it(TstOp.GREV64)


# check that all test cases are covered
for i in TstOp:
    assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}"))


if __name__ == '__main__':
    unittest.main()