hdl.rec: add basic record support.
authorwhitequark <whitequark@whitequark.org>
Fri, 28 Dec 2018 13:22:10 +0000 (13:22 +0000)
committerwhitequark <whitequark@whitequark.org>
Fri, 28 Dec 2018 13:22:10 +0000 (13:22 +0000)
doc/COMPAT_SUMMARY.md
examples/gpio.py
nmigen/__init__.py
nmigen/back/pysim.py
nmigen/back/rtlil.py
nmigen/hdl/rec.py [new file with mode: 0644]
nmigen/hdl/xfrm.py
nmigen/test/test_hdl_rec.py [new file with mode: 0644]
nmigen/test/test_sim.py

index 8cd07d5966333f157a02283c9bb56971ff82f105..5fe1682574b3d184a0a9e4decfacd308c19bee97 100644 (file)
@@ -162,13 +162,15 @@ Compatibility summary
       - (−) `timeline` ?
       - (−) `WaitTimer` ?
       - (−) `BitSlip` ?
-    - (−) `record` ?
-      - (−) `DIR_NONE`/`DIR_S_TO_M`/`DIR_M_TO_S` ?
-      - (−) `set_layout_parameters` ?
-      - (−) `layout_len` ?
-      - (−) `layout_get` ?
-      - (−) `layout_partial` ?
-      - (−) `Record` ?
+    - (−) `record` **obs** → `.hdl.rec.Record`
+      - (−) `DIR_NONE` id
+      - (−) `DIR_M_TO_S` → `DIR_FANOUT`
+      - (−) `DIR_S_TO_M` → `DIR_FANIN`
+      - (−) `set_layout_parameters` **brk**
+      - (−) `layout_len` **brk**
+      - (−) `layout_get` **brk**
+      - (−) `layout_partial` **brk**
+      - (−) `Record` id
     - (−) `resetsync` ?
       - (−) `AsyncResetSynchronizer` ?
     - (−) `roundrobin` ?
index 06caab169088899e6135999c067f38947dbfc35c..dbd2c686bb8db8406de8f3b253566f44172a1a5d 100644 (file)
@@ -10,20 +10,19 @@ class GPIO:
 
     def get_fragment(self, platform):
         m = Module()
-        m.d.comb += self.bus.dat_r.eq(self.pins[self.bus.adr])
+        m.d.comb += self.bus.r_data.eq(self.pins[self.bus.addr])
         with m.If(self.bus.we):
-            m.d.sync += self.pins[self.bus.adr].eq(self.bus.dat_w)
+            m.d.sync += self.pins[self.bus.addr].eq(self.bus.w_data)
         return m.lower(platform)
 
 
 if __name__ == "__main__":
-    # TODO: use Record
-    bus = SimpleNamespace(
-        adr  =Signal(name="adr", max=8),
-        dat_r=Signal(name="dat_r"),
-        dat_w=Signal(name="dat_w"),
-        we   =Signal(name="we"),
-    )
+    bus = Record([
+        ("addr",   3),
+        ("r_data", 1),
+        ("w_data", 1),
+        ("we",     1),
+    ])
     pins = Signal(8)
     gpio = GPIO(Array(pins), bus)
-    main(gpio, ports=[pins, bus.adr, bus.dat_r, bus.dat_w, bus.we])
+    main(gpio, ports=[pins, bus.addr, bus.r_data, bus.w_data, bus.we])
index a9ed14de77e33fde6e87bf632b16ddfbbc009fd1..220f5bb84ab31e70de3dddf2337f7bcc212a26cf 100644 (file)
@@ -3,6 +3,7 @@ from .hdl.dsl import Module
 from .hdl.cd import ClockDomain
 from .hdl.ir import Fragment, Instance
 from .hdl.mem import Memory
+from .hdl.rec import Record
 from .hdl.xfrm import ResetInserter, CEInserter
 
 from .lib.cdc import MultiReg
index 27bed924820c620915b520f970049ff3c385db82..7a5639cb8fd55089a3e28193b62c408ab15bb2a8 100644 (file)
@@ -72,7 +72,12 @@ class _State:
 normalize = Const.normalize
 
 
-class _RHSValueCompiler(ValueVisitor):
+class _ValueCompiler(ValueVisitor):
+    def on_Record(self, value):
+        return self(Cat(value.fields.values()))
+
+
+class _RHSValueCompiler(_ValueCompiler):
     def __init__(self, signal_slots, sensitivity=None, mode="rhs"):
         self.signal_slots = signal_slots
         self.sensitivity  = sensitivity
@@ -202,7 +207,7 @@ class _RHSValueCompiler(ValueVisitor):
         return eval
 
 
-class _LHSValueCompiler(ValueVisitor):
+class _LHSValueCompiler(_ValueCompiler):
     def __init__(self, signal_slots, rhs_compiler):
         self.signal_slots = signal_slots
         self.rhs_compiler = rhs_compiler
