speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / plru2.py
1 # based on microwatt plru.vhdl
2 # https://github.com/antonblanchard/microwatt/blob/f67b1431655c291fc1c99857a5c1ef624d5b264c/plru.vhdl
3
# new PLRU API. Once all users have migrated to the new API in plru2.py,
# plru2.py will be renamed to plru.py.
6 # IMPORTANT: since the API will change more, migration should be blocked on:
7 # https://bugs.libre-soc.org/show_bug.cgi?id=913
8
9 from nmigen.hdl.ir import Elaboratable, Display, Signal, Array, Const, Value
10 from nmigen.hdl.dsl import Module
11 from nmigen.cli import rtlil
12 from nmigen.lib.coding import Decoder
13
14
class PLRU(Elaboratable):
    r""" PLRU - Pseudo Least Recently Used Replacement

    IMPORTANT: since the API will change more, migration should be blocked on:
    https://bugs.libre-soc.org/show_bug.cgi?id=913

    The replacement state is a binary tree of one-bit nodes stored
    breadth-first: node ``n`` has children ``2n + 1`` (left) and
    ``2n + 2`` (right).  When looking up the LRU way, a set node bit means
    "follow the right child".

    PLRU-tree indexing:
    lvl0        0
               / \
              /   \
             /     \
    lvl1    1       2
           / \     / \
    lvl2  3   4   5   6
         / \ / \ / \ / \
         ... ... ... ...
    """

    def __init__(self, log2_num_ways, debug=False):
        # type: (int, bool) -> None
        """
        IMPORTANT: since the API will change more, migration should be blocked on:
        https://bugs.libre-soc.org/show_bug.cgi?id=913

        Arguments:
        log2_num_ways: int
            the log-base-2 of the number of cache ways -- BITS in plru.vhdl.
            Must be positive.
        debug: bool
            true if this should print debugging messages at simulation time.

        Ports:
        acc_i: index of the way being accessed.
        acc_en_i: while high, the access on ``acc_i`` updates the tree at
            the next clock edge.
        lru_o: index of the pseudo-least-recently-used way.
        """
        assert log2_num_ways > 0
        self.log2_num_ways = log2_num_ways
        self.debug = debug
        # NOTE(review): no name= is passed for these ports, so nmigen
        # presumably infers their names from the assignment target; keep the
        # plain ``self.x = Signal(...)`` form.
        self.acc_i = Signal(log2_num_ways)
        self.acc_en_i = Signal()
        self.lru_o = Signal(log2_num_ways)

        # One bit per internal tree node.  The original VHDL array is one
        # entry too big (its last entry is never used), hence the ``- 1``.
        # Exposed only for testing.
        self._tree = Array(Signal(name=f"tree_{i}", reset=0)
                           for i in range(self.num_ways - 1))

        def new_node(i, prefix):
            # holds the tree-node index visited at level ``i`` of a walk
            return Signal(range(self.num_ways), name=f"{prefix}_node_{i}",
                          reset=0)

        levels = range(self.log2_num_ways)
        # Node indices visited by the lookup walk; exposed only for testing.
        self._get_lru_nodes = [new_node(i, "get_lru") for i in levels]
        # Node indices visited by the update walk; exposed only for testing.
        self._upd_lru_nodes = [new_node(i, "upd_lru") for i in levels]

    @property
    def num_ways(self):
        """Number of cache ways, i.e. ``2 ** log2_num_ways``."""
        return 1 << self.log2_num_ways

    def _display(self, msg, *args):
        """Return a list holding one Display statement when debugging is
        enabled, otherwise an empty list."""
        if self.debug:
            # work around not yet having
            # https://gitlab.com/nmigen/nmigen/-/merge_requests/10
            # by sending every argument through Value.cast()
            return [Display(msg, *[Value.cast(arg) for arg in args])]
        return []

    @staticmethod
    def _descend(m, node, go_right, child):
        """Combinationally drive ``child`` with the index of ``node``'s right
        child (``2 * node + 2``) when ``go_right`` is set, otherwise with its
        left child (``2 * node + 1``)."""
        # modified from microwatt version, it uses `node * 2` value to index
        # into tree, rather than using node like is used earlier in the walk;
        # here the doubling only happens when forming the child index
        doubled = node << 1
        with m.If(go_right):
            m.d.comb += child.eq(doubled + 2)
        with m.Else():
            m.d.comb += child.eq(doubled + 1)

    def _get_lru(self, m):
        """ get_lru process in plru.vhdl

        Combinationally walk from the root towards a leaf, following each
        visited node's tree bit.  The bits seen along the way, root first,
        drive ``lru_o`` from its MSB down to its LSB.
        """
        # XXX Check if we can turn that into a little ROM instead that
        # takes the tree bit vector and returns the LRU. See if it's better
        # in terms of FPGA resource usage...
        depth = self.log2_num_ways
        path = self._get_lru_nodes
        m.d.comb += path[0].eq(0)  # the walk starts at the root
        for level in range(depth):
            here = path[level]
            bit = self._tree[here]
            m.d.comb += self._display("GET: i:%i node:%#x val:%i",
                                      level, here, bit)
            m.d.comb += self.lru_o[depth - 1 - level].eq(bit)
            if level != depth - 1:
                self._descend(m, here, bit, path[level + 1])

    def _update_lru(self, m):
        """ update_lru process in plru.vhdl

        While ``acc_en_i`` is high, walk from the root along the path picked
        by the bits of ``acc_i`` (MSB first).  At the next clock edge, each
        visited tree bit is set to the inverse of the access bit taken there,
        so a later lookup steers away from the accessed way.
        """
        depth = self.log2_num_ways
        path = self._upd_lru_nodes
        with m.If(self.acc_en_i):
            m.d.comb += path[0].eq(0)  # the walk starts at the root
            for level in range(depth):
                here = path[level]
                bit = self.acc_i[depth - 1 - level]
                m.d.sync += [
                    self._tree[here].eq(~bit),
                    self._display("UPD: i:%i node:%#x val:%i",
                                  level, here, ~bit),
                ]
                if level != depth - 1:
                    self._descend(m, here, bit, path[level + 1])

    def elaborate(self, platform=None):
        """Build the lookup (comb) and update (sync) logic."""
        m = Module()
        self._get_lru(m)
        self._update_lru(m)
        return m

    def __iter__(self):
        """Iterate over the ports: ``acc_i``, ``acc_en_i``, ``lru_o``."""
        yield from (self.acc_i, self.acc_en_i, self.lru_o)

    def ports(self):
        """Return the ports as a list, in ``__iter__`` order."""
        return [*self]
