49b8599b4ce9dee56747a00faf7e747971519f36
[nmutil.git] / src / nmutil / formal / test_plru.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 import unittest
5 from nmigen.hdl.ast import (AnySeq, Assert, Signal, Assume, Const,
6 unsigned, AnyConst, Value)
7 from nmigen.hdl.dsl import Module
8 from nmutil.formaltest import FHDLTestCase
9 from nmutil.plru import PLRU, PLRUs
10 from nmutil.sim_util import write_il
11 from nmutil.plain_data import plain_data
12
13
@plain_data()
class PrettyPrintState:
    """Output cursor for indentation-aware, character-by-character printing.

    ``indent`` is how many spaces precede the first character of each line,
    ``file`` is handed straight to ``print`` (``None`` means standard
    output), and ``at_line_start`` tracks whether the next character written
    begins a new line.
    """
    __slots__ = "indent", "file", "at_line_start"

    def __init__(self, indent=0, file=None, at_line_start=True):
        self.indent, self.file, self.at_line_start = (
            indent, file, at_line_start)

    def write(self, text):
        # type: (str) -> None
        """Write ``text``, prefixing the first character of each line with
        the current indentation.

        Newline characters are never preceded by indentation, so blank lines
        stay empty. Output is emitted one character per ``print`` call.
        """
        for ch in text:
            if ch != "\n":
                if self.at_line_start:
                    # First visible character on this line: indent it.
                    self.at_line_start = False
                    print(" " * self.indent, file=self.file, end='')
            else:
                self.at_line_start = True
            print(ch, file=self.file, end='')
32
33
@plain_data()
class PLRUNode:
    """One node of the reference PLRU tree built by the formal test.

    ``state`` is the node's state ``Signal``. ``left_child`` and
    ``right_child`` are sub-nodes, or ``None`` where the tree ends.
    """
    __slots__ = "state", "left_child", "right_child"

    def __init__(self, state, left_child=None, right_child=None):
        # type: (Signal, PLRUNode | None, PLRUNode | None) -> None
        self.state, self.left_child, self.right_child = (
            state, left_child, right_child)

    def __pretty_print(self, state):
        # type: (PrettyPrintState) -> None
        """Recursively write this subtree through ``state``.

        Fields are written one per line. Nested nodes are indented one
        level deeper than their parent.
        """
        def emit(child):
            # Absent children print as ``None``; present ones recurse.
            if child is None:
                state.write("None")
            else:
                child.__pretty_print(state)

        state.write("PLRUNode(")
        state.indent += 1
        state.write(f"state={self.state!r},\n")
        state.write("left_child=")
        emit(self.left_child)
        state.write(",\nright_child=")
        emit(self.right_child)
        state.indent -= 1
        state.write("\n)")

    def pretty_print(self, file=None):
        """Print this subtree in an indented, constructor-like layout.

        ``file`` is forwarded to ``print`` (``None`` means standard output).
        No trailing newline is written after the closing parenthesis.
        """
        printer = PrettyPrintState(file=file)
        self.__pretty_print(printer)

    def set_states_from_index(self, m, index):
        # type: (Module, Value) -> None
        """Add synchronous state updates along the path selected by ``index``.

        The most significant bit of ``index`` is registered into this
        node's ``state``. That same bit then picks which child receives the
        remaining lower bits: 1 selects the left child, 0 the right child.
        A missing child ends the walk on that branch.

        NOTE(review): ``TestPLRU.tst`` attaches odd heap nodes
        (``2 * p + 1``) as left children. An MSB of 1 therefore descends
        into the lower-numbered subtree. Confirm that this direction and
        polarity match ``PLRU``'s tree layout.
        """
        msb = index[-1]
        lower_bits = index[:-1]
        m.d.sync += self.state.eq(msb)
        with m.If(msb):
            if self.left_child is not None:
                self.left_child.set_states_from_index(m, lower_bits)
        with m.Else():
            if self.right_child is not None:
                self.right_child.set_states_from_index(m, lower_bits)
74
75
class TestPLRU(FHDLTestCase):
    """Formal checks of ``PLRU``'s internal tree against a ``PLRUNode`` model.

    The model nodes are kept in heap order to match ``dut._plru_tree``.
    Node ``i`` (``i >= 1``) hangs off parent ``(i + 1) // 2 - 1``. It is
    the left child when ``i`` is odd and the right child when ``i`` is even.
    """

    @unittest.skip("not finished yet")
    def tst(self, BITS):
        # type: (int) -> None
        """Try to prove that ``PLRU(BITS)``'s tree bits track the model.

        This method is decorated with ``unittest.skip``, so calling it raises
        ``unittest.SkipTest``. Every ``test_bits_*`` case is therefore
        reported as skipped until the FIXME below is resolved.
        """

        # FIXME: figure out what BITS is supposed to mean -- I would have
        # expected it to be the number of cache ways, or the number of state
        # bits in PLRU, but it's neither of those, making me think whoever
        # converted the code botched their math.
        #
        # Until that's figured out, this test is broken.

        dut = PLRU(BITS)
        write_il(self, dut, ports=dut.ports())
        m = Module()
        nodes = [PLRUNode(Signal(name=f"state_{i}"))
                 for i in range(dut.TLBSZ)]
        self.assertEqual(len(dut._plru_tree), len(nodes))

        # Link every non-root node to its parent. Node 0 is the root and
        # has no parent, so it is skipped here.
        for i in range(1, dut.TLBSZ):
            parent = (i + 1) // 2 - 1
            if i % 2:
                nodes[parent].left_child = nodes[i]
            else:
                nodes[parent].right_child = nodes[i]

        # Compare *every* model bit with the DUT, including the root. The
        # root was previously never asserted because the check lived
        # inside the linking loop above, which starts at 1.
        for i, node in enumerate(nodes):
            m.d.comb += Assert(node.state == dut._plru_tree[i])

        in_index = Signal(range(BITS))

        m.d.comb += [
            in_index.eq(AnySeq(range(BITS))),
            # Only valid way numbers (matters when BITS isn't a power of 2).
            Assume(in_index < BITS),
            dut.acc_i.eq(1 << in_index),
            dut.acc_en.eq(AnySeq(1)),
        ]

        with m.If(dut.acc_en):
            nodes[0].set_states_from_index(m, in_index)

        # Debug dump of the model tree while this test is unfinished.
        nodes[0].pretty_print()

        m.submodules.dut = dut
        self.assertFormal(m, mode="prove")

    def test_bits_1(self):
        self.tst(1)

    def test_bits_2(self):
        self.tst(2)

    def test_bits_3(self):
        self.tst(3)

    def test_bits_4(self):
        self.tst(4)

    def test_bits_5(self):
        self.tst(5)

    def test_bits_6(self):
        self.tst(6)

    def test_bits_7(self):
        self.tst(7)

    def test_bits_8(self):
        self.tst(8)

    def test_bits_9(self):
        self.tst(9)

    def test_bits_10(self):
        self.tst(10)

    def test_bits_11(self):
        self.tst(11)

    def test_bits_12(self):
        self.tst(12)

    def test_bits_13(self):
        self.tst(13)

    def test_bits_14(self):
        self.tst(14)

    def test_bits_15(self):
        self.tst(15)

    def test_bits_16(self):
        self.tst(16)
165
166
if __name__ == "__main__":
    # Executing this file directly runs its unittest test cases.
    unittest.main()