pass in flatten/processing function into _connect_in/out
[ieee754fpu.git] / src / add / singlepipe.py
index ef3137f3efeb780ee621fd542f41221a424eeec5..8296bf2d4ecd9702cd2d32c20dc345734ed2a641 100644 (file)
 
 from nmigen import Signal, Cat, Const, Mux, Module, Value
 from nmigen.cli import verilog, rtlil
+from nmigen.lib.fifo import SyncFIFO
 from nmigen.hdl.ast import ArrayProxy
 from nmigen.hdl.rec import Record, Layout
 
@@ -173,6 +174,22 @@ from abc import ABCMeta, abstractmethod
 from collections.abc import Sequence
 
 
+class RecordObject(Record):
+    def __init__(self, layout=None, name=None):
+        Record.__init__(self, layout=layout or [], name=None)
+
+    def __setattr__(self, k, v):
+        if k in dir(Record) or "fields" not in self.__dict__:
+            return object.__setattr__(self, k, v)
+        self.fields[k] = v
+        if isinstance(v, Record):
+            newlayout = {k: (k, v.layout)}
+        else:
+            newlayout = {k: (k, v.shape())}
+        self.layout.fields.update(newlayout)
+
+
+
 class PrevControl:
     """ contains signals that come *from* the previous stage (both in and out)
         * i_valid: previous stage indicating all incoming data is valid.
@@ -198,13 +215,15 @@ class PrevControl:
             return self.s_o_ready # set dynamically by stage
         return self._o_ready      # return this when not under dynamic control
 
-    def _connect_in(self, prev):
+    def _connect_in(self, prev, direct=False, fn=None):
         """ internal helper function to connect stage to an input source.
             do not use to connect stage-to-stage!
         """
-        return [self.i_valid.eq(prev.i_valid_test),
+        i_valid = prev.i_valid if direct else prev.i_valid_test
+        i_data = fn(prev.i_data) if fn is not None else prev.i_data
+        return [self.i_valid.eq(i_valid),
                 prev.o_ready.eq(self.o_ready),
-                eq(self.i_data, prev.i_data),
+                eq(self.i_data, i_data),
                ]
 
     @property
@@ -256,25 +275,24 @@ class NextControl:
                 eq(nxt.i_data, self.o_data),
                ]
 
-    def _connect_out(self, nxt):
+    def _connect_out(self, nxt, direct=False, fn=None):
         """ internal helper function to connect stage to an output source.
             do not use to connect stage-to-stage!
         """
+        i_ready = nxt.i_ready if direct else nxt.i_ready_test
+        o_data = fn(self.o_data) if fn is not None else self.o_data
         return [nxt.o_valid.eq(self.o_valid),
-                self.i_ready.eq(nxt.i_ready_test),
-                eq(nxt.o_data, self.o_data),
+                self.i_ready.eq(i_ready),
+                eq(nxt.o_data, o_data),
                ]
 
 
-def eq(o, i):
-    """ makes signals equal: a helper routine which identifies if it is being
-        passed a list (or tuple) of objects, or signals, or Records, and calls
-        the objects' eq function.
+class Visitor:
+    """ a helper routine which identifies if it is being passed a list
+        (or tuple) of objects, or signals, or Records, and calls
+        a visitor function.
 
-        complex objects (classes) can be used: they must follow the
-        convention of having an eq member function, which takes the
-        responsibility of further calling eq and returning a list of
-        eq assignments
+        the visiting fn is called when an object is identified.
 
         Record is a special (unusual, recursive) case, where the input may be
         specified as a dictionary (which may contain further dictionaries,
@@ -291,20 +309,92 @@ def eq(o, i):
         python object, enumerate them, find out the list of Signals that way,
         and assign them.
     """
-    res = []
-    if isinstance(o, dict):
+    def visit(self, o, i, act):
+        if isinstance(o, dict):
+            return self.dict_visit(o, i, act)
+
+        res = act.prepare()
+        if not isinstance(o, Sequence):
+            o, i = [o], [i]
+        for (ao, ai) in zip(o, i):
+            #print ("visit", fn, ao, ai)
+            if isinstance(ao, Record):
+                rres = self.record_visit(ao, ai, act)
+            elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
+                rres = self.arrayproxy_visit(ao, ai, act)
+            else:
+                rres = act.fn(ao, ai)
+            res += rres
+        return res
+
+    def dict_visit(self, o, i, act):
+        res = act.prepare()
         for (k, v) in o.items():
             print ("d-eq", v, i[k])
