d882742d764384f64ce59f773cb10c541dede64c
[nmutil.git] / src / nmutil / multipipe.py
1 """ Combinatorial Multi-input and Multi-output multiplexer blocks
2 conforming to Pipeline API
3
4 This work is funded through NLnet under Grant 2019-02-012
5
6 License: LGPLv3+
7
8 Multi-input is complex because if any one input is ready, the output
9 can be ready, and the decision comes from a separate module.
10
11 Multi-output is simple (pretty much identical to UnbufferedPipeline),
12 and the selection is just a mux. The only proviso (difference) being:
13 the outputs not being selected have to have their o_ready signals
14 DEASSERTED.
15
16 https://bugs.libre-soc.org/show_bug.cgi?id=538
17 """
18
19 from math import log
20 from nmigen import Signal, Cat, Const, Mux, Module, Array, Elaboratable
21 from nmigen.cli import verilog, rtlil
22 from nmigen.lib.coding import PriorityEncoder
23 from nmigen.hdl.rec import Record, Layout
24 from nmutil.stageapi import _spec
25
26 from collections.abc import Sequence
27
28 from nmutil.nmoperator import eq
29 from nmutil.iocontrol import NextControl, PrevControl
30
31
class MultiInControlBase(Elaboratable):
    """ Common functions for Pipeline API (multiple inputs, one output)

        Holds an Array of PrevControl objects in ``p`` (one per input)
        and a single NextControl in ``n``.
    """

    def __init__(self, in_multi=None, p_len=1, maskwid=0, routemask=False):
        """ Multi-input Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the *multiple* input stages

            * p: contains ready/valid to the previous stages PLURAL
            * n: contains ready/valid to the next stage

            Arguments:
            * in_multi:  forwarded as first argument to every PrevControl
            * p_len:     number of PrevControl (input) instances to create
            * maskwid:   mask width of each input
            * routemask: when True the output mask is as wide as one input
                         mask; otherwise it is maskwid * p_len (fan-in)

            User must also:
            * add i_data members to PrevControl and
            * add o_data member to NextControl
        """
        self.routemask = routemask
        # set up input and output IO ACK (prev/next ready/valid)
        print("multi_in", self, maskwid, p_len, routemask)
        self.p = Array([PrevControl(in_multi, maskwid=maskwid)
                        for _ in range(p_len)])
        if routemask:
            nmaskwid = maskwid  # straight route mask mode
        else:
            nmaskwid = maskwid * p_len  # fan-in mode
        self.n = NextControl(maskwid=nmaskwid)  # masks fan in (Cat)

    def connect_to_next(self, nxt, p_idx=0):
        """ helper function to connect to the next stage data/valid/ready.

            p_idx selects which of nxt's inputs (nxt.p[p_idx]) to connect to.
        """
        return self.n.connect_to_next(nxt.p[p_idx])

    def _connect_in(self, prev, idx=0, prev_idx=None):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            idx selects the local input (self.p[idx]).  prev_idx, when
            given, selects prev.p[prev_idx]; otherwise prev.p is used as-is.
        """
        if prev_idx is None:
            return self.p[idx]._connect_in(prev.p)
        return self.p[idx]._connect_in(prev.p[prev_idx])

    def _connect_out(self, nxt, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            nxt_idx, when given, selects nxt.n[nxt_idx]; otherwise nxt.n is
            used as-is.  (Previously nxt_idx was read without ever being
            defined, so every call raised NameError.)
        """
        if nxt_idx is None:
            return self.n._connect_out(nxt.n)
        return self.n._connect_out(nxt.n[nxt_idx])

    def set_input(self, i, idx=0):
        """ helper function to set the input data of input number idx
        """
        return eq(self.p[idx].i_data, i)

    def elaborate(self, platform):
        """ registers each input as submodule "p<N>" and the output as "n"
        """
        m = Module()
        for i, p in enumerate(self.p):
            setattr(m.submodules, "p%d" % i, p)
        m.submodules.n = self.n
        return m

    def __iter__(self):
        # everything yielded by each input first, then by the output
        for p in self.p:
            yield from p
        yield from self.n

    def ports(self):
        """ returns the flattened result of iterating over this object
        """
        return list(self)
