Update Simulator interface to current nMigen
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable
6 from nmigen.back.pysim import Simulator, Delay
7 from nmigen.cli import rtlil
8
9 from ieee754.part.partsig import PartitionedSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15 import math
16
def first_zero(x):
    """Return the index of the lowest set bit of *x*, looking only at bits 0-15.

    Despite the name, this is the count of zero bits below the first set
    bit (a trailing-zero count), e.g. ``first_zero(0xFF00) == 8``.
    Returns None when none of bits 0-15 is set.
    """
    for bit in range(16):
        if (x >> bit) & 1:
            return bit
    return None
23
def count_bits(x):
    """Return how many of the low 16 bits of *x* are set (16-bit popcount)."""
    low16 = x & 0xFFFF
    return bin(low16).count("1")
30
31
def perms(k):
    """Return an iterator over all 2**k strings of '0'/'1' of length *k*.

    The strings come in ascending binary order ('00', '01', '10', '11' for
    k=2).  They form a Cartesian product, not permutations, despite the name.
    """
    return (''.join(bits) for bits in itertools.product('01', repeat=k))
34
35
def create_ilang(dut, traces, test_name):
    """Write the RTLIL (ilang) text of *dut* to the file ``<test_name>.il``.

    *traces* is passed to ``rtlil.convert`` as the port list.
    """
    il_text = rtlil.convert(dut, ports=traces)
    il_path = "%s.il" % test_name
    with open(il_path, "w") as il_file:
        il_file.write(il_text)
40
41
def create_simulator(module, traces, test_name):
    """Dump *module* to ``<test_name>.il`` and return a Simulator for it.

    *traces* is used as the port list of the RTLIL dump.
    """
    create_ilang(module, traces, test_name)
    simulator = Simulator(module)
    return simulator
45
46
# XXX this is for coriolis2 experimentation
class TestAddMod2(Elaboratable):
    """Registered variant of TestAddMod, kept for coriolis2 experiments.

    Declares the same operands and outputs as TestAddMod.  In elaborate(),
    every operator result is assigned in the ``sync`` domain; only the
    ``bsig`` copy of ``b`` is combinatorial.

    Args:
        width: bit width of the operands and of the data outputs.
        partpoints: partition specification handed to PartitionedSignal;
            ``len(partpoints) + 1`` sizes the per-partition outputs
            (presumably one bit per partition -- TODO confirm against
            PartitionedSignal).
    """
    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        # partitioned operands
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        # plain-Signal copy of b, used as the scalar shift amount
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width) # left shift by partitioned amount
        self.ls_scal_output = Signal(width) # left shift by scalar amount
        self.rs_output = Signal(width) # right shift by partitioned amount
        self.rs_scal_output = Signal(width) # right shift by scalar amount
        self.sub_output = Signal(width)
        # comparison results, width len(partpoints)+1
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        # mux select and result (PMux chooses between a and b)
        self.mux_sel = Signal(len(partpoints)+1)
        self.mux_out = Signal(width)
        # carry in / carry out of the add and sub operators
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        """Build a module that registers every operator result in ``sync``.

        ``set_module(m)`` is called on both operands before any operator is
        used; presumably PartitionedSignal adds its operator submodules to
        ``m`` (TODO confirm in partsig).  Statement order is kept as-is
        because it determines the order in which those submodules are
        created.
        """
        m = Module()
        comb = m.d.comb
        sync = m.d.sync
        self.a.set_module(m)
        self.b.set_module(m)
        # compares
        sync += self.lt_output.eq(self.a < self.b)
        sync += self.ne_output.eq(self.a != self.b)
        sync += self.le_output.eq(self.a <= self.b)
        sync += self.gt_output.eq(self.a > self.b)
        sync += self.eq_output.eq(self.a == self.b)
        sync += self.ge_output.eq(self.a >= self.b)
        # add (returns result and carry-out)
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        sync += self.add_output.eq(add_out)
        sync += self.add_carry_out.eq(add_carry)
        # sub (returns result and carry-out)
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        sync += self.sub_output.eq(sub_out)
        sync += self.sub_carry_out.eq(sub_carry)
        # neg
        sync += self.neg_output.eq(-self.a)
        # left and right shift by the partitioned amount b
        sync += self.ls_output.eq(self.a << self.b)
        sync += self.rs_output.eq(self.a >> self.b)
        ppts = self.partpoints
        # mux between a and b under control of mux_sel
        sync += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
        # scalar left and right shift: b's raw bits as a plain Signal
        comb += self.bsig.eq(self.b.sig)
        sync += self.ls_scal_output.eq(self.a << self.bsig)
        sync += self.rs_scal_output.eq(self.a >> self.bsig)

        return m
