projects
/
nmigen.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
06c7349
)
hdl.rec: migrate Record from UserValue to ValueCastable.
author
awygle
<awygle@gmail.com>
Fri, 6 Nov 2020 01:10:39 +0000
(17:10 -0800)
committer
GitHub
<noreply@github.com>
Fri, 6 Nov 2020 01:10:39 +0000
(
01:10
+0000)
Closes #528.
nmigen/hdl/rec.py
patch
|
blob
|
history
tests/test_hdl_rec.py
patch
|
blob
|
history
diff --git
a/nmigen/hdl/rec.py
b/nmigen/hdl/rec.py
index 5e5687c8c3ed7da78915cf7ae34863366e2bf7d0..ddc833424c632a68d91fb235f3bff639e129b4f5 100644
(file)
--- a/
nmigen/hdl/rec.py
+++ b/
nmigen/hdl/rec.py
@@
-1,6
+1,6
@@
from enum import Enum
from collections import OrderedDict
from enum import Enum
from collections import OrderedDict
-from functools import reduce
+from functools import reduce
, wraps
from .. import tracer
from .._utils import union, deprecated
from .. import tracer
from .._utils import union, deprecated
@@
-85,8
+85,7
@@
class Layout:
return "Layout([{}])".format(", ".join(field_reprs))
return "Layout([{}])".format(", ".join(field_reprs))
-# Unlike most Values, Record *can* be subclassed.
-class Record(UserValue):
+class Record(ValueCastable):
@staticmethod
def like(other, *, name=None, name_suffix=None, src_loc_at=0):
if name is not None:
@staticmethod
def like(other, *, name=None, name_suffix=None, src_loc_at=0):
if name is not None:
@@
-114,8
+113,6
@@
class Record(UserValue):
return Record(other.layout, name=new_name, fields=fields, src_loc_at=1)
def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
return Record(other.layout, name=new_name, fields=fields, src_loc_at=1)
def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
- super().__init__(src_loc_at=src_loc_at)
-
if name is None:
name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
if name is None:
name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
@@
-146,7
+143,17
@@
class Record(UserValue):
src_loc_at=1 + src_loc_at)
def __getattr__(self, name):
src_loc_at=1 + src_loc_at)
def __getattr__(self, name):
- return self[name]
+ # must check `getattr` before `self` - we need to hit Value methods before fields
+ try:
+ value_attr = getattr(Value, name)
+ if callable(value_attr):
+ @wraps(value_attr)
+ def _wrapper(*args, **kwargs):
+ return value_attr(self, *args, **kwargs)
+ return _wrapper
+ return value_attr
+ except AttributeError:
+ return self[name]
def __getitem__(self, item):
if isinstance(item, str):
def __getitem__(self, item):
if isinstance(item, str):
@@
-166,11
+173,23
@@
class Record(UserValue):
if field_name in item
})
else:
if field_name in item
})
else:
- return super().__getitem__(item)
+ try:
+ return Value.__getitem__(self, item)
+ except KeyError:
+ if self.name is None:
+ reference = "Unnamed record"
+ else:
+ reference = "Record '{}'".format(self.name)
+ raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
+ .format(reference, item, ", ".join(self.fields))) from None
- def lower(self):
+ @ValueCastable.lowermethod
+ def as_value(self):
return Cat(self.fields.values())
return Cat(self.fields.values())
+ def __len__(self):
+ return len(self.as_value())
+
def _lhs_signals(self):
return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())
def _lhs_signals(self):
return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())
diff --git
a/tests/test_hdl_rec.py
b/tests/test_hdl_rec.py
index 718fa4a74a38ed485b4a486ef7c81045c00af854..7e8ae53c03a920708604d1e2a9521269506caa8e 100644
(file)
--- a/
tests/test_hdl_rec.py
+++ b/
tests/test_hdl_rec.py
@@
-135,8
+135,8
@@
class RecordTestCase(FHDLTestCase):
("stb", 1),
])
("stb", 1),
])
- self.assertEqual(repr(r[0]), "(slice (
rec r data stb
) 0:1)")
- self.assertEqual(repr(r[0:3]), "(slice (
rec r data stb
) 0:3)")
+ self.assertEqual(repr(r[0]), "(slice (
cat (sig r__data) (sig r__stb)
) 0:1)")
+ self.assertEqual(repr(r[0:3]), "(slice (
cat (sig r__data) (sig r__stb)
) 0:3)")
def test_wrong_field(self):
r = Record([
def test_wrong_field(self):
r = Record([