# Source: nmigen.git, nmigen/hdl/rec.py at commit 5e5687c8c3ed7da78915cf7ae34863366e2bf7d0
1 from enum import Enum
2 from collections import OrderedDict
3 from functools import reduce
4
5 from .. import tracer
6 from .._utils import union, deprecated
7 from .ast import *
8
9
# Public names exported by `from ... import *`.
__all__ = [
    "Direction",
    "DIR_NONE",
    "DIR_FANOUT",
    "DIR_FANIN",
    "Layout",
    "Record",
]
11
12
class Direction(Enum):
    """Direction of a record field, used by `Record.connect`."""
    NONE = 1
    FANOUT = 2
    FANIN = 3


# Module-level aliases for the enum members.
DIR_NONE, DIR_FANOUT, DIR_FANIN = Direction.NONE, Direction.FANOUT, Direction.FANIN
18
19
class Layout:
    """Ordered description of the fields of a record.

    Each field is given as ``(name, shape)`` or ``(name, shape, direction)``.
    ``shape`` is a ``Layout``, anything accepted by ``Shape.cast``, or (in the
    two-element form only) a list of fields, which is turned into a nested
    ``Layout``. The two-element form gets ``DIR_NONE`` as its direction.
    Fields are stored in ``self.fields`` as ``name -> (shape, direction)``,
    in declaration order.
    """

    @staticmethod
    def cast(obj, *, src_loc_at=0):
        """Return ``obj`` unchanged if it is a ``Layout``, else build a ``Layout`` from it."""
        if isinstance(obj, Layout):
            return obj
        return Layout(obj, src_loc_at=1 + src_loc_at)

    def __init__(self, fields, *, src_loc_at=0):
        """Build a layout from an iterable of field tuples.

        Raises:
            TypeError: a field is not a 2- or 3-tuple, its direction is not a
                ``Direction``, its name is not a ``str``, or its shape is neither
                a ``Layout`` nor castable by ``Shape.cast``.
            NameError: two fields share the same name.
        """
        self.fields = OrderedDict()
        for field in fields:
            if not isinstance(field, tuple) or len(field) not in (2, 3):
                raise TypeError("Field {!r} has invalid layout: should be either "
                                "(name, shape) or (name, shape, direction)"
                                .format(field))
            if len(field) == 2:
                name, shape = field
                direction = DIR_NONE
                # Only the two-element form accepts a list as a nested layout.
                if isinstance(shape, list):
                    shape = Layout.cast(shape, src_loc_at=1 + src_loc_at)
            else:
                name, shape, direction = field
                if not isinstance(direction, Direction):
                    raise TypeError("Field {!r} has invalid direction: should be a Direction "
                                    "instance like DIR_FANIN"
                                    .format(field))
            if not isinstance(name, str):
                raise TypeError("Field {!r} has invalid name: should be a string"
                                .format(field))
            if not isinstance(shape, Layout):
                try:
                    # Validate the shape by casting it; the result itself is not kept.
                    Shape.cast(shape, src_loc_at=1 + src_loc_at)
                except Exception as error:
                    raise TypeError("Field {!r} has invalid shape: should be castable to Shape "
                                    "or a list of fields of a nested record"
                                    .format(field)) from error
            if name in self.fields:
                raise NameError("Field {!r} has a name that is already present in the layout"
                                .format(field))
            self.fields[name] = (shape, direction)

    def __getitem__(self, item):
        """Return ``(shape, direction)`` for a field name.

        A tuple of names instead returns a new ``Layout`` containing only those
        fields, in this layout's order. Names in the tuple that are not present
        are ignored.
        """
        if isinstance(item, tuple):
            return Layout([
                (name, shape, dir)
                for (name, (shape, dir)) in self.fields.items()
                if name in item
            ])

        return self.fields[item]

    def __iter__(self):
        """Yield ``(name, shape, direction)`` for each field, in order."""
        for name, (shape, dir) in self.fields.items():
            yield (name, shape, dir)

    def __eq__(self, other):
        # A Layout can only equal another Layout. Answer False for anything else
        # instead of failing on a missing `.fields` attribute.
        if not isinstance(other, Layout):
            return False
        return self.fields == other.fields

    def __repr__(self):
        field_reprs = []
        for name, shape, dir in self:
            if dir == DIR_NONE:
                field_reprs.append("({!r}, {!r})".format(name, shape))
            else:
                field_reprs.append("({!r}, {!r}, Direction.{})".format(name, shape, dir.name))
        return "Layout([{}])".format(", ".join(field_reprs))