109
110
class TestAddMod(Elaboratable):
    """Combinatorial test harness exposing PartitionedSignal operators.

    Two partitioned operands ``a`` and ``b`` feed comparisons, add/sub (with
    carry in/out), negation, partitioned and scalar shifts, and a PMux.
    Each result is driven onto its own plain Signal in the ``comb`` domain,
    so a simulator can read it after a short Delay.

    Args:
        width: bit width of the operands and of the data outputs.
        partpoints: partition specification handed to PartitionedSignal;
            ``len(partpoints) + 1`` sizes the per-partition outputs
            (presumably one bit per partition -- TODO confirm against
            PartitionedSignal).
    """
    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        # partitioned operands
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        # plain-Signal copy of b, used as the scalar shift amount
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width) # left shift by partitioned amount
        self.ls_scal_output = Signal(width) # left shift by scalar amount
        self.rs_output = Signal(width) # right shift by partitioned amount
        self.rs_scal_output = Signal(width) # right shift by scalar amount
        self.sub_output = Signal(width)
        # comparison results, width len(partpoints)+1
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        # mux select and result (PMux chooses between a and b)
        self.mux_sel = Signal(len(partpoints)+1)
        self.mux_out = Signal(width)
        # carry in / carry out of the add and sub operators
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        """Build a module that drives every operator result combinatorially.

        ``set_module(m)`` is called on both operands before any operator is
        used; presumably PartitionedSignal adds its operator submodules to
        ``m`` (TODO confirm in partsig).  Statement order is kept as-is
        because it determines the order in which those submodules are
        created.
        """
        m = Module()
        comb = m.d.comb
        sync = m.d.sync  # unused here: every output below is in comb
        self.a.set_module(m)
        self.b.set_module(m)
        # compares
        comb += self.lt_output.eq(self.a < self.b)
        comb += self.ne_output.eq(self.a != self.b)
        comb += self.le_output.eq(self.a <= self.b)
        comb += self.gt_output.eq(self.a > self.b)
        comb += self.eq_output.eq(self.a == self.b)
        comb += self.ge_output.eq(self.a >= self.b)
        # add (returns result and carry-out)
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        comb += self.add_output.eq(add_out)
        comb += self.add_carry_out.eq(add_carry)
        # sub (returns result and carry-out)
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        comb += self.sub_output.eq(sub_out)
        comb += self.sub_carry_out.eq(sub_carry)
        # neg
        comb += self.neg_output.eq(-self.a)
        # left shift (partitioned amount b)
        comb += self.ls_output.eq(self.a << self.b)
        # right shift (partitioned amount b)
        comb += self.rs_output.eq(self.a >> self.b)
        ppts = self.partpoints
        # mux between a and b under control of mux_sel
        comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
        # scalar left shift: b's raw bits as a plain Signal
        comb += self.bsig.eq(self.b.sig)
        comb += self.ls_scal_output.eq(self.a << self.bsig)
        # scalar right shift
        comb += self.rs_scal_output.eq(self.a >> self.bsig)

        return m
