add tests for SimdMap and friends
[ieee754fpu.git] / src / ieee754 / part / test / test_util.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 import unittest
5 import math
6 from ieee754.part.util import FpElWid, IntElWid, SimdMap, SimdWHintMap
7
8
class TestElWid(unittest.TestCase):
    """Check the textual representation of the element-width enums."""

    def test_repr(self):
        # each enum member paired with the repr it is expected to produce
        expected_reprs = (
            (FpElWid.F64, "FpElWid.F64"),
            (IntElWid.I8, "IntElWid.I8"),
        )
        for member, text in expected_reprs:
            self.assertEqual(repr(member), text)
13
14
class TestSimdMap(unittest.TestCase):
    """Unit tests for SimdMap: construction, lookup, the mapping helpers
    and the element-wise operator overloads."""

    def test_extract_value(self):
        """extract_value accepts scalars, dicts, nested dicts and SimdMaps,
        and rejects a mapping that contains itself."""
        # scalar / None inputs
        self.assertIsNone(SimdMap.extract_value(IntElWid.I8, None))
        self.assertEqual(SimdMap.extract_value(IntElWid.I8, None, 5), 5)
        self.assertEqual(SimdMap.extract_value(IntElWid.I8, 3, 5), 3)
        # plain dicts: a missing key yields the default (None if omitted)
        self.assertEqual(SimdMap.extract_value(IntElWid.I8,
                                               {FpElWid.F64: 3}, 5), 5)
        self.assertIsNone(SimdMap.extract_value(IntElWid.I8,
                                                {FpElWid.F64: 3}))
        self.assertIsNone(SimdMap.extract_value(FpElWid.F16,
                                                {FpElWid.F64: 3}))
        self.assertEqual(SimdMap.extract_value(FpElWid.F64,
                                               {FpElWid.F64: 3}), 3)
        # a key that is present but maps to None also yields the default
        self.assertEqual(SimdMap.extract_value(FpElWid.F64,
                                               {FpElWid.F64: None}, 5), 5)
        self.assertIsNone(SimdMap.extract_value(FpElWid.F64,
                                                {FpElWid.F64: None}))
        # nested dicts are looked up again with the same element width
        self.assertEqual(SimdMap.extract_value(
            FpElWid.F64, {FpElWid.F64: {FpElWid.F64: 5}}), 5)
        self.assertIsNone(SimdMap.extract_value(
            FpElWid.F64, {FpElWid.F64: {FpElWid.F32: 5}}))
        # SimdMap instances are accepted as well as plain dicts
        simd_map = SimdMap({IntElWid.I8: 3, FpElWid.F64: 5})
        self.assertEqual(SimdMap.extract_value(FpElWid.F64, simd_map), 5)
        self.assertIsNone(SimdMap.extract_value(FpElWid.F32, simd_map))
        self.assertEqual(SimdMap.extract_value(IntElWid.I8, simd_map), 3)
        # a dict that contains itself can never be resolved; build it
        # outside the context manager so only the call under test is checked
        m = {}
        m[IntElWid.I32] = m
        with self.assertRaisesRegex(
                AssertionError, "can't resolve infinitely recursive value"):
            SimdMap.extract_value(IntElWid.I32, m)

    def test_init(self):
        """Construction from a scalar, None, no argument, a dict or a
        SimdMap."""
        # a scalar is stored under every element width
        self.assertEqual(repr(SimdMap(0).mapping),
                         "mappingproxy({FpElWid.F64: 0, FpElWid.F32: 0, "
                         "FpElWid.F16: 0, FpElWid.BF16: 0, IntElWid.I64: 0, "
                         "IntElWid.I32: 0, IntElWid.I16: 0, IntElWid.I8: 0})")
        # None and no argument both produce an empty mapping
        self.assertEqual(repr(SimdMap(None).mapping), "mappingproxy({})")
        self.assertEqual(repr(SimdMap().mapping), "mappingproxy({})")
        # dicts and SimdMaps keep their entries
        self.assertEqual(repr(SimdMap({FpElWid.F64: 5,
                                       IntElWid.I8: 10}).mapping),
                         "mappingproxy({FpElWid.F64: 5, IntElWid.I8: 10})")
        self.assertEqual(repr(SimdMap(SimdMap({FpElWid.F64: 5,
                                               IntElWid.I8: 10})).mapping),
                         "mappingproxy({FpElWid.F64: 5, IntElWid.I8: 10})")

    def test_values(self):
        """values() exposes the stored values in insertion order."""
        self.assertEqual(repr(SimdMap({FpElWid.F64: 5,
                                       IntElWid.I8: 10}).values()),
                         "dict_values([5, 10])")

    def test_keys(self):
        """keys() exposes the stored element widths in insertion order."""
        self.assertEqual(repr(SimdMap({FpElWid.F64: 5,
                                       IntElWid.I8: 10}).keys()),
                         "dict_keys([FpElWid.F64, IntElWid.I8])")

    def test_items(self):
        """items() exposes (element width, value) pairs."""
        self.assertEqual(repr(SimdMap({FpElWid.F64: 5,
                                       IntElWid.I8: 10}).items()),
                         "dict_items([(FpElWid.F64, 5), (IntElWid.I8, 10)])")

    def test_map_and_map_with_elwid(self):
        """map_with_elwid passes the element width to the callback first;
        map passes only the per-width values."""
        def case(*args, expected, expected_args):
            calls = []

            def callback(*cb_args):
                calls.append(cb_args)
                # return a running call count so the resulting map records
                # the order in which element widths were visited
                return len(calls)

            self.assertEqual(repr(SimdMap.map_with_elwid(callback, *args)),
                             repr(SimdMap(expected)))
            self.assertEqual(calls, expected_args)

            # reset the call log in place (callback closes over it), then
            # check map, whose callback gets the same arguments minus the
            # leading element width
            calls.clear()
            map_expected_args = [tuple(i[1:]) for i in expected_args]

            self.assertEqual(repr(SimdMap.map(callback, *args)),
                             repr(SimdMap(expected)))
            self.assertEqual(calls, map_expected_args)

        # no arguments: the callback runs once per element width
        case(expected={
            FpElWid.F64: 1, FpElWid.F32: 2, FpElWid.F16: 3, FpElWid.BF16: 4,
            IntElWid.I64: 5, IntElWid.I32: 6, IntElWid.I16: 7, IntElWid.I8: 8,
        }, expected_args=[
            (FpElWid.F64,), (FpElWid.F32,), (FpElWid.F16,), (FpElWid.BF16,),
            (IntElWid.I64,), (IntElWid.I32,), (IntElWid.I16,), (IntElWid.I8,),
        ])

        # None contributes no element widths at all
        case(None, expected={}, expected_args=[])
        # scalars are passed unchanged for every element width
        case(1, expected={
            FpElWid.F64: 1, FpElWid.F32: 2, FpElWid.F16: 3, FpElWid.BF16: 4,
            IntElWid.I64: 5, IntElWid.I32: 6, IntElWid.I16: 7, IntElWid.I8: 8,
        }, expected_args=[
            (FpElWid.F64, 1), (FpElWid.F32, 1),
            (FpElWid.F16, 1), (FpElWid.BF16, 1),
            (IntElWid.I64, 1), (IntElWid.I32, 1),
            (IntElWid.I16, 1), (IntElWid.I8, 1),
        ])
        case(1, 5, expected={
            FpElWid.F64: 1, FpElWid.F32: 2, FpElWid.F16: 3, FpElWid.BF16: 4,
            IntElWid.I64: 5, IntElWid.I32: 6, IntElWid.I16: 7, IntElWid.I8: 8,
        }, expected_args=[
            (FpElWid.F64, 1, 5), (FpElWid.F32, 1, 5),
            (FpElWid.F16, 1, 5), (FpElWid.BF16, 1, 5),
            (IntElWid.I64, 1, 5), (IntElWid.I32, 1, 5),
            (IntElWid.I16, 1, 5), (IntElWid.I8, 1, 5),
        ])
        # a dict restricts the visited element widths to its keys
        case({FpElWid.F64: 1, IntElWid.I8: 3, FpElWid.F32: 5}, 5, expected={
            FpElWid.F64: 1, FpElWid.F32: 2, IntElWid.I8: 3,
        }, expected_args=[
            (FpElWid.F64, 1, 5), (FpElWid.F32, 5, 5), (IntElWid.I8, 3, 5),
        ])
        case({FpElWid.F64: 1, IntElWid.I8: 3},
             {FpElWid.F64: 5, IntElWid.I8: 7},
             expected={FpElWid.F64: 1, IntElWid.I8: 2},
             expected_args=[(FpElWid.F64, 1, 5), (IntElWid.I8, 3, 7)])
        case(SimdMap({FpElWid.F64: 1, IntElWid.I8: 3}),
             SimdMap({FpElWid.F64: 5, IntElWid.I8: 7}),
             expected={FpElWid.F64: 1, IntElWid.I8: 2},
             expected_args=[(FpElWid.F64, 1, 5), (IntElWid.I8, 3, 7)])

    def test_get(self):
        """get() returns the value, the default, or raises KeyError on
        request."""
        v = SimdMap({FpElWid.F64: 1, IntElWid.I8: 3})
        self.assertEqual(v.get(IntElWid.I8), 3)
        self.assertIsNone(v.get(IntElWid.I16))
        self.assertEqual(v.get(FpElWid.F64), 1)
        self.assertEqual(v.get(IntElWid.I16, default="blah"), "blah")
        self.assertEqual(v.get(FpElWid.F64, default="blah"), 1)
        with self.assertRaises(KeyError):
            v.get(IntElWid.I16, raise_key_error=True)

    def test_iter(self):
        """Iterating a SimdMap yields (element width, value) pairs."""
        self.assertEqual(list(SimdMap({FpElWid.F64: 5,
                                       IntElWid.I8: 10})),
                         [(FpElWid.F64, 5), (IntElWid.I8, 10)])

    def test_ops(self):
        """Every operator is applied per element width, in both the
        forward and the reflected form."""
        a = SimdMap({FpElWid.F64: 5, IntElWid.I8: 10})
        b = SimdMap({FpElWid.F64: "abc", IntElWid.I8: "def"})
        c = SimdMap({FpElWid.F64: -5, IntElWid.I8: 10})
        d = SimdMap({FpElWid.F64: -3.5, IntElWid.I8: 10.5})
        # add
        self.assertEqual(a + 20,
                         SimdMap({FpElWid.F64: 25, IntElWid.I8: 30}))
        self.assertEqual(20 + a,
                         SimdMap({FpElWid.F64: 25, IntElWid.I8: 30}))
        self.assertEqual(b + "ghi",
                         SimdMap({FpElWid.F64: "abcghi",
                                  IntElWid.I8: "defghi"}))
        self.assertEqual("ghi" + b,
                         SimdMap({FpElWid.F64: "ghiabc",
                                  IntElWid.I8: "ghidef"}))
        # sub
        self.assertEqual(a - 20,
                         SimdMap({FpElWid.F64: -15, IntElWid.I8: -10}))
        self.assertEqual(20 - a,
                         SimdMap({FpElWid.F64: 15, IntElWid.I8: 10}))
        # mul
        self.assertEqual(a * 2,
                         SimdMap({FpElWid.F64: 10, IntElWid.I8: 20}))
        self.assertEqual(2 * a,
                         SimdMap({FpElWid.F64: 10, IntElWid.I8: 20}))
        self.assertEqual(b * 2,
                         SimdMap({FpElWid.F64: "abcabc",
                                  IntElWid.I8: "defdef"}))
        self.assertEqual(2 * b,
                         SimdMap({FpElWid.F64: "abcabc",
                                  IntElWid.I8: "defdef"}))

        # floordiv
        self.assertEqual(a // 2,
                         SimdMap({FpElWid.F64: 2, IntElWid.I8: 5}))
        self.assertEqual(20 // a,
                         SimdMap({FpElWid.F64: 4, IntElWid.I8: 2}))

        # truediv (compared via repr so 5.0 is not accepted as 5)
        self.assertEqual(repr(a / 2),
                         repr(SimdMap({FpElWid.F64: 2.5, IntElWid.I8: 5.0})))
        self.assertEqual(repr(20 / a),
                         repr(SimdMap({FpElWid.F64: 4.0, IntElWid.I8: 2.0})))

        # mod
        self.assertEqual(a % 3,
                         SimdMap({FpElWid.F64: 2, IntElWid.I8: 1}))
        self.assertEqual(17 % a,
                         SimdMap({FpElWid.F64: 2, IntElWid.I8: 7}))

        # abs
        self.assertEqual(abs(a),
                         SimdMap({FpElWid.F64: 5, IntElWid.I8: 10}))
        self.assertEqual(abs(c),
                         SimdMap({FpElWid.F64: 5, IntElWid.I8: 10}))

        # and
        self.assertEqual(a & 3,
                         SimdMap({FpElWid.F64: 1, IntElWid.I8: 2}))
        self.assertEqual(31 & a,
                         SimdMap({FpElWid.F64: 5, IntElWid.I8: 10}))

        # divmod
        self.assertEqual(divmod(a, 3),
                         SimdMap({FpElWid.F64: (1, 2), IntElWid.I8: (3, 1)}))

        # ceil
        self.assertEqual(math.ceil(d),
                         SimdMap({FpElWid.F64: -3, IntElWid.I8: 11}))

        # floor
        self.assertEqual(math.floor(d),
                         SimdMap({FpElWid.F64: -4, IntElWid.I8: 10}))

        # invert
        self.assertEqual(~a,
                         SimdMap({FpElWid.F64: -6, IntElWid.I8: -11}))
        self.assertEqual(~c,
                         SimdMap({FpElWid.F64: 4, IntElWid.I8: -11}))

        # lshift
        self.assertEqual(a << 2,
                         SimdMap({FpElWid.F64: 20, IntElWid.I8: 40}))
        self.assertEqual(1 << a,
                         SimdMap({FpElWid.F64: 32, IntElWid.I8: 1024}))

        # rshift
        self.assertEqual(a >> 1,
                         SimdMap({FpElWid.F64: 2, IntElWid.I8: 5}))
        self.assertEqual(1000 >> a,
                         SimdMap({FpElWid.F64: 31, IntElWid.I8: 0}))

        # neg
        self.assertEqual(-a,
                         SimdMap({FpElWid.F64: -5, IntElWid.I8: -10}))
        self.assertEqual(-c,
                         SimdMap({FpElWid.F64: 5, IntElWid.I8: -10}))

        # pos
        self.assertEqual(+a,
                         SimdMap({FpElWid.F64: 5, IntElWid.I8: 10}))
        self.assertEqual(+c,
                         SimdMap({FpElWid.F64: -5, IntElWid.I8: 10}))

        # or
        self.assertEqual(a | 2,
                         SimdMap({FpElWid.F64: 7, IntElWid.I8: 10}))
        self.assertEqual(1 | a,
                         SimdMap({FpElWid.F64: 5, IntElWid.I8: 11}))

        # xor
        self.assertEqual(a ^ 2,
                         SimdMap({FpElWid.F64: 7, IntElWid.I8: 8}))
        self.assertEqual(1 ^ a,
                         SimdMap({FpElWid.F64: 4, IntElWid.I8: 11}))