86
87
# Unlike most Values, Record *can* be subclassed.
class Record(UserValue):
    """A value made of named fields, laid out according to a ``Layout``.

    Each leaf field is a ``Signal`` and each nested layout becomes a nested
    ``Record``. When lowered, the record is ``Cat`` of its fields in layout order.
    Fields are reachable as attributes (``rec.a``) or by name (``rec["a"]``).
    """

    @staticmethod
    def like(other, *, name=None, name_suffix=None, src_loc_at=0):
        """Create a record with the same layout and field shapes as ``other``.

        The new name is ``name`` if given, else ``other.name + name_suffix`` if a
        suffix is given, else the name the tracer finds for the assignment target
        (possibly ``None``). Nested records are copied with ``Record.like`` and
        leaf fields with ``Signal.like``. Each field is named ``<name>__<field>``,
        or just ``<field>`` when the new record is unnamed.
        """
        if name is not None:
            new_name = str(name)
        elif name_suffix is not None:
            new_name = other.name + str(name_suffix)
        else:
            new_name = tracer.get_var_name(depth=2 + src_loc_at, default=None)

        def concat(a, b):
            if a is None:
                return b
            return "{}__{}".format(a, b)

        fields = {}
        for field_name in other.fields:
            field = other[field_name]
            if isinstance(field, Record):
                fields[field_name] = Record.like(field, name=concat(new_name, field_name),
                                                 src_loc_at=1 + src_loc_at)
            else:
                fields[field_name] = Signal.like(field, name=concat(new_name, field_name),
                                                 src_loc_at=1 + src_loc_at)

        # Forward the caller's `src_loc_at` (as done for the fields above) so the
        # record's source location is attributed to the same frame as its fields.
        return Record(other.layout, name=new_name, fields=fields, src_loc_at=1 + src_loc_at)

    def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
        """Create a record for ``layout``.

        ``fields``, if given, supplies existing field objects by name. Each must
        match the layout, which is checked with ``assert``. Missing fields are
        created as new ``Signal``/``Record`` objects named ``<name>__<field>``.
        """
        super().__init__(src_loc_at=src_loc_at)

        if name is None:
            name = tracer.get_var_name(depth=2 + src_loc_at, default=None)

        self.name = name
        self.src_loc = tracer.get_src_loc(src_loc_at)

        def concat(a, b):
            if a is None:
                return b
            return "{}__{}".format(a, b)

        self.layout = Layout.cast(layout, src_loc_at=1 + src_loc_at)
        self.fields = OrderedDict()
        for field_name, field_shape, field_dir in self.layout:
            if fields is not None and field_name in fields:
                field = fields[field_name]
                if isinstance(field_shape, Layout):
                    assert isinstance(field, Record) and field_shape == field.layout, \
                        "Field {!r} must be a Record with layout {!r}".format(field_name,
                                                                               field_shape)
                else:
                    assert isinstance(field, Signal) and Shape.cast(field_shape) == field.shape(), \
                        "Field {!r} must be a Signal of shape {!r}".format(field_name,
                                                                            field_shape)
                self.fields[field_name] = field
            else:
                if isinstance(field_shape, Layout):
                    self.fields[field_name] = Record(field_shape, name=concat(name, field_name),
                                                     src_loc_at=1 + src_loc_at)
                else:
                    self.fields[field_name] = Signal(field_shape, name=concat(name, field_name),
                                                     src_loc_at=1 + src_loc_at)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. If `fields` itself is
        # missing (an instance created without `__init__`, as copy/pickle do),
        # `self[name]` would look up `self.fields`, re-enter `__getattr__` and
        # recurse forever. Report a plain AttributeError instead.
        if "fields" not in self.__dict__:
            raise AttributeError("{!r} object has no attribute {!r}"
                                 .format(type(self).__name__, name))
        return self[name]

    def __getitem__(self, item):
        """Look up a field by name, or select several fields with a tuple of names.

        A tuple returns a new ``Record`` reusing the selected field objects.
        Any other key is handled by the base value class.

        Raises:
            AttributeError: a string key names no field of this record.
        """
        if isinstance(item, str):
            try:
                return self.fields[item]
            except KeyError:
                if self.name is None:
                    reference = "Unnamed record"
                else:
                    reference = "Record '{}'".format(self.name)
                raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
                                     .format(reference, item, ", ".join(self.fields))) from None
        elif isinstance(item, tuple):
            return Record(self.layout[item], fields={
                field_name: field_value
                for field_name, field_value in self.fields.items()
                if field_name in item
            })
        else:
            return super().__getitem__(item)

    def lower(self):
        return Cat(self.fields.values())

    def _lhs_signals(self):
        return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())

    def _rhs_signals(self):
        return union((f._rhs_signals() for f in self.fields.values()), start=SignalSet())

    def __repr__(self):
        fields = []
        for field_name, field in self.fields.items():
            if isinstance(field, Signal):
                fields.append(field_name)
            else:
                fields.append(repr(field))
        name = self.name
        if name is None:
            name = "<unnamed>"
        return "(rec {} {})".format(name, " ".join(fields))

    def connect(self, *subordinates, include=None, exclude=None):
        """Return statements connecting this record to each of ``subordinates``.

        For each selected leaf field, a ``DIR_FANOUT`` field drives the same field
        of every subordinate. A ``DIR_FANIN`` field is driven by the OR of that
        field across all subordinates. Nested records are connected recursively.

        ``include``/``exclude`` restrict which fields of this record are used.
        For a nested field, ``include`` may map the field name to the filter to
        apply inside it. A boolean value, or an ``include`` given as a plain
        collection of names, selects the whole nested record. A nested field
        named in ``exclude`` is skipped entirely.

        Raises:
            AttributeError: a filter names a field this record does not have, or
                a subordinate lacks a field being connected.
            TypeError: a leaf field being connected has ``DIR_NONE``.
        """
        def rec_name(record):
            if record.name is None:
                return "unnamed record"
            else:
                return "record '{}'".format(record.name)

        def sub_filter(filt, field):
            # Filter to forward into nested record `field`; None means "every field".
            # Plain collections (set/list/tuple) cannot be indexed by name, and a
            # boolean is only a membership marker. Both select the whole record.
            try:
                sub = filt[field]
            except (TypeError, KeyError):
                return None
            if isinstance(sub, bool):
                return None
            return sub

        for field in include or {}:
            if field not in self.fields:
                raise AttributeError("Cannot include field '{}' because it is not present in {}"
                                     .format(field, rec_name(self)))
        for field in exclude or {}:
            if field not in self.fields:
                raise AttributeError("Cannot exclude field '{}' because it is not present in {}"
                                     .format(field, rec_name(self)))

        stmts = []
        for field in self.fields:
            if include is not None and field not in include:
                continue
            if exclude is not None and field in exclude:
                continue

            shape, direction = self.layout[field]
            if not isinstance(shape, Layout) and direction == DIR_NONE:
                raise TypeError("Cannot connect field '{}' of {} because it does not have "
                                "a direction"
                                .format(field, rec_name(self)))

            item = self.fields[field]
            subord_items = []
            for subord in subordinates:
                if field not in subord.fields:
                    raise AttributeError("Cannot connect field '{}' of {} to subordinate {} "
                                         "because the subordinate record does not have this field"
                                         .format(field, rec_name(self), rec_name(subord)))
                subord_items.append(subord.fields[field])

            if isinstance(shape, Layout):
                # A field reaching this point is in `include` (if given).
                sub_include = sub_filter(include, field) if include is not None else None
                # A field named in `exclude` was skipped above, so there is never
                # an exclusion filter to forward into a nested record.
                stmts += item.connect(*subord_items, include=sub_include, exclude=None)
            else:
                if direction == DIR_FANOUT:
                    stmts += [sub_item.eq(item) for sub_item in subord_items]
                if direction == DIR_FANIN:
                    # NOTE(review): with no subordinates, `reduce` on an empty list
                    # raises TypeError.
                    stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))]

        return stmts