speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / picker.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 """ Priority Picker: optimized back-to-back PriorityEncoder and Decoder
3 and MultiPriorityPicker: cascading mutually-exclusive pickers
4
5 This work is funded through NLnet under Grant 2019-02-012
6
7 License: LGPLv3+
8
9
PriorityPicker: the input is N bits, the output is N bits wide and
at most one output bit is enabled (the highest-priority set input bit).
12
13 MultiPriorityPicker: likewise except that there are M pickers and
each output is guaranteed mutually exclusive. An enable line is
always output for every picker; optionally, an "index" is output too.
16
17 MultiPriorityPicker is designed for port-selection, when there are
18 multiple "things" (of width N) contending for access to M "ports".
19 When the M=0 "thing" requests a port, it gets allocated port 0
20 (always). However if the M=0 "thing" does *not* request a port,
21 this gives the M=1 "thing" the opportunity to gain access to port 0.
22
23 Given that N may potentially be much greater than M (16 bits wide
where M may be e.g. only 4) we can't just say, "ok so M=N therefore
M=0 gets access to port 0, M=1 gets access to port 1" etc.
26 """
27
28 from nmigen import Module, Signal, Cat, Elaboratable, Array, Const, Mux
29 from nmigen.utils import bits_for
30 from nmigen.cli import rtlil
31 import math
32 from nmutil.prefix_sum import prefix_sum
33
34
class PriorityPicker(Elaboratable):
    """ implements a priority-picker. input: N bits, output: N bits

    Of all the set bits in ``i``, only the highest-priority one is copied
    to ``o`` (by default the lowest-numbered bit; msb_mode / reverse_i
    reorder the input first). Every other output bit is zero, and ``en_o``
    is set whenever any output bit is set.

    * msb_mode is for a MSB-priority picker
    * reverse_i=True is for convenient reversal of the input bits
    * reverse_o=True is for convenient reversal of the output bits
    * `msb_mode=True` is redundant with `reverse_i=True, reverse_o=True`
      but is allowed for backwards compatibility.
    """

    def __init__(self, wid, msb_mode=False, reverse_i=False, reverse_o=False):
        self.wid = wid
        # configuration flags
        self.msb_mode = msb_mode
        self.reverse_i = reverse_i
        self.reverse_o = reverse_o
        # ports
        self.i = Signal(wid, reset_less=True)
        self.o = Signal(wid, reset_less=True)
        # en_o: true if any output is true
        self.en_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # reverse_i and msb_mode each flip the input order; when both are
        # set the two flips cancel out.
        bits = list(self.i)
        if self.reverse_i:
            bits.reverse()
        if self.msb_mode:
            bits.reverse()

        # inverted copy of the (possibly reordered) input bits
        inverted = Signal(self.wid, name="ni", reset_less=True)
        comb += inverted.eq(~Cat(*bits))

        # a bit "wins" only if it is set and every lower-indexed bit is
        # clear.  In msb_mode the winners are collected highest-index
        # first, which undoes the input flip on the way out.
        if self.msb_mode:
            order = range(self.wid - 1, -1, -1)
        else:
            order = range(self.wid)
        picked = []
        for n in order:
            t = Signal(name="t%d" % n, reset_less=True)
            picked.append(t)
            if n:
                # OR-reduce (~bit n, bits below n), then invert:
                # true iff bit n is set and all lower bits are clear
                comb += t.eq(~Cat(inverted[n], *bits[:n]).bool())
            else:
                # bit 0 has nothing below it: it wins whenever it is set
                comb += t.eq(bits[n])

        if self.reverse_o:
            picked.reverse()
        comb += self.o.eq(Cat(*picked))
        # "is any output enabled" summary signal
        comb += self.en_o.eq(self.o.bool())
        return m

    def __iter__(self):
        yield from (self.i, self.o, self.en_o)

    def ports(self):
        return [*self]