index 1572d3eceb6211219bda795d8b673bbd52af680d..9ad30336777ea68dbd61b3fbefac7f0c51c09cd0 100644 (file)
@@ -305,6 +305,9 @@ class _ValueCompiler(xfrm.ValueVisitor):
     def on_ResetSignal(self, value):
         raise NotImplementedError # :nocov:
 
+    def on_Record(self, value):
+        return self(Cat(value.fields.values()))
+
     def on_Cat(self, value):
         return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.parts])))
 
diff --git a/nmigen/hdl/rec.py b/nmigen/hdl/rec.py
new file mode 100644 (file)
index 0000000..32028af
--- /dev/null
@@ -0,0 +1,114 @@
+from enum import Enum
+from collections import OrderedDict
+
+from .. import tracer
+from ..tools import union
+from .ast import *
+
+
+__all__ = ["Direction", "DIR_NONE", "DIR_FANOUT", "DIR_FANIN", "Layout", "Record"]
+
+
+Direction = Enum('Direction', ('NONE', 'FANOUT', 'FANIN'))
+
+DIR_NONE   = Direction.NONE
+DIR_FANOUT = Direction.FANOUT
+DIR_FANIN  = Direction.FANIN
+
+
+class Layout:
+    @staticmethod
+    def wrap(obj):
+        if isinstance(obj, Layout):
+            return obj
+        return Layout(obj)
+
+    def __init__(self, fields):
+        self.fields = OrderedDict()
+        for field in fields:
+            if not isinstance(field, tuple) or len(field) not in (2, 3):
+                raise TypeError("Field {!r} has invalid layout: should be either "
+                                "(name, shape) or (name, shape, direction)"
+                                .format(field))
+            if len(field) == 2:
+                name, shape = field
+                direction = DIR_NONE
+                if isinstance(shape, list):
+                    shape = Layout.wrap(shape)
+            else:
+                name, shape, direction = field
+                if not isinstance(direction, Direction):
+                    raise TypeError("Field {!r} has invalid direction: should be a Direction "
+                                    "instance like DIR_FANIN"
+                                    .format(field))
+            if not isinstance(name, str):
+                raise TypeError("Field {!r} has invalid name: should be a string"
+                                .format(field))
+            if not isinstance(shape, (int, tuple, Layout)):
+                raise TypeError("Field {!r} has invalid shape: should be an int, tuple, or list "
+                                "of fields of a nested record"
+                                .format(field))
+            if name in self.fields:
+                raise NameError("Field {!r} has a name that is already present in the layout"
+                                .format(field))
+            self.fields[name] = (shape, direction)
+
+    def __getitem__(self, name):
+        return self.fields[name]
+
+    def __iter__(self):
+        for name, (shape, dir) in self.fields.items():
+            yield (name, shape, dir)
+
+
+class Record(Value):
+    __slots__ = ("fields",)
+
+    def __init__(self, layout, name=None):
+        if name is None:
+            try:
+                name = tracer.get_var_name()
+            except tracer.NameNotFound:
+                pass
+        self.name    = name
+        self.src_loc = tracer.get_src_loc()
+
+        def concat(a, b):
+            if a is None:
+                return b
+            return "{}_{}".format(a, b)
+
+        self.layout = Layout.wrap(layout)
+        self.fields = OrderedDict()
+        for field_name, field_shape, field_dir in self.layout:
+            if isinstance(field_shape, Layout):
+                self.fields[field_name] = Record(field_shape, name=concat(name, field_name))
+            else:
+                self.fields[field_name] = Signal(field_shape, name=concat(name, field_name))
+
+    def __getattr__(self, name):
+        return self.fields[name]
+
+    def __getitem__(self, name):
+        return self.fields[name]
+
+    def shape(self):
+        return sum(len(f) for f in self.fields.values()), False
+
+    def _lhs_signals(self):
+        return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())
+
+    def _rhs_signals(self):
+        return union((f._rhs_signals() for f in self.fields.values()), start=SignalSet())
+
+    def __repr__(self):
+        fields = []
+        for field_name, field in self.fields.items():
+            if isinstance(field, Signal):
+                fields.append(field_name)
+            else:
+                fields.append(repr(field))
+        name = self.name
+        if name is None:
+            name = "<unnamed>"
+        return "(rec {} {})".format(name, " ".join(fields))
index 2f1f9377f085c925a54ae63adadfee54258143cc..e70f9d17fd140c82bb282fd9729416eeb6db711a 100644 (file)
@@ -7,6 +7,7 @@ from .ast import *
 from .ast import _StatementList
 from .cd import *
 from .ir import *
