95425c0315d1d3e34e8279d2a4989312a501d306
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
5 from nmigen
import Signal
, Module
, Elaboratable
, Mux
, Cat
, Shape
, Repl
6 from nmigen
.back
.pysim
import Simulator
, Delay
, Settle
7 from nmigen
.cli
import rtlil
9 from ieee754
.part
.partsig
import SimdSignal
10 from ieee754
.part_mux
.part_mux
import PMux
12 from random
import randint
35 return map(''.join
, itertools
.product('01', repeat
=k
))
38 def create_ilang(dut
, traces
, test_name
):
39 vl
= rtlil
.convert(dut
, ports
=traces
)
40 with
open("%s.il" % test_name
, "w") as f
:
44 def create_simulator(module
, traces
, test_name
):
45 create_ilang(module
, traces
, test_name
)
46 return Simulator(module
)
49 # XXX this is for coriolis2 experimentation
50 class TestAddMod2(Elaboratable
):
51 def __init__(self
, width
, partpoints
):
52 self
.partpoints
= partpoints
53 self
.a
= SimdSignal(partpoints
, width
)
54 self
.b
= SimdSignal(partpoints
, width
)
55 self
.bsig
= Signal(width
)
56 self
.add_output
= Signal(width
)
57 self
.ls_output
= Signal(width
) # left shift
58 self
.ls_scal_output
= Signal(width
) # left shift
59 self
.rs_output
= Signal(width
) # right shift
60 self
.rs_scal_output
= Signal(width
) # right shift
61 self
.sub_output
= Signal(width
)
62 self
.eq_output
= Signal(len(partpoints
)+1)
63 self
.gt_output
= Signal(len(partpoints
)+1)
64 self
.ge_output
= Signal(len(partpoints
)+1)
65 self
.ne_output
= Signal(len(partpoints
)+1)
66 self
.lt_output
= Signal(len(partpoints
)+1)
67 self
.le_output
= Signal(len(partpoints
)+1)
68 self
.mux_sel2
= Signal(len(partpoints
)+1)
69 self
.mux_sel2
= SimdSignal(partpoints
, len(partpoints
))
70 self
.mux2_out
= Signal(width
)
71 self
.carry_in
= Signal(len(partpoints
)+1)
72 self
.add_carry_out
= Signal(len(partpoints
)+1)
73 self
.sub_carry_out
= Signal(len(partpoints
)+1)
74 self
.neg_output
= Signal(width
)
76 def elaborate(self
, platform
):
82 self
.mux_sel2
.set_module(m
)
84 sync
+= self
.lt_output
.eq(self
.a
< self
.b
)
85 sync
+= self
.ne_output
.eq(self
.a
!= self
.b
)
86 sync
+= self
.le_output
.eq(self
.a
<= self
.b
)
87 sync
+= self
.gt_output
.eq(self
.a
> self
.b
)
88 sync
+= self
.eq_output
.eq(self
.a
== self
.b
)
89 sync
+= self
.ge_output
.eq(self
.a
>= self
.b
)
91 add_out
, add_carry
= self
.a
.add_op(self
.a
, self
.b
,
93 sync
+= self
.add_output
.eq(add_out
)
94 sync
+= self
.add_carry_out
.eq(add_carry
)
96 sub_out
, sub_carry
= self
.a
.sub_op(self
.a
, self
.b
,
98 sync
+= self
.sub_output
.eq(sub_out
)
99 sync
+= self
.sub_carry_out
.eq(sub_carry
)
101 sync
+= self
.neg_output
.eq(-self
.a
)
103 sync
+= self
.ls_output
.eq(self
.a
<< self
.b
)
104 sync
+= self
.rs_output
.eq(self
.a
>> self
.b
)
105 ppts
= self
.partpoints
106 sync
+= self
.mux_out2
.eq(Mux(self
.mux_sel2
, self
.a
, self
.b
))
108 comb
+= self
.bsig
.eq(self
.b
.lower())
109 sync
+= self
.ls_scal_output
.eq(self
.a
<< self
.bsig
)
110 sync
+= self
.rs_scal_output
.eq(self
.a
>> self
.bsig
)
115 class TestMuxMod(Elaboratable
):
116 def __init__(self
, width
, partpoints
):
117 self
.partpoints
= partpoints
118 self
.a
= SimdSignal(partpoints
, width
)
119 self
.b
= SimdSignal(partpoints
, width
)
120 self
.mux_sel
= Signal(len(partpoints
)+1)
121 self
.mux_sel2
= SimdSignal(partpoints
, len(partpoints
)+1)
122 self
.mux_out2
= Signal(width
)
124 def elaborate(self
, platform
):
130 self
.mux_sel2
.set_module(m
)
131 ppts
= self
.partpoints
133 comb
+= self
.mux_out2
.eq(Mux(self
.mux_sel2
, self
.a
, self
.b
))
138 class TestCatMod(Elaboratable
):
139 def __init__(self
, width
, partpoints
):
140 self
.partpoints
= partpoints
141 self
.a
= SimdSignal(partpoints
, width
)
142 self
.b
= SimdSignal(partpoints
, width
*2)
143 self
.o
= SimdSignal(partpoints
, width
*3)
144 self
.cat_out
= self
.o
.sig
146 def elaborate(self
, platform
):
153 comb
+= self
.o
.eq(Cat(self
.a
, self
.b
))
158 class TestReplMod(Elaboratable
):
159 def __init__(self
, width
, partpoints
):
160 self
.partpoints
= partpoints
161 self
.a
= SimdSignal(partpoints
, width
)
162 self
.repl_sel
= Signal(len(partpoints
)+1)
163 self
.repl_out
= Signal(width
*2)
165 def elaborate(self
, platform
):
170 comb
+= self
.repl_out
.eq(Repl(self
.a
, 2))
175 class TestAssMod(Elaboratable
):
176 def __init__(self
, width
, out_shape
, partpoints
, scalar
):
177 self
.partpoints
= partpoints
180 self
.a
= Signal(width
)
182 self
.a
= SimdSignal(partpoints
, width
)
183 self
.ass_out
= SimdSignal(partpoints
, out_shape
)
185 def elaborate(self
, platform
):
190 self
.ass_out
.set_module(m
)
192 comb
+= self
.ass_out
.eq(self
.a
)
197 class TestAddMod(Elaboratable
):
198 def __init__(self
, width
, partpoints
):
199 self
.partpoints
= partpoints
200 self
.a
= SimdSignal(partpoints
, width
)
201 self
.b
= SimdSignal(partpoints
, width
)
202 self
.bsig
= Signal(width
)
203 self
.add_output
= Signal(width
)
204 self
.ls_output
= Signal(width
) # left shift
205 self
.ls_scal_output
= Signal(width
) # left shift
206 self
.rs_output
= Signal(width
) # right shift
207 self
.rs_scal_output
= Signal(width
) # right shift
208 self
.sub_output
= Signal(width
)
209 self
.eq_output
= Signal(len(partpoints
)+1)
210 self
.gt_output
= Signal(len(partpoints
)+1)
211 self
.ge_output
= Signal(len(partpoints
)+1)
212 self
.ne_output
= Signal(len(partpoints
)+1)
213 self
.lt_output
= Signal(len(partpoints
)+1)
214 self
.le_output
= Signal(len(partpoints
)+1)
215 self
.carry_in
= Signal(len(partpoints
)+1)
216 self
.add_carry_out
= Signal(len(partpoints
)+1)
217 self
.sub_carry_out
= Signal(len(partpoints
)+1)
218 self
.neg_output
= Signal(width
)
219 self
.signed_output
= Signal(width
)
220 self
.xor_output
= Signal(len(partpoints
)+1)
221 self
.bool_output
= Signal(len(partpoints
)+1)
222 self
.all_output
= Signal(len(partpoints
)+1)
223 self
.any_output
= Signal(len(partpoints
)+1)
225 def elaborate(self
, platform
):
232 comb
+= self
.lt_output
.eq(self
.a
< self
.b
)
233 comb
+= self
.ne_output
.eq(self
.a
!= self
.b
)
234 comb
+= self
.le_output
.eq(self
.a
<= self
.b
)
235 comb
+= self
.gt_output
.eq(self
.a
> self
.b
)
236 comb
+= self
.eq_output
.eq(self
.a
== self
.b
)
237 comb
+= self
.ge_output
.eq(self
.a
>= self
.b
)
239 add_out
, add_carry
= self
.a
.add_op(self
.a
, self
.b
,
241 comb
+= self
.add_output
.eq(add_out
.sig
)
242 comb
+= self
.add_carry_out
.eq(add_carry
)
244 sub_out
, sub_carry
= self
.a
.sub_op(self
.a
, self
.b
,
246 comb
+= self
.sub_output
.eq(sub_out
.sig
)
247 comb
+= self
.sub_carry_out
.eq(sub_carry
)
248 # neg / signed / unsigned
249 comb
+= self
.neg_output
.eq((-self
.a
).sig
)
250 comb
+= self
.signed_output
.eq(self
.a
.as_signed())
251 # horizontal operators
252 comb
+= self
.xor_output
.eq(self
.a
.xor())
253 comb
+= self
.bool_output
.eq(self
.a
.bool())
254 comb
+= self
.all_output
.eq(self
.a
.all())
255 comb
+= self
.any_output
.eq(self
.a
.any())
257 comb
+= self
.ls_output
.eq(self
.a
<< self
.b
)
259 comb
+= self
.rs_output
.eq(self
.a
>> self
.b
)
260 ppts
= self
.partpoints
262 comb
+= self
.bsig
.eq(self
.b
.lower())
263 comb
+= self
.ls_scal_output
.eq(self
.a
<< self
.bsig
)
265 comb
+= self
.rs_scal_output
.eq(self
.a
>> self
.bsig
)
270 class TestMux(unittest
.TestCase
):
271 @unittest.expectedFailure
# FIXME: test fails in CI
274 part_mask
= Signal(3) # divide into 4-bits
275 module
= TestMuxMod(width
, part_mask
)
277 test_name
= "part_sig_mux"
282 sim
= create_simulator(module
, traces
, test_name
)
286 def test_muxop(msg_prefix
, *maskbit_list
):
287 for a
, b
in [(0x0000, 0x0000),
294 # convert to mask_list
296 for mb
in maskbit_list
:
303 # TODO: sel needs to go through permutations of mask_list
304 for p
in perms(len(mask_list
)):
308 for i
, v
in enumerate(p
):
310 sel |
= maskbit_list
[i
]
311 selmask |
= mask_list
[i
]
313 yield module
.a
.lower().eq(a
)
314 yield module
.b
.lower().eq(b
)
315 yield module
.mux_sel
.eq(sel
)
316 yield module
.mux_sel2
.lower().eq(sel
)
319 # do the partitioned tests
320 for i
, mask
in enumerate(mask_list
):
326 outval2
= (yield module
.mux_out2
)
327 msg
= f
"{msg_prefix}: mux " + \
328 f
"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
329 f
" => 0x{y:X} != 0x{outval2:X}, masklist %s"
330 # print ((msg % str(maskbit_list)).format(locals()))
331 self
.assertEqual(y
, outval2
, msg
% str(maskbit_list
))
333 yield part_mask
.eq(0)
334 yield from test_muxop("16-bit", 0b1111)
335 yield part_mask
.eq(0b10)
336 yield from test_muxop("8-bit", 0b1100, 0b0011)
337 yield part_mask
.eq(0b1111)
338 yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
340 sim
.add_process(async_process
)
342 vcd_file
=open(test_name
+ ".vcd", "w"),
343 gtkw_file
=open(test_name
+ ".gtkw", "w"),
348 class TestCat(unittest
.TestCase
):
349 @unittest.expectedFailure
# FIXME: test fails in CI
352 part_mask
= Signal(3) # divide into 4-bits
353 module
= TestCatMod(width
, part_mask
)
355 test_name
= "part_sig_cat"
360 sim
= create_simulator(module
, traces
, test_name
)
362 # annoying recursive import issue
363 from ieee754
.part_cat
.cat
import get_runlengths
367 def test_catop(msg_prefix
):
368 # define lengths of a/b test input
370 # pairs of test values a, b
371 for a
, b
in [(0x0000, 0x00000000),
372 (0xDCBA, 0x12345678),
373 (0xABCD, 0x01234567),
376 (0x1F1F, 0xF1F1F1F1),
377 (0x0000, 0xFFFFFFFF)]:
379 # convert a and b to partitions
380 apart
, bpart
= [], []
381 ajump
, bjump
= alen
// 4, blen
// 4
383 apart
.append((a
>> (ajump
*i
) & ((1 << ajump
)-1)))
384 bpart
.append((b
>> (bjump
*i
) & ((1 << bjump
)-1)))
386 print("apart bpart", hex(a
), hex(b
),
387 list(map(hex, apart
)), list(map(hex, bpart
)))
389 yield module
.a
.lower().eq(a
)
390 yield module
.b
.lower().eq(b
)
394 # work out the runlengths for this mask.
395 # 0b011 returns [1,1,2] (for a mask of length 3)
396 mval
= yield part_mask
397 runlengths
= get_runlengths(mval
, 3)
404 print("runlength", i
,
406 "apart", hex(apart
[ai
]),
414 print("runlength", i
,
416 "bpart", hex(bpart
[bi
]),
424 outval
= (yield module
.cat_out
)
425 msg
= f
"{msg_prefix}: cat " + \
426 f
"0x{mval:X} 0x{a:X} : 0x{b:X}" + \
427 f
" => 0x{y:X} != 0x{outval:X}"
428 self
.assertEqual(y
, outval
, msg
)
430 yield part_mask
.eq(0)
431 yield from test_catop("16-bit")
432 yield part_mask
.eq(0b10)
433 yield from test_catop("8-bit")
434 yield part_mask
.eq(0b1111)
435 yield from test_catop("4-bit")
437 sim
.add_process(async_process
)
439 vcd_file
=open(test_name
+ ".vcd", "w"),
440 gtkw_file
=open(test_name
+ ".gtkw", "w"),
445 class TestRepl(unittest
.TestCase
):
446 @unittest.expectedFailure
# FIXME: test fails in CI
449 part_mask
= Signal(3) # divide into 4-bits
450 module
= TestReplMod(width
, part_mask
)
452 test_name
= "part_sig_repl"
456 sim
= create_simulator(module
, traces
, test_name
)
458 # annoying recursive import issue
459 from ieee754
.part_repl
.repl
import get_runlengths
463 def test_replop(msg_prefix
):
464 # define length of a test input
477 # convert a to partitions
481 apart
.append((a
>> (ajump
*i
) & ((1 << ajump
)-1)))
483 print("apart", hex(a
), list(map(hex, apart
)))
485 yield module
.a
.lower().eq(a
)
489 # work out the runlengths for this mask.
490 # 0b011 returns [1,1,2] (for a mask of length 3)
491 mval
= yield part_mask
492 runlengths
= get_runlengths(mval
, 3)
496 # a twice because the test is Repl(a, 2)
497 for aidx
in range(2):
499 print("runlength", i
,
501 "apart", hex(apart
[ai
[aidx
]]),
503 y |
= apart
[ai
[aidx
]] << j
509 outval
= (yield module
.repl_out
)
510 msg
= f
"{msg_prefix}: repl " + \
511 f
"0x{mval:X} 0x{a:X}" + \
512 f
" => 0x{y:X} != 0x{outval:X}"
513 self
.assertEqual(y
, outval
, msg
)
515 yield part_mask
.eq(0)
516 yield from test_replop("16-bit")
517 yield part_mask
.eq(0b10)
518 yield from test_replop("8-bit")
519 yield part_mask
.eq(0b1111)
520 yield from test_replop("4-bit")
522 sim
.add_process(async_process
)
524 vcd_file
=open(test_name
+ ".vcd", "w"),
525 gtkw_file
=open(test_name
+ ".gtkw", "w"),
530 class TestAssign(unittest
.TestCase
):
531 def run_tst(self
, in_width
, out_width
, out_signed
, scalar
):
532 part_mask
= Signal(3) # divide into 4-bits
533 module
= TestAssMod(in_width
,
534 Shape(out_width
, out_signed
),
537 test_name
= "part_sig_ass_%d_%d_%s_%s" % (in_width
, out_width
,
538 "signed" if out_signed
else "unsigned",
539 "scalar" if scalar
else "partitioned")
542 module
.ass_out
.lower()]
544 traces
.append(module
.a
)
546 traces
.append(module
.a
.lower())
547 sim
= create_simulator(module
, traces
, test_name
)
549 # annoying recursive import issue
550 from ieee754
.part_cat
.cat
import get_runlengths
554 def test_assop(msg_prefix
):
555 # define lengths of a test input
559 randomvals
.append(randint(0, 65535))
575 # work out the runlengths for this mask.
576 # 0b011 returns [1,1,2] (for a mask of length 3)
577 mval
= yield part_mask
578 runlengths
= get_runlengths(mval
, 3)
580 print("test a", hex(a
), "mask", bin(mval
), "widths",
582 "signed", out_signed
,
585 # convert a to runlengths sub-sections
590 subpart
= (a
>> (ajump
*ai
) & ((1 << (ajump
*i
))-1))
591 # will contain the sign
592 msb
= (subpart
>> ((ajump
*i
)-1))
593 apart
.append((subpart
, msb
))
594 print("apart", ajump
*i
, hex(a
), hex(subpart
), msb
)
601 yield module
.a
.lower().eq(a
)
606 ojump
= out_width
// 4
607 for ai
, i
in enumerate(runlengths
):
608 # get "a" partition value
610 # do sign-extension if needed
612 if out_signed
and ojump
> ajump
:
614 signext
= (-1 << ajump
*
615 i
) & ((1 << (ojump
*i
))-1)
619 av
&= ((1 << (ojump
*i
))-1)
620 print("runlength", i
,
622 "apart", hex(av
), amsb
,
623 "signext", hex(signext
),
630 y
&= (1 << out_width
)-1
633 outval
= (yield module
.ass_out
.lower())
634 outval
&= (1 << out_width
)-1
635 msg
= f
"{msg_prefix}: assign " + \
636 f
"mask 0x{mval:X} input 0x{a:X}" + \
637 f
" => expected 0x{y:X} != actual 0x{outval:X}"
638 self
.assertEqual(y
, outval
, msg
)
640 # run the actual tests, here - 16/8/4 bit partitions
641 for (mask
, name
) in ((0, "16-bit"),
644 with self
.subTest(name
+ " " + test_name
):
645 yield part_mask
.eq(mask
)
647 yield from test_assop(name
)
649 sim
.add_process(async_process
)
651 vcd_file
=open(test_name
+ ".vcd", "w"),
652 gtkw_file
=open(test_name
+ ".gtkw", "w"),
656 @unittest.expectedFailure
# FIXME: test fails in CI
658 for out_width
in [16, 24, 8]:
659 for sign
in [True, False]:
660 for scalar
in [True, False]:
661 self
.run_tst(16, out_width
, sign
, scalar
)
664 class TestSimdSignal(unittest
.TestCase
):
667 part_mask
= Signal(3) # divide into 4-bits
668 module
= TestAddMod(width
, part_mask
)
670 test_name
= "part_sig_add"
676 sim
= create_simulator(module
, traces
, test_name
)
680 def test_xor_fn(a
, mask
):
689 def test_bool_fn(a
, mask
):
693 def test_all_fn(a
, mask
):
694 # slightly different: all bits masked must be 1
698 def test_horizop(msg_prefix
, test_fn
, mod_attr
, *maskbit_list
):
701 randomvals
.append(randint(0, 65535))
720 with self
.subTest("%s %s %s" % (msg_prefix
,
721 test_fn
.__name
__, hex(a
))):
722 yield module
.a
.lower().eq(a
)
724 # convert to mask_list
726 for mb
in maskbit_list
:
733 # do the partitioned tests
734 for i
, mask
in enumerate(mask_list
):
736 # OR y with the lowest set bit in the mask
739 outval
= (yield getattr(module
, "%s_output" % mod_attr
))
740 msg
= f
"{msg_prefix}: {mod_attr} 0x{a:X} " + \
741 f
" => 0x{y:X} != 0x{outval:X}, masklist %s"
742 print((msg
% str(maskbit_list
)).format(locals()))
743 self
.assertEqual(y
, outval
, msg
% str(maskbit_list
))
745 for (test_fn
, mod_attr
) in ((test_xor_fn
, "xor"),
746 (test_all_fn
, "all"),
747 (test_bool_fn
, "any"), # same as bool
748 (test_bool_fn
, "bool"),
751 yield part_mask
.eq(0)
752 yield from test_horizop("16-bit", test_fn
, mod_attr
, 0b1111)
753 yield part_mask
.eq(0b10)
754 yield from test_horizop("8-bit", test_fn
, mod_attr
,
756 yield part_mask
.eq(0b1111)
757 yield from test_horizop("4-bit", test_fn
, mod_attr
,
758 0b1000, 0b0100, 0b0010, 0b0001)
760 def test_ls_scal_fn(carry_in
, a
, b
, mask
):
762 bits
= count_bits(mask
)
763 newb
= b
& ((bits
-1))
764 print("%x %x %x bits %d trunc %x" %
765 (a
, b
, mask
, bits
, newb
))
769 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
770 sum = ((a
& mask
) << b
)
772 carry
= (sum & mask
) != sum
774 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
777 def test_rs_scal_fn(carry_in
, a
, b
, mask
):
779 bits
= count_bits(mask
)
780 newb
= b
& ((bits
-1))
781 print("%x %x %x bits %d trunc %x" %
782 (a
, b
, mask
, bits
, newb
))
786 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
787 sum = ((a
& mask
) >> b
)
789 carry
= (sum & mask
) != sum
791 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
794 def test_ls_fn(carry_in
, a
, b
, mask
):
796 bits
= count_bits(mask
)
797 fz
= first_zero(mask
)
798 newb
= b
& ((bits
-1) << fz
)
799 print("%x %x %x bits %d zero %d trunc %x" %
800 (a
, b
, mask
, bits
, fz
, newb
))
804 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
807 sum = ((a
& mask
) << b
)
809 carry
= (sum & mask
) != sum
811 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
814 def test_rs_fn(carry_in
, a
, b
, mask
):
816 bits
= count_bits(mask
)
817 fz
= first_zero(mask
)
818 newb
= b
& ((bits
-1) << fz
)
819 print("%x %x %x bits %d zero %d trunc %x" %
820 (a
, b
, mask
, bits
, fz
, newb
))
824 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
827 sum = ((a
& mask
) >> b
)
829 carry
= (sum & mask
) != sum
831 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
834 def test_add_fn(carry_in
, a
, b
, mask
):
835 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
836 sum = (a
& mask
) + (b
& mask
) + lsb
838 carry
= (sum & mask
) != sum
839 print(a
, b
, sum, mask
)
842 def test_sub_fn(carry_in
, a
, b
, mask
):
843 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
844 sum = (a
& mask
) + (~b
& mask
) + lsb
846 carry
= (sum & mask
) != sum
849 def test_neg_fn(carry_in
, a
, b
, mask
):
850 lsb
= mask
& ~
(mask
- 1) # has only LSB of mask set
851 pos
= lsb
.bit_length() - 1 # find bit position
852 a
= (a
& mask
) >> pos
# shift it to the beginning
853 return ((-a
) << pos
) & mask
, 0 # negate and shift it back
855 def test_signed_fn(carry_in
, a
, b
, mask
):
858 def test_op(msg_prefix
, carry
, test_fn
, mod_attr
, *mask_list
):
861 a
, b
= randint(0, 1 << 16), randint(0, 1 << 16)
862 rand_data
.append((a
, b
))
863 for a
, b
in [(0x0000, 0x0000),
869 (0x0000, 0xFFFF)] + rand_data
:
870 yield module
.a
.lower().eq(a
)
871 yield module
.b
.lower().eq(b
)
872 carry_sig
= 0xf if carry
else 0
873 yield module
.carry_in
.eq(carry_sig
)
877 for i
, mask
in enumerate(mask_list
):
878 print("i/mask", i
, hex(mask
))
879 res
, c
= test_fn(carry
, a
, b
, mask
)
881 lsb
= mask
& ~
(mask
- 1)
882 bit_set
= int(math
.log2(lsb
))
883 carry_result |
= c
<< int(bit_set
/4)
884 outval
= (yield getattr(module
, "%s_output" % mod_attr
))
885 # TODO: get (and test) carry output as well
886 print(a
, b
, outval
, carry
)
887 msg
= f
"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
888 f
" => 0x{y:X} != 0x{outval:X}"
889 self
.assertEqual(y
, outval
, msg
)
890 if hasattr(module
, "%s_carry_out" % mod_attr
):
891 c_outval
= (yield getattr(module
,
892 "%s_carry_out" % mod_attr
))
893 msg
= f
"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
894 f
" => 0x{carry_result:X} != 0x{c_outval:X}"
895 self
.assertEqual(carry_result
, c_outval
, msg
)
897 # run through series of operations with corresponding
898 # "helper" routines to reproduce the result (test_fn). the same
899 # a/b input is passed to *all* outputs, where the name of the
900 # output attribute (mod_attr) will contain the result to be
901 # compared against the expected output from test_fn
902 for (test_fn
, mod_attr
) in (
903 (test_ls_scal_fn
, "ls_scal"),
905 (test_rs_scal_fn
, "rs_scal"),
907 (test_add_fn
, "add"),
908 (test_sub_fn
, "sub"),
909 (test_neg_fn
, "neg"),
910 (test_signed_fn
, "signed"),
912 yield part_mask
.eq(0)
913 yield from test_op("16-bit", 1, test_fn
, mod_attr
, 0xFFFF)
914 yield from test_op("16-bit", 0, test_fn
, mod_attr
, 0xFFFF)
915 yield part_mask
.eq(0b10)
916 yield from test_op("8-bit", 0, test_fn
, mod_attr
,
918 yield from test_op("8-bit", 1, test_fn
, mod_attr
,
920 yield part_mask
.eq(0b1111)
921 yield from test_op("4-bit", 0, test_fn
, mod_attr
,
922 0xF000, 0x0F00, 0x00F0, 0x000F)
923 yield from test_op("4-bit", 1, test_fn
, mod_attr
,
924 0xF000, 0x0F00, 0x00F0, 0x000F)
926 def test_ne_fn(a
, b
, mask
):
927 return (a
& mask
) != (b
& mask
)
929 def test_lt_fn(a
, b
, mask
):
930 return (a
& mask
) < (b
& mask
)
932 def test_le_fn(a
, b
, mask
):
933 return (a
& mask
) <= (b
& mask
)
935 def test_eq_fn(a
, b
, mask
):
936 return (a
& mask
) == (b
& mask
)
938 def test_gt_fn(a
, b
, mask
):
939 return (a
& mask
) > (b
& mask
)
941 def test_ge_fn(a
, b
, mask
):
942 return (a
& mask
) >= (b
& mask
)
944 def test_binop(msg_prefix
, test_fn
, mod_attr
, *maskbit_list
):
945 for a
, b
in [(0x0000, 0x0000),
955 yield module
.a
.lower().eq(a
)
956 yield module
.b
.lower().eq(b
)
958 # convert to mask_list
960 for mb
in maskbit_list
:
967 # do the partitioned tests
968 for i
, mask
in enumerate(mask_list
):
969 if test_fn(a
, b
, mask
):
970 # OR y with the lowest set bit in the mask
973 outval
= (yield getattr(module
, "%s_output" % mod_attr
))
974 msg
= f
"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
975 f
" => 0x{y:X} != 0x{outval:X}, masklist %s"
976 print((msg
% str(maskbit_list
)).format(locals()))
977 self
.assertEqual(y
, outval
, msg
% str(maskbit_list
))
979 for (test_fn
, mod_attr
) in ((test_eq_fn
, "eq"),
986 yield part_mask
.eq(0)
987 yield from test_binop("16-bit", test_fn
, mod_attr
, 0b1111)
988 yield part_mask
.eq(0b10)
989 yield from test_binop("8-bit", test_fn
, mod_attr
,
991 yield part_mask
.eq(0b1111)
992 yield from test_binop("4-bit", test_fn
, mod_attr
,
993 0b1000, 0b0100, 0b0010, 0b0001)
995 sim
.add_process(async_process
)
997 vcd_file
=open(test_name
+ ".vcd", "w"),
998 gtkw_file
=open(test_name
+ ".gtkw", "w"),
1003 # TODO: adapt to SimdSignal. perhaps a different style?
1005 from nmigen.tests.test_hdl_ast import SignedEnum
1006 def test_matches(self)
1008 self.assertRepr(s.matches(), "(const 1'd0)")
1009 self.assertRepr(s.matches(1), """
1010 (== (sig s) (const 1'd1))
1012 self.assertRepr(s.matches(0, 1), """
1013 (r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1))))
1015 self.assertRepr(s.matches("10--"), """
1016 (== (& (sig s) (const 4'd12)) (const 4'd8))
1018 self.assertRepr(s.matches("1 0--"), """
1019 (== (& (sig s) (const 4'd12)) (const 4'd8))
1022 def test_matches_enum(self):
1023 s = Signal(SignedEnum)
1024 self.assertRepr(s.matches(SignedEnum.FOO), """
1025 (== (sig s) (const 1'sd-1))
1028 def test_matches_width_wrong(self):
1030 with self.assertRaisesRegex(SyntaxError,
1031 r"^Match pattern '--' must have the same width as "
1032 r"match value \(which is 4\)$"):
1034 with self.assertWarnsRegex(SyntaxWarning,
1035 (r"^Match pattern '10110' is wider than match value "
1036 r"\(which has width 4\); "
1037 r"comparison will never be true$")):
1040 def test_matches_bits_wrong(self):
1042 with self.assertRaisesRegex(SyntaxError,
1043 (r"^Match pattern 'abc' must consist of 0, 1, "
1044 r"and - \(don't care\) bits, "
1045 r"and may include whitespace$")):
1048 def test_matches_pattern_wrong(self):
1050 with self.assertRaisesRegex(SyntaxError,
1051 r"^Match pattern must be an integer, a string, "
1052 r"or an enumeration, not 1\.0$"):
1056 if __name__
== '__main__':