speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / multipipe.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 """ Combinatorial Multi-input and Multi-output multiplexer blocks
3 conforming to Pipeline API
4
5 This work is funded through NLnet under Grant 2019-02-012
6
7 License: LGPLv3+
8
9 Multi-input is complex because if any one input is ready, the output
10 can be ready, and the decision comes from a separate module.
11
12 Multi-output is simple (pretty much identical to UnbufferedPipeline),
13 and the selection is just a mux. The only proviso (difference) being:
the outputs not being selected have to have their o_valid signals
DEASSERTED.
16
17 https://bugs.libre-soc.org/show_bug.cgi?id=538
18 """
19
20 from math import log
21 from nmigen import Signal, Cat, Const, Mux, Module, Array, Elaboratable
22 from nmigen.cli import verilog, rtlil
23 from nmigen.lib.coding import PriorityEncoder
24 from nmigen.hdl.rec import Record, Layout
25 from nmutil.stageapi import _spec
26
27 from collections.abc import Sequence
28
29 from nmutil.nmoperator import eq
30 from nmutil.iocontrol import NextControl, PrevControl
31
32
class MultiInControlBase(Elaboratable):
    """ Common functions for Pipeline API

        Multi-input (fan-in) control: an Array of PrevControl instances,
        one per input stage, and a single NextControl for the output.
    """

    def __init__(self, in_multi=None, p_len=1, maskwid=0, routemask=False):
        """ Multi-input Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the *multiple* input stages

            * p: contains ready/valid to the previous stages PLURAL
            * n: contains ready/valid to the next stage

            User must also:
            * add i_data members to PrevControl and
            * add o_data member to NextControl

            :param in_multi: passed through to every PrevControl
            :param p_len: number of input stages (length of self.p)
            :param maskwid: mask width of each input
            :param routemask: if True the output mask is maskwid wide
                              ("straight route" mode); otherwise it is
                              maskwid * p_len wide (input masks fanned in)
        """
        self.routemask = routemask
        # set up input and output IO ACK (prev/next ready/valid)
        print("multi_in", self, maskwid, p_len, routemask)
        self.p = Array([PrevControl(in_multi, maskwid=maskwid)
                        for _ in range(p_len)])
        if routemask:
            nmaskwid = maskwid  # straight route mask mode
        else:
            nmaskwid = maskwid * p_len  # fan-in mode
        self.n = NextControl(maskwid=nmaskwid)  # masks fan in (Cat)

    def connect_to_next(self, nxt, p_idx=0):
        """ helper function to connect to the next stage data/valid/ready.

            :param nxt: next stage; its input ``nxt.p[p_idx]`` is used
        """
        return self.n.connect_to_next(nxt.p[p_idx])

    def _connect_in(self, prev, idx=0, prev_idx=None):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            :param prev: source whose ``p`` (or ``p[prev_idx]``) drives
                         this stage's input ``self.p[idx]``
            :param idx: which of this stage's inputs to connect
            :param prev_idx: optional index into ``prev.p``
        """
        if prev_idx is None:
            return self.p[idx]._connect_in(prev.p)
        return self.p[idx]._connect_in(prev.p[prev_idx])

    def _connect_out(self, nxt, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            :param nxt: sink whose ``n`` (or ``n[nxt_idx]``) is driven by
                        this stage's single output ``self.n``
            :param nxt_idx: optional index into ``nxt.n``, for sinks with
                            multiple outputs (mirrors _connect_in's prev_idx)
        """
        if nxt_idx is None:
            return self.n._connect_out(nxt.n)
        return self.n._connect_out(nxt.n[nxt_idx])

    def set_input(self, i, idx=0):
        """ helper function to set the input data of input ``idx``
        """
        return eq(self.p[idx].i_data, i)

    def elaborate(self, platform):
        m = Module()
        for i, p in enumerate(self.p):
            setattr(m.submodules, "p%d" % i, p)
        m.submodules.n = self.n
        return m

    def __iter__(self):
        for p in self.p:
            yield from p
        yield from self.n

    def ports(self):
        return list(self)
