6a9bcb13d4c92e17f995d964e357b655a63df813
[ieee754fpu.git] / src / ieee754 / part / util.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from enum import Enum
5 from typing import Mapping
6 import operator
7 import math
8 from types import MappingProxyType
9
10
class ElWid(Enum):
    """Common base for the element-width enums.

    The only customization is ``repr``: members are shown as
    ``ClassName.MEMBER`` (the same text ``str`` produces) rather than the
    default ``<ClassName.MEMBER: value>`` form.
    """

    def __repr__(self):
        # delegate straight to Enum's str implementation, e.g. "FpElWid.F64"
        return Enum.__str__(self)
14
15
class FpElWid(ElWid):
    """Floating-point element widths.

    Values are assigned 0, 1, 2, 3 in declaration order
    (F64, F32, F16, BF16); that order is also the enum's iteration order,
    so do not reorder the names.
    """

    F64, F32, F16, BF16 = range(4)
21
22
class IntElWid(ElWid):
    """Integer element widths.

    Values are assigned 0, 1, 2, 3 in declaration order
    (I64, I32, I16, I8); that order is also the enum's iteration order,
    so do not reorder the names.
    """

    I64, I32, I16, I8 = range(4)
28
29
class SimdMap:
    """A map from ElWid values to Python values.

    SimdMap instances are immutable: the stored mapping is built once in
    ``__init__`` and only exposed through a read-only ``MappingProxyType``.

    A value of ``None`` is never stored, so "no entry" and "``None``" are
    interchangeable throughout this class. Operators (``+``, ``&``, ``<<``,
    ``abs``, ``math.floor``, ...) are applied element-wise per elwid via
    ``SimdMap.map``; an operand that is not a map (e.g. a plain ``int``) is
    used for every elwid.
    """

    # Every elwid a SimdMap can hold; this tuple's order is the order used
    # by __init__ and __hash__.
    ALL_ELWIDTHS = (*FpElWid, *IntElWid)
    __slots__ = ("__map",)

    @classmethod
    def extract_value(cls, elwid, values, default=None):
        """get the value for elwid.
        if `values` is a `SimdMap` or a `Mapping`, then return the
        corresponding value for `elwid`, recursing until finding a non-map.
        if `values` ever ends up not existing (in the case of a map) or being
        `None`, return `default`.

        Examples:
            SimdMap.extract_value(IntElWid.I8, 5) == 5
            SimdMap.extract_value(IntElWid.I8, None) == None
            SimdMap.extract_value(IntElWid.I8, None, 3) == 3
            SimdMap.extract_value(IntElWid.I8, {}) == None
            SimdMap.extract_value(IntElWid.I8, {IntElWid.I8: 5}) == 5
            SimdMap.extract_value(IntElWid.I8, {
                IntElWid.I8: {IntElWid.I8: 5},
            }) == 5
            SimdMap.extract_value(IntElWid.I8, {
                IntElWid.I8: SimdMap({IntElWid.I8: 5}),
            }) == 5
        """
        assert elwid in cls.ALL_ELWIDTHS
        step = 0
        while values is not None:
            # specifically use base class to catch all SimdMap instances
            if isinstance(values, SimdMap):
                values = values.__map.get(elwid)
            elif isinstance(values, Mapping):
                values = values.get(elwid)
            else:
                return values
            step += 1
            # use object.__repr__ since repr() would probably recurse forever
            assert step < 10000, (f"can't resolve infinitely recursive "
                                  f"value {object.__repr__(values)}")
        return default

    def __init__(self, values=None):
        """construct a SimdMap.

        `values` may be anything `extract_value` accepts: a SimdMap, a
        Mapping keyed by elwid (possibly nested), a single value that is
        used for every elwid, or `None` for an empty map. Elwids whose
        extracted value is `None` are left out.
        """
        mapping = {}
        for elwid in self.ALL_ELWIDTHS:
            v = self.extract_value(elwid, values)
            if v is not None:
                mapping[elwid] = v
        self.__map = MappingProxyType(mapping)

    @property
    def mapping(self):
        """the values as a read-only Mapping[ElWid, Any]"""
        return self.__map

    def values(self):
        """the stored values, like `dict.values`"""
        return self.__map.values()

    def keys(self):
        """the elwids that have a value, like `dict.keys`"""
        return self.__map.keys()

    def items(self):
        """the (elwid, value) pairs, like `dict.items`"""
        return self.__map.items()

    @classmethod
    def map_with_elwid(cls, f, *args):
        """get the SimdMap of the results of calling
        `f(elwid, value1, value2, value3, ...)` where
        `value1`, `value2`, `value3`, ... are the results of calling
        `cls.extract_value` on each `args`.

        An elwid is skipped (has no entry in the result) if any of the
        extracted values for it is `None`.

        This is similar to Python's built-in `map` function.

        Examples:
            SimdMap.map_with_elwid(lambda elwid, a: a + 1, {IntElWid.I32: 5}) ==
                SimdMap({IntElWid.I32: 6})
            SimdMap.map_with_elwid(lambda elwid, a: a + 1, 3) ==
                SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
            SimdMap.map_with_elwid(lambda elwid, a, b: a + b,
                3, {IntElWid.I8: 4},
            ) == SimdMap({IntElWid.I8: 7})
            SimdMap.map_with_elwid(lambda elwid: elwid.name) ==
                SimdMap({IntElWid.I8: "I8", IntElWid.I16: "I16", ...})
        """
        retval = {}
        for elwid in cls.ALL_ELWIDTHS:
            extracted_args = [cls.extract_value(elwid, arg) for arg in args]
            # Test by identity: `None in extracted_args` would call `__eq__`
            # on every value, which breaks for objects that overload `==`
            # (returning a non-bool or raising instead of returning False).
            if all(arg is not None for arg in extracted_args):
                retval[elwid] = f(elwid, *extracted_args)
        return cls(retval)

    @classmethod
    def map(cls, f, *args):
        """get the SimdMap of the results of calling
        `f(value1, value2, value3, ...)` where
        `value1`, `value2`, `value3`, ... are the results of calling
        `cls.extract_value` on each `args`.

        This is similar to Python's built-in `map` function.

        Examples:
            SimdMap.map(lambda a: a + 1, {IntElWid.I32: 5}) ==
                SimdMap({IntElWid.I32: 6})
            SimdMap.map(lambda a: a + 1, 3) ==
                SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
            SimdMap.map(lambda a, b: a + b,
                3, {IntElWid.I8: 4},
            ) == SimdMap({IntElWid.I8: 7})
        """
        return cls.map_with_elwid(lambda elwid, *args2: f(*args2), *args)

    def get(self, elwid, default=None, *, raise_key_error=False):
        """return the value stored for `elwid`.

        If there is none, return `default`, or, when `raise_key_error` is
        true, raise `KeyError(elwid)` (`default` is then ignored).
        """
        if raise_key_error:
            retval = self.extract_value(elwid, self)
            if retval is None:
                raise KeyError(elwid)
            return retval
        return self.extract_value(elwid, self, default)

    def __iter__(self):
        """return an iterator of (elwid, value) pairs"""
        # __iter__ must return an iterator; a bare items() view makes
        # iter(self) raise TypeError.
        return iter(self.__map.items())

    def __add__(self, other):
        return self.map(operator.add, self, other)

    def __radd__(self, other):
        return self.map(operator.add, other, self)

    def __sub__(self, other):
        return self.map(operator.sub, self, other)

    def __rsub__(self, other):
        return self.map(operator.sub, other, self)

    def __mul__(self, other):
        return self.map(operator.mul, self, other)

    def __rmul__(self, other):
        return self.map(operator.mul, other, self)

    def __floordiv__(self, other):
        return self.map(operator.floordiv, self, other)

    def __rfloordiv__(self, other):
        return self.map(operator.floordiv, other, self)

    def __truediv__(self, other):
        return self.map(operator.truediv, self, other)

    def __rtruediv__(self, other):
        return self.map(operator.truediv, other, self)

    def __mod__(self, other):
        return self.map(operator.mod, self, other)

    def __rmod__(self, other):
        return self.map(operator.mod, other, self)

    def __pow__(self, other, modulo=None):
        # `modulo` is only forwarded when given, so that `None` does not
        # make every elwid drop out of the result.
        if modulo is None:
            return self.map(pow, self, other)
        return self.map(pow, self, other, modulo)

    def __rpow__(self, other):
        return self.map(pow, other, self)

    def __abs__(self):
        return self.map(abs, self)

    def __and__(self, other):
        return self.map(operator.and_, self, other)

    def __rand__(self, other):
        return self.map(operator.and_, other, self)

    def __divmod__(self, other):
        return self.map(divmod, self, other)

    def __rdivmod__(self, other):
        return self.map(divmod, other, self)

    def __ceil__(self):
        return self.map(math.ceil, self)

    def __float__(self):
        # NOTE(review): the built-in float() requires __float__ to return a
        # float, so float(simd_map) raises TypeError; this only works when
        # called directly as simd_map.__float__(). Kept for compatibility.
        return self.map(float, self)

    def __floor__(self):
        return self.map(math.floor, self)

    def __trunc__(self):
        return self.map(math.trunc, self)

    def __round__(self, ndigits=None):
        # `ndigits` is only forwarded when given, so that `None` does not
        # make every elwid drop out of the result.
        if ndigits is None:
            return self.map(round, self)
        return self.map(round, self, ndigits)

    def __eq__(self, other):
        if isinstance(other, SimdMap):
            return self.mapping == other.mapping
        return NotImplemented

    def __hash__(self):
        # requires every stored value to be hashable
        return hash(tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS))

    def __repr__(self):
        return f"{self.__class__.__name__}({dict(self.mapping)})"

    def __invert__(self):
        return self.map(operator.invert, self)

    def __lshift__(self, other):
        return self.map(operator.lshift, self, other)

    def __rlshift__(self, other):
        return self.map(operator.lshift, other, self)

    def __rshift__(self, other):
        return self.map(operator.rshift, self, other)

    def __rrshift__(self, other):
        return self.map(operator.rshift, other, self)

    def __neg__(self):
        return self.map(operator.neg, self)

    def __pos__(self):
        return self.map(operator.pos, self)

    def __or__(self, other):
        return self.map(operator.or_, self, other)

    def __ror__(self, other):
        return self.map(operator.or_, other, self)

    def __xor__(self, other):
        return self.map(operator.xor, self, other)

    def __rxor__(self, other):
        return self.map(operator.xor, other, self)

    def missing_elwidths(self, *, all_elwidths=None):
        """an iterator of the elwidths where self doesn't have a corresponding
        value.

        `all_elwidths` is the collection of elwids to check; it defaults to
        `ALL_ELWIDTHS`.
        """
        if all_elwidths is None:
            all_elwidths = self.ALL_ELWIDTHS
        for elwid in all_elwidths:
            if elwid not in self.keys():
                yield elwid
