initial version of BufferedPipeline with multi-in and multi-out
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 25 Mar 2019 06:30:02 +0000 (06:30 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 25 Mar 2019 06:30:02 +0000 (06:30 +0000)
src/add/example_buf_pipe.py
src/add/test_buf_pipe.py

index 8c4cae44ce9d7e25256b0c22d34813362a5a2f43..6c2345c4d77f88e68d16112785c400f970d585a4 100644 (file)
@@ -97,7 +97,7 @@
     it's quite a complex state machine!
 """
 
-from nmigen import Signal, Cat, Const, Mux, Module
+from nmigen import Signal, Cat, Const, Mux, Module, Array
 from nmigen.cli import verilog, rtlil
 from nmigen.hdl.rec import Record, Layout
 
@@ -244,7 +244,7 @@ class StageChain:
 class PipelineBase:
     """ Common functions for Pipeline API
     """
-    def __init__(self, stage, in_multi=None):
+    def __init__(self, stage, in_multi=None, p_len=1, n_len=1):
         """ pass in a "stage" which may be either a static class or a class
             instance, which has four functions (one optional):
             * ispec: returns input signals according to the input specification
@@ -259,36 +259,56 @@ class PipelineBase:
         self.stage = stage
 
         # set up input and output IO ACK (prev/next ready/valid)
-        self.p = PrevControl(in_multi)
-        self.n = NextControl()
+        p = []
+        n = []
+        for i in range(p_len):
+            p.append(PrevControl(in_multi))
+        for i in range(n_len):
+            n.append(NextControl())
+        if p_len > 0:
+            self.p = Array(p)
+        else:
+            self.p = p
+        if n_len > 0:
+            self.n = Array(n)
+        else:
+            self.n = n
 
-    def connect_to_next(self, nxt):
+    def connect_to_next(self, nxt, p_idx=0, n_idx=0):
         """ helper function to connect to the next stage data/valid/ready.
         """
-        return self.n.connect_to_next(nxt.p)
+        return self.n[n_idx].connect_to_next(nxt.p[p_idx])
 
-    def connect_in(self, prev):
+    def connect_in(self, prev, idx=0, prev_idx=None):
         """ helper function to connect stage to an input source.  do not
             use to connect stage-to-stage!
         """
-        return self.p.connect_in(prev.p)
+        if prev_idx is None:
+            return self.p[idx].connect_in(prev.p)
+        return self.p[idx].connect_in(prev.p[prev_idx])
 
-    def connect_out(self, nxt):
+    def connect_out(self, nxt, idx=0, nxt_idx=None):
         """ helper function to connect stage to an output source.  do not
             use to connect stage-to-stage!
         """
-        return self.n.connect_out(nxt.n)
+        if nxt_idx is None:
+            return self.n[idx].connect_out(nxt.n)
+        return self.n[idx].connect_out(nxt.n[nxt+idx])
 
-    def set_input(self, i):
+    def set_input(self, i, idx=0):
         """ helper function to set the input data
         """
-        return eq(self.p.i_data, i)
+        return eq(self.p[idx].i_data, i)
 
     def ports(self):
-        return [self.p.i_valid, self.n.i_ready,
-                self.n.o_valid, self.p.o_ready,
-                self.p.i_data, self.n.o_data   # XXX need flattening!
-               ]
+        res = []
+        for i in range(len(self.p)):
+            res += [self.p[i].i_valid, self.p[i].o_ready,
+                    self.p[i].i_data]# XXX need flattening!]
+        for i in range(len(self.n)):
+            res += [self.n[i].i_ready, self.n[i].o_valid,
+                    self.n.o_data]   # XXX need flattening!]
+        return res
 
 
 class BufferedPipeline(PipelineBase):
@@ -321,12 +341,16 @@ class BufferedPipeline(PipelineBase):
         input may begin to be processed and transferred directly to output.
 
     """
-    def __init__(self, stage):
+    def __init__(self, stage, n_len=1, p_len=1, p_mux=None, n_mux=None):
         PipelineBase.__init__(self, stage)
+        self.p_mux = p_mux
+        self.n_mux = n_mux
 
         # set up the input and output data
-        self.p.i_data = stage.ispec() # input type
-        self.n.o_data = stage.ospec()
+        for i in range(p_len):
+            self.p[i].i_data = stage.ispec() # input type
+        for i in range(n_len):
+            self.n[i].o_data = stage.ospec()
 
     def elaborate(self, platform):
         m = Module()
@@ -334,49 +358,53 @@ class BufferedPipeline(PipelineBase):
         result = self.stage.ospec()
         r_data = self.stage.ospec()
         if hasattr(self.stage, "setup"):
-            self.stage.setup(m, self.p.i_data)
+            for i in range(p_len):
+                self.stage.setup(m, self.p[i].i_data)
+
+        pi = 0 # TODO: use p_mux to decide which to select
+        ni = 0 # TODO: use n_nux to decide which to select
 
         # establish some combinatorial temporaries
         o_n_validn = Signal(reset_less=True)
         i_p_valid_o_p_ready = Signal(reset_less=True)
         p_i_valid = Signal(reset_less=True)
-        m.d.comb += [p_i_valid.eq(self.p.i_valid_logic()),
-                     o_n_validn.eq(~self.n.o_valid),
-                     i_p_valid_o_p_ready.eq(p_i_valid & self.p.o_ready),
+        m.d.comb += [p_i_valid.eq(self.p[pi].i_valid_logic()),
+                     o_n_validn.eq(~self.n[ni].o_valid),
+                     i_p_valid_o_p_ready.eq(p_i_valid & self.p[pi].o_ready),
         ]
 
         # store result of processing in combinatorial temporary
-        m.d.comb += eq(result, self.stage.process(self.p.i_data))
+        m.d.comb += eq(result, self.stage.process(self.p[pi].i_data))
 
         # if not in stall condition, update the temporary register
-        with m.If(self.p.o_ready): # not stalled
+        with m.If(self.p[pi].o_ready): # not stalled
             m.d.sync += eq(r_data, result) # update buffer
 
-        with m.If(self.n.i_ready): # next stage is ready
-            with m.If(self.p.o_ready): # not stalled
+        with m.If(self.n[ni].i_ready): # next stage is ready
+            with m.If(self.p[pi].o_ready): # not stalled
                 # nothing in buffer: send (processed) input direct to output
-                m.d.sync += [self.n.o_valid.eq(p_i_valid),
-                             eq(self.n.o_data, result), # update output
+                m.d.sync += [self.n[ni].o_valid.eq(p_i_valid),
+                             eq(self.n[ni].o_data, result), # update output
                             ]
             with m.Else(): # p.o_ready is false, and something is in buffer.
                 # Flush the [already processed] buffer to the output port.
-                m.d.sync += [self.n.o_valid.eq(1),      # declare reg empty
-                             eq(self.n.o_data, r_data), # flush buffer
-                             self.p.o_ready.eq(1),      # clear stall condition
+                m.d.sync += [self.n[ni].o_valid.eq(1),      # declare reg empty
+                             eq(self.n[ni].o_data, r_data), # flush buffer
+                             self.p[pi].o_ready.eq(1),      # clear stall 
                             ]
                 # ignore input, since p.o_ready is also false.
 
         # (n.i_ready) is false here: next stage is ready
         with m.Elif(o_n_validn): # next stage being told "ready"
-            m.d.sync += [self.n.o_valid.eq(p_i_valid),
-                         self.p.o_ready.eq(1), # Keep the buffer empty
-                         eq(self.n.o_data, result), # set output data
+            m.d.sync += [self.n[ni].o_valid.eq(p_i_valid),
+                         self.p[pi].o_ready.eq(1), # Keep the buffer empty
+                         eq(self.n[ni].o_data, result), # set output data
                         ]
 
         # (n.i_ready) false and (n.o_valid) true:
         with m.Elif(i_p_valid_o_p_ready):
             # If next stage *is* ready, and not stalled yet, accept input
-            m.d.sync += self.p.o_ready.eq(~(p_i_valid & self.n.o_valid))
+            m.d.sync += self.p[pi].o_ready.eq(~(p_i_valid & self.n[ni].o_valid))
 
         return m
 
index 53c177d80f6d6f93dd9e9522aa788624def8ce77..49d53935b6cff156f0534ce5844a067494ed4dff 100644 (file)
@@ -13,37 +13,41 @@ from random import randint
 
 
 def check_o_n_valid(dut, val):
+    o_n_valid = yield dut.n[0].o_valid
+    assert o_n_valid == val
+
+def check_o_n_valid2(dut, val):
     o_n_valid = yield dut.n.o_valid
     assert o_n_valid == val
 
 
 def testbench(dut):
     #yield dut.i_p_rst.eq(1)
-    yield dut.n.i_ready.eq(0)
-    yield dut.p.o_ready.eq(0)
+    yield dut.n[0].i_ready.eq(0)
+    yield dut.p[0].o_ready.eq(0)
     yield
     yield
     #yield dut.i_p_rst.eq(0)
-    yield dut.n.i_ready.eq(1)
-    yield dut.p.i_data.eq(5)
-    yield dut.p.i_valid.eq(1)
+    yield dut.n[0].i_ready.eq(1)
+    yield dut.p[0].i_data.eq(5)
+    yield dut.p[0].i_valid.eq(1)
     yield
 
-    yield dut.p.i_data.eq(7)
+    yield dut.p[0].i_data.eq(7)
     yield from check_o_n_valid(dut, 0) # effects of i_p_valid delayed
     yield
     yield from check_o_n_valid(dut, 1) # ok *now* i_p_valid effect is felt
 
-    yield dut.p.i_data.eq(2)
+    yield dut.p[0].i_data.eq(2)
     yield
-    yield dut.n.i_ready.eq(0) # begin going into "stall" (next stage says ready)
-    yield dut.p.i_data.eq(9)
+    yield dut.n[0].i_ready.eq(0) # begin going into "stall" (next stage says ready)
+    yield dut.p[0].i_data.eq(9)
     yield
-    yield dut.p.i_valid.eq(0)
-    yield dut.p.i_data.eq(12)
+    yield dut.p[0].i_valid.eq(0)
+    yield dut.p[0].i_data.eq(12)
     yield
-    yield dut.p.i_data.eq(32)
-    yield dut.n.i_ready.eq(1)
+    yield dut.p[0].i_data.eq(32)
+    yield dut.n[0].i_ready.eq(1)
     yield
     yield from check_o_n_valid(dut, 1) # buffer still needs to output
     yield
@@ -66,13 +70,13 @@ def testbench2(dut):
     yield
 
     yield dut.p.i_data.eq(7)
-    yield from check_o_n_valid(dut, 0) # effects of i_p_valid delayed 2 clocks
+    yield from check_o_n_valid2(dut, 0) # effects of i_p_valid delayed 2 clocks
     yield
-    yield from check_o_n_valid(dut, 0) # effects of i_p_valid delayed 2 clocks
+    yield from check_o_n_valid2(dut, 0) # effects of i_p_valid delayed 2 clocks
 
     yield dut.p.i_data.eq(2)
     yield
-    yield from check_o_n_valid(dut, 1) # ok *now* i_p_valid effect is felt
+    yield from check_o_n_valid2(dut, 1) # ok *now* i_p_valid effect is felt
     yield dut.n.i_ready.eq(0) # begin going into "stall" (next stage says ready)
     yield dut.p.i_data.eq(9)
     yield
@@ -82,13 +86,13 @@ def testbench2(dut):
     yield dut.p.i_data.eq(32)
     yield dut.n.i_ready.eq(1)
     yield
-    yield from check_o_n_valid(dut, 1) # buffer still needs to output
+    yield from check_o_n_valid2(dut, 1) # buffer still needs to output
     yield
-    yield from check_o_n_valid(dut, 1) # buffer still needs to output
+    yield from check_o_n_valid2(dut, 1) # buffer still needs to output
     yield
-    yield from check_o_n_valid(dut, 1) # buffer still needs to output
+    yield from check_o_n_valid2(dut, 1) # buffer still needs to output
     yield
-    yield from check_o_n_valid(dut, 0) # buffer outputted, *now* we're done.
+    yield from check_o_n_valid2(dut, 0) # buffer outputted, *now* we're done.
     yield
     yield
     yield
@@ -113,16 +117,16 @@ class Test3:
                     send = True
                 else:
                     send = randint(0, send_range) != 0
-                o_p_ready = yield self.dut.p.o_ready
+                o_p_ready = yield self.dut.p[0].o_ready
                 if not o_p_ready:
                     yield
                     continue
                 if send and self.i != len(self.data):
-                    yield self.dut.p.i_valid.eq(1)
-                    yield self.dut.p.i_data.eq(self.data[self.i])
+                    yield self.dut.p[0].i_valid.eq(1)
+                    yield self.dut.p[0].i_data.eq(self.data[self.i])
                     self.i += 1
                 else:
-                    yield self.dut.p.i_valid.eq(0)
+                    yield self.dut.p[0].i_valid.eq(0)
                 yield
 
     def rcv(self):
@@ -130,13 +134,13 @@ class Test3:
             stall_range = randint(0, 3)
             for j in range(randint(1,10)):
                 stall = randint(0, stall_range) != 0
-                yield self.dut.n.i_ready.eq(stall)
+                yield self.dut.n[0].i_ready.eq(stall)
                 yield
-                o_n_valid = yield self.dut.n.o_valid
-                i_n_ready = yield self.dut.n.i_ready
+                o_n_valid = yield self.dut.n[0].o_valid
+                i_n_ready = yield self.dut.n[0].i_ready
                 if not o_n_valid or not i_n_ready:
                     continue
-                o_data = yield self.dut.n.o_data
+                o_data = yield self.dut.n[0].o_data
                 self.resultfn(o_data, self.data[self.o], self.i, self.o)
                 self.o += 1
                 if self.o == len(self.data):