138
139
140 # FIXME: convert PLRUs to new API
141 # class PLRUs(Elaboratable):
142 # def __init__(self, n_plrus, n_bits):
143 # self.n_plrus = n_plrus
144 # self.n_bits = n_bits
145 # self.valid = Signal()
146 # self.way = Signal(n_bits)
147 # self.index = Signal(n_plrus.bit_length())
148 # self.isel = Signal(n_plrus.bit_length())
149 # self.o_index = Signal(n_bits)
150 #
151 # def elaborate(self, platform):
152 # """Generate TLB PLRUs
153 # """
154 # m = Module()
155 # comb = m.d.comb
156 #
157 # if self.n_plrus == 0:
158 # return m
159 #
160 # # Binary-to-Unary one-hot, enabled by valid
161 # m.submodules.te = te = Decoder(self.n_plrus)
162 # comb += te.n.eq(~self.valid)
163 # comb += te.i.eq(self.index)
164 #
165 # out = Array(Signal(self.n_bits, name="plru_out%d" % x)
166 # for x in range(self.n_plrus))
167 #
168 # for i in range(self.n_plrus):
169 # # PLRU interface
170 # m.submodules["plru_%d" % i] = plru = PLRU(self.n_bits)
171 #
172 # comb += plru.acc_en.eq(te.o[i])
173 # comb += plru.acc_i.eq(self.way)
174 # comb += out[i].eq(plru.lru_o)
175 #
176 # # select output based on index
177 # comb += self.o_index.eq(out[self.isel])
178 #
179 # return m
180 #
181 # def ports(self):
182 # return [self.valid, self.way, self.index, self.isel, self.o_index]
183
184
if __name__ == '__main__':
    # Elaborate a PLRU (log2_num_ways=3) and dump it as RTLIL.
    dut = PLRU(3)
    vl = rtlil.convert(dut, ports=dut.ports())
    # Use an explicit encoding so the written file does not depend on the
    # platform's locale-dependent default text encoding.
    with open("test_plru.il", "w", encoding="utf-8") as f:
        f.write(vl)

    # dut = PLRUs(4, 2)
    # vl = rtlil.convert(dut, ports=dut.ports())
    # with open("test_plrus.il", "w") as f:
    #     f.write(vl)