pass in flatten/processing function into _connect_in/out
[ieee754fpu.git] / src / add / singlepipe.py
index d7aee14d2dd3f4442def4c8449ee47dbc2539293..8296bf2d4ecd9702cd2d32c20dc345734ed2a641 100644 (file)
@@ -174,6 +174,22 @@ from abc import ABCMeta, abstractmethod
 from collections.abc import Sequence
 
 
+class RecordObject(Record):
+    def __init__(self, layout=None, name=None):
+        Record.__init__(self, layout=layout or [], name=None)
+
+    def __setattr__(self, k, v):
+        if k in dir(Record) or "fields" not in self.__dict__:
+            return object.__setattr__(self, k, v)
+        self.fields[k] = v
+        if isinstance(v, Record):
+            newlayout = {k: (k, v.layout)}
+        else:
+            newlayout = {k: (k, v.shape())}
+        self.layout.fields.update(newlayout)
+
+
+
 class PrevControl:
     """ contains signals that come *from* the previous stage (both in and out)
         * i_valid: previous stage indicating all incoming data is valid.
@@ -199,13 +215,15 @@ class PrevControl:
             return self.s_o_ready # set dynamically by stage
         return self._o_ready      # return this when not under dynamic control
 
-    def _connect_in(self, prev):
+    def _connect_in(self, prev, direct=False, fn=None):
         """ internal helper function to connect stage to an input source.
             do not use to connect stage-to-stage!
         """
-        return [self.i_valid.eq(prev.i_valid_test),
+        i_valid = prev.i_valid if direct else prev.i_valid_test
+        i_data = fn(prev.i_data) if fn is not None else prev.i_data
+        return [self.i_valid.eq(i_valid),
                 prev.o_ready.eq(self.o_ready),
-                eq(self.i_data, prev.i_data),
+                eq(self.i_data, i_data),
                ]
 
     @property
@@ -257,25 +275,24 @@ class NextControl:
                 eq(nxt.i_data, self.o_data),
                ]
 
-    def _connect_out(self, nxt):
+    def _connect_out(self, nxt, direct=False, fn=None):
         """ internal helper function to connect stage to an output source.
             do not use to connect stage-to-stage!
         """
+        i_ready = nxt.i_ready if direct else nxt.i_ready_test
+        o_data = fn(self.o_data) if fn is not None else self.o_data
         return [nxt.o_valid.eq(self.o_valid),
-                self.i_ready.eq(nxt.i_ready_test),
-                eq(nxt.o_data, self.o_data),
+                self.i_ready.eq(i_ready),
+                eq(nxt.o_data, o_data),
                ]
 
 
-def eq(o, i):
-    """ makes signals equal: a helper routine which identifies if it is being
-        passed a list (or tuple) of objects, or signals, or Records, and calls
-        the objects' eq function.
+class Visitor:
+    """ a helper routine which identifies if it is being passed a list
+        (or tuple) of objects, or signals, or Records, and calls
+        a visitor function.
 
-        complex objects (classes) can be used: they must follow the
-        convention of having an eq member function, which takes the
-        responsibility of further calling eq and returning a list of
-        eq assignments
+        the visiting fn is called when an object is identified.
 
         Record is a special (unusual, recursive) case, where the input may be
         specified as a dictionary (which may contain further dictionaries,
@@ -292,20 +309,92 @@ def eq(o, i):
         python object, enumerate them, find out the list of Signals that way,
         and assign them.
     """
-    res = []
-    if isinstance(o, dict):
+    def visit(self, o, i, act):
+        if isinstance(o, dict):
+            return self.dict_visit(o, i, act)
+
+        res = act.prepare()
+        if not isinstance(o, Sequence):
+            o, i = [o], [i]
+        for (ao, ai) in zip(o, i):
+            #print ("visit", fn, ao, ai)
+            if isinstance(ao, Record):
+                rres = self.record_visit(ao, ai, act)
+            elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
+                rres = self.arrayproxy_visit(ao, ai, act)
+            else:
+                rres = act.fn(ao, ai)
+            res += rres
+        return res
+
+    def dict_visit(self, o, i, act):
+        res = act.prepare()
         for (k, v) in o.items():
             print ("d-eq", v, i[k])
-            res.append(v.eq(i[k]))
+            res.append(act.fn(v, i[k]))
+        return res
+
+    def record_visit(self, ao, ai, act):
+        res = act.prepare()
+        for idx, (field_name, field_shape, _) in enumerate(ao.layout):
+            if isinstance(field_shape, Layout):
+                val = ai.fields
+            else:
+                val = ai
+            if hasattr(val, field_name): # check for attribute
+                val = getattr(val, field_name)
+            else:
+                val = val[field_name] # dictionary-style specification
+            val = self.visit(ao.fields[field_name], val, act)
+            if isinstance(val, Sequence):
+                res += val
+            else:
+                res.append(val)
         return res
 
