add SimdMap and SimdScope and XLEN
[ieee754fpu.git] / src / ieee754 / part / util.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from enum import Enum
5 from typing import Mapping
6 import operator
7 import math
8 from types import MappingProxyType
9 from contextlib import contextmanager
10
11 from nmigen.hdl.ast import Signal
12
13
class ElWid(Enum):
    """Common base class of the element-width enums.

    `repr()` of a member reuses Enum's `str()` formatting (``Class.MEMBER``)
    instead of Enum's default ``<Class.MEMBER: value>`` form.
    """

    def __repr__(self):
        # explicit two-argument super: same lookup as the zero-argument form
        as_text = super(ElWid, self).__str__()
        return as_text
17
18
class FpElWid(ElWid):
    """Element widths for floating-point SIMD elements.

    Member order and values matter: iterating this enum builds part of
    `SimdMap.ALL_ELWIDTHS`, and `SimdScope` uses the value-0 member (`F64`)
    as its scalar elwid.
    """
    F64 = 0  # 64 bits wide according to XLEN
    F32 = 1  # 32 bits wide according to XLEN
    F16 = 2  # 16 bits wide according to XLEN
    BF16 = 3  # 16 bits wide according to XLEN; presumably bfloat16 -- confirm
24
25
class IntElWid(ElWid):
    """Element widths for integer SIMD elements.

    Member order and values matter: iterating this enum builds part of
    `SimdMap.ALL_ELWIDTHS`, and `SimdScope` uses the value-0 member (`I64`)
    as its scalar elwid.
    """
    I64 = 0  # 64 bits wide according to XLEN
    I32 = 1  # 32 bits wide according to XLEN
    I16 = 2  # 16 bits wide according to XLEN
    I8 = 3  # 8 bits wide according to XLEN
