1 # Proof of correctness for shift/rotate FU
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
5 * https://bugs.libre-soc.org/show_bug.cgi?id=340
9 from shutil
import which
10 from nmigen
import (Module
, Signal
, Elaboratable
, Mux
, Cat
, Repl
,
11 signed
, Array
, Const
, Value
, unsigned
)
12 from nmigen
.asserts
import Assert
, AnyConst
, Assume
, Cover
13 from nmutil
.formaltest
import FHDLTestCase
14 from nmutil
.sim_util
import do_sim
15 from nmigen
.sim
import Delay
17 from soc
.fu
.shift_rot
.main_stage
import ShiftRotMainStage
18 from soc
.fu
.shift_rot
.rotator
import right_mask
, left_mask
19 from soc
.fu
.shift_rot
.pipe_data
import ShiftRotPipeSpec
20 from soc
.fu
.shift_rot
.sr_input_record
import CompSROpSubset
21 from openpower
.decoder
.power_enums
import MicrOp
22 from openpower
.consts
import field
25 from nmutil
.extend
import exts
class TstOp(enum.Enum):
    """ops we're testing, the idea is if we run a separate formal proof for
    each instruction, we end up covering them all and each runs much faster,
    also the formal proofs can be run in parallel."""
    SHL = MicrOp.OP_SHL
    SHR = MicrOp.OP_SHR
    RLC32 = MicrOp.OP_RLC, 32
    RLC64 = MicrOp.OP_RLC, 64
    RLCL = MicrOp.OP_RLCL
    RLCR = MicrOp.OP_RLCR
    EXTSWSLI = MicrOp.OP_EXTSWSLI
    TERNLOG = MicrOp.OP_TERNLOG
    GREV32 = MicrOp.OP_GREV, 32
    GREV64 = MicrOp.OP_GREV, 64

    @property
    def op(self):
        """The MicrOp member for this test-op (some members carry an
        extra 32/64 width tag in a tuple; strip it off)."""
        if isinstance(self.value, tuple):
            return self.value[0]
        return self.value
def eq_any_const(sig: Signal):
    """Drive `sig` with a fresh formal free variable of the same shape.

    src_loc_at=1 attributes the AnyConst to our caller, which makes the
    generated formal netlist easier to trace back to the proof code.
    """
    shape = sig.shape()
    return sig.eq(AnyConst(shape, src_loc_at=1))
class Mask(Elaboratable):
    """Combinatorial PowerPC-style mask generator.

    Produces a 64-bit mask with bits `start`..`end` set, where bit indices
    are in MSB0 order; when start > end the selected range wraps around.
    """
    # copied from qemu's mask fn:
    # https://gitlab.com/qemu-project/qemu/-/blob/477c3b934a47adf7de285863f59d6e4503dd1a6d/target/ppc/internal.h#L21

    def __init__(self):
        self.start = Signal(6)  # MSB0 index of first mask bit
        self.end = Signal(6)    # MSB0 index of last mask bit
        self.out = Signal(64)   # the generated mask

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        max_bit = 63
        max_val = Const(~0, unsigned(64))
        with m.If(self.start == 0):
            # mask starts at the MSB: a plain left-shift suffices
            comb += self.out.eq(max_val << (max_bit - self.end))
        with m.Elif(self.end == max_bit):
            # mask ends at the LSB: a plain right-shift suffices
            comb += self.out.eq(max_val >> self.start)
        with m.Else():
            # general case: difference of two right-shifted all-ones
            # values; invert when the range wraps (start > end)
            ret = (max_val >> self.start) ^ ((max_val >> self.end) >> 1)
            comb += self.out.eq(Mux(self.start > self.end, ~ret, ret))
        return m
