hdl.rec: fix slicing of records.
authorwhitequark <whitequark@whitequark.org>
Fri, 19 Apr 2019 19:55:39 +0000 (19:55 +0000)
committerwhitequark <whitequark@whitequark.org>
Fri, 19 Apr 2019 19:55:39 +0000 (19:55 +0000)
nmigen/hdl/rec.py
nmigen/test/test_hdl_rec.py

index f1ddbd2995e26e78db28cd8df7509491e45f8b16..8f1065b12a15d189d24b2ae71cc7dd863621c4e8 100644 (file)
@@ -85,16 +85,19 @@ class Record(Value):
     def __getattr__(self, name):
         return self[name]
 
-    def __getitem__(self, name):
-        try:
-            return self.fields[name]
-        except KeyError:
-            if self.name is None:
-                reference = "Unnamed record"
-            else:
-                reference = "Record '{}'".format(self.name)
-            raise NameError("{} does not have a field '{}'. Did you mean one of: {}?"
-                            .format(reference, name, ", ".join(self.fields))) from None
+    def __getitem__(self, item):
+        if isinstance(item, str):
+            try:
+                return self.fields[item]
+            except KeyError:
+                if self.name is None:
+                    reference = "Unnamed record"
+                else:
+                    reference = "Record '{}'".format(self.name)
+                raise NameError("{} does not have a field '{}'. Did you mean one of: {}?"
+                                .format(reference, item, ", ".join(self.fields))) from None
+        else:
+            return super().__getitem__(item)
 
     def shape(self):
         return sum(len(f) for f in self.fields.values()), False
index 65e8bf6bc2ea477c4d4de332f3ff48e19d97f05a..847eedc78c0f1db5750e1295e92a1ba0a75388f7 100644 (file)
@@ -79,6 +79,15 @@ class RecordTestCase(FHDLTestCase):
         self.assertEqual(repr(r), "(rec <unnamed> stb)")
         self.assertEqual(r.stb.name, "stb")
 
+    def test_iter(self):
+        r = Record([
+            ("data", 4),
+            ("stb",  1),
+        ])
+
+        self.assertEqual(repr(r[0]),   "(slice (rec r data stb) 0:1)")
+        self.assertEqual(repr(r[0:3]), "(slice (rec r data stb) 0:3)")
+
     def test_wrong_field(self):
         r = Record([
             ("stb", 1),