+from nmigen.hdl.rec import Record, Layout
+
+from collections.abc import Sequence
+
+
+class PrevControl:
+ """ contains signals that come *from* the previous stage (both in and out)
+ * i_valid: previous stage indicating all incoming data is valid.
+ may be a multi-bit signal, where all bits are required
+ to be asserted to indicate "valid".
+ * o_ready: output to next stage indicating readiness to accept data
+ * i_data : an input - added by the user of this class
+ """
+
+ def __init__(self, i_width=1):
+ self.i_valid = Signal(i_width, name="p_i_valid") # prev >>in self
+ self.o_ready = Signal(name="p_o_ready") # prev <<out self
+
+ def connect_in(self, prev):
+ """ helper function to connect stage to an input source. do not
+ use to connect stage-to-stage!
+ """
+ return [self.i_valid.eq(prev.i_valid),
+ prev.o_ready.eq(self.o_ready),
+ eq(self.i_data, prev.i_data),
+ ]
+
+ def i_valid_logic(self):
+ vlen = len(self.i_valid)
+ if vlen > 1: # multi-bit case: valid only when i_valid is all 1s
+ all1s = Const(-1, (len(self.i_valid), False))
+ return self.i_valid == all1s
+ # single-bit i_valid case
+ return self.i_valid
+
+
+class NextControl:
+ """ contains the signals that go *to* the next stage (both in and out)
+ * o_valid: output indicating to next stage that data is valid
+ * i_ready: input from next stage indicating that it can accept data
+ * o_data : an output - added by the user of this class
+ """
+ def __init__(self):
+ self.o_valid = Signal(name="n_o_valid") # self out>> next
+ self.i_ready = Signal(name="n_i_ready") # self <<in next
+
+ def connect_to_next(self, nxt):
+ """ helper function to connect to the next stage data/valid/ready.
+ data/valid is passed *TO* nxt, and ready comes *IN* from nxt.
+ """
+ return [nxt.i_valid.eq(self.o_valid),
+ self.i_ready.eq(nxt.o_ready),
+ eq(nxt.i_data, self.o_data),
+ ]
+
+ def connect_out(self, nxt):
+ """ helper function to connect stage to an output source. do not
+ use to connect stage-to-stage!
+ """
+ return [nxt.o_valid.eq(self.o_valid),
+ self.i_ready.eq(nxt.i_ready),
+ eq(nxt.o_data, self.o_data),
+ ]
+
+
+def eq(o, i):
+ """ makes signals equal: a helper routine which identifies if it is being
+ passed a list (or tuple) of objects, or signals, or Records, and calls
+ the objects' eq function.
+
+ complex objects (classes) can be used: they must follow the
+ convention of having an eq member function, which takes the
+ responsibility of further calling eq and returning a list of
+ eq assignments
+
+ Record is a special (unusual, recursive) case, where the input
+ is specified as a dictionary (which may contain further dictionaries,
+ recursively), where the field names of the dictionary must match
+ the Record's field spec.
+ """
+ if not isinstance(o, Sequence):
+ o, i = [o], [i]
+ res = []
+ for (ao, ai) in zip(o, i):
+ #print ("eq", ao, ai)
+ if isinstance(ao, Record):
+ for idx, (field_name, field_shape, _) in enumerate(ao.layout):
+ if isinstance(field_shape, Layout):
+ rres = eq(ao.fields[field_name], ai.fields[field_name])
+ else:
+ rres = eq(ao.fields[field_name], ai[field_name])
+ res += rres
+ else:
+ rres = ao.eq(ai)
+ if not isinstance(rres, Sequence):
+ rres = [rres]
+ res += rres
+ return res
+
+
+class StageChain:
+ """ pass in a list of stages, and they will automatically be
+ chained together via their input and output specs into a
+ combinatorial chain.