101
102
class MultiOutControlBase(Elaboratable):
    """ Common functions for Pipeline API

        Multi-output (fan-out) control: a single PrevControl for the input
        and an Array of NextControl instances, one per output stage.
    """

    def __init__(self, n_len=1, in_multi=None, maskwid=0, routemask=False):
        """ Multi-output Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the multiple *output* stages
            [MultiInControlBase has multiple *input* stages]

            * p: contains ready/valid to the previous stage
            * n: contains ready/valid to the next stages PLURAL

            User must also:
            * add i_data member to PrevControl and
            * add o_data members to NextControl

            :param n_len: number of output stages (length of self.n)
            :param in_multi: passed through to the PrevControl
            :param maskwid: mask width of each output
            :param routemask: if True the input mask is maskwid wide
                              ("straight route" mode); otherwise it is
                              maskwid * n_len wide (fanned out to outputs)
        """

        if routemask:
            nmaskwid = maskwid  # straight route mask mode
        else:
            nmaskwid = maskwid * n_len  # fan-out mode

        # set up input and output IO ACK (prev/next ready/valid)
        self.p = PrevControl(in_multi, maskwid=nmaskwid)
        self.n = Array([NextControl(maskwid=maskwid) for _ in range(n_len)])

    def connect_to_next(self, nxt, n_idx=0):
        """ helper function to connect output ``n_idx`` to the next stage
            data/valid/ready.
        """
        return self.n[n_idx].connect_to_next(nxt.p)

    def _connect_in(self, prev, idx=0, prev_idx=None):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            This class has a single input (self.p), so that is what gets
            connected; ``idx`` is not used and is kept only so existing
            callers that pass it still work.

            :param prev: source whose ``p`` (or ``p[prev_idx]``) drives
                         this stage's input
            :param prev_idx: optional index into ``prev.p``
        """
        if prev_idx is None:
            return self.p._connect_in(prev.p)
        return self.p._connect_in(prev.p[prev_idx])

    def _connect_out(self, nxt, idx=0, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            :param nxt: sink whose ``n`` (or ``n[nxt_idx]``) is driven by
                        this stage's output ``self.n[idx]``
        """
        if nxt_idx is None:
            return self.n[idx]._connect_out(nxt.n)
        return self.n[idx]._connect_out(nxt.n[nxt_idx])

    def elaborate(self, platform):
        m = Module()
        m.submodules.p = self.p
        for i, n in enumerate(self.n):
            setattr(m.submodules, "n%d" % i, n)
        return m

    def set_input(self, i):
        """ helper function to set the input data
        """
        return eq(self.p.i_data, i)

    def __iter__(self):
        yield from self.p
        for n in self.n:
            yield from n

    def ports(self):
        return list(self)
