hdl.rec: migrate Record from UserValue to ValueCastable.
[nmigen.git] / nmigen / hdl / rec.py
1 from enum import Enum
2 from collections import OrderedDict
3 from functools import reduce, wraps
4
5 from .. import tracer
6 from .._utils import union, deprecated
7 from .ast import *
8
9
# Names exported by ``from ... import *``.
__all__ = [
    "Direction",
    "DIR_NONE",
    "DIR_FANOUT",
    "DIR_FANIN",
    "Layout",
    "Record",
]
11
12
class Direction(Enum):
    """Direction of a record field, as consulted by ``Record.connect``.

    For a leaf field, ``FANOUT`` means the connecting record drives the field of every
    subordinate, ``FANIN`` means the connecting record's field is driven by the OR of the
    subordinates' fields, and ``NONE`` means the field cannot be connected.
    """

    NONE = 1
    FANOUT = 2
    FANIN = 3


# Short module-level aliases for the members.
DIR_NONE = Direction.NONE
DIR_FANOUT = Direction.FANOUT
DIR_FANIN = Direction.FANIN
18
19
class Layout:
    """Ordered collection of named record fields.

    :attr:`fields` maps each field name to a ``(shape, direction)`` pair, in declaration
    order. ``shape`` is either a nested :class:`Layout` or a value accepted by
    ``Shape.cast``; ``direction`` is a :class:`Direction` member.
    """

    @staticmethod
    def cast(obj, *, src_loc_at=0):
        """Return ``obj`` if it already is a :class:`Layout`, otherwise ``Layout(obj)``."""
        if isinstance(obj, Layout):
            return obj
        return Layout(obj, src_loc_at=1 + src_loc_at)

    def __init__(self, fields, *, src_loc_at=0):
        """Build a layout from an iterable of field tuples.

        Parameters
        ----------
        fields : iterable of tuple
            Each item is ``(name, shape)`` or ``(name, shape, direction)``. The two-element
            form uses ``DIR_NONE`` and additionally accepts a ``list`` shape, which is cast
            to a nested :class:`Layout`. Otherwise the shape must be a :class:`Layout` or be
            castable with ``Shape.cast``.
        src_loc_at : int
            Extra stack depth forwarded to the shape validation calls.

        Raises
        ------
        TypeError
            If an item is not a 2- or 3-tuple, its direction is not a :class:`Direction`,
            its name is not a ``str``, or its shape is invalid.
        NameError
            If two fields share a name.
        """
        self.fields = OrderedDict()
        for field in fields:
            if not isinstance(field, tuple) or len(field) not in (2, 3):
                raise TypeError("Field {!r} has invalid layout: should be either "
                                "(name, shape) or (name, shape, direction)"
                                .format(field))
            if len(field) == 2:
                name, shape = field
                direction = DIR_NONE
                # A list describes a nested record. It is only converted in the two-element
                # form; a list given together with a direction fails the shape check below.
                if isinstance(shape, list):
                    # Forward the caller's extra depth, as the Shape.cast call below does.
                    shape = Layout.cast(shape, src_loc_at=1 + src_loc_at)
            else:
                name, shape, direction = field
                if not isinstance(direction, Direction):
                    raise TypeError("Field {!r} has invalid direction: should be a Direction "
                                    "instance like DIR_FANIN"
                                    .format(field))
            if not isinstance(name, str):
                raise TypeError("Field {!r} has invalid name: should be a string"
                                .format(field))
            if not isinstance(shape, Layout):
                try:
                    # Validation only: the result is discarded and the original shape is stored.
                    Shape.cast(shape, src_loc_at=1 + src_loc_at)
                except Exception as error:
                    # Chain the cause so the underlying Shape.cast failure remains visible.
                    raise TypeError("Field {!r} has invalid shape: should be castable to Shape "
                                    "or a list of fields of a nested record"
                                    .format(field)) from error
            if name in self.fields:
                raise NameError("Field {!r} has a name that is already present in the layout"
                                .format(field))
            self.fields[name] = (shape, direction)

    def __getitem__(self, item):
        """Look up one field, or project several.

        A ``tuple`` of names returns a new :class:`Layout` holding only those fields, in this
        layout's order; names that are not present are ignored. Any other key returns the
        ``(shape, direction)`` pair of that field and raises ``KeyError`` if it is missing.
        """
        if isinstance(item, tuple):
            return Layout([
                (name, shape, direction)
                for name, (shape, direction) in self.fields.items()
                if name in item
            ])

        return self.fields[item]

    def __iter__(self):
        """Yield ``(name, shape, direction)`` for each field, in order."""
        for name, (shape, direction) in self.fields.items():
            yield (name, shape, direction)

    def __eq__(self, other):
        """Compare field names, shapes and directions, including their order.

        Returns ``NotImplemented`` for non-:class:`Layout` operands, so such comparisons
        evaluate to ``False`` instead of raising ``AttributeError``.
        """
        if not isinstance(other, Layout):
            return NotImplemented
        return self.fields == other.fields

    def __repr__(self):
        """Render as ``Layout([...])``; ``DIR_NONE`` fields omit the direction."""
        field_reprs = []
        for name, shape, direction in self:
            if direction == DIR_NONE:
                field_reprs.append("({!r}, {!r})".format(name, shape))
            else:
                field_reprs.append("({!r}, {!r}, Direction.{})".format(name, shape, direction.name))
        return "Layout([{}])".format(", ".join(field_reprs))