175
176
class TestPartitionPoints(unittest.TestCase):
    """Simulate TestAddMod and check its outputs against Python models.

    The 16-bit module is driven with fixed operand pairs (plus random ones
    for the arithmetic and shift ops) under three ``part_mask`` settings:
    0, checked as one 16-bit lane; 0b10, checked as two 8-bit lanes; and
    0b1111, checked as four 4-bit lanes.  The expected value comes from
    evaluating a reference function once per lane data mask and OR-ing
    the results.  The run writes ``part_sig_add.il`` (via
    create_simulator), ``part_sig_add.vcd`` and ``part_sig_add.gtkw``.
    """

    def test(self):
        width = 16
        part_mask = Signal(4)  # divide into 4-bits
        module = TestAddMod(width, part_mask)

        test_name = "part_sig_add"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.add_output,
                  module.eq_output]
        sim = create_simulator(module, traces, test_name)

        # randint() includes its upper bound, so use the largest value that
        # fits in the operands (1 << width would be truncated to 0 by the
        # 16-bit signals while the reference models still saw 0x10000).
        max_operand = (1 << width) - 1

        def async_process():

            def maskbits_to_masks(maskbit_list):
                # Expand nibble-select values into 16-bit data masks:
                # bit i selects data bits 4*i..4*i+3, e.g. 0b0011 -> 0x00FF.
                mask_list = []
                for mb in maskbit_list:
                    v = 0
                    for i in range(4):
                        if mb & (1 << i):
                            v |= 0xf << (i * 4)
                    mask_list.append(v)
                return mask_list

            # Reference models used by test_op.  Each takes
            # (carry_in, a, b, mask) and returns (result, carry) for the
            # single lane selected by ``mask``.

            def test_ls_scal_fn(carry_in, a, b, mask):
                # Scalar shift: the amount is all of b, truncated with
                # bits-1 (lane width minus one, a valid bit mask for the
                # power-of-two widths used here).
                bits = count_bits(mask)
                newb = b & (bits - 1)
                print("%x %x %x bits %d trunc %x" %
                      (a, b, mask, bits, newb))
                b = newb
                # TODO: carry (always reported as 0 for now)
                total = (a & mask) << b
                result = mask & total
                print("res", hex(a), hex(b), hex(total), hex(mask),
                      hex(result))
                return result, 0

            def test_rs_scal_fn(carry_in, a, b, mask):
                # As test_ls_scal_fn, but shifting right.
                bits = count_bits(mask)
                newb = b & (bits - 1)
                print("%x %x %x bits %d trunc %x" %
                      (a, b, mask, bits, newb))
                b = newb
                # TODO: carry (always reported as 0 for now)
                total = (a & mask) >> b
                result = mask & total
                print("res", hex(a), hex(b), hex(total), hex(mask),
                      hex(result))
                return result, 0

            def test_ls_fn(carry_in, a, b, mask):
                # Partitioned shift: the amount is this lane's own field of
                # b, moved down by the lane's lowest bit index fz and
                # truncated with bits-1.
                bits = count_bits(mask)
                fz = first_zero(mask)
                newb = b & ((bits - 1) << fz)
                print("%x %x %x bits %d zero %d trunc %x" %
                      (a, b, mask, bits, fz, newb))
                # TODO: carry (always reported as 0 for now)
                b = (newb & mask) >> fz
                total = (a & mask) << b
                result = mask & total
                print("res", hex(a), hex(b), hex(total), hex(mask),
                      hex(result))
                return result, 0

            def test_rs_fn(carry_in, a, b, mask):
                # As test_ls_fn, but shifting right.
                bits = count_bits(mask)
                fz = first_zero(mask)
                newb = b & ((bits - 1) << fz)
                print("%x %x %x bits %d zero %d trunc %x" %
                      (a, b, mask, bits, fz, newb))
                # TODO: carry (always reported as 0 for now)
                b = (newb & mask) >> fz
                total = (a & mask) >> b
                result = mask & total
                print("res", hex(a), hex(b), hex(total), hex(mask),
                      hex(result))
                return result, 0

            def test_add_fn(carry_in, a, b, mask):
                # carry_in injects a 1 at the lane's lowest bit
                lsb = mask & ~(mask - 1) if carry_in else 0
                total = (a & mask) + (b & mask) + lsb
                result = mask & total
                # carry out: the sum has bits outside the lane mask
                carry = (total & mask) != total
                print(a, b, total, mask)
                return result, carry

            def test_sub_fn(carry_in, a, b, mask):
                # a - b as a + ~b + carry_in (two's complement)
                lsb = mask & ~(mask - 1) if carry_in else 0
                total = (a & mask) + (~b & mask) + lsb
                result = mask & total
                carry = (total & mask) != total
                return result, carry

            def test_neg_fn(carry_in, a, b, mask):
                # Per-lane two's-complement negation: 0 - a, computed as
                # 0 + ~a + 1.  carry_in and b are ignored.
                # NOTE: a + ~0 (a plus all-ones) would give a - 1, not -a.
                return test_sub_fn(1, 0, a, mask)

            def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                rand_data = [(randint(0, max_operand),
                              randint(0, max_operand))
                             for _ in range(100)]
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)] + rand_data:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    carry_sig = 0xf if carry else 0
                    yield module.carry_in.eq(carry_sig)
                    yield Delay(0.1e-6)
                    y = 0
                    carry_result = 0
                    for i, mask in enumerate(mask_list):
                        print("i/mask", i, hex(mask))
                        res, c = test_fn(carry, a, b, mask)
                        y |= res
                        # one carry bit per 4-bit slot, placed at the slot
                        # index of the lane's lowest bit
                        lsb = mask & ~(mask - 1)
                        bit_set = int(math.log2(lsb))
                        carry_result |= c << (bit_set // 4)
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    print(a, b, outval, carry)
                    msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                        f" => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)
                    # only ops with a <op>_carry_out signal get carry checks
                    if hasattr(module, "%s_carry_out" % mod_attr):
                        c_outval = (yield getattr(module,
                                                  "%s_carry_out" % mod_attr))
                        msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                            f" => 0x{carry_result:X} != 0x{c_outval:X}"
                        self.assertEqual(carry_result, c_outval, msg)

            for (test_fn, mod_attr) in (
                    (test_ls_scal_fn, "ls_scal"),
                    (test_ls_fn, "ls"),
                    (test_rs_scal_fn, "rs_scal"),
                    (test_rs_fn, "rs"),
                    (test_add_fn, "add"),
                    (test_sub_fn, "sub"),
                    (test_neg_fn, "neg"),
                    ):
                yield part_mask.eq(0)
                yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
                yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)
                yield part_mask.eq(0b10)
                yield from test_op("8-bit", 0, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield from test_op("8-bit", 1, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield part_mask.eq(0b1111)
                yield from test_op("4-bit", 0, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)
                yield from test_op("4-bit", 1, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)

            # Reference predicates for the comparison outputs, applied to
            # the bits of a and b selected by ``mask``.

            def test_ne_fn(a, b, mask):
                return (a & mask) != (b & mask)

            def test_lt_fn(a, b, mask):
                return (a & mask) < (b & mask)

            def test_le_fn(a, b, mask):
                return (a & mask) <= (b & mask)

            def test_eq_fn(a, b, mask):
                return (a & mask) == (b & mask)

            def test_gt_fn(a, b, mask):
                return (a & mask) > (b & mask)

            def test_ge_fn(a, b, mask):
                return (a & mask) >= (b & mask)

            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                mask_list = maskbits_to_masks(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF),
                             (0xABCD, 0xABCE),
                             (0x8000, 0x0000),
                             (0xBEEF, 0xFEED)]:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    yield Delay(0.1e-6)
                    # expected: for every lane where the predicate holds,
                    # OR in that lane's select bits from maskbit_list
                    y = 0
                    for i, mask in enumerate(mask_list):
                        if test_fn(a, b, mask):
                            y |= maskbit_list[i]
                    # check the result
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    msg = f"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
                        f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                    msg = msg % str(maskbit_list)
                    print(msg)
                    self.assertEqual(y, outval, msg)

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_binop("8-bit", test_fn, mod_attr,
                                      0b1100, 0b0011)
                yield part_mask.eq(0b1111)
                yield from test_binop("4-bit", test_fn, mod_attr,
                                      0b1000, 0b0100, 0b0010, 0b0001)

            def test_muxop(msg_prefix, *maskbit_list):
                mask_list = maskbits_to_masks(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:
                    # try every combination of per-lane select bits
                    for p in perms(len(mask_list)):
                        sel = 0
                        selmask = 0
                        for i, v in enumerate(p):
                            if v == '1':
                                sel |= maskbit_list[i]
                                selmask |= mask_list[i]

                        yield module.a.eq(a)
                        yield module.b.eq(b)
                        yield module.mux_sel.eq(sel)
                        yield Delay(0.1e-6)
                        # expected: a's lane where selected, else b's lane
                        y = 0
                        for mask in mask_list:
                            if selmask & mask:
                                y |= (a & mask)
                            else:
                                y |= (b & mask)
                        # check the result
                        outval = (yield module.mux_out)
                        msg = f"{msg_prefix}: mux " + \
                            f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
                            f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                        self.assertEqual(y, outval, msg % str(maskbit_list))

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            yield part_mask.eq(0b1111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
475
476
if __name__ == "__main__":
    # Running this file directly runs its TestCases via the unittest CLI.
    unittest.main()