add SimdWHintMap to support tracking width_hint for XLEN
[ieee754fpu.git] / src / ieee754 / part / util.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from enum import Enum
5 from typing import Mapping
6 import operator
7 import math
8 from types import MappingProxyType
9
10
class ElWid(Enum):
    """Base class shared by the element-width enums.

    It declares no members itself; subclasses supply them. The `repr()` of a
    member is made identical to the enum's plain `str()` form
    (``ClassName.MEMBER``) instead of the default ``<ClassName.MEMBER: value>``.
    """

    def __repr__(self):
        text = super().__str__()
        return text
14
15
class FpElWid(ElWid):
    """Floating-point element widths: F64, F32, F16 and BF16.

    Members are numbered 0 to 3 in declaration order.
    """

    F64 = 0
    F32 = 1
    F16 = 2
    BF16 = 3
21
22
class IntElWid(ElWid):
    """Integer element widths: I64, I32, I16 and I8.

    Members are numbered 0 to 3 in declaration order, widest first.
    """

    I64 = 0
    I32 = 1
    I16 = 2
    I8 = 3
28
29
class SimdMap:
    """A map from ElWid values to Python values.
    SimdMap instances are immutable.

    A value that is neither a `SimdMap` nor a `Mapping` applies to every
    elwid, so `SimdMap(5)` maps each entry of `ALL_ELWIDTHS` to 5. Elwids
    whose value resolves to `None` are left out of the map. Arithmetic and
    bitwise operators are applied element-wise through `map`.
    """

    # every elwid a SimdMap can hold a value for
    ALL_ELWIDTHS = (*FpElWid, *IntElWid)
    __slots__ = ("__map",)

    @staticmethod
    def extract_value_algo(values, default=None, *, simd_map_get, mapping_get):
        """Look through nested maps in `values` until reaching a non-map.

        `simd_map_get(v)` is called to step into a `SimdMap`, and
        `mapping_get(v)` to step into any other `Mapping`. Each returns the
        next value to examine. The first value that is neither kind of map is
        returned. If `values` is, or becomes, `None`, `default` is returned.
        """
        step = 0
        while values is not None:
            # specifically use base class to catch all SimdMap instances
            if isinstance(values, SimdMap):
                values = simd_map_get(values)
            elif isinstance(values, Mapping):
                values = mapping_get(values)
            else:
                return values
            step += 1
            # use object.__repr__ since repr() would probably recurse forever
            assert step < 10000, (f"can't resolve infinitely recursive "
                                  f"value {object.__repr__(values)}")
        return default

    @classmethod
    def extract_value(cls, elwid, values, default=None):
        """get the value for elwid.
        if `values` is a `SimdMap` or a `Mapping`, then return the
        corresponding value for `elwid`, recursing until finding a non-map.
        if `values` ever ends up not existing (in the case of a map) or being
        `None`, return `default`.

        Examples:
        SimdMap.extract_value(IntElWid.I8, 5) == 5
        SimdMap.extract_value(IntElWid.I8, None) == None
        SimdMap.extract_value(IntElWid.I8, None, 3) == 3
        SimdMap.extract_value(IntElWid.I8, {}) == None
        SimdMap.extract_value(IntElWid.I8, {IntElWid.I8: 5}) == 5
        SimdMap.extract_value(IntElWid.I8, {
            IntElWid.I8: {IntElWid.I8: 5},
        }) == 5
        SimdMap.extract_value(IntElWid.I8, {
            IntElWid.I8: SimdMap({IntElWid.I8: 5}),
        }) == 5
        """
        assert elwid in cls.ALL_ELWIDTHS
        return SimdMap.extract_value_algo(
            values, default,
            simd_map_get=lambda v: v.__map.get(elwid),
            mapping_get=lambda v: v.get(elwid))

    def __init__(self, values=None):
        """construct a SimdMap.

        `values` may be a scalar (used for every elwid), a `Mapping`, or a
        `SimdMap`. Nested maps are resolved with `extract_value`.
        """
        mapping = {}
        for elwid in self.ALL_ELWIDTHS:
            v = self.extract_value(elwid, values)
            if v is not None:
                mapping[elwid] = v
        self.__map = MappingProxyType(mapping)

    @property
    def mapping(self):
        """the values as a read-only Mapping[ElWid, Any]"""
        return self.__map

    def values(self):
        """the stored values, in `ALL_ELWIDTHS` order"""
        return self.__map.values()

    def keys(self):
        """the elwids that have a value"""
        return self.__map.keys()

    def items(self):
        """the (elwid, value) pairs"""
        return self.__map.items()

    @classmethod
    def map_with_elwid(cls, f, *args):
        """get the SimdMap of the results of calling
        `f(elwid, value1, value2, value3, ...)` where
        `value1`, `value2`, `value3`, ... are the results of calling
        `cls.extract_value` on each `args`.

        Elwids for which any extracted value is `None` are skipped.
        This is similar to Python's built-in `map` function.

        Examples:
        SimdMap.map_with_elwid(lambda elwid, a: a + 1, {IntElWid.I32: 5}) ==
            SimdMap({IntElWid.I32: 6})
        SimdMap.map_with_elwid(lambda elwid, a: a + 1, 3) ==
            SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
        SimdMap.map_with_elwid(lambda elwid, a, b: a + b,
            3, {IntElWid.I8: 4},
        ) == SimdMap({IntElWid.I8: 7})
        SimdMap.map_with_elwid(lambda elwid: elwid.name) ==
            SimdMap({IntElWid.I8: "I8", IntElWid.I16: "I16"})
        """
        retval = {}
        for elwid in cls.ALL_ELWIDTHS:
            extracted_args = [cls.extract_value(elwid, arg) for arg in args]
            # Test identity rather than `None in extracted_args`: membership
            # falls back to `==`, which misbehaves for values overloading
            # `__eq__`. For example, a multi-element numpy array compared with
            # None yields an array whose truth value is ambiguous.
            if all(arg is not None for arg in extracted_args):
                retval[elwid] = f(elwid, *extracted_args)
        return cls(retval)

    @classmethod
    def map(cls, f, *args):
        """get the SimdMap of the results of calling
        `f(value1, value2, value3, ...)` where
        `value1`, `value2`, `value3`, ... are the results of calling
        `cls.extract_value` on each `args`.

        This is similar to Python's built-in `map` function.

        Examples:
        SimdMap.map(lambda a: a + 1, {IntElWid.I32: 5}) ==
            SimdMap({IntElWid.I32: 6})
        SimdMap.map(lambda a: a + 1, 3) ==
            SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
        SimdMap.map(lambda a, b: a + b,
            3, {IntElWid.I8: 4},
        ) == SimdMap({IntElWid.I8: 7})
        """
        return cls.map_with_elwid(lambda elwid, *args2: f(*args2), *args)

    def get(self, elwid, default=None, *, raise_key_error=False):
        """Return the value for `elwid`.

        If there is no value, return `default`. When `raise_key_error` is
        true, raise `KeyError(elwid)` instead; `default` is then ignored.
        """
        if raise_key_error:
            retval = self.extract_value(elwid, self)
            if retval is None:
                raise KeyError(elwid)
            return retval
        return self.extract_value(elwid, self, default)

    def __iter__(self):
        """return an iterator of (elwid, value) pairs"""
        return iter(self.__map.items())

    def __add__(self, other):
        return self.map(operator.add, self, other)

    def __radd__(self, other):
        return self.map(operator.add, other, self)

    def __sub__(self, other):
        return self.map(operator.sub, self, other)

    def __rsub__(self, other):
        return self.map(operator.sub, other, self)

    def __mul__(self, other):
        return self.map(operator.mul, self, other)

    def __rmul__(self, other):
        return self.map(operator.mul, other, self)

    def __floordiv__(self, other):
        return self.map(operator.floordiv, self, other)

    def __rfloordiv__(self, other):
        return self.map(operator.floordiv, other, self)

    def __truediv__(self, other):
        return self.map(operator.truediv, self, other)

    def __rtruediv__(self, other):
        return self.map(operator.truediv, other, self)

    def __mod__(self, other):
        return self.map(operator.mod, self, other)

    def __rmod__(self, other):
        return self.map(operator.mod, other, self)

    def __abs__(self):
        return self.map(abs, self)

    def __and__(self, other):
        return self.map(operator.and_, self, other)

    def __rand__(self, other):
        return self.map(operator.and_, other, self)

    def __divmod__(self, other):
        return self.map(divmod, self, other)

    def __rdivmod__(self, other):
        # reflected form, so `divmod(scalar, simd_map)` works like the
        # other reflected operators above
        return self.map(divmod, other, self)

    def __ceil__(self):
        return self.map(math.ceil, self)

    def __float__(self):
        # NOTE: returns a SimdMap, so this is only useful when called
        # directly; the built-in float() requires __float__ to return a
        # float and raises TypeError otherwise.
        return self.map(float, self)

    def __floor__(self):
        return self.map(math.floor, self)

    def __eq__(self, other):
        if isinstance(other, SimdMap):
            return self.mapping == other.mapping
        return NotImplemented

    def __hash__(self):
        # hash the values in ALL_ELWIDTHS order so equal maps hash equally
        return hash(tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS))

    def __repr__(self):
        return f"{self.__class__.__name__}({dict(self.mapping)})"

    def __invert__(self):
        return self.map(operator.invert, self)

    def __lshift__(self, other):
        return self.map(operator.lshift, self, other)

    def __rlshift__(self, other):
        return self.map(operator.lshift, other, self)

    def __rshift__(self, other):
        return self.map(operator.rshift, self, other)

    def __rrshift__(self, other):
        return self.map(operator.rshift, other, self)

    def __neg__(self):
        return self.map(operator.neg, self)

    def __pos__(self):
        return self.map(operator.pos, self)

    def __or__(self, other):
        return self.map(operator.or_, self, other)

    def __ror__(self, other):
        return self.map(operator.or_, other, self)

    def __xor__(self, other):
        return self.map(operator.xor, self, other)

    def __rxor__(self, other):
        return self.map(operator.xor, other, self)

    def missing_elwidths(self, *, all_elwidths=None):
        """an iterator of the elwidths where self doesn't have a corresponding
        value; `all_elwidths` defaults to `ALL_ELWIDTHS`"""
        if all_elwidths is None:
            all_elwidths = self.ALL_ELWIDTHS
        for elwid in all_elwidths:
            if elwid not in self.keys():
                yield elwid
