add parent_pspec everywhere
[soc.git] / src / soc / fu / mul / formal / proof_main_stage.py
1 # Proof of correctness for multiplier
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 # Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 # Copyright (C) 2020 Samuel A. Falvo II <kc5tja@arrl.net>
5
6 """Formal Correctness Proof for POWER9 multiplier
7
8 notes for ov/32. similar logic applies for 64-bit quantities (m63)
9
10 m31 = exp_prod[31:64]
11 comb += expected_ov.eq(m31.bool() & ~m31.all())
12
If the instruction enables the OV and OV32 flags to be
set, then we must set them both to 1 if and only if
the resulting product *cannot* be contained within a
signed 32-bit quantity.
17
18 This is detected by checking to see if the resulting
19 upper bits are either all 1s or all 0s. If even *one*
20 bit in this set differs from its peers, then we know
21 the signed value cannot be contained in the destination's
22 field width.
23
24 m31.bool() is true if *any* high bit is set.
25 m31.all() is true if *all* high bits are set.
26
27 m31.bool() m31.all() Meaning
    0          x         All upper bits are 0, so the product
                         is non-negative. Thus, it fits.
    1          0         At least one high bit is set and at
                         least one is clear, so the high bits
                         are not all copies of the result's
                         sign bit. Thus, the product extends
                         beyond the destination register size.
35 1 1 All high bits are set *and* they
36 match the sign bit. The number
37 is properly negative, and fits
38 in the destination register width.
39
Note that OV/OV32 are set only when some, but not all, of the
high bits are set, hence the expression m31.bool() & ~m31.all().
42 """
43
44
45 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
46 signed)
47 from nmigen.asserts import Assert, AnyConst, Assume, Cover
48 from nmutil.formaltest import FHDLTestCase
49 from nmutil.stageapi import StageChain
50 from nmigen.cli import rtlil
51
52 from openpower.decoder.power_fields import DecodeFields
53 from openpower.decoder.power_fieldsn import SignalBitRange
54
55 from soc.fu.mul.pipe_data import CompMULOpSubset, MulPipeSpec
56 from soc.fu.mul.pre_stage import MulMainStage1
57 from soc.fu.mul.main_stage import MulMainStage2
58 from soc.fu.mul.post_stage import MulMainStage3
59
60 from openpower.decoder.power_enums import MicrOp
61 import unittest
62
63
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal-verification harness for the three-stage multiplier pipeline.

    elaborate() chains MulMainStage1 -> MulMainStage2 -> MulMainStage3,
    drives the chain's input with unconstrained (AnyConst) operands and
    operation fields, and asserts that the chain's outputs match a
    reference product computed directly in this module for the
    OP_MUL_H32, OP_MUL_H64 and OP_MUL_L64 micro-ops.
    """

    def __init__(self):
        # inputs and outputs
        # (nothing to configure: every signal is created in elaborate())
        pass

    def elaborate(self, platform) -> Module:
        """Build the stage chain together with its Assume/Assert properties.

        :param platform: not referenced by this method.
        :returns: a Module holding the stage chain (added through
            StageChain.setup) and all formal properties.
        """
        m = Module()
        comb = m.d.comb

        # operation record fed to the DUT; every field is left to the solver
        rec = CompMULOpSubset()

        # Setup random inputs for dut.op
        comb += rec.insn_type.eq(AnyConst(rec.insn_type.width))
        comb += rec.fn_unit.eq(AnyConst(rec.fn_unit.width))
        comb += rec.is_signed.eq(AnyConst(rec.is_signed.width))
        comb += rec.is_32bit.eq(AnyConst(rec.is_32bit.width))
        comb += rec.imm_data.imm.eq(AnyConst(64))
        comb += rec.imm_data.imm_ok.eq(AnyConst(1))

        # set up the mul stages. do not add them to m.submodules, this
        # is handled by StageChain.setup().
        pspec = MulPipeSpec(id_wid=2, parent_pspec=None)
        pipe1 = MulMainStage1(pspec)
        pipe2 = MulMainStage2(pspec)
        pipe3 = MulMainStage3(pspec)

        class Dummy:
            pass
        dut = Dummy()  # make a class into which dut.i and dut.o can be dropped
        dut.i = pipe1.ispec()
        chain = [pipe1, pipe2, pipe3]  # chain of 3 mul stages

        StageChain(chain).setup(m, dut.i)  # input linked here, through chain
        dut.o = chain[-1].o  # output is the last thing in the chain...

        # convenience variables
        a = dut.i.ra
        b = dut.i.rb
        o = dut.o.o.data
        xer_ov_o = dut.o.xer_ov.data
        xer_ov_ok = dut.o.xer_ov.ok

        # For 32- and 64-bit parameters, work out the absolute values of the
        # input parameters for signed multiplies. Needed for signed
        # multiplication.

        abs32_a = Signal(32)
        abs32_b = Signal(32)
        abs64_a = Signal(64)
        abs64_b = Signal(64)
        a32_s = Signal(1)
        b32_s = Signal(1)
        a64_s = Signal(1)
        b64_s = Signal(1)

        # sign bits of the 32-bit and 64-bit views of each operand
        comb += a32_s.eq(a[31])
        comb += b32_s.eq(b[31])
        comb += a64_s.eq(a[63])
        comb += b64_s.eq(b[63])

        # Negation is truncated to the destination width (32/64 bits), so
        # the most negative input maps to itself, which read as unsigned
        # is still the correct magnitude (e.g. 0x80000000 == 2**31).
        comb += abs32_a.eq(Mux(a32_s, -a[0:32], a[0:32]))
        comb += abs32_b.eq(Mux(b32_s, -b[0:32], b[0:32]))
        comb += abs64_a.eq(Mux(a64_s, -a[0:64], a[0:64]))
        comb += abs64_b.eq(Mux(b64_s, -b[0:64], b[0:64]))

        # For 32- and 64-bit quantities, break out whether signs differ.
        # (the _sne suffix is read as "signs not equal").
        #
        # This is required because of the rules of signed multiplication:
        #
        # a*b = +(abs(a)*abs(b)) for two positive numbers a and b.
        # a*b = -(abs(a)*abs(b)) for any one positive number and negative
        # number.
        # a*b = +(abs(a)*abs(b)) for two negative numbers a and b.

        ab32_sne = Signal()
        ab64_sne = Signal()
        comb += ab32_sne.eq(a32_s ^ b32_s)
        comb += ab64_sne.eq(a64_s ^ b64_s)

        # setup random inputs
        # (these drive a/b, which the sign/abs logic above already reads;
        # statement order does not matter between different comb signals)
        comb += [a.eq(AnyConst(64)),
                 b.eq(AnyConst(64)),
                 ]

        # hand the randomised operation record to the DUT
        comb += dut.i.ctx.op.eq(rec)

        # check overflow and result flags
        result_ok = Signal()
        enable_overflow = Signal()  # only ever set to 1 in the OP_MUL_L64 case

        # default to 1, disabled if default case is hit
        # (the later assignment inside m.Default() takes priority)
        comb += result_ok.eq(1)

        # Assert that op gets copied from the input to output
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        # NOTE(review): dut.i.ctx.muxid is not driven by an AnyConst in this
        # method, so this check may only exercise its default value - confirm.
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # Assert that XER_SO propagates through as well.
        # NOTE(review): dut.i.xer_so is likewise not driven here - confirm.
        comb += Assert(dut.o.xer_so == dut.i.xer_so)

        # main assertion of arithmetic operations
        with m.Switch(rec.insn_type):

            ###### HI-32 #####

            with m.Case(MicrOp.OP_MUL_H32):
                comb += Assume(rec.is_32bit)  # OP_MUL_H32 is a 32-bit op

                exp_prod = Signal(64)
                expected_o = Signal.like(exp_prod)

                # unsigned hi32 - mulhwu
                with m.If(~rec.is_signed):
                    comb += exp_prod.eq(a[0:32] * b[0:32])
                    # the high word of the product is expected in both
                    # 32-bit halves of the 64-bit output
                    comb += expected_o.eq(Repl(exp_prod[32:64], 2))
                    comb += Assert(o[0:64] == expected_o)

                # signed hi32 - mulhw
                with m.Else():
                    # Per rules of signed multiplication, if input signs
                    # differ, we negate the product. This implies that
                    # the product is calculated from the absolute values
                    # of the inputs.
                    prod = Signal.like(exp_prod)  # intermediate product
                    comb += prod.eq(abs32_a * abs32_b)
                    comb += exp_prod.eq(Mux(ab32_sne, -prod, prod))
                    comb += expected_o.eq(Repl(exp_prod[32:64], 2))
                    comb += Assert(o[0:64] == expected_o)

            ###### HI-64 #####

            with m.Case(MicrOp.OP_MUL_H64):
                comb += Assume(~rec.is_32bit)

                # full 128-bit product; only the upper 64 bits are checked
                exp_prod = Signal(128)

                # unsigned hi64 - mulhdu
                with m.If(~rec.is_signed):
                    comb += exp_prod.eq(a[0:64] * b[0:64])
                    comb += Assert(o[0:64] == exp_prod[64:128])

                # signed hi64 - mulhd
                with m.Else():
                    # Per rules of signed multiplication, if input signs
                    # differ, we negate the product. This implies that
                    # the product is calculated from the absolute values
                    # of the inputs.
                    prod = Signal.like(exp_prod)  # intermediate product
                    comb += prod.eq(abs64_a * abs64_b)
                    comb += exp_prod.eq(Mux(ab64_sne, -prod, prod))
                    comb += Assert(o[0:64] == exp_prod[64:128])

            ###### LO-64 #####
            # mulli, mullw(o)(u), mulld(o)

            with m.Case(MicrOp.OP_MUL_L64):

                with m.If(rec.is_32bit):  # 32-bit mode
                    expected_ov = Signal()
                    prod = Signal(64)
                    exp_prod = Signal.like(prod)

                    # unsigned lo32 - mullwu
                    # NOTE(review): the v3.0B ISA does not appear to define
                    # an unsigned mullw; this branch covers is_signed=0
                    # inputs that the decoder may never produce - confirm.
                    with m.If(~rec.is_signed):
                        comb += exp_prod.eq(a[0:32] * b[0:32])
                        comb += Assert(o[0:64] == exp_prod[0:64])

                    # signed lo32 - mullw
                    with m.Else():
                        # Per rules of signed multiplication, if input signs
                        # differ, we negate the product. This implies that
                        # the product is calculated from the absolute values
                        # of the inputs.
                        # NOTE(review): abs32_a/abs32_b are only 32 bits wide,
                        # so the [0:64] slices presumably select all 32 bits
                        # (no-op slice) - confirm.
                        comb += prod.eq(abs32_a[0:64] * abs32_b[0:64])
                        comb += exp_prod.eq(Mux(ab32_sne, -prod, prod))
                        comb += Assert(o[0:64] == exp_prod[0:64])

                    # see notes on overflow detection, above
                    # (applied to both the signed and unsigned branches;
                    # the flag is replicated into both XER OV bits)
                    m31 = exp_prod[31:64]
                    comb += expected_ov.eq(m31.bool() & ~m31.all())
                    comb += enable_overflow.eq(1)
                    comb += Assert(xer_ov_o == Repl(expected_ov, 2))

                with m.Else():  # is 64-bit; mulld
                    expected_ov = Signal()
                    prod = Signal(128)
                    exp_prod = Signal.like(prod)

                    # From my reading of the v3.0B ISA spec,
                    # only signed instructions exist.
                    #
                    # Per rules of signed multiplication, if input signs
                    # differ, we negate the product. This implies that
                    # the product is calculated from the absolute values
                    # of the inputs.
                    comb += Assume(rec.is_signed)
                    comb += prod.eq(abs64_a[0:64] * abs64_b[0:64])
                    comb += exp_prod.eq(Mux(ab64_sne, -prod, prod))
                    comb += Assert(o[0:64] == exp_prod[0:64])

                    # see notes on overflow detection, above
                    m63 = exp_prod[63:128]
                    comb += expected_ov.eq(m63.bool() & ~m63.all())
                    comb += enable_overflow.eq(1)
                    comb += Assert(xer_ov_o == Repl(expected_ov, 2))

            # not any of the cases above, disable result checking
            with m.Default():
                comb += result_ok.eq(0)

        # check result "write" is correctly requested
        # (o.ok must be set for every mul micro-op and clear otherwise;
        # xer_ov.ok must be set only for OP_MUL_L64)
        comb += Assert(dut.o.o.ok == result_ok)
        comb += Assert(xer_ov_ok == enable_overflow)

        return m
282
283
class MulTestCase(FHDLTestCase):
    """unittest entry points for the multiplier proof.

    test_formal runs the Driver harness through the formal tools;
    test_ilang dumps the same harness as RTLIL for inspection.
    """

    def test_formal(self):
        # A single Driver instance is shared by both runs: bounded model
        # check first, then cover mode, each with depth=2.  A failure in
        # the first run raises before the second one starts.
        proof = Driver()
        for formal_mode in ("bmc", "cover"):
            self.assertFormal(proof, mode=formal_mode, depth=2)

    def test_ilang(self):
        # Convert a fresh Driver (with no top-level ports) to RTLIL text
        # and save it as main_stage.il in the current working directory.
        il_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as il_file:
            il_file.write(il_text)
295
296
if __name__ == "__main__":
    # Running this file directly executes the proof as a unittest suite;
    # importing it (e.g. from a test collector) does nothing extra.
    unittest.main()