hdl.rec: fix Record.like() being called through a subclass.
[nmigen.git] / nmigen / hdl / rec.py
1 from enum import Enum
2 from collections import OrderedDict
3 from functools import reduce
4
5 from .. import tracer
6 from .._utils import union, deprecated
7 from .ast import *
8
9
# Public names of this module (what ``from ... import *`` exposes).
__all__ = [
    "Direction",
    "DIR_NONE",
    "DIR_FANOUT",
    "DIR_FANIN",
    "Layout",
    "Record",
]
11
12
class Direction(Enum):
    """Direction of a record field, as interpreted by ``Record.connect``."""
    NONE = 1
    FANOUT = 2
    FANIN = 3


# Short module-level aliases for the enumeration members.
DIR_NONE, DIR_FANOUT, DIR_FANIN = Direction.NONE, Direction.FANOUT, Direction.FANIN
18
19
class Layout:
    """Ordered description of the fields of a record.

    ``fields`` is an ``OrderedDict`` mapping each field name to a
    ``(shape, direction)`` pair, where ``shape`` is either the result of
    ``Shape.cast`` or a nested ``Layout``, and ``direction`` is a
    ``Direction`` member.
    """

    @staticmethod
    def cast(obj, *, src_loc_at=0):
        """Return ``obj`` unchanged if it is a ``Layout``, else build one from it."""
        if isinstance(obj, Layout):
            return obj
        return Layout(obj, src_loc_at=1 + src_loc_at)

    # TODO(nmigen-0.2): remove this
    @classmethod
    @deprecated("instead of `Layout.wrap`, use `Layout.cast`")
    def wrap(cls, obj, *, src_loc_at=0):
        """Deprecated alias of ``Layout.cast``."""
        return cls.cast(obj, src_loc_at=1 + src_loc_at)

    def __init__(self, fields, *, src_loc_at=0):
        """Build a layout from an iterable of field descriptions.

        Each field is a tuple, either ``(name, shape)`` or
        ``(name, shape, direction)``. In the two-element form the direction
        is ``DIR_NONE`` and a ``list`` shape is converted to a nested
        ``Layout``. Any shape that is not a ``Layout`` is passed through
        ``Shape.cast``.

        Raises ``TypeError`` for a malformed field, direction, name or shape,
        and ``NameError`` when a field name is repeated.
        """
        self.fields = OrderedDict()
        for field in fields:
            if not isinstance(field, tuple) or len(field) not in (2, 3):
                raise TypeError("Field {!r} has invalid layout: should be either "
                                "(name, shape) or (name, shape, direction)"
                                .format(field))
            if len(field) == 2:
                name, shape = field
                direction = DIR_NONE
                if isinstance(shape, list):
                    shape = Layout.cast(shape)
            else:
                name, shape, direction = field
                if not isinstance(direction, Direction):
                    raise TypeError("Field {!r} has invalid direction: should be a Direction "
                                    "instance like DIR_FANIN"
                                    .format(field))
            if not isinstance(name, str):
                raise TypeError("Field {!r} has invalid name: should be a string"
                                .format(field))
            if not isinstance(shape, Layout):
                try:
                    shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
                except Exception as error:
                    # Keep the underlying failure attached as the explicit cause.
                    raise TypeError("Field {!r} has invalid shape: should be castable to Shape "
                                    "or a list of fields of a nested record"
                                    .format(field)) from error
            if name in self.fields:
                raise NameError("Field {!r} has a name that is already present in the layout"
                                .format(field))
            self.fields[name] = (shape, direction)

    def __getitem__(self, item):
        """Look up one field, or select a sub-layout.

        With a tuple of names, return a new ``Layout`` containing only the
        fields whose name is in the tuple, in this layout's order. With any
        other key, return the stored ``(shape, direction)`` pair (``KeyError``
        if the name is unknown).
        """
        if isinstance(item, tuple):
            return Layout([
                (name, shape, direction)
                for name, (shape, direction) in self.fields.items()
                if name in item
            ])

        return self.fields[item]

    def __iter__(self):
        """Yield ``(name, shape, direction)`` for every field, in order."""
        for name, (shape, direction) in self.fields.items():
            yield (name, shape, direction)

    def __eq__(self, other):
        """Compare field mappings; objects that are not layouts never compare equal."""
        if not isinstance(other, Layout):
            # Let Python try the reflected comparison instead of failing on `other.fields`.
            return NotImplemented
        return self.fields == other.fields

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None
82
83
84 # Unlike most Values, Record *can* be subclassed.
class Record(Value):
    """A value composed of named fields described by a ``Layout``.

    Each field is either a ``Signal`` or, for a nested layout, another
    ``Record``. Fields are reachable as attributes (``rec.field``) and as
    items (``rec["field"]``).
    """

    @staticmethod
    def like(other, *, name=None, name_suffix=None, src_loc_at=0):
        """Create a new record with the same layout as ``other``.

        The new name is ``name`` if given, otherwise ``other.name`` followed
        by ``name_suffix`` if given, otherwise whatever the tracer infers
        (possibly ``None``). Every field is recreated with ``Signal.like`` or,
        for nested records, ``Record.like``.

        This is a static method that always constructs a plain ``Record``,
        so calling it through a subclass does not invoke the subclass
        constructor.
        """
        if name is not None:
            new_name = str(name)
        elif name_suffix is not None:
            new_name = other.name + str(name_suffix)
        else:
            new_name = tracer.get_var_name(depth=2 + src_loc_at, default=None)

        def concat(a, b):
            if a is None:
                return b
            return "{}__{}".format(a, b)

        fields = {}
        for field_name in other.fields:
            field = other[field_name]
            if isinstance(field, Record):
                fields[field_name] = Record.like(field, name=concat(new_name, field_name),
                                                 src_loc_at=1 + src_loc_at)
            else:
                fields[field_name] = Signal.like(field, name=concat(new_name, field_name),
                                                 src_loc_at=1 + src_loc_at)

        # Propagate the caller's offset so that the source location (including that of
        # nested records built by the recursion above) is not attributed to this method.
        return Record(other.layout, name=new_name, fields=fields, src_loc_at=1 + src_loc_at)

    def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
        """Create a record.

        ``layout`` is anything accepted by ``Layout.cast``. ``name`` defaults
        to the name inferred by the tracer (possibly ``None``). ``fields`` is
        an optional mapping from field name to an existing ``Signal`` or
        ``Record`` to use for that field instead of creating a new one; such
        objects are asserted to match the layout. Newly created fields are
        named ``<name>__<field_name>``, or just ``<field_name>`` when the
        record is unnamed.
        """
        if name is None:
            name = tracer.get_var_name(depth=2 + src_loc_at, default=None)

        self.name = name
        self.src_loc = tracer.get_src_loc(src_loc_at)

        def concat(a, b):
            if a is None:
                return b
            return "{}__{}".format(a, b)

        self.layout = Layout.cast(layout, src_loc_at=1 + src_loc_at)
        self.fields = OrderedDict()
        for field_name, field_shape, field_dir in self.layout:
            if fields is not None and field_name in fields:
                field = fields[field_name]
                if isinstance(field_shape, Layout):
                    assert isinstance(field, Record) and field_shape == field.layout
                else:
                    assert isinstance(field, Signal) and field_shape == field.shape()
                self.fields[field_name] = field
            else:
                if isinstance(field_shape, Layout):
                    self.fields[field_name] = Record(field_shape, name=concat(name, field_name),
                                                     src_loc_at=1 + src_loc_at)
                else:
                    self.fields[field_name] = Signal(field_shape, name=concat(name, field_name),
                                                     src_loc_at=1 + src_loc_at)

    def __getattr__(self, name):
        """Resolve an unknown attribute as a field name."""
        # __getattr__ only runs after normal lookup has failed. Item lookup reads
        # `self.fields` and `self.name`, so if either is missing on the instance
        # (e.g. an object created without running __init__), falling through
        # would re-enter this method without bound.
        if name in ("fields", "name"):
            raise AttributeError(name)
        return self[name]

    def __getitem__(self, item):
        """Index the record.

        A string selects the field of that name and raises ``AttributeError``
        if there is none. A tuple of names returns a new ``Record`` over the
        same field objects restricted to those names. Anything else is
        delegated to the base class ``__getitem__``.
        """
        if isinstance(item, str):
            try:
                return self.fields[item]
            except KeyError:
                if self.name is None:
                    reference = "Unnamed record"
                else:
                    reference = "Record '{}'".format(self.name)
                raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
                                     .format(reference, item, ", ".join(self.fields))) from None
        elif isinstance(item, tuple):
            return Record(self.layout[item], fields={
                field_name: field_value
                for field_name, field_value in self.fields.items()
                if field_name in item
            })
        else:
            return super().__getitem__(item)

    def shape(self):
        """Return a ``Shape`` built from the sum of ``len()`` of every field."""
        return Shape(sum(len(f) for f in self.fields.values()))

    def _lhs_signals(self):
        """Return the union of ``_lhs_signals()`` of all fields."""
        return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())

    def _rhs_signals(self):
        """Return the union of ``_rhs_signals()`` of all fields."""
        return union((f._rhs_signals() for f in self.fields.values()), start=SignalSet())

    def __repr__(self):
        fields = []
        for field_name, field in self.fields.items():
            if isinstance(field, Signal):
                fields.append(field_name)
            else:
                fields.append(repr(field))
        name = self.name
        if name is None:
            name = "<unnamed>"
        return "(rec {} {})".format(name, " ".join(fields))

    def connect(self, *subordinates, include=None, exclude=None):
        """Build the statements connecting this record to ``subordinates``.

        For every selected field: a ``DIR_FANOUT`` field is assigned to the
        same field of each subordinate; a ``DIR_FANIN`` field is assigned the
        bitwise OR of that field over all subordinates; a nested record is
        connected recursively.

        ``include`` / ``exclude`` restrict the fields considered. When a
        listed field is a nested record, ``include[field]`` /
        ``exclude[field]`` is passed down to the nested call, so they must
        support item access in that case.

        Returns the list of ``.eq()`` results. With no subordinates the list
        is empty.

        Raises ``AttributeError`` if an included or excluded name is not a
        field of this record, or if a subordinate lacks a field being
        connected. Raises ``TypeError`` if a non-record field being connected
        has direction ``DIR_NONE``.
        """
        def rec_name(record):
            if record.name is None:
                return "unnamed record"
            else:
                return "record '{}'".format(record.name)

        for field in include or {}:
            if field not in self.fields:
                raise AttributeError("Cannot include field '{}' because it is not present in {}"
                                     .format(field, rec_name(self)))
        for field in exclude or {}:
            if field not in self.fields:
                raise AttributeError("Cannot exclude field '{}' because it is not present in {}"
                                     .format(field, rec_name(self)))

        stmts = []
        for field in self.fields:
            if include is not None and field not in include:
                continue
            if exclude is not None and field in exclude:
                continue

            shape, direction = self.layout[field]
            if not isinstance(shape, Layout) and direction == DIR_NONE:
                raise TypeError("Cannot connect field '{}' of {} because it does not have "
                                "a direction"
                                .format(field, rec_name(self)))

            item = self.fields[field]
            subord_items = []
            for subord in subordinates:
                if field not in subord.fields:
                    raise AttributeError("Cannot connect field '{}' of {} to subordinate {} "
                                         "because the subordinate record does not have this field"
                                         .format(field, rec_name(self), rec_name(subord)))
                subord_items.append(subord.fields[field])

            if isinstance(shape, Layout):
                sub_include = include[field] if include and field in include else None
                sub_exclude = exclude[field] if exclude and field in exclude else None
                stmts += item.connect(*subord_items, include=sub_include, exclude=sub_exclude)
            else:
                if direction == DIR_FANOUT:
                    stmts += [sub_item.eq(item) for sub_item in subord_items]
                # reduce() without an initial value raises on an empty sequence;
                # with no subordinates there is nothing to fan in.
                if direction == DIR_FANIN and subord_items:
                    stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))]

        return stmts