pass in flatten/processing function into _connect_in/out
[ieee754fpu.git] / src / add / singlepipe.py
index 137f8163a966ef71fe8018ae894c1eeb4ae7330e..8296bf2d4ecd9702cd2d32c20dc345734ed2a641 100644 (file)
@@ -174,6 +174,22 @@ from abc import ABCMeta, abstractmethod
 from collections.abc import Sequence
 
 
+class RecordObject(Record):
+    def __init__(self, layout=None, name=None):
+        Record.__init__(self, layout=layout or [], name=None)
+
+    def __setattr__(self, k, v):
+        if k in dir(Record) or "fields" not in self.__dict__:
+            return object.__setattr__(self, k, v)
+        self.fields[k] = v
+        if isinstance(v, Record):
+            newlayout = {k: (k, v.layout)}
+        else:
+            newlayout = {k: (k, v.shape())}
+        self.layout.fields.update(newlayout)
+
+
+
 class PrevControl:
     """ contains signals that come *from* the previous stage (both in and out)
         * i_valid: previous stage indicating all incoming data is valid.
@@ -199,17 +215,15 @@ class PrevControl:
             return self.s_o_ready # set dynamically by stage
         return self._o_ready      # return this when not under dynamic control
 
-    def _connect_in(self, prev, direct=False):
+    def _connect_in(self, prev, direct=False, fn=None):
         """ internal helper function to connect stage to an input source.
             do not use to connect stage-to-stage!
         """
-        if direct:
-            i_valid = prev.i_valid
-        else:
-            i_valid = prev.i_valid_test
+        i_valid = prev.i_valid if direct else prev.i_valid_test
+        i_data = fn(prev.i_data) if fn is not None else prev.i_data
         return [self.i_valid.eq(i_valid),
                 prev.o_ready.eq(self.o_ready),
-                eq(self.i_data, prev.i_data),
+                eq(self.i_data, i_data),
                ]
 
     @property
@@ -261,17 +275,15 @@ class NextControl:
                 eq(nxt.i_data, self.o_data),
                ]
 
-    def _connect_out(self, nxt, direct=False):
+    def _connect_out(self, nxt, direct=False, fn=None):
         """ internal helper function to connect stage to an output source.
             do not use to connect stage-to-stage!
         """
-        if direct:
-            i_ready = nxt.i_ready
-        else:
-            i_ready = nxt.i_ready_test
+        i_ready = nxt.i_ready if direct else nxt.i_ready_test
+        o_data = fn(self.o_data) if fn is not None else self.o_data
         return [nxt.o_valid.eq(self.o_valid),
                 self.i_ready.eq(i_ready),
-                eq(nxt.o_data, self.o_data),
+                eq(nxt.o_data, o_data),
                ]
 
 
@@ -333,7 +345,11 @@ class Visitor:
                 val = getattr(val, field_name)
             else:
                 val = val[field_name] # dictionary-style specification
-            res += self.visit(ao.fields[field_name], val, act)
+            val = self.visit(ao.fields[field_name], val, act)
+            if isinstance(val, Sequence):
+                res += val
+            else:
+                res.append(val)
         return res
 
     def arrayproxy_visit(self, ao, ai, act):
@@ -358,6 +374,7 @@ class Eq(Visitor):
     def __call__(self, o, i):
         return self.visit(o, i, self)
 
+
 def eq(o, i):
     """ makes signals equal: a helper routine which identifies if it is being
         passed a list (or tuple) of objects, or signals, or Records, and calls
@@ -366,6 +383,50 @@ def eq(o, i):
     return Eq()(o, i)
 
 
+def flatten(i):
+    """ flattens a compound structure recursively using Cat
+    """
+    if not isinstance(i, Sequence):
+        i = [i]
+    res = []
+    for ai in i:
+        print ("flatten", ai)
+        if isinstance(ai, Record):
+            print ("record", list(ai.layout))
+            rres = []
+            for idx, (field_name, field_shape, _) in enumerate(ai.layout):
+                if isinstance(field_shape, Layout):
+                    val = ai.fields
+                else:
+                    val = ai
+                if hasattr(val, field_name): # check for attribute
+                    val = getattr(val, field_name)
+                else:
+                    val = val[field_name] # dictionary-style specification
+                print ("recidx", idx, field_name, field_shape, val)
+                val = flatten(val)
+                print ("recidx flat", idx, val)
+                if isinstance(val, Sequence):
+                    rres += val
+                else:
+                    rres.append(val)
+
+        elif isinstance(ai, ArrayProxy) and not isinstance(ai, Value):
+            rres = []
+            for p in ai.ports():
+                op = getattr(ai, p.name)
+                #print (op, p, p.name)
+                rres.append(flatten(p))
+        else:
+            rres = ai
+        if not isinstance(rres, Sequence):
+            rres = [rres]
+        res += rres
+        print ("flatten res", res)
+    return Cat(*res)
+
+
+
 class StageCls(metaclass=ABCMeta):
     """ Class-based "Stage" API.  requires instantiation (after derivation)
 
@@ -958,36 +1019,52 @@ class FIFOtest(ControlBase):
         Note: the only things it will accept is a Signal of width "width".
     """
 
-    def __init__(self, width, depth):
+    def __init__(self, iospecfn, width, depth):
 
-        self.fwidth = width
+        self.iospecfn = iospecfn
+        self.fwidth = width # XXX temporary
         self.fdepth = depth
-        def iospecfn():
-            return Signal(width, name="data")
-        stage = PassThroughStage(iospecfn)
-        ControlBase.__init__(self, stage=stage)
+        #stage = PassThroughStage(iospecfn)
+        ControlBase.__init__(self, stage=self)
+
+    def ispec(self): return self.iospecfn()
+    def ospec(self): return Signal(self.fwidth, name="dout")
+    def process(self, i): return i
 
     def elaborate(self, platform):
         self.m = m = ControlBase._elaborate(self, platform)
 
-        fifo = SyncFIFO(self.fwidth, self.fdepth)
+        (fwidth, _) = self.p.i_data.shape()
+        fifo = SyncFIFO(fwidth, self.fdepth)
         m.submodules.fifo = fifo
 
-        # prev: make the FIFO "look" like a PrevControl...
+        # XXX TODO: would be nice to do these...
+        ## prev: make the FIFO "look" like a PrevControl...
         fp = PrevControl()
         fp.i_valid = fifo.we
         fp._o_ready = fifo.writable
         fp.i_data = fifo.din
-        # ... so we can do this!
-        m.d.comb += fp._connect_in(self.p, True)
-        
+        m.d.comb += fp._connect_in(self.p, True, fn=flatten)
+
         # next: make the FIFO "look" like a NextControl...
         fn = NextControl()
         fn.o_valid = fifo.readable
         fn.i_ready = fifo.re
         fn.o_data = fifo.dout
         # ... so we can do this!
-        m.d.comb += fn._connect_out(self.n)
+        m.d.comb += fn._connect_out(self.n, fn=flatten)
+
+        # connect previous rdy/valid/data - do flatten on i_data
+        #m.d.comb += [fifo.we.eq(self.p.i_valid_test),
+        #             self.p.o_ready.eq(fifo.writable),
+        #             eq(fifo.din, flatten(self.p.i_data)),
+        #           ]
+
+        # connect next rdy/valid/data - do flatten on o_data
+        #m.d.comb += [self.n.o_valid.eq(fifo.readable),
+        #             fifo.re.eq(self.n.i_ready_test),
+        #             flatten(self.n.o_data).eq(fifo.dout),
+        #           ]
 
         # err... that should be all!
         return m