170
171
class CombMultiOutPipeline(MultiOutControlBase):
    """ A multi-output Combinatorial block conforming to the Pipeline API

        The single input is routed to the output selected by
        ``n_mux.m_id``; every other output has its o_valid driven to 0.

        Attributes:
        -----------
        p.i_data : stage input data (non-array).  shaped according to ispec
        n.o_data : stage output data array.       shaped according to ospec
    """

    def __init__(self, stage, n_len, n_mux, maskwid=0, routemask=False):
        MultiOutControlBase.__init__(self, n_len=n_len, maskwid=maskwid,
                                     routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.routemask = routemask
        self.n_mux = n_mux

        # set up the input and output data
        self.p.i_data = _spec(stage.ispec, 'i_data')  # input type
        for i in range(n_len):
            name = 'o_data_%d' % i
            self.n[i].o_data = _spec(stage.ospec, name)  # output type

    def process(self, i):
        """ apply the stage's process() if it has one, else pass through """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        m = MultiOutControlBase.elaborate(self, platform)

        if hasattr(self.n_mux, "elaborate"):  # TODO: identify submodule?
            m.submodules.n_mux = self.n_mux

        # need buffer register conforming to *input* spec
        r_data = _spec(self.stage.ispec, 'r_data')  # input type
        if hasattr(self.stage, "setup"):
            self.stage.setup(m, r_data)

        # multiplexer id taken from n_mux
        muxid = self.n_mux.m_id
        print("self.n_mux", self.n_mux)
        print("self.n_mux.m_id", self.n_mux.m_id)

        self.n_mux.m_id.name = "m_id"

        # temporaries
        p_i_valid = Signal(reset_less=True)
        pv = Signal(reset_less=True)
        m.d.comb += p_i_valid.eq(self.p.i_valid_test)
        # pv: input valid AND this stage ready
        m.d.comb += pv.eq(self.p.i_valid & self.p.o_ready)

        # all outputs to next stages first initialised to zero (invalid)
        # the only output "active" is then selected by the muxid
        # (later comb assignments override these defaults)
        for i in range(len(self.n)):
            m.d.comb += self.n[i].o_valid.eq(0)
        if self.routemask:
            m.d.comb += self.n[muxid].o_valid.eq(pv)
            m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
        else:
            data_valid = self.n[muxid].o_valid
            m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
            # NOTE(review): data_valid appears on both sides of this comb
            # assignment (combinatorial feedback) - confirm this is intended
            m.d.comb += data_valid.eq(p_i_valid |
                                      (~self.n[muxid].i_ready & data_valid))

        # send data on: only the output selected by muxid gets the data
        m.d.comb += eq(r_data, self.p.i_data)
        for i in range(len(self.n)):
            with m.If(muxid == i):
                m.d.comb += eq(self.n[i].o_data, self.process(r_data))

        if self.maskwid:
            if self.routemask:  # straight "routing" mode - treat like data
                m.d.comb += self.n[muxid].stop_o.eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += self.n[muxid].mask_o.eq(self.p.mask_i)
            else:
                ml = []  # accumulate output masks
                ms = []  # accumulate output stops
                # fan-out mode.
                # conditionally fan-out mask bits, always fan-out stop bits
                for i in range(len(self.n)):
                    ml.append(self.n[i].mask_o)
                    ms.append(self.n[i].stop_o)
                m.d.comb += Cat(*ms).eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += Cat(*ml).eq(self.p.mask_i)
        return m
264
265
class CombMultiInPipeline(MultiInControlBase):
    """ A multi-input Combinatorial block conforming to the Pipeline API

        Attributes:
        -----------
        p.i_data : StageInput, shaped according to ispec
            The pipeline input (one per input stage)
        n.o_data : StageOutput, shaped according to ospec
            The pipeline output
        r_data : input_shape according to ispec
            Per-input temporaries holding the processed input data.
            Driven combinatorially (m.d.comb) when that input is
            valid, unmasked and ready.
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'i_data_%d' % i
            self.p[i].i_data = _spec(stage.ispec, name)  # input type
        self.n.o_data = _spec(stage.ospec, 'o_data')

    def process(self, i):
        """ apply the stage's process() if it has one, else pass through """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # need an array of buffer registers conforming to *input* spec
        r_data = []
        data_valid = []
        p_i_valid = []
        n_i_readyn = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name)  # input type
            r_data.append(r)
            data_valid.append(Signal(name="data_valid", reset_less=True))
            p_i_valid.append(Signal(name="p_i_valid", reset_less=True))
            n_i_readyn.append(Signal(name="n_i_readyn", reset_less=True))
            if hasattr(self.stage, "setup"):
                print("setup", self, self.stage, r)
                self.stage.setup(m, r)
        # always create Arrays, even of length 1: they are indexed below by
        # the mux id Signal, which a plain Python list cannot accept
        p_i_valid = Array(p_i_valid)
        n_i_readyn = Array(n_i_readyn)
        data_valid = Array(data_valid)

        nirn = Signal(reset_less=True)
        m.d.comb += nirn.eq(~self.n.i_ready)
        mid = self.p_mux.m_id
        print("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input; the selected input (mid) is overridden
        for i in range(p_len):
            m.d.comb += data_valid[i].eq(0)
            m.d.comb += n_i_readyn[i].eq(1)
            m.d.comb += p_i_valid[i].eq(0)
            m.d.comb += self.p[i].o_ready.eq(0)
        p = self.p[mid]
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            # reduce-OR before assigning: maskedout is 1 bit wide, and a
            # multi-bit mask would otherwise be truncated to its bit 0
            m.d.comb += maskedout.eq((p.mask_i & ~p.stop_i).bool())
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active &
                                      self.p[mid].i_valid)
        m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
        m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
        anyvalid = Cat(*[data_valid[i] for i in range(p_len)])
        m.d.comb += self.n.o_valid.eq(anyvalid.bool())
        m.d.comb += data_valid[mid].eq(p_i_valid[mid] | n_i_readyn[mid])

        if self.routemask:
            # XXX hack - fixes loop
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq((p.mask_i & ~p.stop_i).bool())
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.process(self.p[i].i_data))
                    with m.If(mid == i):
                        m.d.comb += eq(self.n.o_data, r_data[i])
        else:
            ml = []  # accumulate output masks
            ms = []  # accumulate output stops
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq((p.mask_i & ~p.stop_i).bool())
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.process(self.p[i].i_data))
                    with m.If(mid == i):
                        m.d.comb += eq(self.n.o_data, r_data[i])
                if self.maskwid:
                    # fan-in: each input's mask is passed only when it fires
                    mlen = len(self.p[i].mask_i)
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        return m
403
404
class NonCombMultiInPipeline(MultiInControlBase):
    """ A multi-input pipeline block conforming to the Pipeline API

        Attributes:
        -----------
        p.i_data : StageInput, shaped according to ispec
            The pipeline input (one per input stage)
        n.o_data : StageOutput, shaped according to ospec
            The pipeline output
        r_data : input_shape according to ispec
            A temporary (buffered) copy of a prior (valid) input.
            This is HELD if the output is not ready.  It is updated
            SYNCHRONOUSLY.

        NOTE(review): work in progress.  As written, r_data is driven with
        m.d.comb (not a sync domain), r_busy is allocated but only ever
        driven to 0, and inputs other than the selected one get
        o_ready = 1 (through the n_i_readyn default).  Confirm the
        intended registered behaviour before relying on this class.
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'i_data_%d' % i
            self.p[i].i_data = _spec(stage.ispec, name)  # input type
        self.n.o_data = _spec(stage.ospec, 'o_data')

    def process(self, i):
        """ apply the stage's process() if it has one, else pass through """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # need an array of buffer registers conforming to *input* spec
        r_data = []
        r_busy = []
        p_i_valid = []
        data_valid = []
        n_i_readyn = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name)  # input type
            r_data.append(r)
            r_busy.append(Signal(name="r_busy%d" % i, reset_less=True))
            p_i_valid.append(Signal(name="p_i_valid%d" % i, reset_less=True))
            data_valid.append(Signal(name="data_valid%d" % i,
                                     reset_less=True))
            n_i_readyn.append(Signal(name="n_i_readyn%d" % i,
                                     reset_less=True))
            if hasattr(self.stage, "setup"):
                print("setup", self, self.stage, r)
                self.stage.setup(m, r)
        # always create Arrays, even of length 1: they are indexed below by
        # the mux id Signal, which a plain Python list cannot accept
        r_data = Array(r_data)
        p_i_valid = Array(p_i_valid)
        r_busy = Array(r_busy)
        data_valid = Array(data_valid)
        n_i_readyn = Array(n_i_readyn)

        nirn = Signal(reset_less=True)
        m.d.comb += nirn.eq(~self.n.i_ready)
        mid = self.p_mux.m_id
        print("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input; the selected input (mid) is overridden
        for i in range(p_len):
            m.d.comb += r_busy[i].eq(0)
            m.d.comb += n_i_readyn[i].eq(1)
            m.d.comb += p_i_valid[i].eq(0)
            m.d.comb += self.p[i].o_ready.eq(n_i_readyn[i])
        p = self.p[mid]
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            # reduce-OR before assigning: maskedout is 1 bit wide, and a
            # multi-bit mask would otherwise be truncated to its bit 0
            m.d.comb += maskedout.eq((p.mask_i & ~p.stop_i).bool())
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active)
        m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
        m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
        anyvalid = Cat(*[data_valid[i] for i in range(p_len)])
        m.d.comb += self.n.o_valid.eq(anyvalid.bool())
        m.d.comb += data_valid[mid].eq(p_i_valid[mid] | n_i_readyn[mid])

        if self.routemask:
            # XXX hack - fixes loop
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq((p.mask_i & ~p.stop_i).bool())
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
        else:
            ml = []  # accumulate output masks
            ms = []  # accumulate output stops
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq((p.mask_i & ~p.stop_i).bool())
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
                if self.maskwid:
                    # fan-in: each input's mask is passed only when it fires
                    mlen = len(self.p[i].mask_i)
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        m.d.comb += eq(self.n.o_data, self.process(r_data[mid]))

        return m