+from .rec import *
 
 
 __all__ = ["ValueVisitor", "ValueTransformer",
@@ -26,6 +27,10 @@ class ValueVisitor(metaclass=ABCMeta):
     def on_Signal(self, value):
         pass # :nocov:
 
+    @abstractmethod
+    def on_Record(self, value):
+        pass # :nocov:
+
     @abstractmethod
     def on_ClockSignal(self, value):
         pass # :nocov:
@@ -66,6 +71,8 @@ class ValueVisitor(metaclass=ABCMeta):
             new_value = self.on_Const(value)
         elif type(value) is Signal:
             new_value = self.on_Signal(value)
+        elif type(value) is Record:
+            new_value = self.on_Record(value)
         elif type(value) is ClockSignal:
             new_value = self.on_ClockSignal(value)
         elif type(value) is ResetSignal:
@@ -100,6 +107,9 @@ class ValueTransformer(ValueVisitor):
     def on_Signal(self, value):
         return value
 
+    def on_Record(self, value):
+        return value
+
     def on_ClockSignal(self, value):
         return value
 
diff --git a/nmigen/test/test_hdl_rec.py b/nmigen/test/test_hdl_rec.py
new file mode 100644 (file)
index 0000000..501fda9
--- /dev/null
@@ -0,0 +1,80 @@
+from ..hdl.ast import *
+from ..hdl.rec import *
+from .tools import *
+
+
+class LayoutTestCase(FHDLTestCase):
+    def test_fields(self):
+        layout = Layout.wrap([
+            ("cyc",  1),
+            ("data", (32, True)),
+            ("stb",  1, DIR_FANOUT),
+            ("ack",  1, DIR_FANIN),
+            ("info", [
+                ("a", 1),
+                ("b", 1),
+            ])
+        ])
+
+        self.assertEqual(layout["cyc"], (1, DIR_NONE))
+        self.assertEqual(layout["data"], ((32, True), DIR_NONE))
+        self.assertEqual(layout["stb"], (1, DIR_FANOUT))
+        self.assertEqual(layout["ack"], (1, DIR_FANIN))
+        sublayout = layout["info"][0]
+        self.assertEqual(layout["info"][1], DIR_NONE)
+        self.assertEqual(sublayout["a"], (1, DIR_NONE))
+        self.assertEqual(sublayout["b"], (1, DIR_NONE))
+
+    def test_wrong_field(self):
+        with self.assertRaises(TypeError,
+                msg="Field (1,) has invalid layout: should be either (name, shape) or "
+                    "(name, shape, direction)"):
+            Layout.wrap([(1,)])
+
+    def test_wrong_name(self):
+        with self.assertRaises(TypeError,
+                msg="Field (1, 1) has invalid name: should be a string"):
+            Layout.wrap([(1, 1)])
+
+    def test_wrong_name_duplicate(self):
+        with self.assertRaises(NameError,
+                msg="Field ('a', 2) has a name that is already present in the layout"):
+            Layout.wrap([("a", 1), ("a", 2)])
+
+    def test_wrong_direction(self):
+        with self.assertRaises(TypeError,
+                msg="Field ('a', 1, 0) has invalid direction: should be a Direction "
+                    "instance like DIR_FANIN"):
+            Layout.wrap([("a", 1, 0)])
+
+    def test_wrong_shape(self):
+        with self.assertRaises(TypeError,
+                msg="Field ('a', 'x') has invalid shape: should be an int, tuple, or "
+                    "list of fields of a nested record"):
+            Layout.wrap([("a", "x")])
+
+
+class RecordTestCase(FHDLTestCase):
+    def test_basic(self):
+        r = Record([
+            ("stb",  1),
+            ("data", 32),
+            ("info", [
+                ("a", 1),
+                ("b", 1),
+            ])
+        ])
+
+        self.assertEqual(repr(r), "(rec r stb data (rec r_info a b))")
+        self.assertEqual(len(r),  35)
+        self.assertIsInstance(r.stb, Signal)
+        self.assertEqual(r.stb.name, "r_stb")
+        self.assertEqual(r["stb"].name, "r_stb")
+
+    def test_unnamed(self):
+        r = [Record([
+            ("stb", 1)
+        ])][0]
+
+        self.assertEqual(repr(r), "(rec <unnamed> stb)")
+        self.assertEqual(r.stb.name, "stb")
index 4dd7b5b16e7e16d89b977d097cc6d5757b227067..cdd83d7648af7503392030e5dff149e854f8e1f7 100644 (file)
@@ -5,6 +5,7 @@ from ..tools import flatten, union
 from ..hdl.ast import *
 from ..hdl.cd import  *
 from ..hdl.mem import *
+from ..hdl.rec import *
 from ..hdl.dsl import  *
 from ..hdl.ir import *
 from ..back.pysim import *
@@ -173,6 +174,14 @@ class SimulatorUnitTestCase(FHDLTestCase):
         stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))]
         self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9))
 
+    def test_record(self):
+        rec = Record([
+            ("l", 1),
+            ("m", 2),
+        ])
+        stmt = lambda y, a: [rec.eq(a), y.eq(rec)]
+        self.assertStatement(stmt, [C(0b101, 3)], C(0b101, 3))
+
     def test_repl(self):
         stmt = lambda y, a: y.eq(Repl(a, 3))
         self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6))