31
32
class SimdMap:
    """An immutable map from `ElWid` values to arbitrary Python values.

    The constructor accepts anything `extract_value` understands:
    * a scalar, used for every elwid
    * a `Mapping` or another `SimdMap`, possibly nested

    Elwids whose resolved value is `None` get no entry.

    The arithmetic and bitwise operators are applied element-wise through
    `map`. An elwid only appears in the result if every operand has a value
    for it; a scalar operand has a value for every elwid.

    SimdMap instances are immutable.
    """

    ALL_ELWIDTHS = (*FpElWid, *IntElWid)
    __slots__ = ("__map",)

    @classmethod
    def extract_value(cls, elwid, values, default=None):
        """get the value for elwid.

        If `values` is a `SimdMap` or a `Mapping`, look up `elwid` in it.
        Repeat this on the result until reaching something that is neither,
        and return that.

        If a lookup finds no entry, or `values` is or becomes `None`, return
        `default`.

        Examples:
            SimdMap.extract_value(IntElWid.I8, 5) == 5
            SimdMap.extract_value(IntElWid.I8, None) == None
            SimdMap.extract_value(IntElWid.I8, None, 3) == 3
            SimdMap.extract_value(IntElWid.I8, {}) == None
            SimdMap.extract_value(IntElWid.I8, {IntElWid.I8: 5}) == 5
            SimdMap.extract_value(IntElWid.I8, {
                IntElWid.I8: {IntElWid.I8: 5},
            }) == 5
            SimdMap.extract_value(IntElWid.I8, {
                IntElWid.I8: SimdMap({IntElWid.I8: 5}),
            }) == 5
        """
        assert elwid in cls.ALL_ELWIDTHS
        step = 0
        while values is not None:
            # specifically use base class to catch all SimdMap instances
            if isinstance(values, SimdMap):
                values = values.__map.get(elwid)
            elif isinstance(values, Mapping):
                values = values.get(elwid)
            else:
                return values
            step += 1
            # use object.__repr__ since repr() would probably recurse forever
            assert step < 10000, (f"can't resolve infinitely recursive "
                                  f"value {object.__repr__(values)}")
        return default

    def __init__(self, values=None):
        """construct a SimdMap from anything `extract_value` accepts"""
        mapping = {}
        for elwid in self.ALL_ELWIDTHS:
            v = self.extract_value(elwid, values)
            if v is not None:
                mapping[elwid] = v
        self.__map = MappingProxyType(mapping)

    @property
    def mapping(self):
        """the values as a read-only Mapping[ElWid, Any]"""
        return self.__map

    def values(self):
        """the stored values, in `ALL_ELWIDTHS` order"""
        return self.__map.values()

    def keys(self):
        """the elwids that have a value, in `ALL_ELWIDTHS` order"""
        return self.__map.keys()

    def items(self):
        """the (elwid, value) pairs, in `ALL_ELWIDTHS` order"""
        return self.__map.items()

    @classmethod
    def map_with_elwid(cls, f, *args):
        """get the SimdMap of the results of calling
        `f(elwid, value1, value2, value3, ...)`.

        Here `value1`, `value2`, `value3`, ... are the results of calling
        `cls.extract_value` on each of `args`. Elwids for which any extracted
        value is `None` are skipped.

        This is similar to Python's built-in `map` function.

        Examples:
            SimdMap.map_with_elwid(lambda elwid, a: a + 1, {IntElWid.I32: 5}) ==
                SimdMap({IntElWid.I32: 6})
            SimdMap.map_with_elwid(lambda elwid, a: a + 1, 3) ==
                SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
            SimdMap.map_with_elwid(lambda elwid, a, b: a + b,
                3, {IntElWid.I8: 4},
            ) == SimdMap({IntElWid.I8: 7})
            # with no args, every elwid gets an entry
            SimdMap.map_with_elwid(lambda elwid: elwid.name) ==
                SimdMap({FpElWid.F64: "F64", ..., IntElWid.I8: "I8"})
        """
        retval = {}
        for elwid in cls.ALL_ELWIDTHS:
            extracted_args = [cls.extract_value(elwid, arg) for arg in args]
            # Test with `is`, not `None in extracted_args`. The `in` operator
            # falls back to `==`, which objects such as nMigen Values or
            # numpy arrays overload to raise or to return non-bool results.
            if all(arg is not None for arg in extracted_args):
                retval[elwid] = f(elwid, *extracted_args)
        return cls(retval)

    @classmethod
    def map(cls, f, *args):
        """get the SimdMap of the results of calling
        `f(value1, value2, value3, ...)`.

        Here `value1`, `value2`, `value3`, ... are the results of calling
        `cls.extract_value` on each of `args`.

        This is similar to Python's built-in `map` function.

        Examples:
            SimdMap.map(lambda a: a + 1, {IntElWid.I32: 5}) ==
                SimdMap({IntElWid.I32: 6})
            SimdMap.map(lambda a: a + 1, 3) ==
                SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
            SimdMap.map(lambda a, b: a + b,
                3, {IntElWid.I8: 4},
            ) == SimdMap({IntElWid.I8: 7})
        """
        return cls.map_with_elwid(lambda elwid, *args2: f(*args2), *args)

    def get(self, elwid, default=None, *, raise_key_error=False):
        """return the value for `elwid`, or `default` if there is none.

        If `raise_key_error` is true, a missing value raises `KeyError(elwid)`
        instead of returning `default`.
        """
        if raise_key_error:
            retval = self.extract_value(elwid, self)
            if retval is None:
                raise KeyError(elwid)
            return retval
        return self.extract_value(elwid, self, default)

    def __iter__(self):
        """return an iterator of (elwid, value) pairs"""
        # `__iter__` must return an iterator. A bare items view is only an
        # iterable, which makes `iter()` and `for` loops raise TypeError.
        return iter(self.__map.items())

    def __add__(self, other):
        return self.map(operator.add, self, other)

    def __radd__(self, other):
        return self.map(operator.add, other, self)

    def __sub__(self, other):
        return self.map(operator.sub, self, other)

    def __rsub__(self, other):
        return self.map(operator.sub, other, self)

    def __mul__(self, other):
        return self.map(operator.mul, self, other)

    def __rmul__(self, other):
        return self.map(operator.mul, other, self)

    def __floordiv__(self, other):
        return self.map(operator.floordiv, self, other)

    def __rfloordiv__(self, other):
        return self.map(operator.floordiv, other, self)

    def __truediv__(self, other):
        return self.map(operator.truediv, self, other)

    def __rtruediv__(self, other):
        return self.map(operator.truediv, other, self)

    def __mod__(self, other):
        return self.map(operator.mod, self, other)

    def __rmod__(self, other):
        return self.map(operator.mod, other, self)

    def __pow__(self, other):
        return self.map(operator.pow, self, other)

    def __rpow__(self, other):
        return self.map(operator.pow, other, self)

    def __abs__(self):
        return self.map(abs, self)

    def __and__(self, other):
        return self.map(operator.and_, self, other)

    def __rand__(self, other):
        return self.map(operator.and_, other, self)

    def __divmod__(self, other):
        return self.map(divmod, self, other)

    def __rdivmod__(self, other):
        return self.map(divmod, other, self)

    def __ceil__(self):
        return self.map(math.ceil, self)

    def __float__(self):
        # NOTE: this returns a SimdMap, not a float. The built-in `float(x)`
        # rejects non-float results with TypeError, so use
        # `SimdMap.map(float, x)` for an element-wise conversion.
        return self.map(float, self)

    def __floor__(self):
        return self.map(math.floor, self)

    def __eq__(self, other):
        """two SimdMaps are equal when their mappings are equal"""
        if isinstance(other, SimdMap):
            return self.mapping == other.mapping
        return NotImplemented

    def __hash__(self):
        # consistent with __eq__: equal mappings give equal tuples
        return hash(tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS))

    def __repr__(self):
        return f"{self.__class__.__name__}({dict(self.mapping)})"

    def __invert__(self):
        return self.map(operator.invert, self)

    def __lshift__(self, other):
        return self.map(operator.lshift, self, other)

    def __rlshift__(self, other):
        return self.map(operator.lshift, other, self)

    def __rshift__(self, other):
        return self.map(operator.rshift, self, other)

    def __rrshift__(self, other):
        return self.map(operator.rshift, other, self)

    def __neg__(self):
        return self.map(operator.neg, self)

    def __pos__(self):
        return self.map(operator.pos, self)

    def __or__(self, other):
        return self.map(operator.or_, self, other)

    def __ror__(self, other):
        return self.map(operator.or_, other, self)

    def __xor__(self, other):
        return self.map(operator.xor, self, other)

    def __rxor__(self, other):
        return self.map(operator.xor, other, self)

    def missing_elwidths(self, *, all_elwidths=None):
        """an iterator of the elwidths where self doesn't have a corresponding
        value; `all_elwidths` defaults to `ALL_ELWIDTHS`"""
        if all_elwidths is None:
            all_elwidths = self.ALL_ELWIDTHS
        for elwid in all_elwidths:
            if elwid not in self.keys():
                yield elwid