-            res.append(v.eq(i[k]))
+            res.append(act.fn(v, i[k]))
+        return res
+
+    def record_visit(self, ao, ai, act):
+        res = act.prepare()
+        for idx, (field_name, field_shape, _) in enumerate(ao.layout):
+            if isinstance(field_shape, Layout):
+                val = ai.fields
+            else:
+                val = ai
+            if hasattr(val, field_name): # check for attribute
+                val = getattr(val, field_name)
+            else:
+                val = val[field_name] # dictionary-style specification
+            val = self.visit(ao.fields[field_name], val, act)
+            if isinstance(val, Sequence):
+                res += val
+            else:
+                res.append(val)
         return res
 
-    if not isinstance(o, Sequence):
-        o, i = [o], [i]
-    for (ao, ai) in zip(o, i):
-        #print ("eq", ao, ai)
-        if isinstance(ao, Record):
+    def arrayproxy_visit(self, ao, ai, act):
+        res = act.prepare()
+        for p in ai.ports():
+            op = getattr(ao, p.name)
+            #print (op, p, p.name)
+            res.append(fn(op, p))
+        return res
+
+
+class Eq(Visitor):
+    def __init__(self):
+        self.res = []
+    def prepare(self):
+        return []
+    def fn(self, o, i):
+        rres = o.eq(i)
+        if not isinstance(rres, Sequence):
+            rres = [rres]
+        return rres
+    def __call__(self, o, i):
+        return self.visit(o, i, self)
+
+
+def eq(o, i):
+    """ makes signals equal: a helper routine which identifies if it is being
+        passed a list (or tuple) of objects, or signals, or Records, and calls
+        the objects' eq function.
+    """
+    return Eq()(o, i)
+
+
+def flatten(i):
+    """ flattens a compound structure recursively using Cat
+    """
+    if not isinstance(i, Sequence):
+        i = [i]
+    res = []
+    for ai in i:
+        print ("flatten", ai)
+        if isinstance(ai, Record):
+            print ("record", list(ai.layout))
             rres = []
-            for idx, (field_name, field_shape, _) in enumerate(ao.layout):
+            for idx, (field_name, field_shape, _) in enumerate(ai.layout):
                 if isinstance(field_shape, Layout):
                     val = ai.fields
                 else:
@@ -313,19 +403,28 @@ def eq(o, i):
                     val = getattr(val, field_name)
                 else:
                     val = val[field_name] # dictionary-style specification
-                rres += eq(ao.fields[field_name], val)
-        elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
+                print ("recidx", idx, field_name, field_shape, val)
+                val = flatten(val)
+                print ("recidx flat", idx, val)
+                if isinstance(val, Sequence):
+                    rres += val
+                else:
+                    rres.append(val)
+
+        elif isinstance(ai, ArrayProxy) and not isinstance(ai, Value):
             rres = []
             for p in ai.ports():
-                op = getattr(ao, p.name)
+                op = getattr(ai, p.name)
                 #print (op, p, p.name)
-                rres.append(op.eq(p))
+                rres.append(flatten(p))
         else:
-            rres = ao.eq(ai)
+            rres = ai
         if not isinstance(rres, Sequence):
             rres = [rres]
         res += rres
-    return res
+        print ("flatten res", res)
+    return Cat(*res)
+
 
 
 class StageCls(metaclass=ABCMeta):
@@ -638,7 +737,7 @@ class BufferedHandshake(ControlBase):
             self.m.d.sync += [self.n.o_valid.eq(p_i_valid), # valid if p_valid
                               eq(self.n.o_data, result),    # update output
                              ]