-    if not isinstance(o, Sequence):
-        o, i = [o], [i]
-    for (ao, ai) in zip(o, i):
-        #print ("eq", ao, ai)
-        if isinstance(ao, Record):
+    def arrayproxy_visit(self, ao, ai, act):
+        res = act.prepare()
+        for p in ai.ports():
+            op = getattr(ao, p.name)
+            #print (op, p, p.name)
+            res.append(fn(op, p))
+        return res
+
+
+class Eq(Visitor):
+    def __init__(self):
+        self.res = []
+    def prepare(self):
+        return []
+    def fn(self, o, i):
+        rres = o.eq(i)
+        if not isinstance(rres, Sequence):
+            rres = [rres]
+        return rres
+    def __call__(self, o, i):
+        return self.visit(o, i, self)
+
+
+def eq(o, i):
+    """ makes signals equal: a helper routine which identifies if it is being
+        passed a list (or tuple) of objects, or signals, or Records, and calls
+        the objects' eq function.
+    """
+    return Eq()(o, i)
+
+
+def flatten(i):
+    """ flattens a compound structure recursively using Cat
+    """
+    if not isinstance(i, Sequence):
+        i = [i]
+    res = []
+    for ai in i:
+        print ("flatten", ai)
+        if isinstance(ai, Record):
+            print ("record", list(ai.layout))
             rres = []
-            for idx, (field_name, field_shape, _) in enumerate(ao.layout):
+            for idx, (field_name, field_shape, _) in enumerate(ai.layout):
                 if isinstance(field_shape, Layout):
                     val = ai.fields
                 else:
@@ -314,19 +403,28 @@ def eq(o, i):
                     val = getattr(val, field_name)
                 else:
                     val = val[field_name] # dictionary-style specification
-                rres += eq(ao.fields[field_name], val)
-        elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
+                print ("recidx", idx, field_name, field_shape, val)
+                val = flatten(val)
+                print ("recidx flat", idx, val)
+                if isinstance(val, Sequence):
+                    rres += val
+                else:
+                    rres.append(val)
+
+        elif isinstance(ai, ArrayProxy) and not isinstance(ai, Value):
             rres = []
             for p in ai.ports():
-                op = getattr(ao, p.name)
+                op = getattr(ai, p.name)
                 #print (op, p, p.name)
-                rres.append(op.eq(p))
+                rres.append(flatten(p))
         else:
-            rres = ao.eq(ai)
+            rres = ai
         if not isinstance(rres, Sequence):
             rres = [rres]
         res += rres
-    return res
+        print ("flatten res", res)
+    return Cat(*res)
+
 
 
 class StageCls(metaclass=ABCMeta):
@@ -921,35 +1019,52 @@ class FIFOtest(ControlBase):
         Note: the only things it will accept is a Signal of width "width".
     """
 
-    def __init__(self, width, depth):
+    def __init__(self, iospecfn, width, depth):
 
-        self.fwidth = width
+        self.iospecfn = iospecfn
+        self.fwidth = width # XXX temporary
         self.fdepth = depth
-        def iospecfn():
-            return Signal(width, name="data")
-        stage = PassThroughStage(iospecfn)
-        ControlBase.__init__(self, stage=stage)
+        #stage = PassThroughStage(iospecfn)
+        ControlBase.__init__(self, stage=self)
+
+    def ispec(self): return self.iospecfn()
+    def ospec(self): return Signal(self.fwidth, name="dout")
+    def process(self, i): return i
 
     def elaborate(self, platform):
         self.m = m = ControlBase._elaborate(self, platform)
 
-        fifo = SyncFIFO(self.fwidth, self.fdepth)
+        (fwidth, _) = self.p.i_data.shape()
+        fifo = SyncFIFO(fwidth, self.fdepth)
+        m.submodules.fifo = fifo
 
-        # prev: make the FIFO "look" like a PrevControl...
+        # XXX TODO: would be nice to do these...
+        ## prev: make the FIFO "look" like a PrevControl...
         fp = PrevControl()
-        fp.i_valid = fifo.writable
-        fp._o_ready = fifo.we
+        fp.i_valid = fifo.we
+        fp._o_ready = fifo.writable
         fp.i_data = fifo.din
-        # ... so we can do this!
-        m.d.comb += fp._connect_in(self)
-        
+        m.d.comb += fp._connect_in(self.p, True, fn=flatten)
+
         # next: make the FIFO "look" like a NextControl...
         fn = NextControl()
         fn.o_valid = fifo.readable
         fn.i_ready = fifo.re
         fn.o_data = fifo.dout
         # ... so we can do this!
-        m.d.comb += fn._connect_out(self)
+        m.d.comb += fn._connect_out(self.n, fn=flatten)
+
+        # connect previous rdy/valid/data - do flatten on i_data
+        #m.d.comb += [fifo.we.eq(self.p.i_valid_test),
+        #             self.p.o_ready.eq(fifo.writable),
+        #             eq(fifo.din, flatten(self.p.i_data)),
+        #           ]
+
+        # connect next rdy/valid/data - do flatten on o_data
+        #m.d.comb += [self.n.o_valid.eq(fifo.readable),
+        #             fifo.re.eq(self.n.i_ready_test),
+        #             flatten(self.n.o_data).eq(fifo.dout),
+        #           ]
 
         # err... that should be all!
         return m