+from nmigen.hdl.rec import Record, Layout
+
+from collections.abc import Sequence
+
+
+class PrevControl:
+ """ contains signals that come *from* the previous stage (both in and out)
+ * i_valid: previous stage indicating all incoming data is valid.
+ may be a multi-bit signal, where all bits are required
+ to be asserted to indicate "valid".
+ * o_ready: output to next stage indicating readiness to accept data
+ * i_data : an input - added by the user of this class
+ """
+
+ def __init__(self, i_width=1):
+ self.i_valid = Signal(i_width, name="p_i_valid") # prev >>in self
+ self.o_ready = Signal(name="p_o_ready") # prev <<out self
+
+ def connect_in(self, prev):
+ """ helper function to connect stage to an input source. do not
+ use to connect stage-to-stage!
+ """
+ return [self.i_valid.eq(prev.i_valid),
+ prev.o_ready.eq(self.o_ready),
+ eq(self.i_data, prev.i_data),
+ ]
+
+ def i_valid_logic(self):
+ vlen = len(self.i_valid)
+ if vlen > 1: # multi-bit case: valid only when i_valid is all 1s
+ all1s = Const(-1, (len(self.i_valid), False))
+ return self.i_valid == all1s
+ # single-bit i_valid case
+ return self.i_valid
+
+
+class NextControl:
+ """ contains the signals that go *to* the next stage (both in and out)
+ * o_valid: output indicating to next stage that data is valid
+ * i_ready: input from next stage indicating that it can accept data
+ * o_data : an output - added by the user of this class
+ """
+ def __init__(self):
+ self.o_valid = Signal(name="n_o_valid") # self out>> next
+ self.i_ready = Signal(name="n_i_ready") # self <<in next
+
+ def connect_to_next(self, nxt):
+ """ helper function to connect to the next stage data/valid/ready.
+ data/valid is passed *TO* nxt, and ready comes *IN* from nxt.
+ """
+ return [nxt.i_valid.eq(self.o_valid),
+ self.i_ready.eq(nxt.o_ready),
+ eq(nxt.i_data, self.o_data),
+ ]
+
+ def connect_out(self, nxt):
+ """ helper function to connect stage to an output source. do not
+ use to connect stage-to-stage!
+ """
+ return [nxt.o_valid.eq(self.o_valid),
+ self.i_ready.eq(nxt.i_ready),
+ eq(nxt.o_data, self.o_data),
+ ]
+
+
+def eq(o, i):
+ """ makes signals equal: a helper routine which identifies if it is being
+ passed a list (or tuple) of objects, or signals, or Records, and calls
+ the objects' eq function.
+
+ complex objects (classes) can be used: they must follow the
+ convention of having an eq member function, which takes the
+ responsibility of further calling eq and returning a list of
+ eq assignments
+
+ Record is a special (unusual, recursive) case, where the input
+ is specified as a dictionary (which may contain further dictionaries,
+ recursively), where the field names of the dictionary must match
+ the Record's field spec.
+ """
+ if not isinstance(o, Sequence):
+ o, i = [o], [i]
+ res = []
+ for (ao, ai) in zip(o, i):
+ #print ("eq", ao, ai)
+ if isinstance(ao, Record):
+ for idx, (field_name, field_shape, _) in enumerate(ao.layout):
+ if isinstance(field_shape, Layout):
+ rres = eq(ao.fields[field_name], ai.fields[field_name])
+ else:
+ rres = eq(ao.fields[field_name], ai[field_name])
+ res += rres
+ else:
+ rres = ao.eq(ai)
+ if not isinstance(rres, Sequence):
+ rres = [rres]
+ res += rres
+ return res
+
+
+class StageChain:
+ """ pass in a list of stages, and they will automatically be
+ chained together via their input and output specs into a
+ combinatorial chain.
+
+ * input to this class will be the input of the first stage
+ * output of first stage goes into input of second
+ * output of second goes into input into third (etc. etc.)
+ * the output of this class will be the output of the last stage
+ """
+ def __init__(self, chain):
+ self.chain = chain
+
+ def ispec(self):
+ return self.chain[0].ispec()
+
+ def ospec(self):
+ return self.chain[-1].ospec()
+
+ def setup(self, m, i):
+ for (idx, c) in enumerate(self.chain):
+ if hasattr(c, "setup"):
+ c.setup(m, i) # stage may have some module stuff
+ o = self.chain[idx].ospec() # only the last assignment survives
+ m.d.comb += eq(o, c.process(i)) # process input into "o"
+ if idx != len(self.chain)-1:
+ ni = self.chain[idx+1].ispec() # becomes new input on next loop
+ m.d.comb += eq(ni, o) # assign output to next input
+ i = ni
+ self.o = o # last loop is the output
+
+ def process(self, i):
+ return self.o
+
+
+class PipelineBase:
+ """ Common functions for Pipeline API
+ """
+ def __init__(self, stage, in_multi=None):
+ """ pass in a "stage" which may be either a static class or a class
+ instance, which has four functions (one optional):
+ * ispec: returns input signals according to the input specification
+ * ispec: returns output signals to the output specification
+ * process: takes an input instance and returns processed data
+ * setup: performs any module linkage if the stage uses one.
+
+ User must also:
+ * add i_data member to PrevControl and
+ * add o_data member to NextControl
+ """
+ self.stage = stage
+
+ # set up input and output IO ACK (prev/next ready/valid)
+ self.p = PrevControl(in_multi)
+ self.n = NextControl()
+
+ def connect_to_next(self, nxt):
+ """ helper function to connect to the next stage data/valid/ready.
+ """
+ return self.n.connect_to_next(nxt.p)
+
+ def connect_in(self, prev):
+ """ helper function to connect stage to an input source. do not
+ use to connect stage-to-stage!
+ """
+ return self.p.connect_in(prev.p)
+
+ def connect_out(self, nxt):
+ """ helper function to connect stage to an output source. do not
+ use to connect stage-to-stage!
+ """
+ return self.n.connect_out(nxt.n)
+
+ def set_input(self, i):
+ """ helper function to set the input data
+ """
+ return eq(self.p.i_data, i)
+
+ def ports(self):
+ return [self.p.i_valid, self.n.i_ready,
+ self.n.o_valid, self.p.o_ready,
+ self.p.i_data, self.n.o_data # XXX need flattening!
+ ]