86
87
class Record(ValueCastable):
    """A group of named fields (signals and nested records) described by a :class:`Layout`.

    Fields are reachable as ``record["name"]`` or ``record.name``. The record lowers to the
    concatenation of its fields in layout order (see :meth:`as_value`). Attribute lookups that
    name an attribute of ``Value`` are forwarded to ``Value`` with the record as ``self``, and
    take precedence over fields with the same name.
    """

    @staticmethod
    def like(other, *, name=None, name_suffix=None, src_loc_at=0):
        """Create a record with the same layout as ``other`` and newly created fields.

        Leaf fields are created with ``Signal.like`` and nested records with
        ``Record.like``. Each is named ``<new name>__<field name>``, or just
        ``<field name>`` when the new record is unnamed.

        Parameters
        ----------
        other : Record
            Template record.
        name : object or None
            Name of the new record, converted with ``str``.
        name_suffix : object or None
            Used only when ``name`` is ``None``: the new name is
            ``other.name + str(name_suffix)``.
        src_loc_at : int
            Extra stack depth for name inference and source location tracking.

        When both ``name`` and ``name_suffix`` are ``None``, the name comes from
        ``tracer.get_var_name`` (presumably the variable the result is assigned to) and may
        end up ``None``.
        """
        if name is not None:
            new_name = str(name)
        elif name_suffix is not None:
            new_name = other.name + str(name_suffix)
        else:
            new_name = tracer.get_var_name(depth=2 + src_loc_at, default=None)

        def concat(a, b):
            # Prefix child names with the parent's name, when there is one.
            if a is None:
                return b
            return "{}__{}".format(a, b)

        fields = {}
        for field_name in other.fields:
            field = other[field_name]
            if isinstance(field, Record):
                fields[field_name] = Record.like(field, name=concat(new_name, field_name),
                                                 src_loc_at=1 + src_loc_at)
            else:
                fields[field_name] = Signal.like(field, name=concat(new_name, field_name),
                                                 src_loc_at=1 + src_loc_at)

        # Forward the caller's extra depth, as done for the fields above. A fixed `1` would
        # ignore `src_loc_at` and attribute the record to a different frame than its fields.
        return Record(other.layout, name=new_name, fields=fields, src_loc_at=1 + src_loc_at)

    def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
        """Create a record from ``layout``.

        Parameters
        ----------
        layout : Layout or list
            Passed through :meth:`Layout.cast`.
        name : str or None
            Record name. When ``None`` it is taken from ``tracer.get_var_name`` and may stay
            ``None``. Fields created here are named ``<name>__<field name>``, or just
            ``<field name>`` for an unnamed record.
        fields : mapping or None
            Optional existing field objects keyed by field name. Such a field must be a
            :class:`Record` with an equal layout (for nested layouts), or a ``Signal`` whose
            shape equals ``Shape.cast`` of the field shape. Fields not supplied are created.
        src_loc_at : int
            Extra stack depth for name inference and source location tracking.
        """
        if name is None:
            name = tracer.get_var_name(depth=2 + src_loc_at, default=None)

        self.name = name
        self.src_loc = tracer.get_src_loc(src_loc_at)

        def concat(a, b):
            # Prefix child names with the parent's name, when there is one.
            if a is None:
                return b
            return "{}__{}".format(a, b)

        self.layout = Layout.cast(layout, src_loc_at=1 + src_loc_at)
        self.fields = OrderedDict()
        for field_name, field_shape, _field_dir in self.layout:
            if fields is not None and field_name in fields:
                field = fields[field_name]
                # NOTE(review): these checks are `assert`s and are skipped under `python -O`.
                if isinstance(field_shape, Layout):
                    assert isinstance(field, Record) and field_shape == field.layout
                else:
                    assert isinstance(field, Signal) and Shape.cast(field_shape) == field.shape()
                self.fields[field_name] = field
            elif isinstance(field_shape, Layout):
                self.fields[field_name] = Record(field_shape, name=concat(name, field_name),
                                                 src_loc_at=1 + src_loc_at)
            else:
                self.fields[field_name] = Signal(field_shape, name=concat(name, field_name),
                                                 src_loc_at=1 + src_loc_at)

    def __getattr__(self, name):
        """Resolve attributes not found normally: ``Value`` attributes first, then fields."""
        # `fields` is always assigned by __init__, so reaching this point for it means the
        # instance was created without running __init__ (for example by copy/pickle
        # reconstruction). Falling through to `self[name]` would read `self.fields` again and
        # recurse without bound, so report a missing attribute instead.
        if name == "fields":
            raise AttributeError(name)
        # `Value` must be checked before the fields: Value methods take precedence.
        try:
            value_attr = getattr(Value, name)
        except AttributeError:
            return self[name]
        if callable(value_attr):
            @wraps(value_attr)
            def _wrapper(*args, **kwargs):
                return value_attr(self, *args, **kwargs)
            return _wrapper
        return value_attr

    def __getitem__(self, item):
        """Index the record.

        A ``str`` returns that field, raising ``AttributeError`` listing the available
        fields if it is missing. A ``tuple`` of names returns a new :class:`Record` over the
        projected layout that shares this record's field objects. Any other key is handed to
        ``Value.__getitem__`` with the record as ``self``.
        """
        if isinstance(item, str):
            try:
                return self.fields[item]
            except KeyError:
                if self.name is None:
                    reference = "Unnamed record"
                else:
                    reference = "Record '{}'".format(self.name)
                raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
                                     .format(reference, item, ", ".join(self.fields))) from None
        elif isinstance(item, tuple):
            return Record(self.layout[item], fields={
                field_name: field_value
                for field_name, field_value in self.fields.items()
                if field_name in item
            })
        else:
            try:
                return Value.__getitem__(self, item)
            except KeyError:
                # NOTE(review): Value.__getitem__ is not expected to raise KeyError for
                # non-string keys; this handler (and its "field" wording) looks like dead
                # code carried over from the str case — confirm before removing.
                if self.name is None:
                    reference = "Unnamed record"
                else:
                    reference = "Record '{}'".format(self.name)
                raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
                                     .format(reference, item, ", ".join(self.fields))) from None

    @ValueCastable.lowermethod
    def as_value(self):
        """Lower the record to the concatenation of its fields, in layout order."""
        return Cat(self.fields.values())

    def __len__(self):
        """Total width of the record, i.e. of :meth:`as_value`."""
        return len(self.as_value())

    def shape(self):
        """Shape of the record, i.e. of :meth:`as_value`.

        Defined explicitly rather than left to ``__getattr__``, which would call
        ``Value.shape`` with the record itself as ``self`` instead of the lowered value.
        That lookup path is also taken by ``Value`` methods that call ``self.shape()``.
        """
        return self.as_value().shape()

    def _lhs_signals(self):
        """Union of the ``_lhs_signals()`` of all fields."""
        return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())

    def _rhs_signals(self):
        """Union of the ``_rhs_signals()`` of all fields."""
        return union((f._rhs_signals() for f in self.fields.values()), start=SignalSet())

    def __repr__(self):
        """Render as ``(rec <name> <fields>)``.

        ``Signal`` fields appear by field name, other fields by their own repr.
        """
        fields = []
        for field_name, field in self.fields.items():
            if isinstance(field, Signal):
                fields.append(field_name)
            else:
                fields.append(repr(field))
        name = self.name
        if name is None:
            name = "<unnamed>"
        return "(rec {} {})".format(name, " ".join(fields))

    def connect(self, *subordinates, include=None, exclude=None):
        """Return the statements connecting this record to ``subordinates``.

        For each selected leaf field, a ``DIR_FANOUT`` field of this record drives the same
        field of every subordinate. A ``DIR_FANIN`` field of this record is driven by the OR
        of the subordinates' fields. Nested records are connected recursively.

        Parameters
        ----------
        *subordinates : Record
            Records that must contain every selected field of this record.
        include : container of str or None
            If given, only these fields are connected. For a nested field,
            ``include[name]`` becomes the nested ``include``, so ``include`` must then be
            subscriptable (e.g. a ``dict``).
        exclude : container of str or None
            If given, these fields are skipped.

        Returns
        -------
        list
            The statements produced by the fields' ``eq`` calls.

        Raises
        ------
        AttributeError
            If ``include``/``exclude`` names a field this record lacks, or a subordinate
            lacks a selected field.
        TypeError
            If a selected leaf field has direction ``DIR_NONE``.
        """
        def rec_name(record):
            if record.name is None:
                return "unnamed record"
            else:
                return "record '{}'".format(record.name)

        for field in include or {}:
            if field not in self.fields:
                raise AttributeError("Cannot include field '{}' because it is not present in {}"
                                     .format(field, rec_name(self)))
        for field in exclude or {}:
            if field not in self.fields:
                raise AttributeError("Cannot exclude field '{}' because it is not present in {}"
                                     .format(field, rec_name(self)))

        stmts = []
        for field in self.fields:
            if include is not None and field not in include:
                continue
            if exclude is not None and field in exclude:
                continue

            shape, direction = self.layout[field]
            if not isinstance(shape, Layout) and direction == DIR_NONE:
                raise TypeError("Cannot connect field '{}' of {} because it does not have "
                                "a direction"
                                .format(field, rec_name(self)))

            item = self.fields[field]
            subord_items = []
            for subord in subordinates:
                if field not in subord.fields:
                    raise AttributeError("Cannot connect field '{}' of {} to subordinate {} "
                                         "because the subordinate record does not have this field"
                                         .format(field, rec_name(self), rec_name(subord)))
                subord_items.append(subord.fields[field])

            if isinstance(shape, Layout):
                sub_include = include[field] if include and field in include else None
                # NOTE(review): fields listed in `exclude` were skipped above, so this is
                # always None; a nested exclude cannot be expressed this way.
                sub_exclude = exclude[field] if exclude and field in exclude else None
                stmts += item.connect(*subord_items, include=sub_include, exclude=sub_exclude)
            else:
                if direction == DIR_FANOUT:
                    stmts += [sub_item.eq(item) for sub_item in subord_items]
                if direction == DIR_FANIN:
                    # NOTE(review): reduce() raises TypeError when there are no subordinates.
                    stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))]

        return stmts