1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
5 from typing
import Mapping
8 from types
import MappingProxyType
9 from contextlib
import contextmanager
11 from nmigen
.hdl
.ast
import Signal
16 return super().__str
__()
26 class IntElWid(ElWid
):
34 """A map from ElWid values to Python values.
35 SimdMap instances are immutable."""
37 ALL_ELWIDTHS
= (*FpElWid
, *IntElWid
)
38 __slots__
= ("__map",)
41 def extract_value(cls
, elwid
, values
, default
=None):
42 """get the value for elwid.
43 if `values` is a `SimdMap` or a `Mapping`, then return the
44 corresponding value for `elwid`, recursing until finding a non-map.
45 if `values` ever ends up not existing (in the case of a map) or being
46 `None`, return `default`.
49 SimdMap.extract_value(IntElWid.I8, 5) == 5
50 SimdMap.extract_value(IntElWid.I8, None) == None
51 SimdMap.extract_value(IntElWid.I8, None, 3) == 3
52 SimdMap.extract_value(IntElWid.I8, {}) == None
53 SimdMap.extract_value(IntElWid.I8, {IntElWid.I8: 5}) == 5
54 SimdMap.extract_value(IntElWid.I8, {
55 IntElWid.I8: {IntElWid.I8: 5},
57 SimdMap.extract_value(IntElWid.I8, {
58 IntElWid.I8: SimdMap({IntElWid.I8: 5}),
61 assert elwid
in cls
.ALL_ELWIDTHS
63 while values
is not None:
64 # specifically use base class to catch all SimdMap instances
65 if isinstance(values
, SimdMap
):
66 values
= values
.__map
.get(elwid
)
67 elif isinstance(values
, Mapping
):
68 values
= values
.get(elwid
)
72 # use object.__repr__ since repr() would probably recurse forever
73 assert step
< 10000, (f
"can't resolve infinitely recursive "
74 f
"value {object.__repr__(values)}")
77 def __init__(self
, values
=None):
78 """construct a SimdMap"""
80 for elwid
in self
.ALL_ELWIDTHS
:
81 v
= self
.extract_value(elwid
, values
)
84 self
.__map
= MappingProxyType(mapping
)
88 """the values as a read-only Mapping[ElWid, Any]"""
92 return self
.__map
.values()
95 return self
.__map
.keys()
98 return self
.__map
.items()
101 def map_with_elwid(cls
, f
, *args
):
102 """get the SimdMap of the results of calling
103 `f(elwid, value1, value2, value3, ...)` where
104 `value1`, `value2`, `value3`, ... are the results of calling
105 `cls.extract_value` on each `args`.
107 This is similar to Python's built-in `map` function.
110 SimdMap.map_with_elwid(lambda elwid, a: a + 1, {IntElWid.I32: 5}) ==
111 SimdMap({IntElWid.I32: 6})
112 SimdMap.map_with_elwid(lambda elwid, a: a + 1, 3) ==
113 SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
114 SimdMap.map_with_elwid(lambda elwid, a, b: a + b,
116 ) == SimdMap({IntElWid.I8: 7})
117 SimdMap.map_with_elwid(lambda elwid: elwid.name) ==
118 SimdMap({IntElWid.I8: "I8", IntElWid.I16: "I16"})
121 for elwid
in cls
.ALL_ELWIDTHS
:
122 extracted_args
= [cls
.extract_value(elwid
, arg
) for arg
in args
]
123 if None not in extracted_args
:
124 retval
[elwid
] = f(elwid
, *extracted_args
)
128 def map(cls
, f
, *args
):
129 """get the SimdMap of the results of calling
130 `f(value1, value2, value3, ...)` where
131 `value1`, `value2`, `value3`, ... are the results of calling
132 `cls.extract_value` on each `args`.
134 This is similar to Python's built-in `map` function.
137 SimdMap.map(lambda a: a + 1, {IntElWid.I32: 5}) ==
138 SimdMap({IntElWid.I32: 6})
139 SimdMap.map(lambda a: a + 1, 3) ==
140 SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
141 SimdMap.map(lambda a, b: a + b,
143 ) == SimdMap({IntElWid.I8: 7})
145 return cls
.map_with_elwid(lambda elwid
, *args2
: f(*args2
), *args
)
147 def get(self
, elwid
, default
=None, *, raise_key_error
=False):
149 retval
= self
.extract_value(elwid
, self
)
153 return self
.extract_value(elwid
, self
, default
)
156 """return an iterator of (elwid, value) pairs"""
157 return self
.__map
.items()
def __add__(self, other):
    """Per-elwid ``self + other``, computed through ``SimdMap.map``."""
    add = operator.add
    return self.map(add, self, other)
def __radd__(self, other):
    """Reflected per-elwid ``other + self``, computed through ``SimdMap.map``."""
    add = operator.add
    return self.map(add, other, self)
def __sub__(self, other):
    """Per-elwid ``self - other``, computed through ``SimdMap.map``."""
    sub = operator.sub
    return self.map(sub, self, other)
def __rsub__(self, other):
    """Reflected per-elwid ``other - self`` (operand order matters)."""
    sub = operator.sub
    return self.map(sub, other, self)
def __mul__(self, other):
    """Per-elwid ``self * other``, computed through ``SimdMap.map``."""
    mul = operator.mul
    return self.map(mul, self, other)
def __rmul__(self, other):
    """Reflected per-elwid ``other * self``, computed through ``SimdMap.map``."""
    mul = operator.mul
    return self.map(mul, other, self)
def __floordiv__(self, other):
    """Per-elwid ``self // other``, computed through ``SimdMap.map``."""
    floordiv = operator.floordiv
    return self.map(floordiv, self, other)