class TstMask(unittest.TestCase):
    """Exhaustively compare the Mask module against a bit-by-bit software
    model for every (start, end) pair."""

    def test_mask(self):
        dut = Mask()

        def case(start, end, expected):
            # drive one (start, end) pair and compare dut.out
            with self.subTest(start=start, end=end):
                yield dut.start.eq(start)
                yield dut.end.eq(end)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=hex(out), expected=hex(expected)):
                    self.assertEqual(expected, out)

        def process():
            for start in range(64):
                for end in range(64):
                    # software model: set bits start..end in MSB0 order,
                    # wrapping around when start > end
                    expected = 0
                    if start > end:
                        for i in range(start, 64):
                            expected |= 1 << (63 - i)
                        for i in range(0, end + 1):
                            expected |= 1 << (63 - i)
                    else:
                        for i in range(start, end + 1):
                            expected |= 1 << (63 - i)
                    yield from case(start, end, expected)
        with do_sim(self, dut, [dut.start, dut.end, dut.out]) as sim:
            sim.add_process(process)
            sim.run()
def rotl64(v, amt):
    """Rotate the 64-bit value `v` left by the 6-bit amount `amt`."""
    v |= Const(0, 64)  # convert to value at least 64-bits wide
    amt |= Const(0, 6)  # convert to value at least 6-bits wide
    # concatenate two copies of v and shift right: the rotated value
    # lands in the low 64 bits
    return (Cat(v[:64], v[:64]) >> (64 - amt[:6]))[:64]
def rotl32(v, amt):
    """Rotate the 32-bit value `v` left by `amt`, by doubling it into a
    64-bit value and reusing rotl64."""
    v |= Const(0, 32)  # convert to value at least 32-bits wide
    return rotl64(Cat(v[:32], v[:32]), amt)
