2 from collections
import OrderedDict
3 from functools
import reduce, wraps
6 from .._utils
import union
, deprecated
# Names exported by a star-import of this module.
__all__ = [
    "Direction",
    "DIR_NONE",
    "DIR_FANOUT",
    "DIR_FANIN",
    "Layout",
    "Record",
]
# Enumeration of field directions; members are declared in the order
# NONE, FANOUT, FANIN.
Direction = Enum('Direction', ('NONE', 'FANOUT', 'FANIN'))

# Module-level aliases for the three members. Iterating an Enum class yields
# its members in definition order, so the unpacking below pairs each alias
# with the member of the same name.
DIR_NONE, DIR_FANOUT, DIR_FANIN = Direction
22 def cast(obj
, *, src_loc_at
=0):
23 if isinstance(obj
, Layout
):
25 return Layout(obj
, src_loc_at
=1 + src_loc_at
)
27 def __init__(self
, fields
, *, src_loc_at
=0):
28 self
.fields
= OrderedDict()
30 if not isinstance(field
, tuple) or len(field
) not in (2, 3):
31 raise TypeError("Field {!r} has invalid layout: should be either "
32 "(name, shape) or (name, shape, direction)"
37 if isinstance(shape
, list):
38 shape
= Layout
.cast(shape
)
40 name
, shape
, direction
= field
41 if not isinstance(direction
, Direction
):
42 raise TypeError("Field {!r} has invalid direction: should be a Direction "
43 "instance like DIR_FANIN"
45 if not isinstance(name
, str):
46 raise TypeError("Field {!r} has invalid name: should be a string"
48 if not isinstance(shape
, Layout
):
50 # Check provided shape by calling Shape.cast and checking for exception
51 Shape
.cast(shape
, src_loc_at
=1 + src_loc_at
)
52 except Exception as error
:
53 raise TypeError("Field {!r} has invalid shape: should be castable to Shape "
54 "or a list of fields of a nested record"
56 if name
in self
.fields
:
57 raise NameError("Field {!r} has a name that is already present in the layout"
59 self
.fields
[name
] = (shape
, direction
)
61 def __getitem__(self
, item
):
62 if isinstance(item
, tuple):
65 for (name
, (shape
, dir)) in self
.fields
.items()
69 return self
.fields
[item
]
72 for name
, (shape
, dir) in self
.fields
.items():
73 yield (name
, shape
, dir)
def __eq__(self, other):
    """Compare this object with ``other`` by their ``fields`` attributes.

    Returns the result of ``self.fields == other.fields``.

    If ``other`` has no ``fields`` attribute (for example ``None`` or an
    unrelated object), ``NotImplemented`` is returned so that Python falls
    back to its default comparison handling and ``==`` evaluates to
    ``False``, instead of the comparison raising ``AttributeError``.
    """
    try:
        other_fields = other.fields
    except AttributeError:
        # Not something we know how to compare against; let Python decide.
        return NotImplemented
    return self.fields == other_fields
80 for name
, shape
, dir in self
:
82 field_reprs
.append("({!r}, {!r})".format(name
, shape
))
84 field_reprs
.append("({!r}, {!r}, Direction.{})".format(name
, shape
, dir.name
))
85 return "Layout([{}])".format(", ".join(field_reprs
))
88 class Record(ValueCastable
):
90 def like(other
, *, name
=None, name_suffix
=None, src_loc_at
=0):
93 elif name_suffix
is not None:
94 new_name
= other
.name
+ str(name_suffix
)
96 new_name
= tracer
.get_var_name(depth
=2 + src_loc_at
, default
=None)
101 return "{}__{}".format(a
, b
)
104 for field_name
in other
.fields
:
105 field
= other
[field_name
]
106 if isinstance(field
, Record
):
107 fields
[field_name
] = Record
.like(field
, name
=concat(new_name
, field_name
),
108 src_loc_at
=1 + src_loc_at
)
110 fields
[field_name
] = Signal
.like(field
, name
=concat(new_name
, field_name
),
111 src_loc_at
=1 + src_loc_at
)
113 return Record(other
.layout
, name
=new_name
, fields
=fields
, src_loc_at
=1)
115 def __init__(self
, layout
, *, name
=None, fields
=None, src_loc_at
=0):
117 name
= tracer
.get_var_name(depth
=2 + src_loc_at
, default
=None)
120 self
.src_loc
= tracer
.get_src_loc(src_loc_at
)
125 return "{}__{}".format(a
, b
)
127 self
.layout
= Layout
.cast(layout
, src_loc_at
=1 + src_loc_at
)
128 self
.fields
= OrderedDict()
129 for field_name
, field_shape
, field_dir
in self
.layout
:
130 if fields
is not None and field_name
in fields
:
131 field
= fields
[field_name
]
132 if isinstance(field_shape
, Layout
):
133 assert isinstance(field
, Record
) and field_shape
== field
.layout
135 assert isinstance(field
, Signal
) and Shape
.cast(field_shape
) == field
.shape()
136 self
.fields
[field_name
] = field
138 if isinstance(field_shape
, Layout
):
139 self
.fields
[field_name
] = Record(field_shape
, name
=concat(name
, field_name
),
140 src_loc_at
=1 + src_loc_at
)
142 self
.fields
[field_name
] = Signal(field_shape
, name
=concat(name
, field_name
),
143 src_loc_at
=1 + src_loc_at
)
145 def __getattr__(self
, name
):
146 # must check `getattr` before `self` - we need to hit Value methods before fields
148 value_attr
= getattr(Value
, name
)
149 if callable(value_attr
):
151 def _wrapper(*args
, **kwargs
):
152 return value_attr(self
, *args
, **kwargs
)
155 except AttributeError:
158 def __getitem__(self
, item
):
159 if isinstance(item
, str):
161 return self
.fields
[item
]
163 if self
.name
is None:
164 reference
= "Unnamed record"
166 reference
= "Record '{}'".format(self
.name
)
167 raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
168 .format(reference
, item
, ", ".join(self
.fields
))) from None
169 elif isinstance(item
, tuple):
170 return Record(self
.layout
[item
], fields
={
171 field_name
: field_value
172 for field_name
, field_value
in self
.fields
.items()
173 if field_name
in item
177 return Value
.__getitem
__(self
, item
)
179 if self
.name
is None:
180 reference
= "Unnamed record"
182 reference
= "Record '{}'".format(self
.name
)
183 raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
184 .format(reference
, item
, ", ".join(self
.fields
))) from None
186 @ValueCastable.lowermethod
188 return Cat(self
.fields
.values())
191 return len(self
.as_value())
def _lhs_signals(self):
    """Combine the ``_lhs_signals()`` results of every field.

    Each value in ``self.fields`` is asked for its ``_lhs_signals()``; the
    results are merged with ``union``, starting from a new ``SignalSet()``.
    """
    per_field = (member._lhs_signals() for member in self.fields.values())
    return union(per_field, start=SignalSet())
def _rhs_signals(self):
    """Combine the ``_rhs_signals()`` results of every field.

    Each value in ``self.fields`` is asked for its ``_rhs_signals()``; the
    results are merged with ``union``, starting from a new ``SignalSet()``.
    """
    per_field = (member._rhs_signals() for member in self.fields.values())
    return union(per_field, start=SignalSet())
201 for field_name
, field
in self
.fields
.items():
202 if isinstance(field
, Signal
):
203 fields
.append(field_name
)
205 fields
.append(repr(field
))
209 return "(rec {} {})".format(name
, " ".join(fields
))
211 def connect(self
, *subordinates
, include
=None, exclude
=None):
212 def rec_name(record
):
213 if record
.name
is None:
214 return "unnamed record"
216 return "record '{}'".format(record
.name
)
218 for field
in include
or {}:
219 if field
not in self
.fields
:
220 raise AttributeError("Cannot include field '{}' because it is not present in {}"
221 .format(field
, rec_name(self
)))
222 for field
in exclude
or {}:
223 if field
not in self
.fields
:
224 raise AttributeError("Cannot exclude field '{}' because it is not present in {}"
225 .format(field
, rec_name(self
)))
228 for field
in self
.fields
:
229 if include
is not None and field
not in include
:
231 if exclude
is not None and field
in exclude
:
234 shape
, direction
= self
.layout
[field
]
235 if not isinstance(shape
, Layout
) and direction
== DIR_NONE
:
236 raise TypeError("Cannot connect field '{}' of {} because it does not have "
238 .format(field
, rec_name(self
)))
240 item
= self
.fields
[field
]
242 for subord
in subordinates
:
243 if field
not in subord
.fields
:
244 raise AttributeError("Cannot connect field '{}' of {} to subordinate {} "
245 "because the subordinate record does not have this field"
246 .format(field
, rec_name(self
), rec_name(subord
)))
247 subord_items
.append(subord
.fields
[field
])
249 if isinstance(shape
, Layout
):
250 sub_include
= include
[field
] if include
and field
in include
else None
251 sub_exclude
= exclude
[field
] if exclude
and field
in exclude
else None
252 stmts
+= item
.connect(*subord_items
, include
=sub_include
, exclude
=sub_exclude
)
254 if direction
== DIR_FANOUT
:
255 stmts
+= [sub_item
.eq(item
) for sub_item
in subord_items
]
256 if direction
== DIR_FANIN
:
257 stmts
+= [item
.eq(reduce(lambda a
, b
: a | b
, subord_items
))]