272
273
class SimdWHintMap(SimdMap):
    """SimdMap with a width hint.

    Besides the per-elwid values, an instance carries an optional
    `width_hint`. `map` applies the same function to the operands' width
    hints. The operators inherited from `SimdMap` dispatch through
    `self.map`, so when invoked on a SimdWHintMap they propagate the width
    hint too.
    """

    __slots__ = ("__width_hint",)

    @classmethod
    def extract_width_hint(cls, values, default=None):
        """get the value for width hint.

        A `SimdWHintMap` yields its `width_hint`. Any other `SimdMap` or
        `Mapping` yields `default`. A non-map value is returned as is.
        """
        def simd_map_get(v):
            return v.width_hint if isinstance(v, SimdWHintMap) else None
        return SimdMap.extract_value_algo(
            values, default,
            simd_map_get=simd_map_get,
            mapping_get=lambda _: None)

    def __init__(self, values=None, *, width_hint=None):
        """construct a SimdWHintMap.

        When `width_hint` is `None`, it is derived from `values` via
        `extract_width_hint`. That means it is inherited from a
        `SimdWHintMap`, taken from a non-map scalar, and otherwise `None`.
        """
        super().__init__(values)
        if width_hint is None:
            width_hint = values

        self.__width_hint = self.extract_width_hint(width_hint)

    @property
    def width_hint(self):
        """the width hint, or `None` if there isn't one"""
        return self.__width_hint

    def __eq__(self, other):
        if isinstance(other, SimdMap):
            return self.mapping == other.mapping \
                and self.width_hint == self.extract_width_hint(other)
        return NotImplemented

    def __hash__(self):
        # without a width hint, hash like a plain SimdMap, since __eq__
        # considers the two equal in that case
        if self.width_hint is None:
            return super().__hash__()
        return hash((tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS),
                     self.width_hint))

    def __repr__(self):
        wh = ""
        if self.width_hint is not None:
            wh = f", width_hint={self.width_hint!r}"
        return f"{self.__class__.__name__}({dict(self.mapping)}{wh})"

    @classmethod
    def map(cls, f, *args):
        """get the SimdWHintMap of the results of calling
        `f(value1, value2, value3, ...)` where
        `value1`, `value2`, `value3`, ... are the results of calling
        `cls.extract_value` on each `args`.

        The width hint of the result is `f` applied to the width hints
        extracted from `args`. It is `None` if any of them is `None` or if
        `f` raises ArithmeticError, LookupError or ValueError on them.
        """
        retval = {}
        for elwid in cls.ALL_ELWIDTHS:
            extracted_args = [cls.extract_value(elwid, arg) for arg in args]
            # Test identity rather than `None in extracted_args`: membership
            # falls back to `==`, which misbehaves for values overloading
            # `__eq__`, such as multi-element numpy arrays.
            if all(arg is not None for arg in extracted_args):
                retval[elwid] = f(*extracted_args)
        width_hint = None
        try:
            extracted_args = [cls.extract_width_hint(arg) for arg in args]
            if all(arg is not None for arg in extracted_args):
                width_hint = f(*extracted_args)
        except (ArithmeticError, LookupError, ValueError):
            # ignore some errors and just clear width_hint
            pass
        return cls(retval, width_hint=width_hint)