def __rfloordiv__(self, other):
    """Reflected per-elwid ``other // self`` (operand order matters)."""
    floordiv = operator.floordiv
    return self.map(floordiv, other, self)
def __truediv__(self, other):
    """Per-elwid ``self / other``, computed through ``SimdMap.map``."""
    truediv = operator.truediv
    return self.map(truediv, self, other)
def __rtruediv__(self, other):
    """Reflected per-elwid ``other / self`` (operand order matters)."""
    truediv = operator.truediv
    return self.map(truediv, other, self)
def __mod__(self, other):
    """Per-elwid ``self % other``, computed through ``SimdMap.map``."""
    mod = operator.mod
    return self.map(mod, self, other)
def __rmod__(self, other):
    """Reflected per-elwid ``other % self`` (operand order matters)."""
    mod = operator.mod
    return self.map(mod, other, self)
196 return self
.map(abs, self
)
def __and__(self, other):
    """Per-elwid ``self & other``, computed through ``SimdMap.map``."""
    bit_and = operator.and_
    return self.map(bit_and, self, other)
def __rand__(self, other):
    """Reflected per-elwid ``other & self``, computed through ``SimdMap.map``."""
    bit_and = operator.and_
    return self.map(bit_and, other, self)
def __divmod__(self, other):
    """Per-elwid ``divmod(self, other)``, computed through ``SimdMap.map``.

    Each resulting value is whatever the builtin ``divmod`` returns for that
    elwid's operands (a ``(quotient, remainder)`` pair for numbers).
    """
    return self.map(divmod, self, other)

def __rdivmod__(self, other):
    """Reflected per-elwid ``divmod(other, self)``.

    Every other binary operator on this class has a reflected twin
    (``__rfloordiv__``, ``__rmod__``, ...). Without this one,
    ``divmod(scalar, simd_map)`` raises TypeError, because the scalar's
    ``__divmod__`` returns NotImplemented and there is no fallback.
    """
    return self.map(divmod, other, self)
208 return self
.map(math
.ceil
, self
)
211 return self
.map(float, self
)
214 return self
.map(math
.floor
, self
)
def __eq__(self, other):
    """Equal when ``other`` is a SimdMap with an equal ``mapping``.

    Non-SimdMap operands return NotImplemented so Python can try the
    reflected comparison.
    """
    if not isinstance(other, SimdMap):
        return NotImplemented
    return self.mapping == other.mapping
222 return hash(tuple(self
.mapping
.get(i
) for i
in self
.ALL_ELWIDTHS
))
225 return f
"{self.__class__.__name__}({dict(self.mapping)})"
def __invert__(self):
    """Per-elwid ``~self``, computed through ``SimdMap.map``."""
    invert = operator.invert
    return self.map(invert, self)
def __lshift__(self, other):
    """Per-elwid ``self << other``, computed through ``SimdMap.map``."""
    lshift = operator.lshift
    return self.map(lshift, self, other)
def __rlshift__(self, other):
    """Reflected per-elwid ``other << self`` (operand order matters)."""
    lshift = operator.lshift
    return self.map(lshift, other, self)
def __rshift__(self, other):
    """Per-elwid ``self >> other``, computed through ``SimdMap.map``."""
    rshift = operator.rshift
    return self.map(rshift, self, other)
def __rrshift__(self, other):
    """Reflected per-elwid ``other >> self`` (operand order matters)."""
    rshift = operator.rshift
    return self.map(rshift, other, self)
243 return self
.map(operator
.neg
, self
)
246 return self
.map(operator
.pos
, self
)
def __or__(self, other):
    """Per-elwid ``self | other``, computed through ``SimdMap.map``."""
    bit_or = operator.or_
    return self.map(bit_or, self, other)
def __ror__(self, other):
    """Reflected per-elwid ``other | self``, computed through ``SimdMap.map``."""
    bit_or = operator.or_
    return self.map(bit_or, other, self)
def __xor__(self, other):
    """Per-elwid ``self ^ other``, computed through ``SimdMap.map``."""
    bit_xor = operator.xor
    return self.map(bit_xor, self, other)
def __rxor__(self, other):
    """Reflected per-elwid ``other ^ self``, computed through ``SimdMap.map``."""
    bit_xor = operator.xor
    return self.map(bit_xor, other, self)