267 yield elwid
268
269
def _check_for_missing_elwidths(name, all_elwidths=None):
    """Assert that the module-level SimdMap named `name` has a value for
    every elwid in `all_elwidths` (every elwid when `None`)."""
    simd_map = globals()[name]
    missing = [*simd_map.missing_elwidths(all_elwidths=all_elwidths)]
    assert not missing, f"{name} is missing entries for {missing}"
273
274
# Element width in bits for every elwid (the width each member's name states).
XLEN = SimdMap({
    IntElWid.I64: 64,
    IntElWid.I32: 32,
    IntElWid.I16: 16,
    IntElWid.I8: 8,
    FpElWid.F64: 64,
    FpElWid.F32: 32,
    FpElWid.F16: 16,
    FpElWid.BF16: 16,
})

# Number of parts per element that SimdScope uses for FpElWid scopes when
# neither `part_counts` nor an ElWid-member `elwid` is passed and `scalar` is
# false. XLEN / part count is 16 for every entry, so a part corresponds to
# 16 bits of an element.
DEFAULT_FP_PART_COUNTS = SimdMap({
    FpElWid.F64: 4,
    FpElWid.F32: 2,
    FpElWid.F16: 1,
    FpElWid.BF16: 1,
})

# Same as above, for IntElWid scopes. XLEN / part count is 8 for every entry,
# so a part corresponds to 8 bits of an element.
DEFAULT_INT_PART_COUNTS = SimdMap({
    IntElWid.I64: 8,
    IntElWid.I32: 4,
    IntElWid.I16: 2,
    IntElWid.I8: 1,
})