100
101
class MultiOutControlBase(Elaboratable):
    """ Common functions for Pipeline API (one input, multiple outputs)

        Holds a single PrevControl in ``p`` and an Array of NextControl
        objects in ``n`` (one per output).
    """

    def __init__(self, n_len=1, in_multi=None, maskwid=0, routemask=False):
        """ Multi-output Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the multiple *output* stages
            [MultiInControlBase has multiple *input* stages]

            * p: contains ready/valid to the previous stage
            * n: contains ready/valid to the next stages PLURAL

            Arguments:
            * n_len:     number of NextControl (output) instances to create
            * in_multi:  forwarded as first argument to PrevControl
            * maskwid:   mask width of each output
            * routemask: when True the input mask is as wide as one output
                         mask; otherwise it is maskwid * n_len (fan-out)

            User must also:
            * add i_data member to PrevControl and
            * add o_data members to NextControl
        """

        if routemask:
            nmaskwid = maskwid  # straight route mask mode
        else:
            nmaskwid = maskwid * n_len  # fan-out mode

        # set up input and output IO ACK (prev/next ready/valid)
        self.p = PrevControl(in_multi, maskwid=nmaskwid)
        self.n = Array([NextControl(maskwid=maskwid) for _ in range(n_len)])

    def connect_to_next(self, nxt, n_idx=0):
        """ helper function to connect to the next stage data/valid/ready.

            n_idx selects which local output (self.n[n_idx]) is connected.
        """
        return self.n[n_idx].connect_to_next(nxt.p)

    def _connect_in(self, prev, idx=0):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            There is only one input (self.p), so idx is accepted purely for
            backward compatibility and is ignored.  The previous code called
            _connect_in on self.n[idx], i.e. on an output-side control
            object, where MultiInControlBase calls it on an input-side one.
        """
        return self.p._connect_in(prev.p)

    def _connect_out(self, nxt, idx=0, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            idx selects the local output (self.n[idx]).  nxt_idx, when
            given, selects nxt.n[nxt_idx]; otherwise nxt.n is used as-is.
        """
        if nxt_idx is None:
            return self.n[idx]._connect_out(nxt.n)
        return self.n[idx]._connect_out(nxt.n[nxt_idx])

    def elaborate(self, platform):
        """ registers the input as submodule "p" and each output as "n<N>"
        """
        m = Module()
        m.submodules.p = self.p
        for i, n in enumerate(self.n):
            setattr(m.submodules, "n%d" % i, n)
        return m

    def set_input(self, i):
        """ helper function to set the input data
        """
        return eq(self.p.i_data, i)

    def __iter__(self):
        # everything yielded by the input first, then by each output
        yield from self.p
        for n in self.n:
            yield from n

    def ports(self):
        """ returns the flattened result of iterating over this object
        """
        return list(self)
