add fixed PLRU
[nmutil.git] / src / nmutil / plru2.py
1 # based on microwatt plru.vhdl
2 # https://github.com/antonblanchard/microwatt/blob/f67b1431655c291fc1c99857a5c1ef624d5b264c/plru.vhdl
3
4 # new PLRU API, once all users have migrated to new API in plru2.py, then
5 # plru2.py will be renamed to plru.py.
6
7 from nmigen.hdl.ir import Elaboratable, Display, Signal, Array, Const, Value
8 from nmigen.hdl.dsl import Module
9 from nmigen.cli import rtlil
10 from nmigen.lib.coding import Decoder
11
12
class PLRU(Elaboratable):
    r"""Pseudo Least Recently Used (PLRU) replacement tracker.

    The state is a binary tree of single-bit signals, ``num_ways - 1`` of
    them, stored breadth-first so that node ``n`` has its children at
    ``2 * n + 1`` and ``2 * n + 2``.

    PLRU-tree indexing:
    lvl0        0
               / \
              /   \
             /     \
    lvl1    1       2
           / \     / \
    lvl2  3   4   5   6
         / \ / \ / \ / \
         ... ... ... ...

    Ports:
        acc_i: the way being accessed (``log2_num_ways`` bits wide).
        acc_en_i: when set, the tree is updated from ``acc_i`` in the
            ``sync`` domain.
        lru_o: the way selected by walking the tree from the root,
            most-significant bit first.
    """

    def __init__(self, log2_num_ways, debug=False):
        # type: (int, bool) -> None
        """Create the ports and the internal tree/node signals.

        Arguments:
        log2_num_ways: int
            the log-base-2 of the number of cache ways -- BITS in plru.vhdl
        debug: bool
            true if this should print debugging messages at simulation time.
        """
        assert log2_num_ways > 0
        self.log2_num_ways = log2_num_ways
        self.debug = debug
        self.acc_i = Signal(log2_num_ways)
        self.acc_en_i = Signal()
        self.lru_o = Signal(log2_num_ways)

        # The original vhdl declares its array one entry too large (the last
        # entry is never used); only num_ways - 1 tree bits are created here.
        tree_bits = [Signal(name=f"tree_{idx}", reset=0)
                     for idx in range(self.num_ways - 1)]
        # exposed only for testing
        self._tree = Array(tree_bits)

        def make_level_nodes(prefix):
            # one tree-index signal per tree level, named <prefix>_node_<lvl>
            return [Signal(range(self.num_ways), name=f"{prefix}_node_{lvl}",
                           reset=0)
                    for lvl in range(self.log2_num_ways)]

        # exposed only for testing
        self._get_lru_nodes = make_level_nodes("get_lru")

        # exposed only for testing
        self._upd_lru_nodes = make_level_nodes("upd_lru")

    @property
    def num_ways(self):
        """Number of ways tracked: two to the power of log2_num_ways."""
        return 1 << self.log2_num_ways

    def _display(self, msg, *args):
        """Return a list holding one Display statement, or an empty list
        when debugging is disabled."""
        if self.debug:
            # work around not yet having
            # https://gitlab.com/nmigen/nmigen/-/merge_requests/10
            # by sending through Value.cast()
            return [Display(msg, *(Value.cast(arg) for arg in args))]
        return []

    def _get_lru(self, m):
        """ get_lru process in plru.vhdl

        Walks the tree from the root, copying each visited tree bit into
        lru_o (most-significant bit first) and using that bit to choose
        which child is visited at the next level.
        """
        # XXX Check if we can turn that into a little ROM instead that
        # takes the tree bit vector and returns the LRU. See if it's better
        # in term of FPGA resource usage...
        comb = m.d.comb
        level_nodes = self._get_lru_nodes
        last_level = self.log2_num_ways - 1

        comb += level_nodes[0].eq(0)
        for level, cur in enumerate(level_nodes):
            tree_bit = self._tree[cur]
            comb += self._display("GET: i:%i node:%#x val:%i",
                                  level, cur, tree_bit)
            comb += self.lru_o[last_level - level].eq(tree_bit)
            if level == last_level:
                # bottom level reached: there is no child node to select
                continue
            # modified from microwatt version, it uses `node * 2` value
            # to index into tree, rather than using node like is used
            # earlier in this loop iteration
            doubled = cur << 1
            with m.If(tree_bit):
                comb += level_nodes[level + 1].eq(doubled + 2)
            with m.Else():
                comb += level_nodes[level + 1].eq(doubled + 1)

    def _update_lru(self, m):
        """ update_lru process in plru.vhdl

        While acc_en_i is set, follows the path selected by the bits of
        acc_i (most-significant bit first) and, in the sync domain, writes
        the inverse of each access bit into the tree node visited at that
        level.
        """
        comb = m.d.comb
        sync = m.d.sync
        level_nodes = self._upd_lru_nodes
        last_level = self.log2_num_ways - 1

        with m.If(self.acc_en_i):
            comb += level_nodes[0].eq(0)
            for level, cur in enumerate(level_nodes):
                acc_bit = self.acc_i[last_level - level]
                sync += self._tree[cur].eq(~acc_bit)
                sync += self._display("UPD: i:%i node:%#x val:%i",
                                      level, cur, ~acc_bit)
                if level == last_level:
                    # bottom level reached: there is no child node to select
                    continue
                doubled = cur << 1
                with m.If(acc_bit):
                    comb += level_nodes[level + 1].eq(doubled + 2)
                with m.Else():
                    comb += level_nodes[level + 1].eq(doubled + 1)

    def elaborate(self, platform=None):
        """Build the module: the LRU read-out logic followed by the
        tree update logic."""
        m = Module()
        self._get_lru(m)
        self._update_lru(m)
        return m

    def __iter__(self):
        """Yield the port signals in order: acc_i, acc_en_i, lru_o."""
        yield from (self.acc_i, self.acc_en_i, self.lru_o)

    def ports(self):
        """Return the port signals as a list, in iteration order."""
        return [*self]