-        # buffer flush conditions (NOTE: n.o_valid override data passthru)
+        # buffer flush conditions (NOTE: can override data passthru conditions)
         with self.m.If(nir_por_n): # not stalled
             # Flush the [already processed] buffer to the output port.
             self.m.d.sync += [self.n.o_valid.eq(1),  # reg empty
@@ -661,6 +760,35 @@ class SimpleHandshake(ControlBase):
         stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                               +--process->--^
+        Truth Table
+
+        Inputs   Temporary  Output
+        -------  ---------- -----
+        P P N N  PiV& ~NiV&  N P
+        i o i o  PoR  NoV    o o
+        V R R V              V R
+
+        -------   -    -     - -
+        0 0 0 0   0    0    >0 0
+        0 0 0 1   0    1    >1 0
+        0 0 1 0   0    0     0 1
+        0 0 1 1   0    0     0 1
+        -------   -    -     - -
+        0 1 0 0   0    0    >0 0
+        0 1 0 1   0    1    >1 0
+        0 1 1 0   0    0     0 1
+        0 1 1 1   0    0     0 1
+        -------   -    -     - -
+        1 0 0 0   0    0    >0 0
+        1 0 0 1   0    1    >1 0
+        1 0 1 0   0    0     0 1
+        1 0 1 1   0    0     0 1
+        -------   -    -     - -
+        1 1 0 0   1    0     1 0
+        1 1 0 1   1    1     1 0
+        1 1 1 0   1    0     1 1
+        1 1 1 1   1    0     1 1
+        -------   -    -     - -
     """
 
     def elaborate(self, platform):
@@ -735,6 +863,38 @@ class UnbufferedPipeline(ControlBase):
         result: output_shape according to ospec
             The output of the combinatorial logic.  it is updated
             COMBINATORIALLY (no clock dependence).
+
+        Truth Table
+
+        Inputs  Temp  Output
+        -------   -   -----
+        P P N N ~NiR&  N P
+        i o i o  NoV   o o
+        V R R V        V R
+
+        -------   -    - -
+        0 0 0 0   0    0 1
+        0 0 0 1   1    1 0
+        0 0 1 0   0    0 1
+        0 0 1 1   0    0 1
+        -------   -    - -
+        0 1 0 0   0    0 1
+        0 1 0 1   1    1 0
+        0 1 1 0   0    0 1
+        0 1 1 1   0    0 1
+        -------   -    - -
+        1 0 0 0   0    1 1
+        1 0 0 1   1    1 0
+        1 0 1 0   0    1 1
+        1 0 1 1   0    1 1
+        -------   -    - -
+        1 1 0 0   0    1 1
+        1 1 0 1   1    1 0
+        1 1 1 0   0    1 1
+        1 1 1 1   0    1 1
+        -------   -    - -
+
+        Note: PoR is *NOT* involved in the above decision-making.
     """
 
     def elaborate(self, platform):
@@ -853,3 +1013,59 @@ class RegisterPipeline(UnbufferedPipeline):
     def __init__(self, iospecfn):
         UnbufferedPipeline.__init__(self, PassThroughStage(iospecfn))
 
+
+class FIFOtest(ControlBase):
+    """ A test of using a SyncFIFO to see if it will work.
+        Note: the only things it will accept is a Signal of width "width".
+    """
+
+    def __init__(self, iospecfn, width, depth):
+
+        self.iospecfn = iospecfn
+        self.fwidth = width # XXX temporary
+        self.fdepth = depth
+        #stage = PassThroughStage(iospecfn)
+        ControlBase.__init__(self, stage=self)
+
+    def ispec(self): return self.iospecfn()
+    def ospec(self): return Signal(self.fwidth, name="dout")
+    def process(self, i): return i
+
+    def elaborate(self, platform):
+        self.m = m = ControlBase._elaborate(self, platform)
+
+        (fwidth, _) = self.p.i_data.shape()
+        fifo = SyncFIFO(fwidth, self.fdepth)
+        m.submodules.fifo = fifo
+
+        # XXX TODO: would be nice to do these...
+        ## prev: make the FIFO "look" like a PrevControl...
+        fp = PrevControl()
+        fp.i_valid = fifo.we
+        fp._o_ready = fifo.writable
+        fp.i_data = fifo.din
+        m.d.comb += fp._connect_in(self.p, True, fn=flatten)
+
+        # next: make the FIFO "look" like a NextControl...
+        fn = NextControl()
+        fn.o_valid = fifo.readable
+        fn.i_ready = fifo.re
+        fn.o_data = fifo.dout
+        # ... so we can do this!
+        m.d.comb += fn._connect_out(self.n, fn=flatten)
+
+        # connect previous rdy/valid/data - do flatten on i_data
+        #m.d.comb += [fifo.we.eq(self.p.i_valid_test),
+        #             self.p.o_ready.eq(fifo.writable),
+        #             eq(fifo.din, flatten(self.p.i_data)),
+        #           ]
+
+        # connect next rdy/valid/data - do flatten on o_data
+        #m.d.comb += [self.n.o_valid.eq(fifo.readable),
+        #             fifo.re.eq(self.n.i_ready_test),
+        #             flatten(self.n.o_data).eq(fifo.dout),
+        #           ]
+
+        # err... that should be all!
+        return m
+