265
266
def _check_for_missing_elwidths(name, all_elwidths=None):
    """Assert that the module-level object called `name` has no missing
    elwids.

    The object is looked up by name in this module's globals, and the name
    is also used in the assertion message. `all_elwidths` is forwarded
    as-is to the object's `missing_elwidths` method.
    """
    simd_map = globals()[name]
    absent = [*simd_map.missing_elwidths(all_elwidths=all_elwidths)]
    assert not absent, f"{name} is missing entries for {absent}"
270
271
# Element width for every elwid (presumably in bits -- the numbers match
# the member names, e.g. F32 -> 32, I8 -> 8).
XLEN = SimdMap({
    FpElWid.F64: 64, FpElWid.F32: 32, FpElWid.F16: 16, FpElWid.BF16: 16,
    IntElWid.I64: 64, IntElWid.I32: 32, IntElWid.I16: 16, IntElWid.I8: 8,
})

# Default element counts: 64 divided by each element's XLEN, giving
#   F64 -> 1, F32 -> 2, F16 -> 4, BF16 -> 4
#   I64 -> 1, I32 -> 2, I16 -> 4, I8 -> 8
DEFAULT_FP_VEC_EL_COUNTS = SimdMap(
    {elwid: 64 // XLEN.mapping[elwid] for elwid in FpElWid})

DEFAULT_INT_VEC_EL_COUNTS = SimdMap(
    {elwid: 64 // XLEN.mapping[elwid] for elwid in IntElWid})

# Import-time self-checks: XLEN must cover every elwid; each count table
# must cover its own enum.
_check_for_missing_elwidths("XLEN")
_check_for_missing_elwidths("DEFAULT_FP_VEC_EL_COUNTS", FpElWid)
_check_for_missing_elwidths("DEFAULT_INT_VEC_EL_COUNTS", IntElWid)