minor code-shuffle, comments
[soc.git] / src / soc / fu / mul / formal / proof_main_stage.py
# Proof of correctness for the multiply (MUL) pipeline stages (MulMainStage1-3)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3
4 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
5 signed)
6 from nmigen.asserts import Assert, AnyConst, Assume, Cover
7 from nmutil.formaltest import FHDLTestCase
8 from nmutil.stageapi import StageChain
9 from nmigen.cli import rtlil
10
11 from soc.fu.mul.pipe_data import CompMULOpSubset, MulPipeSpec
12 from soc.fu.mul.pre_stage import MulMainStage1
13 from soc.fu.mul.main_stage import MulMainStage2
14 from soc.fu.mul.post_stage import MulMainStage3
15
16 from soc.decoder.power_enums import MicrOp
17 import unittest
18
19
20 # This defines a module to drive the device under test and assert
21 # properties about its outputs
class Driver(Elaboratable):
    """Formal-verification harness for the 3-stage MUL pipeline.

    Chains MulMainStage1 -> MulMainStage2 -> MulMainStage3 via StageChain,
    drives the operands (ra, rb) and the operation record with AnyConst
    values, and asserts that the pipeline output matches a reference
    product computed directly from the inputs for OP_MUL_H32, OP_MUL_H64
    and OP_MUL_L64.  It also asserts that the op context and XER_SO are
    passed through unchanged.
    """

    def __init__(self):
        # no configuration: all inputs/outputs are created in elaborate()
        pass

    def elaborate(self, platform):
        """Build and return the harness Module (``platform`` is unused)."""
        m = Module()
        comb = m.d.comb

        rec = CompMULOpSubset()

        # Setup random inputs for dut.op
        comb += rec.insn_type.eq(AnyConst(rec.insn_type.width))
        comb += rec.fn_unit.eq(AnyConst(rec.fn_unit.width))
        comb += rec.is_signed.eq(AnyConst(rec.is_signed.width))
        comb += rec.is_32bit.eq(AnyConst(rec.is_32bit.width))
        comb += rec.imm_data.imm.eq(AnyConst(64))
        comb += rec.imm_data.imm_ok.eq(AnyConst(1))
        # TODO, the rest of these. (the for-loop hides Assert-failures)

        # set up the mul stages. do not add them to m.submodules, this
        # is handled by StageChain.setup().
        pspec = MulPipeSpec(id_wid=2)
        pipe1 = MulMainStage1(pspec)
        pipe2 = MulMainStage2(pspec)
        pipe3 = MulMainStage3(pspec)

        class Dummy: pass
        dut = Dummy()  # make a class into which dut.i and dut.o can be dropped
        dut.i = pipe1.ispec()
        chain = [pipe1, pipe2, pipe3]  # chain of 3 mul stages

        StageChain(chain).setup(m, dut.i)  # input linked here, through chain
        dut.o = chain[-1].o  # output is the last thing in the chain...

        # convenience variables
        a = dut.i.ra
        b = dut.i.rb
        o = dut.o.o.data

        # work out absolute value (treating the low 32 bits as signed) of
        # a and b.  abs32_* are unsigned 32-bit, so |-2**31| == 2**31 fits.
        abs32_a = Signal(32)
        abs32_b = Signal(32)
        a32_s = Signal(1)
        b32_s = Signal(1)
        comb += a32_s.eq(a[31])
        comb += b32_s.eq(b[31])
        comb += abs32_a.eq(Mux(a32_s, -a[0:32], a[0:32]))
        comb += abs32_b.eq(Mux(b32_s, -b[0:32], b[0:32]))

        # work out absolute value (treating all 64 bits as signed) of a
        # and b.  abs64_* are unsigned 64-bit, so |-2**63| == 2**63 fits.
        abs64_a = Signal(64)
        abs64_b = Signal(64)
        a64_s = Signal(1)
        b64_s = Signal(1)
        comb += a64_s.eq(a[63])
        comb += b64_s.eq(b[63])
        comb += abs64_a.eq(Mux(a64_s, -a[0:64], a[0:64]))
        comb += abs64_b.eq(Mux(b64_s, -b[0:64], b[0:64]))

        # set when a and b have DIFFERENT signs (XOR of the sign bits).
        # The sign of a*b is sign(a) ^ sign(b), so when this is set the
        # product of the absolute values must be negated to obtain the
        # signed product.
        ab32_sdiff = Signal()
        ab64_sdiff = Signal()
        comb += ab32_sdiff.eq(a32_s ^ b32_s)
        comb += ab64_sdiff.eq(a64_s ^ b64_s)

        # setup random inputs
        comb += [a.eq(AnyConst(64)),
                 b.eq(AnyConst(64)),
                 ]

        comb += dut.i.ctx.op.eq(rec)

        # Assert that op gets copied from the input to output
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # Assert that XER_SO propagates through as well.
        # Doesn't mean that the ok signal is always set though.
        comb += Assert(dut.o.xer_so.data == dut.i.xer_so)

        # main assertion of arithmetic operations
        with m.Switch(rec.insn_type):
            with m.Case(MicrOp.OP_MUL_H32):
                comb += Assume(rec.is_32bit)  # OP_MUL_H32 is a 32-bit op

                exp_prod = Signal(64)
                expected_o = Signal.like(exp_prod)

                # unsigned hi32 - mulhwu
                with m.If(~rec.is_signed):
                    comb += exp_prod.eq(a[0:32] * b[0:32])
                    # high word of the product, duplicated into both
                    # 32-bit halves of the 64-bit output
                    comb += expected_o.eq(Repl(exp_prod[32:64], 2))
                    comb += Assert(o[0:64] == expected_o)

                # signed hi32 - mulhw
                with m.Else():
                    prod = Signal.like(exp_prod)  # intermediate product
                    comb += prod.eq(abs32_a * abs32_b)
                    # negate |a|*|b| when the operand signs differ
                    comb += exp_prod.eq(Mux(ab32_sdiff, -prod, prod))
                    comb += expected_o.eq(Repl(exp_prod[32:64], 2))
                    comb += Assert(o[0:64] == expected_o)

            with m.Case(MicrOp.OP_MUL_H64):
                comb += Assume(~rec.is_32bit)

                exp_prod = Signal(128)

                # unsigned hi64 - mulhdu
                with m.If(~rec.is_signed):
                    comb += exp_prod.eq(a[0:64] * b[0:64])
                    comb += Assert(o[0:64] == exp_prod[64:128])

                # signed hi64 - mulhd
                with m.Else():
                    prod = Signal.like(exp_prod)  # intermediate product
                    comb += prod.eq(abs64_a * abs64_b)
                    # negate |a|*|b| when the operand signs differ
                    comb += exp_prod.eq(Mux(ab64_sdiff, -prod, prod))
                    comb += Assert(o[0:64] == exp_prod[64:128])

            # mulli, mullw(o)(u), mulld(o)
            with m.Case(MicrOp.OP_MUL_L64):
                with m.If(rec.is_32bit):
                    expected_ov = Signal()
                    prod = Signal(64)
                    exp_prod = Signal.like(prod)

                    # unsigned lo32 - mullwu
                    # NOTE(review): the ISA appears to define only a signed
                    # mullw (compare the 64-bit branch below) - confirm
                    # whether this unsigned case can actually occur.
                    with m.If(~rec.is_signed):
                        comb += exp_prod.eq(a[0:32] * b[0:32])
                        comb += Assert(o[0:64] == exp_prod[0:64])

                    # signed lo32 - mullw
                    with m.Else():
                        # |a|*|b| (each at most 2**31) fits in 64 bits;
                        # negate it when the operand signs differ
                        comb += prod.eq(abs32_a * abs32_b)
                        comb += exp_prod.eq(Mux(ab32_sdiff, -prod, prod))
                        comb += Assert(o[0:64] == exp_prod[0:64])

                    # 32-bit overflow: the result fits in a signed 32-bit
                    # value only if bits 31..63 of the product are all
                    # copies of the sign bit, i.e. all 0s or all 1s.
                    # bool() is "any bit set", all() is "every bit set",
                    # so bool() & ~all() means "neither all-0 nor all-1".
                    m31 = exp_prod[31:64]
                    comb += expected_ov.eq(m31.bool() & ~m31.all())
                    comb += Assert(dut.o.xer_ov.data == Repl(expected_ov, 2))

                with m.Else():  # is 64-bit; mulld
                    expected_ov = Signal()
                    prod = Signal(128)
                    exp_prod = Signal.like(prod)

                    # From my reading of the v3.0B ISA spec,
                    # only signed instructions exist.
                    comb += Assume(rec.is_signed)

                    # negate |a|*|b| when the operand signs differ
                    comb += prod.eq(abs64_a * abs64_b)
                    comb += exp_prod.eq(Mux(ab64_sdiff, -prod, prod))
                    comb += Assert(o[0:64] == exp_prod[0:64])

                    # 64-bit overflow: bits 63..127 of the product must
                    # all equal the sign bit (all 0s or all 1s) for the
                    # result to fit; bool() & ~all() detects otherwise.
                    m63 = exp_prod[63:128]
                    comb += expected_ov.eq(m63.bool() & ~m63.all())
                    comb += Assert(dut.o.xer_ov.data == Repl(expected_ov, 2))

        return m
187
188
class MulTestCase(FHDLTestCase):
    """Runs the MUL pipeline proof and dumps the harness as RTLIL."""

    def test_formal(self):
        """Check the Driver harness with BMC and then cover, depth 2."""
        harness = Driver()
        for proof_mode in ("bmc", "cover"):
            self.assertFormal(harness, mode=proof_mode, depth=2)

    def test_ilang(self):
        """Convert a fresh Driver to RTLIL and write it to main_stage.il."""
        il_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as il_file:
            il_file.write(il_text)
199
200
if __name__ == "__main__":
    # run every TestCase defined in this module when executed as a script
    unittest.main(module="__main__")