transform to meet nmigen FIFOInterface API
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 16 Apr 2019 12:00:59 +0000 (13:00 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 16 Apr 2019 12:00:59 +0000 (13:00 +0100)
src/add/ChiselQueue.py

index c62477a6e230e1c7fcb2b81488618370ff1e5cbe..b0f96e8ff111b684734591c72a4dbbe4c87c23c0 100644 (file)
@@ -27,114 +27,128 @@ from nmigen import Module, Signal, Memory, Mux
 from nmigen.tools import bits_for
 from typing import Tuple, Any, List
 from nmigen.cli import main
+from nmigen.lib.fifo import FIFOInterface
 
 # translated from https://github.com/freechipsproject/chisel3/blob/a4a29e29c3f1eed18f851dcf10bdc845571dfcb6/src/main/scala/chisel3/util/Decoupled.scala#L185   # noqa
 
 
-class Queue:
-    def __init__(self,
-                 data_width: int,
-                 entries: int,
-                 *,
-                 pipe: bool = False,
-                 flow: bool = False):
-        self.entries = entries
-        self.__pipe = pipe
-        self.__flow = flow
-        self.enq_data = Signal(data_width)
-        self.enq_ready = Signal(1)
-        self.enq_valid = Signal(1)
-        self.deq_data = Signal(data_width)
-        self.deq_ready = Signal(1)
-        self.deq_valid = Signal(1)
-        self.count = Signal(bits_for(entries))
-        self.__ram = Memory(data_width, entries if entries > 1 else 2)
-        self.__ram_read = self.__ram.read_port(synchronous=False)
-        self.__ram_write = self.__ram.write_port()
-        ptr_width = bits_for(entries - 1) if entries > 1 else 0
-        self.__enq_ptr = Signal(ptr_width, reset=0)
-        self.__deq_ptr = Signal(ptr_width, reset=0)
-        self.__maybe_full = Signal(1, reset=0)
-        self.__ptr_match = self.__enq_ptr == self.__deq_ptr
-        self.__empty = self.__ptr_match & ~self.__maybe_full
-        self.__full = self.__ptr_match & self.__maybe_full
-        self.__do_enq = Signal(1)
-        self.__do_deq = Signal(1)
-        self.__ptr_diff = self.__enq_ptr - self.__deq_ptr
-
-    def elaborate(self, platform: Any) -> Module:
+class Queue(FIFOInterface):
+    def __init__(self, width, depth, fwft=True, pipe=False):
+        """ din  = enq_data, writable  = enq_ready, we = enq_valid
+            dout = deq_data, re = deq_ready, readable = deq_valid
+        """
+        FIFOInterface.__init__(self, width, depth, fwft)
+        self.pipe = pipe
+        self.depth = depth
+        self.count = Signal(bits_for(depth))
+
+    def elaborate(self, platform):
         m = Module()
-        m.submodules.ram_read = self.__ram_read
-        m.submodules.ram_write = self.__ram_write
-        m.d.comb += self.__do_enq.eq(self.enq_ready & self.enq_valid)
-        m.d.comb += self.__do_deq.eq(self.deq_ready & self.deq_valid)
-        m.d.comb += self.__ram_write.addr.eq(self.__enq_ptr)
-        m.d.comb += self.__ram_write.data.eq(self.enq_data)
-        m.d.comb += self.__ram_write.en.eq(0)
-        with m.If(self.__do_enq):
-            m.d.comb += self.__ram_write.en.eq(1)
-            with m.If(self.__enq_ptr == self.entries - 1):
-                m.d.sync += self.__enq_ptr.eq(0)
+
+        ram = Memory(self.width, self.depth if self.depth > 1 else 2)
+        ram_read = ram.read_port(synchronous=False)
+        ram_write = ram.write_port()
+        ptr_width = bits_for(self.depth - 1) if self.depth > 1 else 0
+
+        enq_ptr = Signal(ptr_width)
+        deq_ptr = Signal(ptr_width)
+        maybe_full = Signal(reset_less=True)
+        do_enq = Signal(reset_less=True)
+        do_deq = Signal(reset_less=True)
+
+        # temporaries
+        ptr_diff = Signal(ptr_width)
+        ptr_match = Signal(reset_less=True)
+        empty = Signal(reset_less=True)
+        full = Signal(reset_less=True)
+
+        m.d.comb += [ptr_match.eq(enq_ptr == deq_ptr),
+                     ptr_diff.eq(enq_ptr - deq_ptr),
+                     empty.eq(ptr_match & ~maybe_full),
+                     full.eq(ptr_match & maybe_full)]
+
+        m.submodules.ram_read = ram_read
+        m.submodules.ram_write = ram_write
+
+        m.d.comb += [do_enq.eq(self.writable & self.we),
+                     do_deq.eq(self.re & self.readable),
+                     ram_write.addr.eq(enq_ptr),
+                     ram_write.data.eq(self.din),
+                     ram_write.en.eq(0)]
+
+        with m.If(do_enq):
+            m.d.comb += ram_write.en.eq(1)
+            with m.If(enq_ptr == self.depth - 1):
+                m.d.sync += enq_ptr.eq(0)
             with m.Else():
-                m.d.sync += self.__enq_ptr.eq(self.__enq_ptr + 1)
-        with m.If(self.__do_deq):
-            with m.If(self.__deq_ptr == self.entries - 1):
-                m.d.sync += self.__deq_ptr.eq(0)
+                m.d.sync += enq_ptr.eq(enq_ptr + 1)
+
+        with m.If(do_deq):
+            with m.If(deq_ptr == self.depth - 1):
+                m.d.sync += deq_ptr.eq(0)
             with m.Else():
-                m.d.sync += self.__deq_ptr.eq(self.__deq_ptr + 1)
-        with m.If(self.__do_enq != self.__do_deq):
-            m.d.sync += self.__maybe_full.eq(self.__do_enq)
-        m.d.comb += self.deq_valid.eq(~self.__empty)
-        m.d.comb += self.enq_ready.eq(~self.__full)
-        m.d.comb += self.__ram_read.addr.eq(self.__deq_ptr)
-        m.d.comb += self.deq_data.eq(self.__ram_read.data)
-        if self.__flow:
-            with m.If(self.enq_valid):
-                m.d.comb += self.deq_valid.eq(1)
-            with m.If(self.__empty):
-                m.d.comb += self.deq_data.eq(self.enq_data)
-                m.d.comb += self.__do_deq.eq(0)
-                with m.If(self.deq_ready):
-                    m.d.comb += self.__do_enq.eq(0)
-
-        if self.__pipe:
-            with m.If(self.deq_ready):
-                m.d.comb += self.enq_ready.eq(1)
-
-        if self.entries == 1 << len(self.count):  # is entries a power of 2
+                m.d.sync += deq_ptr.eq(deq_ptr + 1)
+
+        with m.If(do_enq != do_deq):
+            m.d.sync += maybe_full.eq(do_enq)
+
+        m.d.comb += [self.readable.eq(~empty),
+                     self.writable.eq(~full),
+                     ram_read.addr.eq(deq_ptr),
+                     self.dout.eq(ram_read.data)]
+
+        # first-word fall-through: same as "flow" parameter in Chisel3 Queue
+        # basically instead of relying on the Memory characteristics (which
+        # in FPGAs do not have write-through), then when the queue is empty
+        # take the output directly from the input, i.e. *bypass* the SRAM.
+        # this done combinatorially to give the exact same characteristics
+        # as Memory "write-through"... without relying on a changing API
+        if self.fwft:
+            with m.If(self.we):
+                m.d.comb += self.readable.eq(1)
+            with m.If(empty):
+                m.d.comb += self.dout.eq(self.din)
+                m.d.comb += do_deq.eq(0)
+                with m.If(self.re):
+                    m.d.comb += do_enq.eq(0)
+
+        if self.pipe:
+            with m.If(self.re):
+                m.d.comb += self.writable.eq(1)
+
+        if self.depth == 1 << len(self.count):  # is depth a power of 2
             m.d.comb += self.count.eq(
-                Mux(self.__maybe_full & self.__ptr_match, self.entries, 0)
-                | self.__ptr_diff)
+                Mux(self.maybe_full & self.ptr_match, self.depth, 0)
+                | self.ptr_diff)
         else:
