1431a0386d595a1252c19c1254d0c050295d748e
[soc.git] / src / soc / fu / spr / formal / proof_main_stage.py
1 # Proof of correctness for SPR pipeline, main stage
2
3
4 """
5 Links:
6 * https://bugs.libre-soc.org/show_bug.cgi?id=418
7 """
8
9 import unittest
10
11 from nmigen import (Elaboratable, Module, Signal, Cat)
12 from nmigen.asserts import Assert, AnyConst, Assume
13 from nmigen.cli import rtlil
14
15 from nmutil.formaltest import FHDLTestCase
16
17 from soc.fu.spr.main_stage import SPRMainStage
18 from soc.fu.spr.pipe_data import SPRPipeSpec
19 from soc.fu.spr.spr_input_record import CompSPROpSubset
20
21 from openpower.decoder.power_decoder2 import decode_spr_num
22 from openpower.decoder.power_enums import MicrOp, SPR, XER_bits
23 from openpower.decoder.power_fields import DecodeFields
24 from openpower.decoder.power_fieldsn import SignalBitRange
25
def xer_bit(name):
    """Map a named XER field to its LSB0 bit index in a 64-bit value.

    XER_bits gives POWER (MSB0) bit numbers, where bit 0 is the most
    significant bit, whereas nmigen slices count from the least
    significant bit; hence the ``63 - n`` flip.
    """
    power_bit = XER_bits[name]
    lsb0_index = 63 - power_bit
    return lsb0_index
29
30
class Driver(Elaboratable):
    """
    Defines a module to drive the device under test and assert properties
    about its outputs.

    Every DUT input that the properties below depend on (the operation
    subset, muxid, operand ``a``, the fast SPR input ``fast1`` and the XER
    SO/OV/CA fields) is driven from an AnyConst, so the solver explores all
    values instead of relying on how undriven signals are handled.
    """

    def elaborate(self, platform):
        """Build the DUT, drive its inputs with AnyConsts, add assertions."""
        m = Module()
        comb = m.d.comb

        # cookie-cutting most of this from alu formal proof_main_stage.py

        rec = CompSPROpSubset()
        # Setup random inputs for dut.op
        for p in rec.ports():
            comb += p.eq(AnyConst(len(p)))

        pspec = SPRPipeSpec(id_wid=2)
        m.submodules.dut = dut = SPRMainStage(pspec)

        # frequently used aliases
        a = dut.i.a
        ca_in = dut.i.xer_ca[0]     # CA carry in
        ca32_in = dut.i.xer_ca[1]   # CA32 carry in 32
        so_in = dut.i.xer_so        # SO sticky overflow
        ov_in = dut.i.xer_ov[0]     # XER OV in
        ov32_in = dut.i.xer_ov[1]   # XER OV32 in
        o = dut.o.o

        # setup random inputs.  Each XER field is driven as a whole (not
        # just one bit of it) so CA32, OV and OV32 are free as well, and
        # fast1 is driven because the MFSPR fast-path check compares
        # the output against it.  Widths come from the signals themselves.
        comb += [a.eq(AnyConst(len(a))),
                 dut.i.fast1.eq(AnyConst(len(dut.i.fast1))),
                 dut.i.xer_ca.eq(AnyConst(len(dut.i.xer_ca))),
                 dut.i.xer_ov.eq(AnyConst(len(dut.i.xer_ov))),
                 so_in.eq(AnyConst(len(so_in)))]

        # and for the context muxid
        comb += dut.i.ctx.muxid.eq(AnyConst(len(dut.i.ctx.muxid)))

        # assign the PowerDecode2 operation subset
        comb += dut.i.ctx.op.eq(rec)

        # check that the operation (op) is passed through (and muxid)
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # MTSPR: decode the XFX-form SPR field of the instruction
        fields = DecodeFields(SignalBitRange, [dut.i.ctx.op.insn])
        fields.create_specs()
        xfx = fields.FormXFX
        spr = Signal(len(xfx.SPR))
        comb += spr.eq(decode_spr_num(xfx.SPR))

        with m.Switch(dut.i.ctx.op.insn_type):

            # OP_MTSPR
            with m.Case(MicrOp.OP_MTSPR):
                with m.Switch(spr):
                    with m.Case(SPR.CTR, SPR.LR, SPR.TAR, SPR.SRR0, SPR.SRR1):
                        comb += [
                            Assert(dut.o.fast1.data == a),
                            Assert(dut.o.fast1.ok),

                            # If a fast-path SPR is referenced, no other OKs
                            # can fire.
                            Assert(~dut.o.xer_so.ok),
                            Assert(~dut.o.xer_ov.ok),
                            Assert(~dut.o.xer_ca.ok),
                        ]
                    with m.Case(SPR.XER):
                        comb += [
                            Assert(dut.o.xer_so.data == a[xer_bit('SO')]),
                            Assert(dut.o.xer_so.ok),
                            Assert(dut.o.xer_ov.data == Cat(
                                a[xer_bit('OV')], a[xer_bit('OV32')]
                            )),
                            Assert(dut.o.xer_ov.ok),
                            Assert(dut.o.xer_ca.data == Cat(
                                a[xer_bit('CA')], a[xer_bit('CA32')]
                            )),
                            Assert(dut.o.xer_ca.ok),

                            # XER is not a fast-path SPR.
                            Assert(~dut.o.fast1.ok),
                        ]
                    # slow SPRs TODO

            # OP_MFSPR
            with m.Case(MicrOp.OP_MFSPR):
                comb += Assert(o.ok)
                with m.Switch(spr):
                    with m.Case(SPR.CTR, SPR.LR, SPR.TAR, SPR.SRR0, SPR.SRR1):
                        comb += Assert(o.data == dut.i.fast1)
                    with m.Case(SPR.XER):
                        # each XER input bit must appear at its POWER
                        # position in the output
                        bits = {
                            'SO': so_in,
                            'OV': ov_in,
                            'OV32': ov32_in,
                            'CA': ca_in,
                            'CA32': ca32_in,
                        }
                        comb += [
                            Assert(o[xer_bit(b)] == bits[b])
                            for b in bits
                        ]
                    # slow SPRs TODO

        return m
139
140
class SPRMainStageTestCase(FHDLTestCase):
    """Formal checks and RTLIL export for the SPR main stage Driver."""

    def test_formal(self):
        # bounded model check first, then cover; a fresh Driver each time
        for proof_mode in ("bmc", "cover"):
            self.assertFormal(Driver(), mode=proof_mode, depth=100)

    def test_ilang(self):
        # dump the elaborated Driver as RTLIL for manual inspection
        rtlil_text = rtlil.convert(Driver(), ports=[])
        with open("spr_main_stage.il", "w") as il_file:
            il_file.write(rtlil_text)
150
151
if __name__ == "__main__":
    # run the test case(s) above when executed as a script
    unittest.main()