119 # This defines a module to drive the device under test and assert
120 # properties about its outputs
121 class Driver(Elaboratable
):
122 def __init__(self
, which
):
123 assert isinstance(which
, TstOp
)
126 def elaborate(self
, platform
):
130 pspec
= ShiftRotPipeSpec(id_wid
=2, parent_pspec
=None)
131 pspec
.draft_bitmanip
= True
132 m
.submodules
.dut
= dut
= ShiftRotMainStage(pspec
)
134 # Set inputs to formal variables
136 eq_any_const(dut
.i
.ctx
.op
.insn_type
),
137 eq_any_const(dut
.i
.ctx
.op
.fn_unit
),
138 eq_any_const(dut
.i
.ctx
.op
.imm_data
.data
),
139 eq_any_const(dut
.i
.ctx
.op
.imm_data
.ok
),
140 eq_any_const(dut
.i
.ctx
.op
.rc
.rc
),
141 eq_any_const(dut
.i
.ctx
.op
.rc
.ok
),
142 eq_any_const(dut
.i
.ctx
.op
.oe
.oe
),
143 eq_any_const(dut
.i
.ctx
.op
.oe
.ok
),
144 eq_any_const(dut
.i
.ctx
.op
.write_cr0
),
145 eq_any_const(dut
.i
.ctx
.op
.input_carry
),
146 eq_any_const(dut
.i
.ctx
.op
.output_carry
),
147 eq_any_const(dut
.i
.ctx
.op
.input_cr
),
148 eq_any_const(dut
.i
.ctx
.op
.is_32bit
),
149 eq_any_const(dut
.i
.ctx
.op
.is_signed
),
150 eq_any_const(dut
.i
.ctx
.op
.insn
),
151 eq_any_const(dut
.i
.xer_ca
),
152 eq_any_const(dut
.i
.ra
),
153 eq_any_const(dut
.i
.rb
),
154 eq_any_const(dut
.i
.rc
),
157 # check that the operation (op) is passed through (and muxid)
158 comb
+= Assert(dut
.o
.ctx
.op
== dut
.i
.ctx
.op
)
159 comb
+= Assert(dut
.o
.ctx
.muxid
== dut
.i
.ctx
.muxid
)
161 # we're only checking a particular operation:
162 comb
+= Assume(dut
.i
.ctx
.op
.insn_type
== self
.which
.op
)
164 # dispatch to check fn for each op
165 getattr(self
, f
"_check_{self.which.name.lower()}")(m
, dut
)
169 # all following code in elaborate is kept for ease of reference, to be
170 # deleted once this proof is completed.
172 # convenience variables
173 rs
= dut
.i
.rs
# register to shift
174 b
= dut
.i
.rb
# register containing amount to shift by
175 ra
= dut
.i
.a
# source register if masking is to be done
176 carry_in
= dut
.i
.xer_ca
[0]
177 carry_in32
= dut
.i
.xer_ca
[1]
178 carry_out
= dut
.o
.xer_ca
180 print("fields", rec
.fields
)
181 itype
= rec
.insn_type
184 m_fields
= dut
.fields
.FormM
185 md_fields
= dut
.fields
.FormMD
187 # setup random inputs
188 comb
+= rs
.eq(AnyConst(64))
189 comb
+= ra
.eq(AnyConst(64))
190 comb
+= b
.eq(AnyConst(64))
191 comb
+= carry_in
.eq(AnyConst(1))
192 comb
+= carry_in32
.eq(AnyConst(1))
195 comb
+= dut
.i
.ctx
.op
.eq(rec
)
197 # check that the operation (op) is passed through (and muxid)
198 comb
+= Assert(dut
.o
.ctx
.op
== dut
.i
.ctx
.op
)
199 comb
+= Assert(dut
.o
.ctx
.muxid
== dut
.i
.ctx
.muxid
)
201 # signed and signed/32 versions of input rs
202 a_signed
= Signal(signed(64))
203 a_signed_32
= Signal(signed(32))
204 comb
+= a_signed
.eq(rs
)
205 comb
+= a_signed_32
.eq(rs
[0:32])
208 mb
= Signal(7, reset_less
=True)
209 ml
= Signal(64, reset_less
=True)
212 with m
.If((itype
== MicrOp
.OP_RLC
) |
(itype
== MicrOp
.OP_RLCL
)):
213 with m
.If(rec
.is_32bit
):
214 comb
+= mb
.eq(m_fields
.MB
[:])
216 comb
+= mb
.eq(md_fields
.mb
[:])
218 with m
.If(rec
.is_32bit
):
219 comb
+= mb
.eq(b
[0:6])
222 comb
+= ml
.eq(left_mask(m
, mb
))
225 me
= Signal(7, reset_less
=True)
226 mr
= Signal(64, reset_less
=True)
229 with m
.If((itype
== MicrOp
.OP_RLC
) |
(itype
== MicrOp
.OP_RLCR
)):
230 with m
.If(rec
.is_32bit
):
231 comb
+= me
.eq(m_fields
.ME
[:])
233 comb
+= me
.eq(md_fields
.me
[:])
235 with m
.If(rec
.is_32bit
):
236 comb
+= me
.eq(b
[0:6])
239 comb
+= mr
.eq(right_mask(m
, me
))
245 # main assertion of arithmetic operations
246 with m
.Switch(itype
):
248 # left-shift: 64/32-bit
249 with m
.Case(MicrOp
.OP_SHL
):
250 comb
+= Assume(ra
== 0)
251 with m
.If(rec
.is_32bit
):
252 comb
+= Assert(o
[0:32] == ((rs
<< b
[0:6]) & 0xffffffff))
253 comb
+= Assert(o
[32:64] == 0)
255 comb
+= Assert(o
== ((rs
<< b
[0:7]) & ((1 << 64)-1)))
257 # right-shift: 64/32-bit / signed
258 with m
.Case(MicrOp
.OP_SHR
):
259 comb
+= Assume(ra
== 0)
260 with m
.If(~rec
.is_signed
):
261 with m
.If(rec
.is_32bit
):
262 comb
+= Assert(o
[0:32] == (rs
[0:32] >> b
[0:6]))
263 comb
+= Assert(o
[32:64] == 0)
265 comb
+= Assert(o
== (rs
>> b
[0:7]))
267 with m
.If(rec
.is_32bit
):
268 comb
+= Assert(o
[0:32] == (a_signed_32
>> b
[0:6]))
269 comb
+= Assert(o
[32:64] == Repl(rs
[31], 32))
271 comb
+= Assert(o
== (a_signed
>> b
[0:7]))
273 # extswsli: 32/64-bit moded
274 with m
.Case(MicrOp
.OP_EXTSWSLI
):
275 comb
+= Assume(ra
== 0)
276 with m
.If(rec
.is_32bit
):
277 comb
+= Assert(o
[0:32] == ((rs
<< b
[0:6]) & 0xffffffff))
278 comb
+= Assert(o
[32:64] == 0)
280 # sign-extend to 64 bit
281 a_s
= Signal(64, reset_less
=True)
282 comb
+= a_s
.eq(exts(rs
, 32, 64))
283 comb
+= Assert(o
== ((a_s
<< b
[0:7]) & ((1 << 64)-1)))
285 # rlwinm, rlwnm, rlwimi
286 # *CAN* these even be 64-bit capable? I don't think they are.
287 with m
.Case(MicrOp
.OP_RLC
):
288 comb
+= Assume(ra
== 0)
289 comb
+= Assume(rec
.is_32bit
)
291 # Duplicate some signals so that they're much easier to find
293 # Pro-tip: when debugging, factor out expressions into
295 # signals, and search using a unique grep-tag (RLC in my case).
297 # debugging, resubstitute values to comply with surrounding
300 mrl
= Signal(64, reset_less
=True, name
='MASK_FOR_RLC')
302 comb
+= mrl
.eq(ml | mr
)
304 comb
+= mrl
.eq(ml
& mr
)
306 ainp
= Signal(64, reset_less
=True, name
='A_INP_FOR_RLC')
307 comb
+= ainp
.eq(field(rs
, 32, 63))
309 sh
= Signal(6, reset_less
=True, name
='SH_FOR_RLC')
310 comb
+= sh
.eq(b
[0:6])
312 exp_shl
= Signal(64, reset_less
=True,
313 name
='A_SHIFTED_LEFT_BY_SH_FOR_RLC')
314 comb
+= exp_shl
.eq((ainp
<< sh
) & 0xFFFFFFFF)
316 exp_shr
= Signal(64, reset_less
=True,
317 name
='A_SHIFTED_RIGHT_FOR_RLC')
318 comb
+= exp_shr
.eq((ainp
>> (32 - sh
)) & 0xFFFFFFFF)
320 exp_rot
= Signal(64, reset_less
=True,
321 name
='A_ROTATED_LEFT_FOR_RLC')
322 comb
+= exp_rot
.eq(exp_shl | exp_shr
)
324 exp_ol
= Signal(32, reset_less
=True,
325 name
='EXPECTED_OL_FOR_RLC')
326 comb
+= exp_ol
.eq(field((exp_rot
& mrl
) |
(ainp
& ~mrl
),
329 act_ol
= Signal(32, reset_less
=True, name
='ACTUAL_OL_FOR_RLC')
330 comb
+= act_ol
.eq(field(o
, 32, 63))
332 # If I uncomment the following lines, I can confirm that all
333 # 32-bit rotations work. If I uncomment only one of the
334 # following lines, I can confirm that all 32-bit rotations
335 # work. When I remove/recomment BOTH lines, however, the
336 # assertion fails. Why??
338 # comb += Assume(mr == 0xFFFFFFFF)
339 # comb += Assume(ml == 0xFFFFFFFF)
340 # with m.If(rec.is_32bit):
341 # comb += Assert(act_ol == exp_ol)
342 # comb += Assert(field(o, 0, 31) == 0)
345 with m
.Case(MicrOp
.OP_RLCR
):
347 with m
.Case(MicrOp
.OP_RLCL
):
349 with m
.Case(MicrOp
.OP_TERNLOG
):
350 lut
= dut
.fields
.FormTLI
.TLI
[:]
352 idx
= Cat(dut
.i
.rb
[i
], dut
.i
.ra
[i
], dut
.i
.rc
[i
])
355 comb
+= Assert(dut
.o
.o
.data
[i
] == lut
[j
])
356 with m
.Case(MicrOp
.OP_GREV
):
357 ra_bits
= Array(dut
.i
.ra
[i
] for i
in range(64))
358 with m
.If(dut
.i
.ctx
.op
.is_32bit
):
359 # assert zero-extended
360 comb
+= Assert(dut
.o
.o
.data
[32:] == 0)
362 idx
= dut
.i
.rb
[0:5] ^ i
363 comb
+= Assert(dut
.o
.o
.data
[i
]
367 idx
= dut
.i
.rb
[0:6] ^ i
368 comb
+= Assert(dut
.o
.o
.data
[i
]
374 # check that data ok was only enabled when op actioned
375 comb
+= Assert(dut
.o
.o
.ok
== o_ok
)
379 def _check_shl(self
, m
, dut
):
380 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
381 expected
= Signal(64)
382 with m
.If(dut
.i
.ctx
.op
.is_32bit
):
383 m
.d
.comb
+= expected
.eq((dut
.i
.rs
<< dut
.i
.rb
[:6])[:32])
385 m
.d
.comb
+= expected
.eq((dut
.i
.rs
<< dut
.i
.rb
[:7])[:64])
386 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
387 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
389 def _check_shr(self
, m
, dut
):
390 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
391 expected
= Signal(64)
393 shift_in_s
= Signal(signed(128))
394 shift_roundtrip
= Signal(signed(128))
395 shift_in_u
= Signal(128)
396 shift_amt
= Signal(7)
397 with m
.If(dut
.i
.ctx
.op
.is_32bit
):
399 shift_amt
.eq(dut
.i
.rb
[:6]),
400 shift_in_s
.eq(dut
.i
.rs
[:32].as_signed()),
401 shift_in_u
.eq(dut
.i
.rs
[:32]),
405 shift_amt
.eq(dut
.i
.rb
[:7]),
406 shift_in_s
.eq(dut
.i
.rs
.as_signed()),
407 shift_in_u
.eq(dut
.i
.rs
),
410 with m
.If(dut
.i
.ctx
.op
.is_signed
):
412 expected
.eq(shift_in_s
>> shift_amt
),
413 shift_roundtrip
.eq((shift_in_s
>> shift_amt
) << shift_amt
),
414 carry
.eq((shift_in_s
< 0) & (shift_roundtrip
!= shift_in_s
)),
418 expected
.eq(shift_in_u
>> shift_amt
),
421 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
422 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== Repl(carry
, 2))
424 def _check_rlc32(self
, m
, dut
):
425 m
.d
.comb
+= Assume(dut
.i
.ctx
.op
.is_32bit
)
426 # rlwimi, rlwinm, and rlwnm
428 m
.submodules
.mask
= mask
= Mask()
429 expected
= Signal(64)
431 m
.d
.comb
+= rot
.eq(rotl32(dut
.i
.rs
[:32], dut
.i
.rb
[:5]))
432 m
.d
.comb
+= mask
.start
.eq(dut
.fields
.FormM
.MB
[:] + 32)
433 m
.d
.comb
+= mask
.end
.eq(dut
.fields
.FormM
.ME
[:] + 32)
435 # for rlwinm and rlwnm, ra is guaranteed to be 0, so that part of
436 # the expression turns into a no-op
437 m
.d
.comb
+= expected
.eq((rot
& mask
.out
) |
(dut
.i
.ra
& ~mask
.out
))
438 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
439 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
441 def _check_rlc64(self
, m
, dut
):
442 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
)
445 # `rb` is always a 6-bit immediate
446 m
.d
.comb
+= Assume(dut
.i
.rb
[6:] == 0)
448 m
.submodules
.mask
= mask
= Mask()
449 expected
= Signal(64)
451 m
.d
.comb
+= rot
.eq(rotl64(dut
.i
.rs
, dut
.i
.rb
[:6]))
452 mb
= dut
.fields
.FormMD
.mb
[:]
453 m
.d
.comb
+= mask
.start
.eq(Cat(mb
[1:6], mb
[0]))
454 m
.d
.comb
+= mask
.end
.eq(63 - dut
.i
.rb
[:6])
456 # for rldic, ra is guaranteed to be 0, so that part of
457 # the expression turns into a no-op
458 m
.d
.comb
+= expected
.eq((rot
& mask
.out
) |
(dut
.i
.ra
& ~mask
.out
))
459 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
460 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
462 def _check_rlcl(self
, m
, dut
):
463 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
)
466 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_signed
)
467 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
469 m
.submodules
.mask
= mask
= Mask()
470 m
.d
.comb
+= mask
.end
.eq(63)
471 mb
= dut
.fields
.FormMD
.mb
[:]
472 m
.d
.comb
+= mask
.start
.eq(Cat(mb
[1:6], mb
[0]))
475 m
.d
.comb
+= rot
.eq(rotl64(dut
.i
.rs
, dut
.i
.rb
[:6]))
477 expected
= Signal(64)
478 m
.d
.comb
+= expected
.eq(rot
& mask
.out
)
480 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
481 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
    def _check_rlcr(self, m, dut):
        """Formal check for OP_RLCR: not written yet."""
        raise NotImplementedError