130
131
132 # FIXME: convert PLRUs to new API
133 # class PLRUs(Elaboratable):
134 # def __init__(self, n_plrus, n_bits):
135 # self.n_plrus = n_plrus
136 # self.n_bits = n_bits
137 # self.valid = Signal()
138 # self.way = Signal(n_bits)
139 # self.index = Signal(n_plrus.bit_length())
140 # self.isel = Signal(n_plrus.bit_length())
141 # self.o_index = Signal(n_bits)
142 #
143 # def elaborate(self, platform):
144 # """Generate TLB PLRUs
145 # """
146 # m = Module()
147 # comb = m.d.comb
148 #
149 # if self.n_plrus == 0:
150 # return m
151 #
152 # # Binary-to-Unary one-hot, enabled by valid
153 # m.submodules.te = te = Decoder(self.n_plrus)
154 # comb += te.n.eq(~self.valid)
155 # comb += te.i.eq(self.index)
156 #
157 # out = Array(Signal(self.n_bits, name="plru_out%d" % x)
158 # for x in range(self.n_plrus))
159 #
160 # for i in range(self.n_plrus):
161 # # PLRU interface
162 # m.submodules["plru_%d" % i] = plru = PLRU(self.n_bits)
163 #
164 # comb += plru.acc_en.eq(te.o[i])
165 # comb += plru.acc_i.eq(self.way)
166 # comb += out[i].eq(plru.lru_o)
167 #
168 # # select output based on index
169 # comb += self.o_index.eq(out[self.isel])
170 #
171 # return m
172 #
173 # def ports(self):
174 # return [self.valid, self.way, self.index, self.isel, self.o_index]
175
176
if __name__ == '__main__':
    # Convert a PLRU built with log2_num_ways=3 to RTLIL and write the
    # result to test_plru.il in the current directory.
    dut = PLRU(3)
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("test_plru.il", "w") as il_file:
        il_file.write(il_text)
182
183 # dut = PLRUs(4, 2)
184 # vl = rtlil.convert(dut, ports=dut.ports())
185 # with open("test_plrus.il", "w") as f:
186 # f.write(vl)