95425c0315d1d3e34e8279d2a4989312a501d306
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
6 from nmigen.back.pysim import Simulator, Delay, Settle
7 from nmigen.cli import rtlil
8
9 from ieee754.part.partsig import SimdSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15 import math
16
17
18 def first_zero(x):
19 res = 0
20 for i in range(16):
21 if x & (1 << i):
22 return res
23 res += 1
24
25
26 def count_bits(x):
27 res = 0
28 for i in range(16):
29 if x & (1 << i):
30 res += 1
31 return res
32
33
34 def perms(k):
35 return map(''.join, itertools.product('01', repeat=k))
36
37
38 def create_ilang(dut, traces, test_name):
39 vl = rtlil.convert(dut, ports=traces)
40 with open("%s.il" % test_name, "w") as f:
41 f.write(vl)
42
43
44 def create_simulator(module, traces, test_name):
45 create_ilang(module, traces, test_name)
46 return Simulator(module)
47
48
49 # XXX this is for coriolis2 experimentation
50 class TestAddMod2(Elaboratable):
51 def __init__(self, width, partpoints):
52 self.partpoints = partpoints
53 self.a = SimdSignal(partpoints, width)
54 self.b = SimdSignal(partpoints, width)
55 self.bsig = Signal(width)
56 self.add_output = Signal(width)
57 self.ls_output = Signal(width) # left shift
58 self.ls_scal_output = Signal(width) # left shift
59 self.rs_output = Signal(width) # right shift
60 self.rs_scal_output = Signal(width) # right shift
61 self.sub_output = Signal(width)
62 self.eq_output = Signal(len(partpoints)+1)
63 self.gt_output = Signal(len(partpoints)+1)
64 self.ge_output = Signal(len(partpoints)+1)
65 self.ne_output = Signal(len(partpoints)+1)
66 self.lt_output = Signal(len(partpoints)+1)
67 self.le_output = Signal(len(partpoints)+1)
68 self.mux_sel2 = Signal(len(partpoints)+1)
69 self.mux_sel2 = SimdSignal(partpoints, len(partpoints))
70 self.mux2_out = Signal(width)
71 self.carry_in = Signal(len(partpoints)+1)
72 self.add_carry_out = Signal(len(partpoints)+1)
73 self.sub_carry_out = Signal(len(partpoints)+1)
74 self.neg_output = Signal(width)
75
76 def elaborate(self, platform):
77 m = Module()
78 comb = m.d.comb
79 sync = m.d.sync
80 self.a.set_module(m)
81 self.b.set_module(m)
82 self.mux_sel2.set_module(m)
83 # compares
84 sync += self.lt_output.eq(self.a < self.b)
85 sync += self.ne_output.eq(self.a != self.b)
86 sync += self.le_output.eq(self.a <= self.b)
87 sync += self.gt_output.eq(self.a > self.b)
88 sync += self.eq_output.eq(self.a == self.b)
89 sync += self.ge_output.eq(self.a >= self.b)
90 # add
91 add_out, add_carry = self.a.add_op(self.a, self.b,
92 self.carry_in)
93 sync += self.add_output.eq(add_out)
94 sync += self.add_carry_out.eq(add_carry)
95 # sub
96 sub_out, sub_carry = self.a.sub_op(self.a, self.b,
97 self.carry_in)
98 sync += self.sub_output.eq(sub_out)
99 sync += self.sub_carry_out.eq(sub_carry)
100 # neg
101 sync += self.neg_output.eq(-self.a)
102 # left shift
103 sync += self.ls_output.eq(self.a << self.b)
104 sync += self.rs_output.eq(self.a >> self.b)
105 ppts = self.partpoints
106 sync += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
107 # scalar left shift
108 comb += self.bsig.eq(self.b.lower())
109 sync += self.ls_scal_output.eq(self.a << self.bsig)
110 sync += self.rs_scal_output.eq(self.a >> self.bsig)
111
112 return m
113
114
115 class TestMuxMod(Elaboratable):
116 def __init__(self, width, partpoints):
117 self.partpoints = partpoints
118 self.a = SimdSignal(partpoints, width)
119 self.b = SimdSignal(partpoints, width)
120 self.mux_sel = Signal(len(partpoints)+1)
121 self.mux_sel2 = SimdSignal(partpoints, len(partpoints)+1)
122 self.mux_out2 = Signal(width)
123
124 def elaborate(self, platform):
125 m = Module()
126 comb = m.d.comb
127 sync = m.d.sync
128 self.a.set_module(m)
129 self.b.set_module(m)
130 self.mux_sel2.set_module(m)
131 ppts = self.partpoints
132
133 comb += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
134
135 return m
136
137
138 class TestCatMod(Elaboratable):
139 def __init__(self, width, partpoints):
140 self.partpoints = partpoints
141 self.a = SimdSignal(partpoints, width)
142 self.b = SimdSignal(partpoints, width*2)
143 self.o = SimdSignal(partpoints, width*3)
144 self.cat_out = self.o.sig
145
146 def elaborate(self, platform):
147 m = Module()
148 comb = m.d.comb
149 self.a.set_module(m)
150 self.b.set_module(m)
151 self.o.set_module(m)
152
153 comb += self.o.eq(Cat(self.a, self.b))
154
155 return m
156
157
158 class TestReplMod(Elaboratable):
159 def __init__(self, width, partpoints):
160 self.partpoints = partpoints
161 self.a = SimdSignal(partpoints, width)
162 self.repl_sel = Signal(len(partpoints)+1)
163 self.repl_out = Signal(width*2)
164
165 def elaborate(self, platform):
166 m = Module()
167 comb = m.d.comb
168 self.a.set_module(m)
169
170 comb += self.repl_out.eq(Repl(self.a, 2))
171
172 return m
173
174
175 class TestAssMod(Elaboratable):
176 def __init__(self, width, out_shape, partpoints, scalar):
177 self.partpoints = partpoints
178 self.scalar = scalar
179 if scalar:
180 self.a = Signal(width)
181 else:
182 self.a = SimdSignal(partpoints, width)
183 self.ass_out = SimdSignal(partpoints, out_shape)
184
185 def elaborate(self, platform):
186 m = Module()
187 comb = m.d.comb
188 if not self.scalar:
189 self.a.set_module(m)
190 self.ass_out.set_module(m)
191
192 comb += self.ass_out.eq(self.a)
193
194 return m
195
196
197 class TestAddMod(Elaboratable):
198 def __init__(self, width, partpoints):
199 self.partpoints = partpoints
200 self.a = SimdSignal(partpoints, width)
201 self.b = SimdSignal(partpoints, width)
202 self.bsig = Signal(width)
203 self.add_output = Signal(width)
204 self.ls_output = Signal(width) # left shift
205 self.ls_scal_output = Signal(width) # left shift
206 self.rs_output = Signal(width) # right shift
207 self.rs_scal_output = Signal(width) # right shift
208 self.sub_output = Signal(width)
209 self.eq_output = Signal(len(partpoints)+1)
210 self.gt_output = Signal(len(partpoints)+1)
211 self.ge_output = Signal(len(partpoints)+1)
212 self.ne_output = Signal(len(partpoints)+1)
213 self.lt_output = Signal(len(partpoints)+1)
214 self.le_output = Signal(len(partpoints)+1)
215 self.carry_in = Signal(len(partpoints)+1)
216 self.add_carry_out = Signal(len(partpoints)+1)
217 self.sub_carry_out = Signal(len(partpoints)+1)
218 self.neg_output = Signal(width)
219 self.signed_output = Signal(width)
220 self.xor_output = Signal(len(partpoints)+1)
221 self.bool_output = Signal(len(partpoints)+1)
222 self.all_output = Signal(len(partpoints)+1)
223 self.any_output = Signal(len(partpoints)+1)
224
225 def elaborate(self, platform):
226 m = Module()
227 comb = m.d.comb
228 sync = m.d.sync
229 self.a.set_module(m)
230 self.b.set_module(m)
231 # compares
232 comb += self.lt_output.eq(self.a < self.b)
233 comb += self.ne_output.eq(self.a != self.b)
234 comb += self.le_output.eq(self.a <= self.b)
235 comb += self.gt_output.eq(self.a > self.b)
236 comb += self.eq_output.eq(self.a == self.b)
237 comb += self.ge_output.eq(self.a >= self.b)
238 # add
239 add_out, add_carry = self.a.add_op(self.a, self.b,
240 self.carry_in)
241 comb += self.add_output.eq(add_out.sig)
242 comb += self.add_carry_out.eq(add_carry)
243 # sub
244 sub_out, sub_carry = self.a.sub_op(self.a, self.b,
245 self.carry_in)
246 comb += self.sub_output.eq(sub_out.sig)
247 comb += self.sub_carry_out.eq(sub_carry)
248 # neg / signed / unsigned
249 comb += self.neg_output.eq((-self.a).sig)
250 comb += self.signed_output.eq(self.a.as_signed())
251 # horizontal operators
252 comb += self.xor_output.eq(self.a.xor())
253 comb += self.bool_output.eq(self.a.bool())
254 comb += self.all_output.eq(self.a.all())
255 comb += self.any_output.eq(self.a.any())
256 # left shift
257 comb += self.ls_output.eq(self.a << self.b)
258 # right shift
259 comb += self.rs_output.eq(self.a >> self.b)
260 ppts = self.partpoints
261 # scalar left shift
262 comb += self.bsig.eq(self.b.lower())
263 comb += self.ls_scal_output.eq(self.a << self.bsig)
264 # scalar right shift
265 comb += self.rs_scal_output.eq(self.a >> self.bsig)
266
267 return m
268
269
270 class TestMux(unittest.TestCase):
271 @unittest.expectedFailure # FIXME: test fails in CI
272 def test(self):
273 width = 16
274 part_mask = Signal(3) # divide into 4-bits
275 module = TestMuxMod(width, part_mask)
276
277 test_name = "part_sig_mux"
278 traces = [part_mask,
279 module.a.sig,
280 module.b.sig,
281 module.mux_out2]
282 sim = create_simulator(module, traces, test_name)
283
284 def async_process():
285
286 def test_muxop(msg_prefix, *maskbit_list):
287 for a, b in [(0x0000, 0x0000),
288 (0x1234, 0x1234),
289 (0xABCD, 0xABCD),
290 (0xFFFF, 0x0000),
291 (0x0000, 0x0000),
292 (0xFFFF, 0xFFFF),
293 (0x0000, 0xFFFF)]:
294 # convert to mask_list
295 mask_list = []
296 for mb in maskbit_list:
297 v = 0
298 for i in range(4):
299 if mb & (1 << i):
300 v |= 0xf << (i*4)
301 mask_list.append(v)
302
303 # TODO: sel needs to go through permutations of mask_list
304 for p in perms(len(mask_list)):
305
306 sel = 0
307 selmask = 0
308 for i, v in enumerate(p):
309 if v == '1':
310 sel |= maskbit_list[i]
311 selmask |= mask_list[i]
312
313 yield module.a.lower().eq(a)
314 yield module.b.lower().eq(b)
315 yield module.mux_sel.eq(sel)
316 yield module.mux_sel2.lower().eq(sel)
317 yield Delay(0.1e-6)
318 y = 0
319 # do the partitioned tests
320 for i, mask in enumerate(mask_list):
321 if (selmask & mask):
322 y |= (a & mask)
323 else:
324 y |= (b & mask)
325 # check the result
326 outval2 = (yield module.mux_out2)
327 msg = f"{msg_prefix}: mux " + \
328 f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
329 f" => 0x{y:X} != 0x{outval2:X}, masklist %s"
330 # print ((msg % str(maskbit_list)).format(locals()))
331 self.assertEqual(y, outval2, msg % str(maskbit_list))
332
333 yield part_mask.eq(0)
334 yield from test_muxop("16-bit", 0b1111)
335 yield part_mask.eq(0b10)
336 yield from test_muxop("8-bit", 0b1100, 0b0011)
337 yield part_mask.eq(0b1111)
338 yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
339
340 sim.add_process(async_process)
341 with sim.write_vcd(
342 vcd_file=open(test_name + ".vcd", "w"),
343 gtkw_file=open(test_name + ".gtkw", "w"),
344 traces=traces):
345 sim.run()
346
347
348 class TestCat(unittest.TestCase):
349 @unittest.expectedFailure # FIXME: test fails in CI
350 def test(self):
351 width = 16
352 part_mask = Signal(3) # divide into 4-bits
353 module = TestCatMod(width, part_mask)
354
355 test_name = "part_sig_cat"
356 traces = [part_mask,
357 module.a.sig,
358 module.b.sig,
359 module.cat_out]
360 sim = create_simulator(module, traces, test_name)
361
362 # annoying recursive import issue
363 from ieee754.part_cat.cat import get_runlengths
364
365 def async_process():
366
367 def test_catop(msg_prefix):
368 # define lengths of a/b test input
369 alen, blen = 16, 32
370 # pairs of test values a, b
371 for a, b in [(0x0000, 0x00000000),
372 (0xDCBA, 0x12345678),
373 (0xABCD, 0x01234567),
374 (0xFFFF, 0x0000),
375 (0x0000, 0x0000),
376 (0x1F1F, 0xF1F1F1F1),
377 (0x0000, 0xFFFFFFFF)]:
378
379 # convert a and b to partitions
380 apart, bpart = [], []
381 ajump, bjump = alen // 4, blen // 4
382 for i in range(4):
383 apart.append((a >> (ajump*i) & ((1 << ajump)-1)))
384 bpart.append((b >> (bjump*i) & ((1 << bjump)-1)))
385
386 print("apart bpart", hex(a), hex(b),
387 list(map(hex, apart)), list(map(hex, bpart)))
388
389 yield module.a.lower().eq(a)
390 yield module.b.lower().eq(b)
391 yield Delay(0.1e-6)
392
393 y = 0
394 # work out the runlengths for this mask.
395 # 0b011 returns [1,1,2] (for a mask of length 3)
396 mval = yield part_mask
397 runlengths = get_runlengths(mval, 3)
398 j = 0
399 ai = 0
400 bi = 0
401 for i in runlengths:
402 # a first
403 for _ in range(i):
404 print("runlength", i,
405 "ai", ai,
406 "apart", hex(apart[ai]),
407 "j", j)
408 y |= apart[ai] << j
409 print(" y", hex(y))
410 j += ajump
411 ai += 1
412 # now b
413 for _ in range(i):
414 print("runlength", i,
415 "bi", bi,
416 "bpart", hex(bpart[bi]),
417 "j", j)
418 y |= bpart[bi] << j
419 print(" y", hex(y))
420 j += bjump
421 bi += 1
422
423 # check the result
424 outval = (yield module.cat_out)
425 msg = f"{msg_prefix}: cat " + \
426 f"0x{mval:X} 0x{a:X} : 0x{b:X}" + \
427 f" => 0x{y:X} != 0x{outval:X}"
428 self.assertEqual(y, outval, msg)
429
430 yield part_mask.eq(0)
431 yield from test_catop("16-bit")
432 yield part_mask.eq(0b10)
433 yield from test_catop("8-bit")
434 yield part_mask.eq(0b1111)
435 yield from test_catop("4-bit")
436
437 sim.add_process(async_process)
438 with sim.write_vcd(
439 vcd_file=open(test_name + ".vcd", "w"),
440 gtkw_file=open(test_name + ".gtkw", "w"),
441 traces=traces):
442 sim.run()
443
444
445 class TestRepl(unittest.TestCase):
446 @unittest.expectedFailure # FIXME: test fails in CI
447 def test(self):
448 width = 16
449 part_mask = Signal(3) # divide into 4-bits
450 module = TestReplMod(width, part_mask)
451
452 test_name = "part_sig_repl"
453 traces = [part_mask,
454 module.a.sig,
455 module.repl_out]
456 sim = create_simulator(module, traces, test_name)
457
458 # annoying recursive import issue
459 from ieee754.part_repl.repl import get_runlengths
460
461 def async_process():
462
463 def test_replop(msg_prefix):
464 # define length of a test input
465 alen = 16
466 # test values a
467 for a in [0x0000,
468 0xDCBA,
469 0x1234,
470 0xABCD,
471 0xFFFF,
472 0x0000,
473 0x1F1F,
474 0xF1F1,
475 ]:
476
477 # convert a to partitions
478 apart = []
479 ajump = alen // 4
480 for i in range(4):
481 apart.append((a >> (ajump*i) & ((1 << ajump)-1)))
482
483 print("apart", hex(a), list(map(hex, apart)))
484
485 yield module.a.lower().eq(a)
486 yield Delay(0.1e-6)
487
488 y = 0
489 # work out the runlengths for this mask.
490 # 0b011 returns [1,1,2] (for a mask of length 3)
491 mval = yield part_mask
492 runlengths = get_runlengths(mval, 3)
493 j = 0
494 ai = [0, 0]
495 for i in runlengths:
496 # a twice because the test is Repl(a, 2)
497 for aidx in range(2):
498 for _ in range(i):
499 print("runlength", i,
500 "ai", ai,
501 "apart", hex(apart[ai[aidx]]),
502 "j", j)
503 y |= apart[ai[aidx]] << j
504 print(" y", hex(y))
505 j += ajump
506 ai[aidx] += 1
507
508 # check the result
509 outval = (yield module.repl_out)
510 msg = f"{msg_prefix}: repl " + \
511 f"0x{mval:X} 0x{a:X}" + \
512 f" => 0x{y:X} != 0x{outval:X}"
513 self.assertEqual(y, outval, msg)
514
515 yield part_mask.eq(0)
516 yield from test_replop("16-bit")
517 yield part_mask.eq(0b10)
518 yield from test_replop("8-bit")
519 yield part_mask.eq(0b1111)
520 yield from test_replop("4-bit")
521
522 sim.add_process(async_process)
523 with sim.write_vcd(
524 vcd_file=open(test_name + ".vcd", "w"),
525 gtkw_file=open(test_name + ".gtkw", "w"),
526 traces=traces):
527 sim.run()
528
529
530 class TestAssign(unittest.TestCase):
531 def run_tst(self, in_width, out_width, out_signed, scalar):
532 part_mask = Signal(3) # divide into 4-bits
533 module = TestAssMod(in_width,
534 Shape(out_width, out_signed),
535 part_mask, scalar)
536
537 test_name = "part_sig_ass_%d_%d_%s_%s" % (in_width, out_width,
538 "signed" if out_signed else "unsigned",
539 "scalar" if scalar else "partitioned")
540
541 traces = [part_mask,
542 module.ass_out.lower()]
543 if module.scalar:
544 traces.append(module.a)
545 else:
546 traces.append(module.a.lower())
547 sim = create_simulator(module, traces, test_name)
548
549 # annoying recursive import issue
550 from ieee754.part_cat.cat import get_runlengths
551
552 def async_process():
553
554 def test_assop(msg_prefix):
555 # define lengths of a test input
556 alen = in_width
557 randomvals = []
558 for i in range(10):
559 randomvals.append(randint(0, 65535))
560 # test values a
561 for a in [0x0001,
562 0x0010,
563 0x0100,
564 0x1000,
565 0x000c,
566 0x00c0,
567 0x0c00,
568 0xc000,
569 0x1234,
570 0xDCBA,
571 0xABCD,
572 0x0000,
573 0xFFFF,
574 ] + randomvals:
575 # work out the runlengths for this mask.
576 # 0b011 returns [1,1,2] (for a mask of length 3)
577 mval = yield part_mask
578 runlengths = get_runlengths(mval, 3)
579
580 print("test a", hex(a), "mask", bin(mval), "widths",
581 in_width, out_width,
582 "signed", out_signed,
583 "scalar", scalar)
584
585 # convert a to runlengths sub-sections
586 apart = []
587 ajump = alen // 4
588 ai = 0
589 for i in runlengths:
590 subpart = (a >> (ajump*ai) & ((1 << (ajump*i))-1))
591 # will contain the sign
592 msb = (subpart >> ((ajump*i)-1))
593 apart.append((subpart, msb))
594 print("apart", ajump*i, hex(a), hex(subpart), msb)
595 if not scalar:
596 ai += i
597
598 if scalar:
599 yield module.a.eq(a)
600 else:
601 yield module.a.lower().eq(a)
602 yield Delay(0.1e-6)
603
604 y = 0
605 j = 0
606 ojump = out_width // 4
607 for ai, i in enumerate(runlengths):
608 # get "a" partition value
609 av, amsb = apart[ai]
610 # do sign-extension if needed
611 signext = 0
612 if out_signed and ojump > ajump:
613 if amsb:
614 signext = (-1 << ajump *
615 i) & ((1 << (ojump*i))-1)
616 av |= signext
617 # truncate if needed
618 if ojump < ajump:
619 av &= ((1 << (ojump*i))-1)
620 print("runlength", i,
621 "ai", ai,
622 "apart", hex(av), amsb,
623 "signext", hex(signext),
624 "j", j)
625 y |= av << j
626 print(" y", hex(y))
627 j += ojump*i
628 ai += 1
629
630 y &= (1 << out_width)-1
631
632 # check the result
633 outval = (yield module.ass_out.lower())
634 outval &= (1 << out_width)-1
635 msg = f"{msg_prefix}: assign " + \
636 f"mask 0x{mval:X} input 0x{a:X}" + \
637 f" => expected 0x{y:X} != actual 0x{outval:X}"
638 self.assertEqual(y, outval, msg)
639
640 # run the actual tests, here - 16/8/4 bit partitions
641 for (mask, name) in ((0, "16-bit"),
642 (0b10, "8-bit"),
643 (0b111, "4-bit")):
644 with self.subTest(name + " " + test_name):
645 yield part_mask.eq(mask)
646 yield Settle()
647 yield from test_assop(name)
648
649 sim.add_process(async_process)
650 with sim.write_vcd(
651 vcd_file=open(test_name + ".vcd", "w"),
652 gtkw_file=open(test_name + ".gtkw", "w"),
653 traces=traces):
654 sim.run()
655
656 @unittest.expectedFailure # FIXME: test fails in CI
657 def test(self):
658 for out_width in [16, 24, 8]:
659 for sign in [True, False]:
660 for scalar in [True, False]:
661 self.run_tst(16, out_width, sign, scalar)
662
663
664 class TestSimdSignal(unittest.TestCase):
665 def test(self):
666 width = 16
667 part_mask = Signal(3) # divide into 4-bits
668 module = TestAddMod(width, part_mask)
669
670 test_name = "part_sig_add"
671 traces = [part_mask,
672 module.a.sig,
673 module.b.sig,
674 module.add_output,
675 module.eq_output]
676 sim = create_simulator(module, traces, test_name)
677
678 def async_process():
679
680 def test_xor_fn(a, mask):
681 test = (a & mask)
682 result = 0
683 while test != 0:
684 bit = (test & 1)
685 result ^= bit
686 test >>= 1
687 return result
688
689 def test_bool_fn(a, mask):
690 test = (a & mask)
691 return test != 0
692
693 def test_all_fn(a, mask):
694 # slightly different: all bits masked must be 1
695 test = (a & mask)
696 return test == mask
697
698 def test_horizop(msg_prefix, test_fn, mod_attr, *maskbit_list):
699 randomvals = []
700 for i in range(100):
701 randomvals.append(randint(0, 65535))
702 for a in [0x0000,
703 0x1111,
704 0x0001,
705 0x0010,
706 0x0100,
707 0x1000,
708 0x000F,
709 0x00F0,
710 0x0F00,
711 0xF000,
712 0x00FF,
713 0xFF00,
714 0x1234,
715 0xABCD,
716 0xFFFF,
717 0x8000,
718 0xBEEF, 0xFEED,
719 ]+randomvals:
720 with self.subTest("%s %s %s" % (msg_prefix,
721 test_fn.__name__, hex(a))):
722 yield module.a.lower().eq(a)
723 yield Delay(0.1e-6)
724 # convert to mask_list
725 mask_list = []
726 for mb in maskbit_list:
727 v = 0
728 for i in range(4):
729 if mb & (1 << i):
730 v |= 0xf << (i*4)
731 mask_list.append(v)
732 y = 0
733 # do the partitioned tests
734 for i, mask in enumerate(mask_list):
735 if test_fn(a, mask):
736 # OR y with the lowest set bit in the mask
737 y |= maskbit_list[i]
738 # check the result
739 outval = (yield getattr(module, "%s_output" % mod_attr))
740 msg = f"{msg_prefix}: {mod_attr} 0x{a:X} " + \
741 f" => 0x{y:X} != 0x{outval:X}, masklist %s"
742 print((msg % str(maskbit_list)).format(locals()))
743 self.assertEqual(y, outval, msg % str(maskbit_list))
744
745 for (test_fn, mod_attr) in ((test_xor_fn, "xor"),
746 (test_all_fn, "all"),
747 (test_bool_fn, "any"), # same as bool
748 (test_bool_fn, "bool"),
749 #(test_ne_fn, "ne"),
750 ):
751 yield part_mask.eq(0)
752 yield from test_horizop("16-bit", test_fn, mod_attr, 0b1111)
753 yield part_mask.eq(0b10)
754 yield from test_horizop("8-bit", test_fn, mod_attr,
755 0b1100, 0b0011)
756 yield part_mask.eq(0b1111)
757 yield from test_horizop("4-bit", test_fn, mod_attr,
758 0b1000, 0b0100, 0b0010, 0b0001)
759
760 def test_ls_scal_fn(carry_in, a, b, mask):
761 # reduce range of b
762 bits = count_bits(mask)
763 newb = b & ((bits-1))
764 print("%x %x %x bits %d trunc %x" %
765 (a, b, mask, bits, newb))
766 b = newb
767 # TODO: carry
768 carry_in = 0
769 lsb = mask & ~(mask-1) if carry_in else 0
770 sum = ((a & mask) << b)
771 result = mask & sum
772 carry = (sum & mask) != sum
773 carry = 0
774 print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
775 return result, carry
776
777 def test_rs_scal_fn(carry_in, a, b, mask):
778 # reduce range of b
779 bits = count_bits(mask)
780 newb = b & ((bits-1))
781 print("%x %x %x bits %d trunc %x" %
782 (a, b, mask, bits, newb))
783 b = newb
784 # TODO: carry
785 carry_in = 0
786 lsb = mask & ~(mask-1) if carry_in else 0
787 sum = ((a & mask) >> b)
788 result = mask & sum
789 carry = (sum & mask) != sum
790 carry = 0
791 print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
792 return result, carry
793
794 def test_ls_fn(carry_in, a, b, mask):
795 # reduce range of b
796 bits = count_bits(mask)
797 fz = first_zero(mask)
798 newb = b & ((bits-1) << fz)
799 print("%x %x %x bits %d zero %d trunc %x" %
800 (a, b, mask, bits, fz, newb))
801 b = newb
802 # TODO: carry
803 carry_in = 0
804 lsb = mask & ~(mask-1) if carry_in else 0
805 b = (b & mask)
806 b = b >> fz
807 sum = ((a & mask) << b)
808 result = mask & sum
809 carry = (sum & mask) != sum
810 carry = 0
811 print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
812 return result, carry
813
814 def test_rs_fn(carry_in, a, b, mask):
815 # reduce range of b
816 bits = count_bits(mask)
817 fz = first_zero(mask)
818 newb = b & ((bits-1) << fz)
819 print("%x %x %x bits %d zero %d trunc %x" %
820 (a, b, mask, bits, fz, newb))
821 b = newb
822 # TODO: carry
823 carry_in = 0
824 lsb = mask & ~(mask-1) if carry_in else 0
825 b = (b & mask)
826 b = b >> fz
827 sum = ((a & mask) >> b)
828 result = mask & sum
829 carry = (sum & mask) != sum
830 carry = 0
831 print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
832 return result, carry
833
834 def test_add_fn(carry_in, a, b, mask):
835 lsb = mask & ~(mask-1) if carry_in else 0
836 sum = (a & mask) + (b & mask) + lsb
837 result = mask & sum
838 carry = (sum & mask) != sum
839 print(a, b, sum, mask)
840 return result, carry
841
842 def test_sub_fn(carry_in, a, b, mask):
843 lsb = mask & ~(mask-1) if carry_in else 0
844 sum = (a & mask) + (~b & mask) + lsb
845 result = mask & sum
846 carry = (sum & mask) != sum
847 return result, carry
848
849 def test_neg_fn(carry_in, a, b, mask):
850 lsb = mask & ~(mask - 1) # has only LSB of mask set
851 pos = lsb.bit_length() - 1 # find bit position
852 a = (a & mask) >> pos # shift it to the beginning
853 return ((-a) << pos) & mask, 0 # negate and shift it back
854
855 def test_signed_fn(carry_in, a, b, mask):
856 return a & mask, 0
857
858 def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
859 rand_data = []
860 for i in range(100):
861 a, b = randint(0, 1 << 16), randint(0, 1 << 16)
862 rand_data.append((a, b))
863 for a, b in [(0x0000, 0x0000),
864 (0x1234, 0x1234),
865 (0xABCD, 0xABCD),
866 (0xFFFF, 0x0000),
867 (0x0000, 0x0000),
868 (0xFFFF, 0xFFFF),
869 (0x0000, 0xFFFF)] + rand_data:
870 yield module.a.lower().eq(a)
871 yield module.b.lower().eq(b)
872 carry_sig = 0xf if carry else 0
873 yield module.carry_in.eq(carry_sig)
874 yield Delay(0.1e-6)
875 y = 0
876 carry_result = 0
877 for i, mask in enumerate(mask_list):
878 print("i/mask", i, hex(mask))
879 res, c = test_fn(carry, a, b, mask)
880 y |= res
881 lsb = mask & ~(mask - 1)
882 bit_set = int(math.log2(lsb))
883 carry_result |= c << int(bit_set/4)
884 outval = (yield getattr(module, "%s_output" % mod_attr))
885 # TODO: get (and test) carry output as well
886 print(a, b, outval, carry)
887 msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
888 f" => 0x{y:X} != 0x{outval:X}"
889 self.assertEqual(y, outval, msg)
890 if hasattr(module, "%s_carry_out" % mod_attr):
891 c_outval = (yield getattr(module,
892 "%s_carry_out" % mod_attr))
893 msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
894 f" => 0x{carry_result:X} != 0x{c_outval:X}"
895 self.assertEqual(carry_result, c_outval, msg)
896
897 # run through series of operations with corresponding
898 # "helper" routines to reproduce the result (test_fn). the same
899 # a/b input is passed to *all* outputs, where the name of the
900 # output attribute (mod_attr) will contain the result to be
901 # compared against the expected output from test_fn
902 for (test_fn, mod_attr) in (
903 (test_ls_scal_fn, "ls_scal"),
904 (test_ls_fn, "ls"),
905 (test_rs_scal_fn, "rs_scal"),
906 (test_rs_fn, "rs"),
907 (test_add_fn, "add"),
908 (test_sub_fn, "sub"),
909 (test_neg_fn, "neg"),
910 (test_signed_fn, "signed"),
911 ):
912 yield part_mask.eq(0)
913 yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
914 yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)
915 yield part_mask.eq(0b10)
916 yield from test_op("8-bit", 0, test_fn, mod_attr,
917 0xFF00, 0x00FF)
918 yield from test_op("8-bit", 1, test_fn, mod_attr,
919 0xFF00, 0x00FF)
920 yield part_mask.eq(0b1111)
921 yield from test_op("4-bit", 0, test_fn, mod_attr,
922 0xF000, 0x0F00, 0x00F0, 0x000F)
923 yield from test_op("4-bit", 1, test_fn, mod_attr,
924 0xF000, 0x0F00, 0x00F0, 0x000F)
925
926 def test_ne_fn(a, b, mask):
927 return (a & mask) != (b & mask)
928
929 def test_lt_fn(a, b, mask):
930 return (a & mask) < (b & mask)
931
932 def test_le_fn(a, b, mask):
933 return (a & mask) <= (b & mask)
934
935 def test_eq_fn(a, b, mask):
936 return (a & mask) == (b & mask)
937
938 def test_gt_fn(a, b, mask):
939 return (a & mask) > (b & mask)
940
941 def test_ge_fn(a, b, mask):
942 return (a & mask) >= (b & mask)
943
944 def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
945 for a, b in [(0x0000, 0x0000),
946 (0x1234, 0x1234),
947 (0xABCD, 0xABCD),
948 (0xFFFF, 0x0000),
949 (0x0000, 0x0000),
950 (0xFFFF, 0xFFFF),
951 (0x0000, 0xFFFF),
952 (0xABCD, 0xABCE),
953 (0x8000, 0x0000),
954 (0xBEEF, 0xFEED)]:
955 yield module.a.lower().eq(a)
956 yield module.b.lower().eq(b)
957 yield Delay(0.1e-6)
958 # convert to mask_list
959 mask_list = []
960 for mb in maskbit_list:
961 v = 0
962 for i in range(4):
963 if mb & (1 << i):
964 v |= 0xf << (i*4)
965 mask_list.append(v)
966 y = 0
967 # do the partitioned tests
968 for i, mask in enumerate(mask_list):
969 if test_fn(a, b, mask):
970 # OR y with the lowest set bit in the mask
971 y |= maskbit_list[i]
972 # check the result
973 outval = (yield getattr(module, "%s_output" % mod_attr))
974 msg = f"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
975 f" => 0x{y:X} != 0x{outval:X}, masklist %s"
976 print((msg % str(maskbit_list)).format(locals()))
977 self.assertEqual(y, outval, msg % str(maskbit_list))
978
979 for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
980 (test_gt_fn, "gt"),
981 (test_ge_fn, "ge"),
982 (test_lt_fn, "lt"),
983 (test_le_fn, "le"),
984 (test_ne_fn, "ne"),
985 ):
986 yield part_mask.eq(0)
987 yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
988 yield part_mask.eq(0b10)
989 yield from test_binop("8-bit", test_fn, mod_attr,
990 0b1100, 0b0011)
991 yield part_mask.eq(0b1111)
992 yield from test_binop("4-bit", test_fn, mod_attr,
993 0b1000, 0b0100, 0b0010, 0b0001)
994
995 sim.add_process(async_process)
996 with sim.write_vcd(
997 vcd_file=open(test_name + ".vcd", "w"),
998 gtkw_file=open(test_name + ".gtkw", "w"),
999 traces=traces):
1000 sim.run()
1001
1002
1003 # TODO: adapt to SimdSignal. perhaps a different style?
1004 r'''
1005 from nmigen.tests.test_hdl_ast import SignedEnum
1006 def test_matches(self)
1007 s = Signal(4)
1008 self.assertRepr(s.matches(), "(const 1'd0)")
1009 self.assertRepr(s.matches(1), """
1010 (== (sig s) (const 1'd1))
1011 """)
1012 self.assertRepr(s.matches(0, 1), """
1013 (r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1))))
1014 """)
1015 self.assertRepr(s.matches("10--"), """
1016 (== (& (sig s) (const 4'd12)) (const 4'd8))
1017 """)
1018 self.assertRepr(s.matches("1 0--"), """
1019 (== (& (sig s) (const 4'd12)) (const 4'd8))
1020 """)
1021
1022 def test_matches_enum(self):
1023 s = Signal(SignedEnum)
1024 self.assertRepr(s.matches(SignedEnum.FOO), """
1025 (== (sig s) (const 1'sd-1))
1026 """)
1027
1028 def test_matches_width_wrong(self):
1029 s = Signal(4)
1030 with self.assertRaisesRegex(SyntaxError,
1031 r"^Match pattern '--' must have the same width as "
1032 r"match value \(which is 4\)$"):
1033 s.matches("--")
1034 with self.assertWarnsRegex(SyntaxWarning,
1035 (r"^Match pattern '10110' is wider than match value "
1036 r"\(which has width 4\); "
1037 r"comparison will never be true$")):
1038 s.matches(0b10110)
1039
1040 def test_matches_bits_wrong(self):
1041 s = Signal(4)
1042 with self.assertRaisesRegex(SyntaxError,
1043 (r"^Match pattern 'abc' must consist of 0, 1, "
1044 r"and - \(don't care\) bits, "
1045 r"and may include whitespace$")):
1046 s.matches("abc")
1047
1048 def test_matches_pattern_wrong(self):
1049 s = Signal(4)
1050 with self.assertRaisesRegex(SyntaxError,
1051 r"^Match pattern must be an integer, a string, "
1052 r"or an enumeration, not 1\.0$"):
1053 s.matches(1.0)
1054 '''
1055
1056 if __name__ == '__main__':
1057 unittest.main()