486 def _check_extswsli(self
, m
, dut
):
487 m
.d
.comb
+= Assume(dut
.i
.ra
== 0)
488 m
.d
.comb
+= Assume(dut
.i
.rb
[6:] == 0)
489 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
) # all instrs. are 64-bit
490 expected
= Signal(64)
491 m
.d
.comb
+= expected
.eq((dut
.i
.rs
[0:32].as_signed() << dut
.i
.rb
[:6]))
492 m
.d
.comb
+= Assert(dut
.o
.o
.data
== expected
)
493 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
495 def _check_ternlog(self
, m
, dut
):
496 lut
= dut
.fields
.FormTLI
.TLI
[:]
498 idx
= Cat(dut
.i
.rb
[i
], dut
.i
.ra
[i
], dut
.i
.rc
[i
])
501 m
.d
.comb
+= Assert(dut
.o
.o
.data
[i
] == lut
[j
])
502 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
504 def _check_grev32(self
, m
, dut
):
505 m
.d
.comb
+= Assume(dut
.i
.ctx
.op
.is_32bit
)
506 # assert zero-extended
507 m
.d
.comb
+= Assert(dut
.o
.o
.data
[32:] == 0)
509 m
.d
.comb
+= eq_any_const(i
)
510 idx
= dut
.i
.rb
[0: 5] ^ i
511 m
.d
.comb
+= Assert((dut
.o
.o
.data
>> i
)[0] == (dut
.i
.ra
>> idx
)[0])
512 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
514 def _check_grev64(self
, m
, dut
):
515 m
.d
.comb
+= Assume(~dut
.i
.ctx
.op
.is_32bit
)
517 m
.d
.comb
+= eq_any_const(i
)
518 idx
= dut
.i
.rb
[0: 6] ^ i
519 m
.d
.comb
+= Assert((dut
.o
.o
.data
>> i
)[0] == (dut
.i
.ra
>> idx
)[0])
520 m
.d
.comb
+= Assert(dut
.o
.xer_ca
.data
== 0)
class ALUTestCase(FHDLTestCase):
    """One formal proof per TstOp member, so each proof stays small and
    the whole set can be run in parallel."""

    def run_it(self, which):
        module = Driver(which)
        self.assertFormal(module, mode="bmc", depth=2)
        self.assertFormal(module, mode="cover", depth=2)

    def test_shl(self):
        self.run_it(TstOp.SHL)

    def test_shr(self):
        self.run_it(TstOp.SHR)

    def test_rlc32(self):
        self.run_it(TstOp.RLC32)

    def test_rlc64(self):
        self.run_it(TstOp.RLC64)

    def test_rlcl(self):
        self.run_it(TstOp.RLCL)

    def test_rlcr(self):
        self.run_it(TstOp.RLCR)

    def test_extswsli(self):
        self.run_it(TstOp.EXTSWSLI)

    def test_ternlog(self):
        self.run_it(TstOp.TERNLOG)

    def test_grev32(self):
        self.run_it(TstOp.GREV32)

    def test_grev64(self):
        self.run_it(TstOp.GREV64)
# check that all test cases are covered: every TstOp member must have a
# matching test_* method on ALUTestCase (fails at import time otherwise)
for i in TstOp:
    assert callable(getattr(ALUTestCase, f"test_{i.name.lower()}"))
565 if __name__
== '__main__':