hdl.rec: migrate Record from UserValue to ValueCastable.
authorawygle <awygle@gmail.com>
Fri, 6 Nov 2020 01:10:39 +0000 (17:10 -0800)
committerGitHub <noreply@github.com>
Fri, 6 Nov 2020 01:10:39 +0000 (01:10 +0000)
Closes #528.

nmigen/hdl/rec.py
tests/test_hdl_rec.py

index 5e5687c8c3ed7da78915cf7ae34863366e2bf7d0..ddc833424c632a68d91fb235f3bff639e129b4f5 100644 (file)
@@ -1,6 +1,6 @@
 from enum import Enum
 from collections import OrderedDict
-from functools import reduce
+from functools import reduce, wraps
 
 from .. import tracer
 from .._utils import union, deprecated
@@ -85,8 +85,7 @@ class Layout:
         return "Layout([{}])".format(", ".join(field_reprs))
 
 
-# Unlike most Values, Record *can* be subclassed.
-class Record(UserValue):
+class Record(ValueCastable):
     @staticmethod
     def like(other, *, name=None, name_suffix=None, src_loc_at=0):
         if name is not None:
@@ -114,8 +113,6 @@ class Record(UserValue):
         return Record(other.layout, name=new_name, fields=fields, src_loc_at=1)
 
     def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
-        super().__init__(src_loc_at=src_loc_at)
-
         if name is None:
             name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
 
@@ -146,7 +143,17 @@ class Record(UserValue):
                                                      src_loc_at=1 + src_loc_at)
 
     def __getattr__(self, name):
-        return self[name]
+        # must check `getattr` before `self` - we need to hit Value methods before fields
+        try:
+            value_attr = getattr(Value, name)
+            if callable(value_attr):
+                @wraps(value_attr)
+                def _wrapper(*args, **kwargs):
+                    return value_attr(self, *args, **kwargs)
+                return _wrapper
+            return value_attr
+        except AttributeError:
+            return self[name]
 
     def __getitem__(self, item):
         if isinstance(item, str):
@@ -166,11 +173,23 @@ class Record(UserValue):
                 if field_name in item
             })
         else:
-            return super().__getitem__(item)
+            try:
+                return Value.__getitem__(self, item)
+            except KeyError:
+                if self.name is None:
+                    reference = "Unnamed record"
+                else:
+                    reference = "Record '{}'".format(self.name)
+                raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
+                                     .format(reference, item, ", ".join(self.fields))) from None
 
-    def lower(self):
+    @ValueCastable.lowermethod
+    def as_value(self):
         return Cat(self.fields.values())
 
+    def __len__(self):
+        return len(self.as_value())
+
     def _lhs_signals(self):
         return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())
 
index 718fa4a74a38ed485b4a486ef7c81045c00af854..7e8ae53c03a920708604d1e2a9521269506caa8e 100644 (file)
@@ -135,8 +135,8 @@ class RecordTestCase(FHDLTestCase):
             ("stb",  1),
         ])
 
-        self.assertEqual(repr(r[0]),   "(slice (rec r data stb) 0:1)")
-        self.assertEqual(repr(r[0:3]), "(slice (rec r data stb) 0:3)")
+        self.assertEqual(repr(r[0]),   "(slice (cat (sig r__data) (sig r__stb)) 0:1)")
+        self.assertEqual(repr(r[0:3]), "(slice (cat (sig r__data) (sig r__stb)) 0:3)")
 
     def test_wrong_field(self):
         r = Record([