# Import-time sanity checks (asserts) that each table covers all its elwids.
_check_for_missing_elwidths("XLEN")
_check_for_missing_elwidths("DEFAULT_FP_PART_COUNTS", FpElWid)
_check_for_missing_elwidths("DEFAULT_INT_PART_COUNTS", IntElWid)
303
304
class SimdScope:
    """The global scope object for SimdSignal and friends

    Use it as a context manager. Inside ``with SimdScope(...) as s:``,
    `SimdScope.get()` returns `s`. Scopes nest, and the innermost active
    scope is the current one.

    Members:
    * part_counts: SimdMap
        A map from `ElWid` values `k` to the number of parts in an element
        when `self.elwid == k`. Values should be minimized, since higher values
        often create bigger circuits. Every value must be a power of two.

        Example:
        # here, an I8 element is 1 part wide
        part_counts = {IntElWid.I8: 1, IntElWid.I16: 2, IntElWid.I32: 4,
                       IntElWid.I64: 8}

        Another Example:
        # here, an F16 element is 1 part wide
        part_counts = {FpElWid.F16: 1, FpElWid.BF16: 1, FpElWid.F32: 2,
                       FpElWid.F64: 4}
    * full_part_count: int
        The largest value in `part_counts`.
    * simd_full_width_hint: int
        The default value for SimdLayout's full_width argument, the full number
        of bits in a SIMD value. Must be a multiple of `full_part_count`.
    * elwid_type: type
        The ElWid enum class (`IntElWid` or `FpElWid`) used by this scope.
    * elwid: ElWid or nmigen Value with a shape of some ElWid class
        The current elwid (simd element type).
    """

    # stack of active scopes, innermost last; shared by the whole class
    __SCOPE_STACK = []

    # sentinel meaning "no default given" for `get`
    __NO_DEFAULT = object()

    @classmethod
    def get(cls, default=__NO_DEFAULT):
        """get the current SimdScope.

        If no scope is active, return `default` when it was given; otherwise
        raise `ValueError`.

        Example:
            SimdScope.get(None) is None  # outside any scope
            SimdScope.get() raises ValueError  # outside any scope
            with SimdScope(...) as s:
                SimdScope.get() is s
        """
        if cls.__SCOPE_STACK:
            retval = cls.__SCOPE_STACK[-1]
            assert isinstance(retval, SimdScope), "inconsistent scope stack"
            return retval
        if default is not cls.__NO_DEFAULT:
            return default
        raise ValueError("not in a `with SimdScope()` statement")

    def __enter__(self):
        """make this the current scope"""
        self.__SCOPE_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """restore the previous scope; exceptions are not suppressed"""
        # Pop outside the assert: `python -O` strips assert statements, which
        # would otherwise skip the pop and leave this scope active forever.
        popped = self.__SCOPE_STACK.pop()
        assert popped is self, "inconsistent scope stack"
        return False

    def __init__(self, *, simd_full_width_hint=64, elwid=None,
                 part_counts=None, elwid_type=IntElWid, scalar=False):
        """construct a SimdScope.

        Arguments:
        * simd_full_width_hint: see the class docstring.
        * elwid: an `IntElWid`/`FpElWid` member, or an nmigen Value.
          If it is an enum member, its type overrides `elwid_type`. If
          `part_counts` is also `None`, it defaults to `{elwid: 1}`.
          If `elwid` is `None`, it becomes `elwid_type(0)` when `scalar` is
          true, and a new `Signal(elwid_type)` otherwise.
        * part_counts: anything `SimdMap` accepts.
          If `None`, it defaults to `{elwid_type(0): 1}` when `scalar` is
          true, and otherwise to `DEFAULT_FP_PART_COUNTS` or
          `DEFAULT_INT_PART_COUNTS` depending on `elwid_type`.
        * elwid_type: `IntElWid` or `FpElWid`.
        * scalar: select single-part (scalar) defaults as described above.

        Raises AssertionError in these cases:
        * the elwid types are inconsistent
        * a part count is not a power of two
        * `simd_full_width_hint` is not a multiple of `full_part_count`
        """
        # TODO: add more arguments/members and processing for further
        # integration (the original comment was cut off here)
        self.simd_full_width_hint = simd_full_width_hint
        if isinstance(elwid, (IntElWid, FpElWid)):
            elwid_type = type(elwid)
            if part_counts is None:
                part_counts = SimdMap({elwid: 1})
        assert issubclass(elwid_type, (IntElWid, FpElWid))
        self.elwid_type = elwid_type
        # the value-0 member (I64 / F64) serves as the scalar elwid
        scalar_elwid = elwid_type(0)
        if part_counts is None:
            if scalar:
                part_counts = SimdMap({scalar_elwid: 1})
            elif issubclass(elwid_type, FpElWid):
                part_counts = DEFAULT_FP_PART_COUNTS
            else:
                part_counts = DEFAULT_INT_PART_COUNTS

        def check(part_elwid, part_count):
            # validate one entry and normalize its count to an int
            assert type(part_elwid) is elwid_type, "inconsistent ElWid types"
            part_count = int(part_count)
            assert part_count != 0 and (part_count & (part_count - 1)) == 0,\
                "part_counts values must all be powers of two"
            return part_count

        self.part_counts = SimdMap.map_with_elwid(check, part_counts)
        # Use the validated, int-normalized counts. The raw argument may be
        # any Mapping/SimdMap with non-int values or keys that are not elwids.
        self.full_part_count = max(self.part_counts.values())
        assert self.simd_full_width_hint % self.full_part_count == 0,\
            "simd_full_width_hint must be a multiple of full_part_count"
        if elwid is not None:
            self.elwid = elwid
        elif scalar:
            self.elwid = scalar_elwid
        else:
            self.elwid = Signal(elwid_type)

    def __repr__(self):
        return (f"SimdScope(\n"
                f"        simd_full_width_hint={self.simd_full_width_hint},\n"
                f"        elwid={self.elwid},\n"
                f"        elwid_type={self.elwid_type},\n"
                f"        part_counts={self.part_counts},\n"
                f"        full_part_count={self.full_part_count})")
396 f" part_counts={self.part_counts},\n"
397 f" full_part_count={self.full_part_count})")