169
170
class CombMultiOutPipeline(MultiOutControlBase):
    """ A multi-output Combinatorial block conforming to the Pipeline API

        One input (p) is routed to one of n_len outputs (n[i]); the output
        is chosen by the m_id attribute of the n_mux object.

        Attributes:
        -----------
        p.i_data : stage input data (non-array).  shaped according to ispec
        n.o_data : stage output data array.       shaped according to ospec
    """

    def __init__(self, stage, n_len, n_mux, maskwid=0, routemask=False):
        """ Arguments:
            * stage:     object providing ispec and ospec, and optionally
                         setup and process
            * n_len:     number of outputs
            * n_mux:     object whose m_id attribute selects the output
            * maskwid:   mask width, forwarded to MultiOutControlBase
            * routemask: mask routing mode, forwarded to MultiOutControlBase
        """
        MultiOutControlBase.__init__(self, n_len=n_len, maskwid=maskwid,
                                     routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.routemask = routemask
        self.n_mux = n_mux

        # set up the input and output data
        self.p.i_data = _spec(stage.ispec, 'i_data')  # input type
        for i in range(n_len):
            name = 'o_data_%d' % i
            self.n[i].o_data = _spec(stage.ospec, name)  # output type

    def process(self, i):
        """ returns stage.process(i) if the stage has a process function,
            otherwise returns i unmodified
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ builds the combinatorial routing logic.

            NOTE: the order of the m.d.comb statements below matters: the
            per-output defaults are emitted first and the muxid-indexed
            assignments afterwards (later assignments take priority in
            nmigen).
        """
        m = MultiOutControlBase.elaborate(self, platform)

        if hasattr(self.n_mux, "elaborate"):  # TODO: identify submodule?
            m.submodules.n_mux = self.n_mux

        # need buffer register conforming to *input* spec
        r_data = _spec(self.stage.ispec, 'r_data')  # input type
        if hasattr(self.stage, "setup"):
            self.stage.setup(m, r_data)

        # multiplexer id taken from n_mux
        muxid = self.n_mux.m_id
        print("self.n_mux", self.n_mux)
        print("self.n_mux.m_id", self.n_mux.m_id)

        # renames the selector signal in place (affects the n_mux object)
        self.n_mux.m_id.name = "m_id"

        # temporaries
        p_i_valid = Signal(reset_less=True)
        pv = Signal(reset_less=True)
        m.d.comb += p_i_valid.eq(self.p.i_valid_test)
        # m.d.comb += pv.eq(self.p.i_valid) #& self.n[muxid].i_ready)
        # pv: input valid AND input accepted
        m.d.comb += pv.eq(self.p.i_valid & self.p.o_ready)

        # all outputs to next stages first initialised to zero (invalid)
        # the only output "active" is then selected by the muxid
        for i in range(len(self.n)):
            m.d.comb += self.n[i].o_valid.eq(0)
        if self.routemask:
            # with m.If(pv):
            m.d.comb += self.n[muxid].o_valid.eq(pv)
            m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
        else:
            # data_valid is an alias for the selected output's o_valid
            data_valid = self.n[muxid].o_valid
            m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
            # NOTE(review): data_valid appears on both sides of this comb
            # assignment - confirm this feedback is intended
            m.d.comb += data_valid.eq(p_i_valid |
                                      (~self.n[muxid].i_ready & data_valid))

        # send data on
        # with m.If(pv):
        m.d.comb += eq(r_data, self.p.i_data)
        #m.d.comb += eq(self.n[muxid].o_data, self.process(r_data))
        # only the output whose index equals muxid receives processed data
        for i in range(len(self.n)):
            with m.If(muxid == i):
                m.d.comb += eq(self.n[i].o_data, self.process(r_data))

        if self.maskwid:
            if self.routemask:  # straight "routing" mode - treat like data
                m.d.comb += self.n[muxid].stop_o.eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += self.n[muxid].mask_o.eq(self.p.mask_i)
            else:
                ml = []  # accumulate output masks
                ms = []  # accumulate output stops
                # fan-out mode.
                # conditionally fan-out mask bits, always fan-out stop bits
                for i in range(len(self.n)):
                    ml.append(self.n[i].mask_o)
                    ms.append(self.n[i].stop_o)
                # the concatenation of all output stops/masks is assigned
                # from the (wide) input stop/mask in one go
                m.d.comb += Cat(*ms).eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += Cat(*ml).eq(self.p.mask_i)
        return m
263
264
class CombMultiInPipeline(MultiInControlBase):
    """ A multi-input Combinatorial block conforming to the Pipeline API

        One of p_len inputs (p[i]) is routed to the single output (n);
        the input is chosen by the m_id attribute of the p_mux object.

        Attributes:
        -----------
        p[i].i_data : StageInput, shaped according to ispec
            The pipeline inputs (one per input port)
        n.o_data : StageOutput, shaped according to ospec
            The pipeline output
        r_data : input_shape according to ispec
            Per-input temporaries local to elaborate().  In this class
            they are driven in the combinatorial domain (m.d.comb).
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        """ Arguments:
            * stage:     object providing ispec and ospec, and optionally
                         setup and process
            * p_len:     number of inputs
            * p_mux:     submodule providing m_id (selected input) and active
            * maskwid:   mask width, forwarded to MultiInControlBase
            * routemask: mask routing mode, forwarded to MultiInControlBase
        """
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'i_data_%d' % i
            self.p[i].i_data = _spec(stage.ispec, name)  # input type
        self.n.o_data = _spec(stage.ospec, 'o_data')

    def process(self, i):
        """ returns stage.process(i) if the stage has a process function,
            otherwise returns i unmodified
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ builds the combinatorial input-selection logic.

            NOTE: the order of the m.d.comb statements below matters: the
            per-input defaults are emitted first and the mid-indexed
            assignments afterwards (later assignments take priority in
            nmigen).
        """
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # need an array of buffer registers conforming to *input* spec
        r_data = []
        data_valid = []
        p_i_valid = []
        n_i_readyn = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name)  # input type
            r_data.append(r)
            # NOTE(review): the same name is passed for every index here
            data_valid.append(Signal(name="data_valid", reset_less=True))
            p_i_valid.append(Signal(name="p_i_valid", reset_less=True))
            n_i_readyn.append(Signal(name="n_i_readyn", reset_less=True))
            if hasattr(self.stage, "setup"):
                print("setup", self, self.stage, r)
                self.stage.setup(m, r)
        # these are indexed by the Signal "mid" below, hence Array.
        # r_data is only ever indexed by a Python int so stays a list.
        if True:  # len(r_data) > 1: # hmm always create an Array even of len 1
            p_i_valid = Array(p_i_valid)
            n_i_readyn = Array(n_i_readyn)
            data_valid = Array(data_valid)

        nirn = Signal(reset_less=True)  # inverted n.i_ready
        m.d.comb += nirn.eq(~self.n.i_ready)
        mid = self.p_mux.m_id
        print("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input; the selected input is overridden below
        for i in range(p_len):
            m.d.comb += data_valid[i].eq(0)
            m.d.comb += n_i_readyn[i].eq(1)
            m.d.comb += p_i_valid[i].eq(0)
            #m.d.comb += self.p[i].o_ready.eq(~data_valid[i] | self.n.i_ready)
            m.d.comb += self.p[i].o_ready.eq(0)
        p = self.p[mid]
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active &
                                      self.p[mid].i_valid)
        m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
        m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
        # NOTE(review): this Signal is replaced by the Cat() three lines
        # down and "i" is the leftover loop index - looks like dead code
        anyvalid = Signal(i, reset_less=True)
        av = []
        for i in range(p_len):
            av.append(data_valid[i])
        anyvalid = Cat(*av)
        # output is valid if any input's data_valid is set
        m.d.comb += self.n.o_valid.eq(anyvalid.bool())
        m.d.comb += data_valid[mid].eq(p_i_valid[mid] |
                                       (n_i_readyn[mid]))

        if self.routemask:
            # XXX hack - fixes loop
            # stop_o is taken from the last input only
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: input i is not masked out, is valid and is accepted
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                #m.d.comb += vr.eq(p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.process(self.p[i].i_data))
                    with m.If(mid == i):
                        m.d.comb += eq(self.n.o_data, r_data[i])
        else:
            ml = []  # accumulate output masks
            ms = []  # accumulate output stops
            for i in range(p_len):
                # NOTE(review): vr is created twice; the first Signal is
                # discarded by the re-assignment two lines down
                vr = Signal(reset_less=True)
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: input i is not masked out, is valid and is accepted
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.process(self.p[i].i_data))
                    with m.If(mid == i):
                        m.d.comb += eq(self.n.o_data, r_data[i])
                if self.maskwid:
                    mlen = len(self.p[i].mask_i)
                    # NOTE(review): s and e are computed but never read
                    s = mlen*i
                    e = mlen*(i+1)
                    # mask bits are zeroed unless vr; stop bits always pass
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        #print ("o_data", self.n.o_data, "r_data[mid]", mid, r_data[mid])
        #m.d.comb += eq(self.n.o_data, r_data[mid])

        return m
402
403
class NonCombMultiInPipeline(MultiInControlBase):
    """ A multi-input pipeline block conforming to the Pipeline API

        One of p_len inputs (p[i]) is routed to the single output (n);
        the input is chosen by the m_id attribute of the p_mux object.
        Unlike CombMultiInPipeline, stage.process is applied once to the
        selected r_data rather than to each input separately.

        NOTE(review): despite the name, every assignment in elaborate()
        is in the combinatorial domain and r_busy is only ever driven to
        zero - this class looks like work in progress.  The handshake
        equations are kept exactly as written; only the undefined names
        that stopped elaborate() from running have been declared.

        Attributes:
        -----------
        p[i].i_data : StageInput, shaped according to ispec
            The pipeline inputs (one per input port)
        n.o_data : StageOutput, shaped according to ospec
            The pipeline output
        r_data : input_shape according to ispec
            Per-input temporaries local to elaborate().
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        """ Arguments:
            * stage:     object providing ispec and ospec, and optionally
                         setup and process
            * p_len:     number of inputs
            * p_mux:     submodule providing m_id (selected input) and active
            * maskwid:   mask width, forwarded to MultiInControlBase
            * routemask: mask routing mode, forwarded to MultiInControlBase
        """
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'i_data_%d' % i
            self.p[i].i_data = _spec(stage.ispec, name)  # input type
        self.n.o_data = _spec(stage.ospec, 'o_data')

    def process(self, i):
        """ returns stage.process(i) if the stage has a process function,
            otherwise returns i unmodified
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ builds the input-selection logic.

            NOTE: the order of the m.d.comb statements below matters: the
            per-input defaults are emitted first and the mid-indexed
            assignments afterwards (later assignments take priority in
            nmigen).
        """
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # need an array of buffer registers conforming to *input* spec
        r_data = []
        r_busy = []
        data_valid = []
        p_i_valid = []
        n_i_readyn = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name)  # input type
            r_data.append(r)
            r_busy.append(Signal(name="r_busy%d" % i, reset_less=True))
            # data_valid and n_i_readyn are used further down but were
            # never declared (NameError); declare them per input, as
            # CombMultiInPipeline does
            data_valid.append(Signal(name="data_valid%d" % i,
                                     reset_less=True))
            p_i_valid.append(Signal(name="p_i_valid%d" % i, reset_less=True))
            n_i_readyn.append(Signal(name="n_i_readyn%d" % i,
                                     reset_less=True))
            if hasattr(self.stage, "setup"):
                print("setup", self, self.stage, r)
                self.stage.setup(m, r)
        # always wrap in Array (even for a single input): these are indexed
        # by the Signal "mid" below, which a plain list does not support
        r_data = Array(r_data)
        r_busy = Array(r_busy)
        data_valid = Array(data_valid)
        p_i_valid = Array(p_i_valid)
        n_i_readyn = Array(n_i_readyn)

        nirn = Signal(reset_less=True)  # inverted n.i_ready
        m.d.comb += nirn.eq(~self.n.i_ready)
        mid = self.p_mux.m_id
        print("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input; the selected input is overridden below
        for i in range(p_len):
            m.d.comb += r_busy[i].eq(0)
            m.d.comb += n_i_readyn[i].eq(1)
            m.d.comb += p_i_valid[i].eq(0)
            # NOTE(review): CombMultiInPipeline defaults o_ready to 0 here
            m.d.comb += self.p[i].o_ready.eq(n_i_readyn[i])
        p = self.p[mid]
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
        else:
            m.d.comb += maskedout.eq(1)
        # NOTE(review): unlike CombMultiInPipeline this does not include
        # self.p[mid].i_valid - confirm
        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active)
        m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
        m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
        # output is valid if any input's data_valid is set
        anyvalid = Cat(*[data_valid[i] for i in range(p_len)])
        m.d.comb += self.n.o_valid.eq(anyvalid.bool())
        m.d.comb += data_valid[mid].eq(p_i_valid[mid] |
                                       (n_i_readyn[mid]))

        if self.routemask:
            # XXX hack - fixes loop
            # stop_o is taken from the last input only
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: input i is not masked out, is valid and is accepted
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                #m.d.comb += vr.eq(p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
        else:
            ml = []  # accumulate output masks
            ms = []  # accumulate output stops
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: input i is not masked out, is valid and is accepted
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
                if self.maskwid:
                    mlen = len(self.p[i].mask_i)
                    # mask bits are zeroed unless vr; stop bits always pass
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        # the selected input's data is processed and sent to the output
        m.d.comb += eq(self.n.o_data, self.process(r_data[mid]))

        return m