533
534
class CombMuxOutPipe(CombMultiOutPipeline):
    """ Multi-output combinatorial pipe whose stage is also its n-way mux.

        The output index is taken from the input-data field named
        ``muxidname`` (``"muxid"`` when not supplied or empty).  That field
        is installed as ``stage.m_id``, which CombMultiOutPipeline reads
        (via n_mux) to pick the active output.
    """

    def __init__(self, stage, n_len, maskwid=0, muxidname=None,
                 routemask=False):
        if not muxidname:
            muxidname = "muxid"
        # HACK: the stage object doubles as the n_mux
        CombMultiOutPipeline.__init__(self, stage, n_len=n_len, n_mux=stage,
                                      maskwid=maskwid, routemask=routemask)

        # HACK: since the stage *is* the mux, point its m_id at the muxid
        # field of the input data that the base __init__ just created
        input_muxid = getattr(self.p.i_data, muxidname)
        print("combmuxout", muxidname, input_muxid)
        stage.m_id = input_muxid
548
549
class InputPriorityArbiter(Elaboratable):
    """ arbitration module for Input-Mux pipe, based on PriorityEncoder

        Feeds one "valid" bit per input of ``pipe`` into a PriorityEncoder.
        In fan-in mask mode (maskwid set, routemask clear), an input only
        counts as valid if some bit of ``mask_i & ~stop_i`` is set.

        Outputs:
        * m_id: index of the selected input (the encoder's output)
        * active: asserted when the encoder sees at least one valid input
    """

    def __init__(self, pipe, num_rows):
        """ :param pipe: the multi-input pipe being arbitrated; only its
                         ``p``, ``maskwid`` and ``routemask`` are read, at
                         elaborate time
            :param num_rows: number of inputs (must equal len(pipe.p))
            :raises ValueError: if num_rows < 1
        """
        if num_rows < 1:
            raise ValueError("num_rows must be at least 1, got %r"
                             % (num_rows,))
        self.pipe = pipe
        self.num_rows = num_rows
        # bits needed to index num_rows inputs, i.e. ceil(log2(num_rows)).
        # Exact integer arithmetic: flooring a float log2 under-sized m_id
        # for non-power-of-two row counts and dropped the top index.
        self.mmax = (num_rows - 1).bit_length()
        self.m_id = Signal(self.mmax, reset_less=True)  # multiplex id
        self.active = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        assert len(self.pipe.p) == self.num_rows, \
            "must declare input to be same size"
        pe = PriorityEncoder(self.num_rows)
        m.submodules.selector = pe

        # connect priority encoder
        in_ready = []
        for i in range(self.num_rows):
            p = self.pipe.p[i]
            p_i_valid = Signal(reset_less=True)
            if self.pipe.maskwid and not self.pipe.routemask:
                maskedout = Signal(reset_less=True)
                # reduce-OR before assigning: maskedout is 1 bit wide, and
                # a multi-bit mask would otherwise be truncated to bit 0
                m.d.comb += maskedout.eq((p.mask_i & ~p.stop_i).bool())
                m.d.comb += p_i_valid.eq(maskedout & p.i_valid_test)
            else:
                m.d.comb += p_i_valid.eq(p.i_valid_test)
            in_ready.append(p_i_valid)
        m.d.comb += pe.i.eq(Cat(*in_ready))  # array of input "valids"
        m.d.comb += self.active.eq(~pe.n)  # encoder active (one input valid)
        m.d.comb += self.m_id.eq(pe.o)  # output one active input

        return m

    def ports(self):
        return [self.m_id, self.active]
589
590
class PriorityCombMuxInPipe(CombMultiInPipeline):
    """ Example use of CombMultiInPipeline: a ``p_len``-input combinatorial
        pipe whose active input is chosen by an InputPriorityArbiter.
    """

    def __init__(self, stage, p_len=2, maskwid=0, routemask=False):
        # the arbiter only stores a reference to this pipe here (it reads
        # self.p when elaborated), so it may be built before the base init
        arbiter = InputPriorityArbiter(self, p_len)
        CombMultiInPipeline.__init__(self, stage, p_len=p_len, p_mux=arbiter,
                                     maskwid=maskwid, routemask=routemask)
599
600
if __name__ == '__main__':

    from nmutil.test.example_buf_pipe import ExampleStage

    # ExampleStage is passed as a class and used directly as the stage;
    # the resulting RTLIL netlist is written to test_combpipe.il
    pipe = PriorityCombMuxInPipe(ExampleStage)
    pipe_ports = pipe.ports()
    rtlil_text = rtlil.convert(pipe, ports=pipe_ports)
    with open("test_combpipe.il", "w") as il_file:
        il_file.write(rtlil_text)