95
96
class MultiPriorityPicker(Elaboratable):
    """ implements a cascade of mutually-exclusive priority pickers

    ``levels`` outputs of ``wid`` bits each.  Each output has at most one
    bit set, and no bit is set in more than one output.

    Each picker masks out the one below it, such that the first
    gets top priority, the second cannot have the same bit that
    the first has set, and so on. To do this, a "mask" accumulates
    the outputs of all earlier pickers and masks the input of the next.

    ``en_o`` always holds one enable bit per picker.  Also outputted
    (optional): an index for each picked "thing".

    :param wid: width of each input and of each output
    :param levels: number of cascaded pickers (and of outputs)
    :param indices: when true, also create ``idx_o``; ``idx_o[j]`` is the
        number of enabled pickers before picker ``j``
    :param multi_in: when true, ``i`` is an Array of ``levels`` inputs (one
        per picker); otherwise ``i`` is a single input shared by all
    """

    def __init__(self, wid, levels, indices=False, multi_in=False):
        self.levels = levels
        self.wid = wid
        self.indices = indices
        self.multi_in = multi_in

        if multi_in:
            # multiple inputs, one per picker
            self.i = Array(Signal(self.wid, name="i_%d" % j, reset_less=True)
                           for j in range(self.levels))
        else:
            # only the one input, shared by every picker
            self.i = Signal(self.wid, reset_less=True)

        # one wid-bit output per picker (at most one bit set in each)
        self.o = Array(Signal(self.wid, name="o_%d" % j, reset_less=True)
                       for j in range(self.levels))

        # one enable bit per picker
        self.en_o = Signal(self.levels, name="en_o", reset_less=True)

        if not self.indices:
            return

        # add an array of indices
        lidx = self._index_width()
        self.idx_o = Array(Signal(lidx, name="idxo_%d" % j, reset_less=True)
                           for j in range(self.levels))

    def _index_width(self):
        """Bit width of each ``idx_o`` / count signal: ceil(log2(levels))."""
        return math.ceil(math.log2(self.levels))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # create Priority Pickers.  Each picker's input is masked with the
        # OR of every earlier picker's output, so no two pickers can ever
        # select the same bit.
        prev_pp = None
        p_mask = None
        pp_l = []
        for j in range(self.levels):
            i = self.i[j] if self.multi_in else self.i
            pp = PriorityPicker(self.wid)
            pp_l.append(pp)
            setattr(m.submodules, "pp%d" % j, pp)
            comb += self.o[j].eq(pp.o)
            if prev_pp is None:
                # first picker: nothing picked yet, so nothing to mask
                comb += pp.i.eq(i)
                p_mask = Const(0, self.wid)
            else:
                mask = Signal(self.wid, name="m_%d" % j, reset_less=True)
                comb += mask.eq(prev_pp.o | p_mask)  # accumulate output bits
                comb += pp.i.eq(i & ~mask)  # mask out already-picked bits
                p_mask = mask
            prev_pp = pp

        # concat the per-picker enable bits
        comb += self.en_o.eq(Cat(*(pp.en_o for pp in pp_l)))

        if not self.indices:
            return m

        # idx_o[j] = number of enabled pickers strictly before picker j
        lidx = self._index_width()
        prev_count = 0
        for j, pp in enumerate(pp_l):
            count1 = Signal(lidx, name="count_%d" % j, reset_less=True)
            comb += count1.eq(prev_count + Const(1, lidx))
            comb += self.idx_o[j].eq(prev_count)
            prev_count = Mux(pp.en_o, count1, prev_count)

        return m

    def __iter__(self):
        if self.multi_in:
            yield from self.i
        else:
            yield self.i
        yield from self.o
        yield self.en_o
        if not self.indices:
            return
        yield from self.idx_o

    def ports(self):
        return list(self)
214
215
class BetterMultiPriorityPicker(Elaboratable):
    """A better replacement for MultiPriorityPicker that has O(log levels)
    latency, rather than > O(levels) latency.

    Input bit ``i[k]`` drives ``o[L][k]`` when it is set and the saturating
    running count of set input bits before ``k`` equals ``L`` (assuming
    ``prefix_sum`` yields inclusive running totals).  ``en_o[L]`` is set
    when ``o[L]`` has any bit set.

    :param width: width of ``i`` and of each ``o[L]``; must be >= 1
    :param levels: number of outputs; must satisfy ``1 <= levels <= width``
    :param work_efficient: forwarded to ``prefix_sum``
    """

    def __init__(self, width, levels, *, work_efficient=False):
        assert isinstance(width, int) and width >= 1
        assert isinstance(levels, int) and 1 <= levels <= width
        assert isinstance(work_efficient, bool)
        self.width = width
        self.levels = levels
        self.work_efficient = work_efficient
        # the saturation value must never collide with a real level number
        assert self.__index_sat > self.levels - 1
        self.i = Signal(width)
        self.o = [Signal(width, name=f"o_{i}") for i in range(levels)]
        self.en_o = Signal(levels)

    @property
    def __index_width(self):
        """Bit width of the running-count signals."""
        return bits_for(self.levels)

    @property
    def __index_sat(self):
        """Largest value representable in ``__index_width`` bits."""
        return (1 << self.__index_width) - 1

    def elaborate(self, platform):
        m = Module()
        index_width = self.__index_width
        index_sat = self.__index_sat

        def sat_add(a, b):
            # one extra bit so the sum of two index_width-bit operands fits
            # (explicit name keeps the generated signal called "sum")
            total = Signal(index_width + 1, name="sum")
            m.d.comb += total.eq(a + b)
            retval = Signal(index_width)
            # top bit set => sum exceeded index_sat: clamp to index_sat,
            # which is >= levels and so never matches a real level below
            m.d.comb += retval.eq(Mux(total[-1], index_sat, total))
            return retval

        # the last input bit is not fed in: nothing after it needs a count.
        # NOTE(review): assumes prefix_sum returns inclusive running totals,
        # so after prepending 0, indexes[k] counts set bits in i[0:k].
        indexes = prefix_sum((self.i[k] for k in range(self.width - 1)),
                             sat_add, work_efficient=self.work_efficient)
        indexes.insert(0, 0)
        for k in range(self.width):
            sig = Signal(index_width, name=f"index_{k}")
            m.d.comb += sig.eq(indexes[k])
            indexes[k] = sig

        for level in range(self.levels):
            out = self.o[level]
            m.d.comb += self.en_o[level].eq(out.bool())
            for k in range(self.width):
                index_matches = indexes[k] == level
                m.d.comb += out[k].eq(index_matches & self.i[k])

        return m

    def __iter__(self):
        yield self.i
        yield from self.o
        yield self.en_o

    def ports(self):
        return list(self)
272
273
if __name__ == '__main__':
    def _write_il(dut, filename):
        """Convert *dut* to RTLIL (using its own port list) and save it."""
        text = rtlil.convert(dut, ports=dut.ports())
        with open(filename, "w") as f:
            f.write(text)

    _write_il(PriorityPicker(16), "test_picker.il")
    _write_il(MultiPriorityPicker(5, 4, indices=True),
              "test_multi_picker.il")
    _write_il(MultiPriorityPicker(5, 4, indices=False, multi_in=True),
              "test_multi_picker_noidx.il")