269
270
class TestSimdWHintMap(unittest.TestCase):
    """Unit tests for SimdWHintMap, a SimdMap that additionally carries a
    width_hint."""

    def test_extract_width_hint(self):
        """extract_width_hint reads a SimdWHintMap's hint, passes scalars
        through and otherwise falls back to the default."""
        self.assertIsNone(SimdWHintMap.extract_width_hint(None))
        self.assertEqual(SimdWHintMap.extract_width_hint(None, 5), 5)
        self.assertEqual(SimdWHintMap.extract_width_hint(3), 3)
        self.assertEqual(SimdWHintMap.extract_width_hint(3, 5), 3)
        self.assertEqual(SimdWHintMap.extract_width_hint(
            {FpElWid.F64: 3}, 5), 5)
        self.assertIsNone(SimdWHintMap.extract_width_hint(
            {FpElWid.F64: 3}))
        a = SimdWHintMap({IntElWid.I8: 3, FpElWid.F64: 5}, width_hint=7)
        # a plain SimdMap built from `a` reports no width hint
        b = SimdMap(a)
        self.assertEqual(SimdWHintMap.extract_width_hint(a), 7)
        self.assertIsNone(SimdWHintMap.extract_width_hint(b))

    def test_init(self):
        """The width hint is taken from the argument or the wrapped map,
        and an explicit width_hint overrides the wrapped one."""
        self.assertEqual(repr(SimdWHintMap(width_hint=1)),
                         "SimdWHintMap({}, width_hint=1)")
        self.assertEqual(repr(SimdWHintMap(width_hint="abc")),
                         "SimdWHintMap({}, width_hint='abc')")
        self.assertEqual(repr(SimdWHintMap()),
                         "SimdWHintMap({})")
        # wrapping another SimdWHintMap keeps its hint
        self.assertEqual(repr(SimdWHintMap(SimdWHintMap(width_hint=1))),
                         "SimdWHintMap({}, width_hint=1)")
        self.assertEqual(repr(SimdWHintMap(SimdWHintMap(width_hint="abc"))),
                         "SimdWHintMap({}, width_hint='abc')")
        self.assertEqual(repr(SimdWHintMap(SimdWHintMap())),
                         "SimdWHintMap({})")
        # an explicit width_hint wins over the wrapped map's hint
        self.assertEqual(repr(SimdWHintMap(SimdWHintMap(width_hint=1),
                                           width_hint=2)),
                         "SimdWHintMap({}, width_hint=2)")
        self.assertEqual(repr(SimdWHintMap(SimdWHintMap({FpElWid.F16: 5},
                                                        width_hint=1),
                                           width_hint=2)),
                         "SimdWHintMap({FpElWid.F16: 5}, width_hint=2)")
        self.assertEqual(repr(SimdWHintMap(SimdWHintMap({FpElWid.F16: 5},
                                                        width_hint=1))),
                         "SimdWHintMap({FpElWid.F16: 5}, width_hint=1)")
        self.assertEqual(repr(SimdWHintMap(SimdWHintMap({FpElWid.F16: 5}))),
                         "SimdWHintMap({FpElWid.F16: 5})")

    def test_eq(self):
        """Equality compares the width hint as well as the values."""
        self.assertEqual(SimdWHintMap(), SimdMap())
        self.assertEqual(SimdWHintMap({FpElWid.F16: 5}),
                         SimdMap({FpElWid.F16: 5}))
        self.assertNotEqual(SimdWHintMap({FpElWid.F16: 5}),
                            SimdMap({FpElWid.F16: 6}))
        self.assertNotEqual(SimdWHintMap({FpElWid.F16: 5}, width_hint=3),
                            SimdMap({FpElWid.F16: 5}))
        self.assertEqual(SimdWHintMap({FpElWid.F16: 5}, width_hint=3),
                         SimdWHintMap({FpElWid.F16: 5}, width_hint=3))

    def test_ops(self):
        """Operators are applied to the per-width values and to the width
        hint."""
        a = SimdWHintMap({FpElWid.F16: 3, FpElWid.F32: 10}, width_hint=12)
        self.assertEqual(a + 1,
                         SimdWHintMap({FpElWid.F32: 11, FpElWid.F16: 4},
                                      width_hint=13))
        self.assertEqual(a - a,
                         SimdWHintMap({FpElWid.F32: 0, FpElWid.F16: 0},
                                      width_hint=0))
        self.assertEqual(a - 12,
                         SimdWHintMap({FpElWid.F32: -2, FpElWid.F16: -9},
                                      width_hint=0))
        # test exceptions being suppressed for width_hint: 5 // 0 on the
        # hint must not raise, the result just has no width hint
        self.assertEqual(5 // (a - 12),
                         SimdWHintMap({FpElWid.F32: -3, FpElWid.F16: -1}))
        # test exceptions not being suppressed for non-width_hint values;
        # build the divisor first so only the floordiv is under test
        divisor = a - 3  # the F16 value becomes 0
        with self.assertRaises(ZeroDivisionError):
            5 // divisor
340
341
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()