-            m.d.comb += self.count.eq(Mux(
-                self.__ptr_match,
-                Mux(self.__maybe_full, self.entries, 0),
-                Mux(self.__deq_ptr > self.__enq_ptr,
-                    self.entries + self.__ptr_diff,
-                    self.__ptr_diff)))
+            m.d.comb += self.count.eq(Mux(ptr_match,
+                                          Mux(maybe_full, self.depth, 0),
+                                          Mux(deq_ptr > enq_ptr,
+                                              self.depth + ptr_diff,
+                                              ptr_diff)))
 
         return m
 
 
 if __name__ == "__main__":
-    reg_stage = Queue(1, 1, pipe=True)
-    break_ready_chain_stage = Queue(1, 1, flow=True)
+    reg_stage = Queue(1, 2, pipe=True)
+    break_ready_chain_stage = Queue(1, 2, pipe=True, fwft=True)
     m = Module()
     ports = []
 
-    def queue_ports(queue: Queue, name_prefix: str) -> List[Signal]:
+    def queue_ports(queue, name_prefix):
         retval = []
         for name in ["count",
-                     "deq_data",
-                     "deq_valid",
-                     "enq_ready"]:
+                     "dout",
+                     "readable",
+                     "writable"]:
             port = getattr(queue, name)
             signal = Signal(port.shape(), name=name_prefix+name)
             m.d.comb += signal.eq(port)
             retval.append(signal)
-        for name in ["deq_ready",
-                     "enq_data",
-                     "enq_valid"]:
+        for name in ["re",
+                     "din",
+                     "we"]:
             port = getattr(queue, name)
             signal = Signal(port.shape(), name=name_prefix+name)
             m.d.comb += port.eq(signal)