c4dd461d0a2596b7134a595287b19faae092a2c2
[soc.git] / src / soc / fu / shift_rot / formal / proof_main_stage.py
# Proof of correctness for the shift/rotate main stage (ShiftRotMainStage)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=340
6 """
7
8 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
9 signed, Array)
10 from nmigen.asserts import Assert, AnyConst, Assume, Cover
11 from nmutil.formaltest import FHDLTestCase
12 from nmigen.cli import rtlil
13
14 from soc.fu.shift_rot.main_stage import ShiftRotMainStage
15 from soc.fu.shift_rot.rotator import right_mask, left_mask
16 from soc.fu.shift_rot.pipe_data import ShiftRotPipeSpec
17 from soc.fu.shift_rot.sr_input_record import CompSROpSubset
18 from openpower.decoder.power_enums import MicrOp
19 from openpower.consts import field
20
21 import unittest
22 from nmutil.extend import exts
23
24
class Driver(Elaboratable):
    """Formal-verification harness for ShiftRotMainStage.

    Drives every input of the device under test with unconstrained
    ``AnyConst`` values, then asserts properties of the outputs:

    * the operation record and muxid are passed through unchanged;
    * per-instruction-type results for shifts (OP_SHL, OP_SHR),
      extswsli (OP_EXTSWSLI), ternlog (OP_TERNLOG) and grev (OP_GREV);
    * ``o.ok`` is set exactly for the instruction types this harness
      recognises, and cleared for everything else.

    NOTE(review): the OP_RLC, OP_RLCR and OP_RLCL cases currently assert
    nothing about the data output; only the ``o.ok`` / passthrough checks
    apply to them.
    """

    def __init__(self):
        # inputs and outputs
        # (no state: everything is created inside elaborate())
        pass

    def elaborate(self, platform):
        """Build the harness module wrapping the DUT.

        :param platform: unused; required by the Elaboratable interface.
        :returns: the nmigen Module containing the DUT as submodule "dut"
                  plus the Assume/Assert statements.
        """
        m = Module()
        comb = m.d.comb

        rec = CompSROpSubset()
        # Setup random inputs for dut.op.  Do them explicitly (rather than
        # the loop below) so that we can see which ones cause failures in
        # the debug report.
        # for p in rec.ports():
        #     comb += p.eq(AnyConst(p.width))
        comb += rec.insn_type.eq(AnyConst(rec.insn_type.width))
        comb += rec.fn_unit.eq(AnyConst(rec.fn_unit.width))
        comb += rec.imm_data.data.eq(AnyConst(rec.imm_data.data.width))
        comb += rec.imm_data.ok.eq(AnyConst(rec.imm_data.ok.width))
        comb += rec.rc.rc.eq(AnyConst(rec.rc.rc.width))
        comb += rec.rc.ok.eq(AnyConst(rec.rc.ok.width))
        comb += rec.oe.oe.eq(AnyConst(rec.oe.oe.width))
        comb += rec.oe.ok.eq(AnyConst(rec.oe.ok.width))
        comb += rec.write_cr0.eq(AnyConst(rec.write_cr0.width))
        comb += rec.input_carry.eq(AnyConst(rec.input_carry.width))
        comb += rec.output_carry.eq(AnyConst(rec.output_carry.width))
        comb += rec.input_cr.eq(AnyConst(rec.input_cr.width))
        comb += rec.is_32bit.eq(AnyConst(rec.is_32bit.width))
        comb += rec.is_signed.eq(AnyConst(rec.is_signed.width))
        comb += rec.insn.eq(AnyConst(rec.insn.width))

        # draft_bitmanip is enabled so the DUT is built with the
        # ternlog/grev paths checked further down (presumably it gates
        # them in ShiftRotMainStage -- confirm in main_stage.py).
        pspec = ShiftRotPipeSpec(id_wid=2, parent_pspec=None)
        pspec.draft_bitmanip = True
        m.submodules.dut = dut = ShiftRotMainStage(pspec)

        # convenience variables
        rs = dut.i.rs  # register to shift
        b = dut.i.rb   # register containing amount to shift by
        ra = dut.i.a   # source register if masking is to be done
        carry_in = dut.i.xer_ca[0]
        carry_in32 = dut.i.xer_ca[1]
        carry_out = dut.o.xer_ca  # NOTE(review): not yet checked anywhere
        o = dut.o.o.data
        print("fields", rec.fields)  # debug output at elaboration time
        itype = rec.insn_type

        # instruction fields (M-form for 32-bit rotates, MD-form for 64-bit)
        m_fields = dut.fields.FormM
        md_fields = dut.fields.FormMD

        # setup random inputs
        comb += rs.eq(AnyConst(64))
        comb += ra.eq(AnyConst(64))
        comb += b.eq(AnyConst(64))
        comb += carry_in.eq(AnyConst(1))
        comb += carry_in32.eq(AnyConst(1))

        # copy operation
        comb += dut.i.ctx.op.eq(rec)

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # signed and signed/32 versions of input rs, so that ">>" below
        # performs an arithmetic (sign-filling) shift
        a_signed = Signal(signed(64))
        a_signed_32 = Signal(signed(32))
        comb += a_signed.eq(rs)
        comb += a_signed_32.eq(rs[0:32])

        # The mb/me/ml/mr mask signals below are consumed only by the
        # OP_RLC case further down; for every other instruction type they
        # are computed but unused.

        # masks: start-left
        mb = Signal(7, reset_less=True)
        ml = Signal(64, reset_less=True)

        # clear left?  Rotates take MB from the instruction (M-form MB in
        # 32-bit mode, MD-form mb in 64-bit mode); otherwise it is derived
        # from the shift amount in b.
        with m.If((itype == MicrOp.OP_RLC) | (itype == MicrOp.OP_RLCL)):
            with m.If(rec.is_32bit):
                comb += mb.eq(m_fields.MB[:])
            with m.Else():
                comb += mb.eq(md_fields.mb[:])
        with m.Else():
            with m.If(rec.is_32bit):
                comb += mb.eq(b[0:6])
            with m.Else():
                comb += mb.eq(b+32)
        comb += ml.eq(left_mask(m, mb))

        # masks: end-right
        me = Signal(7, reset_less=True)
        mr = Signal(64, reset_less=True)

        # clear right?  Same scheme as above, using ME / me.
        with m.If((itype == MicrOp.OP_RLC) | (itype == MicrOp.OP_RLCR)):
            with m.If(rec.is_32bit):
                comb += me.eq(m_fields.ME[:])
            with m.Else():
                comb += me.eq(md_fields.me[:])
        with m.Else():
            with m.If(rec.is_32bit):
                comb += me.eq(b[0:6])
            with m.Else():
                comb += me.eq(63-b)
        comb += mr.eq(right_mask(m, me))

        # must check Data.ok: expected to be 1 for every handled
        # instruction type, overridden to 0 in the Default case below
        o_ok = Signal()
        comb += o_ok.eq(1)

        # main assertion of arithmetic operations
        with m.Switch(itype):

            # left-shift: 64/32-bit
            with m.Case(MicrOp.OP_SHL):
                # RA is pinned to zero for the shift checks (presumably so
                # that the DUT's "masking" source register cannot affect
                # the result -- confirm against main_stage.py)
                comb += Assume(ra == 0)
                with m.If(rec.is_32bit):
                    # 32-bit: 6-bit shift amount, result zero-extended
                    comb += Assert(o[0:32] == ((rs << b[0:6]) & 0xffffffff))
                    comb += Assert(o[32:64] == 0)
                with m.Else():
                    # 64-bit: 7-bit shift amount, so amounts 64..127 shift
                    # everything out; mask truncates the widened result
                    comb += Assert(o == ((rs << b[0:7]) & ((1 << 64)-1)))

            # right-shift: 64/32-bit / signed
            with m.Case(MicrOp.OP_SHR):
                comb += Assume(ra == 0)
                with m.If(~rec.is_signed):
                    # logical right shift
                    with m.If(rec.is_32bit):
                        comb += Assert(o[0:32] == (rs[0:32] >> b[0:6]))
                        comb += Assert(o[32:64] == 0)
                    with m.Else():
                        comb += Assert(o == (rs >> b[0:7]))
                with m.Else():
                    # arithmetic right shift on the signed copies
                    with m.If(rec.is_32bit):
                        comb += Assert(o[0:32] == (a_signed_32 >> b[0:6]))
                        # upper half is the sign of the 32-bit input
                        comb += Assert(o[32:64] == Repl(rs[31], 32))
                    with m.Else():
                        comb += Assert(o == (a_signed >> b[0:7]))

            # extswsli: 32/64-bit moded
            with m.Case(MicrOp.OP_EXTSWSLI):
                comb += Assume(ra == 0)
                with m.If(rec.is_32bit):
                    comb += Assert(o[0:32] == ((rs << b[0:6]) & 0xffffffff))
                    comb += Assert(o[32:64] == 0)
                with m.Else():
                    # sign-extend to 64 bit
                    a_s = Signal(64, reset_less=True)
                    comb += a_s.eq(exts(rs, 32, 64))
                    comb += Assert(o == ((a_s << b[0:7]) & ((1 << 64)-1)))

            # rlwinm, rlwnm, rlwimi
            # Original question: "*CAN* these even be 64-bit capable?
            # I don't think they are."
            # NOTE(review): the mask setup above reads the MD-form mb/me
            # fields for OP_RLC when is_32bit is clear, which suggests
            # OP_RLC *can* be issued in 64-bit mode; the Assume below
            # restricts this case to 32-bit only -- confirm in the decoder.
            with m.Case(MicrOp.OP_RLC):
                comb += Assume(ra == 0)
                comb += Assume(rec.is_32bit)

                # Duplicate some signals so that they're much easier to find
                # in gtkwave.
                # Pro-tip: when debugging, factor out expressions into
                # explicitly named signals, and search using a unique
                # grep-tag (RLC in my case).  After debugging, resubstitute
                # values to comply with surrounding code norms.

                # combined rotate mask: OR of the two masks when the mask
                # wraps (mb > me), AND otherwise
                mrl = Signal(64, reset_less=True, name='MASK_FOR_RLC')
                with m.If(mb > me):
                    comb += mrl.eq(ml | mr)
                with m.Else():
                    comb += mrl.eq(ml & mr)

                # 32-bit rotate input, taken from rs via the (big-endian
                # numbered) field() helper
                ainp = Signal(64, reset_less=True, name='A_INP_FOR_RLC')
                comb += ainp.eq(field(rs, 32, 63))

                sh = Signal(6, reset_less=True, name='SH_FOR_RLC')
                comb += sh.eq(b[0:6])

                # expected 32-bit rotate-left built from two shifts
                exp_shl = Signal(64, reset_less=True,
                                 name='A_SHIFTED_LEFT_BY_SH_FOR_RLC')
                comb += exp_shl.eq((ainp << sh) & 0xFFFFFFFF)

                exp_shr = Signal(64, reset_less=True,
                                 name='A_SHIFTED_RIGHT_FOR_RLC')
                comb += exp_shr.eq((ainp >> (32 - sh)) & 0xFFFFFFFF)

                exp_rot = Signal(64, reset_less=True,
                                 name='A_ROTATED_LEFT_FOR_RLC')
                comb += exp_rot.eq(exp_shl | exp_shr)

                exp_ol = Signal(32, reset_less=True,
                                name='EXPECTED_OL_FOR_RLC')
                comb += exp_ol.eq(field((exp_rot & mrl) | (ainp & ~mrl),
                                        32, 63))

                act_ol = Signal(32, reset_less=True, name='ACTUAL_OL_FOR_RLC')
                comb += act_ol.eq(field(o, 32, 63))

                # NOTE(review): the actual assertions for this case are
                # commented out, so exp_ol/act_ol are currently only
                # debug signals.  Original author's note: with both
                # Assumes below (or either one) enabled, all 32-bit
                # rotations pass; with both removed the Assert fails --
                # cause still unknown.

                # comb += Assume(mr == 0xFFFFFFFF)
                # comb += Assume(ml == 0xFFFFFFFF)
                # with m.If(rec.is_32bit):
                #    comb += Assert(act_ol == exp_ol)
                #    comb += Assert(field(o, 0, 31) == 0)

            # TODO: no data-output checks yet for OP_RLCR / OP_RLCL
            with m.Case(MicrOp.OP_RLCR):
                pass
            with m.Case(MicrOp.OP_RLCL):
                pass
            with m.Case(MicrOp.OP_TERNLOG):
                # bitwise 3-input lookup: for each bit i, the index is
                # Cat(rb[i], ra[i], rc[i]) (rb is bit 0 of the index) and
                # the output bit must equal that entry of the TLI LUT.
                # NOTE(review): uses dut.i.ra directly; presumably the same
                # register as the `ra` convenience alias (dut.i.a) -- confirm
                lut = dut.fields.FormTLI.TLI[:]
                for i in range(64):
                    idx = Cat(dut.i.rb[i], dut.i.ra[i], dut.i.rc[i])
                    for j in range(8):
                        with m.If(j == idx):
                            comb += Assert(dut.o.o.data[i] == lut[j])
            with m.Case(MicrOp.OP_GREV):
                # generalised bit-reverse: output bit i == ra bit (i XOR k),
                # k being the low 5 (32-bit) or 6 (64-bit) bits of rb
                ra_bits = Array(dut.i.ra[i] for i in range(64))
                with m.If(dut.i.ctx.op.is_32bit):
                    # assert zero-extended
                    comb += Assert(dut.o.o.data[32:] == 0)
                    for i in range(32):
                        idx = dut.i.rb[0:5] ^ i
                        comb += Assert(dut.o.o.data[i]
                                       == ra_bits[idx])
                with m.Else():
                    for i in range(64):
                        idx = dut.i.rb[0:6] ^ i
                        comb += Assert(dut.o.o.data[i]
                                       == ra_bits[idx])

            with m.Default():
                # unhandled instruction types: output must not be valid
                comb += o_ok.eq(0)

        # check that data ok was only enabled when op actioned
        comb += Assert(dut.o.o.ok == o_ok)

        return m
267
268
class ALUTestCase(FHDLTestCase):
    """Run the shift/rotate formal harness and dump its RTLIL netlist."""

    def test_formal(self):
        """Bounded model check first, then cover, both to depth 2."""
        driver = Driver()
        for mode in ("bmc", "cover"):
            self.assertFormal(driver, mode=mode, depth=2)

    def test_ilang(self):
        """Elaborate the harness and write it to main_stage.il (cwd)."""
        rtlil_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as outfile:
            outfile.write(rtlil_text)
280
281
if __name__ == "__main__":
    # executed as a script: hand control to the unittest runner
    unittest.main()