340
341
def _check_for_missing_elwidths(name, all_elwidths=None):
    """Assert that the module-level map called `name` has an entry for every
    elwid in `all_elwidths`.

    `all_elwidths` is passed straight to the map's `missing_elwidths`.
    `name` is looked up in this module's globals so the failure message can
    name the offending constant.
    """
    table = globals()[name]
    missing = [*table.missing_elwidths(all_elwidths=all_elwidths)]
    assert not missing, f"{name} is missing entries for {missing}"
345
346
# XLEN for every elwid: 64 for F64/I64 down to 8 for I8 (BF16 shares F16's
# 16). The width hint is 64. Entry order is irrelevant because SimdMap
# re-orders by ALL_ELWIDTHS.
XLEN = SimdWHintMap(
    {
        FpElWid.F64: 64,
        FpElWid.F32: 32,
        FpElWid.F16: 16,
        FpElWid.BF16: 16,
        IntElWid.I64: 64,
        IntElWid.I32: 32,
        IntElWid.I16: 16,
        IntElWid.I8: 8,
    },
    width_hint=64,
)

# Default element counts. Each count equals 64 // XLEN for that elwid.
DEFAULT_FP_VEC_EL_COUNTS = SimdMap(
    {
        FpElWid.F64: 1,
        FpElWid.F32: 2,
        FpElWid.F16: 4,
        FpElWid.BF16: 4,
    }
)

DEFAULT_INT_VEC_EL_COUNTS = SimdMap(
    {
        IntElWid.I64: 1,
        IntElWid.I32: 2,
        IntElWid.I16: 4,
        IntElWid.I8: 8,
    }
)

# import-time sanity checks: each table must cover the elwids it is meant to
_check_for_missing_elwidths("XLEN")
_check_for_missing_elwidths("DEFAULT_FP_VEC_EL_COUNTS", all_elwidths=FpElWid)
_check_for_missing_elwidths("DEFAULT_INT_VEC_EL_COUNTS", all_elwidths=IntElWid)