532
533
class CombMuxOutPipe(CombMultiOutPipeline):
    """ Multi-output pipe in which the stage doubles as the output
        selector: the selector id is read from a field of the input data.
    """

    def __init__(self, stage, n_len, maskwid=0, muxidname=None,
                 routemask=False):
        """ muxidname is the attribute of p.i_data used as the selector
            (defaults to "muxid" when not given or empty).
        """
        if not muxidname:
            muxidname = "muxid"

        # HACK: stage is also the n-way multiplexer
        CombMultiOutPipeline.__init__(self, stage, n_len, stage,
                                      maskwid=maskwid, routemask=routemask)

        # HACK: n-mux is also the stage... so set the muxid equal to input muxid
        selector = getattr(self.p.i_data, muxidname)
        print("combmuxout", muxidname, selector)
        stage.m_id = selector
547
548
class InputPriorityArbiter(Elaboratable):
    """ arbitration module for Input-Mux pipe, based on PriorityEncoder

        Attributes:
        -----------
        m_id   : index chosen by the priority encoder among the inputs
                 of pipe that are currently flagged valid
        active : set when the priority encoder reports at least one
                 valid input
    """

    def __init__(self, pipe, num_rows):
        """ Arguments:
            * pipe:     the pipeline whose inputs (pipe.p) are arbitrated
            * num_rows: number of inputs; must equal len(pipe.p) by the
                        time elaborate() runs

            Raises ValueError if num_rows is less than 1.
        """
        if num_rows < 1:
            raise ValueError("num_rows must be at least 1")
        self.pipe = pipe
        self.num_rows = num_rows
        # number of bits needed to hold indices 0..num_rows-1.  the
        # previous int(log(n)/log(2)) rounded *down*, so e.g. 3 rows got a
        # 1-bit m_id and index 2 could not be represented.
        self.mmax = (num_rows - 1).bit_length()
        self.m_id = Signal(self.mmax, reset_less=True)  # multiplex id
        self.active = Signal(reset_less=True)

    def elaborate(self, platform):
        """ wires one "valid" bit per input into a PriorityEncoder and
            exposes its result as m_id / active
        """
        m = Module()

        assert len(self.pipe.p) == self.num_rows, \
            "must declare input to be same size"
        pe = PriorityEncoder(self.num_rows)
        m.submodules.selector = pe

        # connect priority encoder
        in_ready = []
        for i in range(self.num_rows):
            p_i_valid = Signal(reset_less=True)
            if self.pipe.maskwid and not self.pipe.routemask:
                # fan-in mask mode: an input only counts as valid when it
                # has mask bits set that are not cancelled by stop bits
                p = self.pipe.p[i]
                maskedout = Signal(reset_less=True)
                m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                m.d.comb += p_i_valid.eq(maskedout.bool() & p.i_valid_test)
            else:
                m.d.comb += p_i_valid.eq(self.pipe.p[i].i_valid_test)
            in_ready.append(p_i_valid)
        m.d.comb += pe.i.eq(Cat(*in_ready))  # array of input "valids"
        m.d.comb += self.active.eq(~pe.n)  # encoder active (one input valid)
        m.d.comb += self.m_id.eq(pe.o)  # output one active input

        return m

    def ports(self):
        """ returns the two output signals, m_id and active
        """
        return [self.m_id, self.active]
588
589
class PriorityCombMuxInPipe(CombMultiInPipeline):
    """ an example of how to use the combinatorial pipeline: a multi-input
        pipe whose input is selected by an InputPriorityArbiter.
    """

    def __init__(self, stage, p_len=2, maskwid=0, routemask=False):
        # the arbiter is handed a reference to this pipe before the
        # base-class constructor has run
        arbiter = InputPriorityArbiter(self, p_len)
        CombMultiInPipeline.__init__(self, stage, p_len=p_len,
                                     p_mux=arbiter, maskwid=maskwid,
                                     routemask=routemask)
598
599
if __name__ == '__main__':

    from nmutil.test.example_buf_pipe import ExampleStage

    # build the example priority-mux pipe and dump it as RTLIL text
    pipe = PriorityCombMuxInPipe(ExampleStage)
    il_text = rtlil.convert(pipe, ports=pipe.ports())
    with open("test_combpipe.il", "w") as il_file:
        il_file.write(il_text)