5e5687c8c3ed7da78915cf7ae34863366e2bf7d0
2 from collections
import OrderedDict
3 from functools
import reduce
6 from .._utils
import union
, deprecated
10 __all__
= ["Direction", "DIR_NONE", "DIR_FANOUT", "DIR_FANIN", "Layout", "Record"]
13 Direction
= Enum('Direction', ('NONE', 'FANOUT', 'FANIN'))
15 DIR_NONE
= Direction
.NONE
16 DIR_FANOUT
= Direction
.FANOUT
17 DIR_FANIN
= Direction
.FANIN
22 def cast(obj
, *, src_loc_at
=0):
23 if isinstance(obj
, Layout
):
25 return Layout(obj
, src_loc_at
=1 + src_loc_at
)
27 def __init__(self
, fields
, *, src_loc_at
=0):
28 self
.fields
= OrderedDict()
30 if not isinstance(field
, tuple) or len(field
) not in (2, 3):
31 raise TypeError("Field {!r} has invalid layout: should be either "
32 "(name, shape) or (name, shape, direction)"
37 if isinstance(shape
, list):
38 shape
= Layout
.cast(shape
)
40 name
, shape
, direction
= field
41 if not isinstance(direction
, Direction
):
42 raise TypeError("Field {!r} has invalid direction: should be a Direction "
43 "instance like DIR_FANIN"
45 if not isinstance(name
, str):
46 raise TypeError("Field {!r} has invalid name: should be a string"
48 if not isinstance(shape
, Layout
):
50 # Check provided shape by calling Shape.cast and checking for exception
51 Shape
.cast(shape
, src_loc_at
=1 + src_loc_at
)
52 except Exception as error
:
53 raise TypeError("Field {!r} has invalid shape: should be castable to Shape "
54 "or a list of fields of a nested record"
56 if name
in self
.fields
:
57 raise NameError("Field {!r} has a name that is already present in the layout"
59 self
.fields
[name
] = (shape
, direction
)
61 def __getitem__(self
, item
):
62 if isinstance(item
, tuple):
65 for (name
, (shape
, dir)) in self
.fields
.items()
69 return self
.fields
[item
]
72 for name
, (shape
, dir) in self
.fields
.items():
73 yield (name
, shape
, dir)
75 def __eq__(self
, other
):
76 return self
.fields
== other
.fields
80 for name
, shape
, dir in self
:
82 field_reprs
.append("({!r}, {!r})".format(name
, shape
))
84 field_reprs
.append("({!r}, {!r}, Direction.{})".format(name
, shape
, dir.name
))
85 return "Layout([{}])".format(", ".join(field_reprs
))
88 # Unlike most Values, Record *can* be subclassed.
89 class Record(UserValue
):
91 def like(other
, *, name
=None, name_suffix
=None, src_loc_at
=0):
94 elif name_suffix
is not None:
95 new_name
= other
.name
+ str(name_suffix
)
97 new_name
= tracer
.get_var_name(depth
=2 + src_loc_at
, default
=None)
102 return "{}__{}".format(a
, b
)
105 for field_name
in other
.fields
:
106 field
= other
[field_name
]
107 if isinstance(field
, Record
):
108 fields
[field_name
] = Record
.like(field
, name
=concat(new_name
, field_name
),
109 src_loc_at
=1 + src_loc_at
)
111 fields
[field_name
] = Signal
.like(field
, name
=concat(new_name
, field_name
),
112 src_loc_at
=1 + src_loc_at
)
114 return Record(other
.layout
, name
=new_name
, fields
=fields
, src_loc_at
=1)
116 def __init__(self
, layout
, *, name
=None, fields
=None, src_loc_at
=0):
117 super().__init
__(src_loc_at
=src_loc_at
)
120 name
= tracer
.get_var_name(depth
=2 + src_loc_at
, default
=None)
123 self
.src_loc
= tracer
.get_src_loc(src_loc_at
)
128 return "{}__{}".format(a
, b
)
130 self
.layout
= Layout
.cast(layout
, src_loc_at
=1 + src_loc_at
)
131 self
.fields
= OrderedDict()
132 for field_name
, field_shape
, field_dir
in self
.layout
:
133 if fields
is not None and field_name
in fields
:
134 field
= fields
[field_name
]
135 if isinstance(field_shape
, Layout
):
136 assert isinstance(field
, Record
) and field_shape
== field
.layout
138 assert isinstance(field
, Signal
) and Shape
.cast(field_shape
) == field
.shape()
139 self
.fields
[field_name
] = field
141 if isinstance(field_shape
, Layout
):
142 self
.fields
[field_name
] = Record(field_shape
, name
=concat(name
, field_name
),
143 src_loc_at
=1 + src_loc_at
)
145 self
.fields
[field_name
] = Signal(field_shape
, name
=concat(name
, field_name
),
146 src_loc_at
=1 + src_loc_at
)
148 def __getattr__(self
, name
):
151 def __getitem__(self
, item
):
152 if isinstance(item
, str):
154 return self
.fields
[item
]
156 if self
.name
is None:
157 reference
= "Unnamed record"
159 reference
= "Record '{}'".format(self
.name
)
160 raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
161 .format(reference
, item
, ", ".join(self
.fields
))) from None
162 elif isinstance(item
, tuple):
163 return Record(self
.layout
[item
], fields
={
164 field_name
: field_value
165 for field_name
, field_value
in self
.fields
.items()
166 if field_name
in item
169 return super().__getitem
__(item
)
172 return Cat(self
.fields
.values())
174 def _lhs_signals(self
):
175 return union((f
._lhs
_signals
() for f
in self
.fields
.values()), start
=SignalSet())
177 def _rhs_signals(self
):
178 return union((f
._rhs
_signals
() for f
in self
.fields
.values()), start
=SignalSet())
182 for field_name
, field
in self
.fields
.items():
183 if isinstance(field
, Signal
):
184 fields
.append(field_name
)
186 fields
.append(repr(field
))
190 return "(rec {} {})".format(name
, " ".join(fields
))
192 def connect(self
, *subordinates
, include
=None, exclude
=None):
193 def rec_name(record
):
194 if record
.name
is None:
195 return "unnamed record"
197 return "record '{}'".format(record
.name
)
199 for field
in include
or {}:
200 if field
not in self
.fields
:
201 raise AttributeError("Cannot include field '{}' because it is not present in {}"
202 .format(field
, rec_name(self
)))
203 for field
in exclude
or {}:
204 if field
not in self
.fields
:
205 raise AttributeError("Cannot exclude field '{}' because it is not present in {}"
206 .format(field
, rec_name(self
)))
209 for field
in self
.fields
:
210 if include
is not None and field
not in include
:
212 if exclude
is not None and field
in exclude
:
215 shape
, direction
= self
.layout
[field
]
216 if not isinstance(shape
, Layout
) and direction
== DIR_NONE
:
217 raise TypeError("Cannot connect field '{}' of {} because it does not have "
219 .format(field
, rec_name(self
)))
221 item
= self
.fields
[field
]
223 for subord
in subordinates
:
224 if field
not in subord
.fields
:
225 raise AttributeError("Cannot connect field '{}' of {} to subordinate {} "
226 "because the subordinate record does not have this field"
227 .format(field
, rec_name(self
), rec_name(subord
)))
228 subord_items
.append(subord
.fields
[field
])
230 if isinstance(shape
, Layout
):
231 sub_include
= include
[field
] if include
and field
in include
else None
232 sub_exclude
= exclude
[field
] if exclude
and field
in exclude
else None
233 stmts
+= item
.connect(*subord_items
, include
=sub_include
, exclude
=sub_exclude
)
235 if direction
== DIR_FANOUT
:
236 stmts
+= [sub_item
.eq(item
) for sub_item
in subord_items
]
237 if direction
== DIR_FANIN
:
238 stmts
+= [item
.eq(reduce(lambda a
, b
: a | b
, subord_items
))]