260 def missing_elwidths(self
, *, all_elwidths
=None):
261 """an iterator of the elwidths where self doesn't have a corresponding
263 if all_elwidths
is None:
264 all_elwidths
= self
.ALL_ELWIDTHS
265 for elwid
in all_elwidths
:
266 if elwid
not in self
.keys():
def _check_for_missing_elwidths(name, all_elwidths=None):
    """Assert that the module global named ``name`` covers every elwid.

    The global is looked up by name so the assertion message can mention it.
    ``all_elwidths`` is forwarded unchanged to its ``missing_elwidths``
    method.
    """
    table = globals()[name]
    missing = [*table.missing_elwidths(all_elwidths=all_elwidths)]
    assert not missing, f"{name} is missing entries for {missing}"
286 DEFAULT_FP_PART_COUNTS
= SimdMap({
293 DEFAULT_INT_PART_COUNTS
= SimdMap({
# Import-time sanity checks: each default table must have an entry for every
# elwid it is checked against (None -> the method's own default set).
for _table_name, _table_elwidths in (
    ("XLEN", None),
    ("DEFAULT_FP_PART_COUNTS", FpElWid),
    ("DEFAULT_INT_PART_COUNTS", IntElWid),
):
    _check_for_missing_elwidths(_table_name, _table_elwidths)
del _table_name, _table_elwidths
306 """The global scope object for SimdSignal and friends
309 * part_counts: SimdMap
310 a map from `ElWid` values `k` to the number of parts in an element
311 when `self.elwid == k`. Values should be minimized, since higher values
312 often create bigger circuits.
315 # here, an I8 element is 1 part wide
316 part_counts = {ElWid.I8: 1, ElWid.I16: 2, ElWid.I32: 4, ElWid.I64: 8}
319 # here, an F16 element is 1 part wide
320 part_counts = {ElWid.F16: 1, ElWid.BF16: 1, ElWid.F32: 2, ElWid.F64: 4}
321 * simd_full_width_hint: int
322 the default value for SimdLayout's full_width argument, the full number
323 of bits in a SIMD value.
324 * elwid: ElWid or nmigen Value with a shape of some ElWid class
325 the current elwid (simd element type)
332 """get the current SimdScope.
335 SimdScope.get(None) is None
336 SimdScope.get() raises ValueError
337 with SimdScope(...) as s:
340 if len(cls
.__SCOPE
_STACK
) > 0:
341 retval
= cls
.__SCOPE
_STACK
[-1]
342 assert isinstance(retval
, SimdScope
), "inconsistent scope stack"
344 raise ValueError("not in a `with SimdScope()` statement")
347 self
.__SCOPE
_STACK
.append(self
)
def __exit__(self, exc_type, exc_value, traceback):
    """Leave this scope by popping it off the class-wide scope stack.

    Returns None, so any exception raised inside the ``with`` body
    propagates.
    """
    # Pop outside the assert: assert statements are removed when Python runs
    # with -O. Putting the pop inside the assert would then skip it, and the
    # stack would keep returning this (exited) scope as the current one.
    popped = self.__SCOPE_STACK.pop()
    assert popped is self, "inconsistent scope stack"
354 def __init__(self
, *, simd_full_width_hint
=64, elwid
=None,
355 part_counts
=None, elwid_type
=IntElWid
, scalar
=False):
356 # TODO: add more arguments/members and processing for integration with
357 self
.simd_full_width_hint
= simd_full_width_hint
358 if isinstance(elwid
, (IntElWid
, FpElWid
)):
359 elwid_type
= type(elwid
)
360 if part_counts
is None:
361 part_counts
= SimdMap({elwid
: 1})
362 assert issubclass(elwid_type
, (IntElWid
, FpElWid
))
363 self
.elwid_type
= elwid_type
364 scalar_elwid
= elwid_type(0)
365 if part_counts
is None:
367 part_counts
= SimdMap({scalar_elwid
: 1})
368 elif issubclass(elwid_type
, FpElWid
):
369 part_counts
= DEFAULT_FP_PART_COUNTS
371 part_counts
= DEFAULT_INT_PART_COUNTS
373 def check(elwid
, part_count
):
374 assert type(elwid
) == elwid_type
, "inconsistent ElWid types"
375 part_count
= int(part_count
)
376 assert part_count
!= 0 and (part_count
& (part_count
- 1)) == 0,\
377 "part_counts values must all be powers of two"
380 self
.part_counts
= SimdMap
.map_with_elwid(check
, part_counts
)
381 self
.full_part_count
= max(part_counts
.values())
382 assert self
.simd_full_width_hint
% self
.full_part_count
== 0,\
383 "simd_full_width_hint must be a multiple of full_part_count"
384 if elwid
is not None:
387 self
.elwid
= scalar_elwid
389 self
.elwid
= Signal(elwid_type
)
392 return (f
"SimdScope(\n"
393 f
" simd_full_width_hint={self.simd_full_width_hint},\n"
394 f
" elwid={self.elwid},\n"
395 f
" elwid_type={self.elwid_type},\n"
396 f
" part_counts={self.part_counts},\n"
397 f
" full_part_count={self.full_part_count})")