Add proof that all other opcodes decode to INVALID
[soc.git] / src / soc / decoder / formal / proof_decoder.py
1 from nmigen import Module, Signal, Elaboratable, Cat
2 from nmigen.asserts import Assert, AnyConst, Assume
3 from nmigen.test.utils import FHDLTestCase
4
5 from soc.decoder.power_decoder import create_pdecode, PowerOp
6 from soc.decoder.power_enums import (In1Sel, In2Sel, In3Sel,
7 OutSel, RC, Form, Function,
8 LdstLen, CryIn,
9 InternalOp, SPR, get_csv)
10 from soc.decoder.power_decoder2 import (PowerDecode2,
11 Decode2ToExecute1Type)
12 import unittest
13 import pdb
14
class Driver(Elaboratable):
    """Formal-verification harness for the Power instruction decoder.

    An arbitrary constant 32-bit instruction word is fed into a
    ``PowerDecode2`` instance, and properties of the decoded fields are
    asserted so that a formal tool can check them for every instruction.
    """

    def __init__(self):
        self.instruction = Signal(32, reset_less=True)
        # Both are populated by elaborate(); the assert_* helpers below
        # read them, so they may only be called during elaboration.
        self.m = None
        self.comb = None

    def elaborate(self, platform):
        """Build the proof module and return it.

        ``platform`` is accepted for the Elaboratable interface and is
        not used.
        """
        self.m = Module()
        self.comb = self.m.d.comb
        # NOTE(review): this replaces the signal created in __init__
        # (which was reset_less=True) with a fresh one -- confirm that
        # this is intended.
        self.instruction = Signal(32)

        # Let the solver pick any constant instruction word.
        self.comb += self.instruction.eq(AnyConst(32))

        pdecode = create_pdecode()

        self.m.submodules.pdecode2 = pdecode2 = PowerDecode2(pdecode)
        dec1 = pdecode2.dec
        self.comb += pdecode2.dec.opcode_in.eq(self.instruction)

        # ignore special decoding of nop
        self.comb += Assume(self.instruction != 0x60000000)

        # The per-opcode table proof is currently disabled:
        #self.assert_dec1_decode(dec1, dec1.dec)

        self.assert_form(dec1, pdecode2)
        return self.m

    def assert_dec1_decode(self, dec1, decoders):
        """Assert that ``dec1`` decodes according to the decoder tables.

        ``decoders`` is a single decoder description or a list of them.
        For each one, a switch over the opcode bits it selects is built:
        subdecoders are handled recursively, every table row that names
        a unit gets a case asserting the row's values, and every other
        opcode must decode to ``InternalOp.OP_ILLEGAL``.
        """
        if not isinstance(decoders, list):
            decoders = [decoders]
        for d in decoders:
            print(d.pattern)
            lo, hi = d.bitsel[0], d.bitsel[1]
            opcode_switch = Signal(hi - lo)
            self.comb += opcode_switch.eq(self.instruction[lo:hi])
            with self.m.Switch(opcode_switch):
                self.handle_subdecoders(dec1, d)
                for row in d.opcodes:
                    opcode = row['opcode']
                    # Opcodes containing '-' are left as strings and
                    # passed to Case() as-is.
                    if d.opint and '-' not in opcode:
                        opcode = int(opcode, 0)
                    # Rows without a unit get no case of their own, so
                    # they fall through to the Default() check below.
                    if not row['unit']:
                        continue
                    with self.m.Case(opcode):
                        self.comb += self.assert_dec1_signals(dec1, row)
                with self.m.Default():
                    # Fixed: this previously referenced the undefined
                    # name ``dec`` and raised NameError.
                    self.comb += Assert(dec1.op.internal_op ==
                                        InternalOp.OP_ILLEGAL)

    def handle_subdecoders(self, dec1, decoders):
        """Add one case per subdecoder of ``decoders`` to the open switch.

        Must be called inside a ``with self.m.Switch(...)`` block.  Each
        entry of ``decoders.subdecoders`` is either a decoder or a list
        of decoders; for a list, the pattern of its first element is
        used as the case pattern.
        """
        for dec in decoders.subdecoders:
            if isinstance(dec, list):
                pattern = dec[0].pattern
            else:
                pattern = dec.pattern
            with self.m.Case(pattern):
                self.assert_dec1_decode(dec1, dec)

    def assert_dec1_signals(self, dec, row):
        """Return assertions tying ``dec.op`` fields to a table ``row``.

        ``row`` is a mapping with the keys 'unit', 'internal op', 'in1',
        'in2', 'in3', 'out', 'ldst len', 'rc', 'cry in' and 'form'; each
        value is looked up by name in the matching enum.
        """
        op = dec.op
        return [Assert(op.function_unit == Function[row['unit']]),
                Assert(op.internal_op == InternalOp[row['internal op']]),
                Assert(op.in1_sel == In1Sel[row['in1']]),
                Assert(op.in2_sel == In2Sel[row['in2']]),
                Assert(op.in3_sel == In3Sel[row['in3']]),
                Assert(op.out_sel == OutSel[row['out']]),
                Assert(op.ldst_len == LdstLen[row['ldst len']]),
                Assert(op.rc_sel == RC[row['rc']]),
                Assert(op.cry_in == CryIn[row['cry in']]),
                Assert(op.form == Form[row['form']]),
                ]

    # This is to assert that the decoder conforms to the table listed
    # in PowerISA public spec v3.0B, Section 1.6, page 12
    def assert_form(self, dec, dec2):
        """Assert which operand selectors each instruction form may use.

        Switches on ``dec.op.form`` and, per form, restricts the decoded
        input/output selectors to the listed values.  ``dec2`` is
        currently unused.
        """
        with self.m.Switch(dec.op.form):
            with self.m.Case(Form.A):
                self.comb += Assert(dec.op.in1_sel.matches(
                    In1Sel.NONE, In1Sel.RA, In1Sel.RA_OR_ZERO))
                self.comb += Assert(dec.op.in2_sel.matches(
                    In2Sel.RB, In2Sel.NONE))
                self.comb += Assert(dec.op.in3_sel.matches(
                    In3Sel.RS, In3Sel.NONE))
                self.comb += Assert(dec.op.out_sel.matches(
                    OutSel.NONE, OutSel.RT))
                # TODO: the spec table also lists XO and Rc fields for
                # this form; they are not checked here yet.
            with self.m.Case(Form.B):
                # No constraints asserted for B-form yet.
                pass
            with self.m.Case(Form.D):
                self.comb += Assert(dec.op.in1_sel.matches(
                    In1Sel.NONE, In1Sel.RA, In1Sel.RA_OR_ZERO))
                self.comb += Assert(dec.op.in2_sel.matches(
                    In2Sel.CONST_UI, In2Sel.CONST_SI, In2Sel.CONST_UI_HI,
                    In2Sel.CONST_SI_HI))
                self.comb += Assert(dec.op.out_sel.matches(
                    OutSel.NONE, OutSel.RT, OutSel.RA))
            with self.m.Case(Form.I):
                self.comb += Assert(dec.op.in2_sel.matches(
                    In2Sel.CONST_LI))

    def instr_bits(self, start, end=None):
        """Return instruction bits ``start``..``end`` (inclusive).

        The instruction is bit-reversed before slicing, so index 0 is
        the highest-numbered bit of ``self.instruction``.  When ``end``
        is omitted, the single bit ``start`` is returned.
        """
        # ``is None`` rather than truthiness, so an explicit end=0 is
        # not mistaken for "not supplied".
        if end is None:
            end = start
        return self.instruction[::-1][start:end+1]
120
class DecoderTestCase(FHDLTestCase):
    """Runs the decoder proof harness through the formal flow."""

    def test_decoder(self):
        """Check the Driver's assertions with bounded model checking."""
        self.assertFormal(Driver(), mode="bmc", depth=4)
125
# Run the test case(s) in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()