hdl.rec: implement Record.connect.
authorwhitequark <whitequark@whitequark.org>
Sun, 21 Apr 2019 06:37:08 +0000 (06:37 +0000)
committerwhitequark <whitequark@whitequark.org>
Sun, 21 Apr 2019 06:37:08 +0000 (06:37 +0000)
Fixes #31.

nmigen/hdl/rec.py
nmigen/test/test_hdl_rec.py

index 8f1065b12a15d189d24b2ae71cc7dd863621c4e8..2ab00c89886fb470b78958b6dee2c3a01ded53a4 100644 (file)
@@ -1,5 +1,6 @@
 from enum import Enum
 from collections import OrderedDict
+from functools import reduce
 
 from .. import tracer
 from ..tools import union
@@ -119,3 +120,53 @@ class Record(Value):
         if name is None:
             name = "<unnamed>"
         return "(rec {} {})".format(name, " ".join(fields))
+
+    def connect(self, *subordinates, include=None, exclude=None):
+        def rec_name(record):
+            if record.name is None:
+                return "unnamed record"
+            else:
+                return "record '{}'".format(record.name)
+
+        for field in include or {}:
+            if field not in self.fields:
+                raise AttributeError("Cannot include field '{}' because it is not present in {}"
+                                     .format(field, rec_name(self)))
+        for field in exclude or {}:
+            if field not in self.fields:
+                raise AttributeError("Cannot exclude field '{}' because it is not present in {}"
+                                     .format(field, rec_name(self)))
+
+        stmts = []
+        for field in self.fields:
+            if include is not None and field not in include:
+                continue
+            if exclude is not None and field in exclude:
+                continue
+
+            shape, direction = self.layout[field]
+            if not isinstance(shape, Layout) and direction == DIR_NONE:
+                raise TypeError("Cannot connect field '{}' of {} because it does not have "
+                                "a direction"
+                                .format(field, rec_name(self)))
+
+            item = self.fields[field]
+            subord_items = []
+            for subord in subordinates:
+                if field not in subord.fields:
+                    raise AttributeError("Cannot connect field '{}' of {} to subordinate {} "
+                                         "because the subordinate record does not have this field"
+                                         .format(field, rec_name(self), rec_name(subord)))
+                subord_items.append(subord.fields[field])
+
+            if isinstance(shape, Layout):
+                sub_include = include[field] if include and field in include else None
+                sub_exclude = exclude[field] if exclude and field in exclude else None
+                stmts += item.connect(*subord_items, include=sub_include, exclude=sub_exclude)
+            else:
+                if direction == DIR_FANOUT:
+                    stmts += [sub_item.eq(item) for sub_item in subord_items]
+                if direction == DIR_FANIN:
+                    stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))]
+
+        return stmts
index 847eedc78c0f1db5750e1295e92a1ba0a75388f7..81aa89d2fd7c43fd049e3c5ad85715d3e6403330 100644 (file)
@@ -105,3 +105,121 @@ class RecordTestCase(FHDLTestCase):
         with self.assertRaises(NameError,
                 msg="Unnamed record does not have a field 'en'. Did you mean one of: stb, ack?"):
             r.en
+
+
+class ConnectTestCase(FHDLTestCase):
+    def setUp_flat(self):
+        self.core_layout = [
+            ("addr",   32, DIR_FANOUT),
+            ("data_r", 32, DIR_FANIN),
+            ("data_w", 32, DIR_FANIN),
+        ]
+        self.periph_layout = [
+            ("addr",   32, DIR_FANOUT),
+            ("data_r", 32, DIR_FANIN),
+            ("data_w", 32, DIR_FANIN),
+        ]
+
+    def setUp_nested(self):
+        self.core_layout = [
+            ("addr",   32, DIR_FANOUT),
+            ("data", [
+                ("r",  32, DIR_FANIN),
+                ("w",  32, DIR_FANIN),
+            ]),
+        ]
+        self.periph_layout = [
+            ("addr",   32, DIR_FANOUT),
+            ("data", [
+                ("r",  32, DIR_FANIN),
+                ("w",  32, DIR_FANIN),
+            ]),
+        ]
+
+    def test_flat(self):
+        self.setUp_flat()
+
+        core    = Record(self.core_layout)
+        periph1 = Record(self.periph_layout)
+        periph2 = Record(self.periph_layout)
+
+        stmts = core.connect(periph1, periph2)
+        self.assertRepr(stmts, """(
+            (eq (sig periph1__addr) (sig core__addr))
+            (eq (sig periph2__addr) (sig core__addr))
+            (eq (sig core__data_r) (| (sig periph1__data_r) (sig periph2__data_r)))
+            (eq (sig core__data_w) (| (sig periph1__data_w) (sig periph2__data_w)))
+        )""")
+
+    def test_flat_include(self):
+        self.setUp_flat()
+
+        core    = Record(self.core_layout)
+        periph1 = Record(self.periph_layout)
+        periph2 = Record(self.periph_layout)
+
+        stmts = core.connect(periph1, periph2, include={"addr": True})
+        self.assertRepr(stmts, """(
+            (eq (sig periph1__addr) (sig core__addr))
+            (eq (sig periph2__addr) (sig core__addr))
+        )""")
+
+    def test_flat_exclude(self):
+        self.setUp_flat()
+
+        core    = Record(self.core_layout)
+        periph1 = Record(self.periph_layout)
+        periph2 = Record(self.periph_layout)
+
+        stmts = core.connect(periph1, periph2, exclude={"addr": True})
+        self.assertRepr(stmts, """(
+            (eq (sig core__data_r) (| (sig periph1__data_r) (sig periph2__data_r)))
+            (eq (sig core__data_w) (| (sig periph1__data_w) (sig periph2__data_w)))
+        )""")
+
+    def test_nested(self):
+        self.setUp_nested()
+
+        core    = Record(self.core_layout)
+        periph1 = Record(self.periph_layout)
+        periph2 = Record(self.periph_layout)
+
+        stmts = core.connect(periph1, periph2)
+        self.maxDiff = None
+        self.assertRepr(stmts, """(
+            (eq (sig periph1__addr) (sig core__addr))
+            (eq (sig periph2__addr) (sig core__addr))
+            (eq (sig core__data__r) (| (sig periph1__data__r) (sig periph2__data__r)))
+            (eq (sig core__data__w) (| (sig periph1__data__w) (sig periph2__data__w)))
+        )""")
+
+    def test_wrong_include_exclude(self):
+        self.setUp_flat()
+
+        core   = Record(self.core_layout)
+        periph = Record(self.periph_layout)
+
+        with self.assertRaises(AttributeError,
+                msg="Cannot include field 'foo' because it is not present in record 'core'"):
+            core.connect(periph, include={"foo": True})
+
+        with self.assertRaises(AttributeError,
+                msg="Cannot exclude field 'foo' because it is not present in record 'core'"):
+            core.connect(periph, exclude={"foo": True})
+
+    def test_wrong_direction(self):
+        recs = [Record([("x", 1)]) for _ in range(2)]
+
+        with self.assertRaises(TypeError,
+                msg="Cannot connect field 'x' of unnamed record because it does not have "
+                    "a direction"):
+            recs[0].connect(recs[1])
+
+    def test_wrong_missing_field(self):
+        core   = Record([("addr", 32, DIR_FANOUT)])
+        periph = Record([])
+
+        with self.assertRaises(AttributeError,
+                msg="Cannot connect field 'addr' of record 'core' to subordinate record 'periph' "
+                    "because the subordinate record does not have this field"):
+            core.connect(periph)