b30c4b7c801158e531a9619beead852a9b14d1c6
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
5 from nmigen
.hdl
.ast
import AnyConst
, Assert
, Signal
6 from nmigen
.hdl
.dsl
import Module
7 from nmutil
.formaltest
import FHDLTestCase
8 from nmutil
.lut
import BitwiseMux
, BitwiseLut
, TreeBitwiseLut
9 from nmigen
.sim
import Delay
10 from nmutil
.sim_util
import do_sim
, hash_256
13 class TestBitwiseMux(FHDLTestCase
):
16 dut
= BitwiseMux(width
)
18 def case(sel
, t
, f
, expected
):
19 with self
.subTest(sel
=bin(sel
), t
=bin(t
), f
=bin(f
)):
24 output
= yield dut
.output
25 with self
.subTest(output
=bin(output
), expected
=bin(expected
)):
26 self
.assertEqual(expected
, output
)
29 for sel
in range(2 ** width
):
30 for t
in range(2 ** width
):
31 for f
in range(2**width
):
33 for i
in range(width
):
39 yield from case(sel
, t
, f
, expected
)
40 with
do_sim(self
, dut
, [dut
.sel
, dut
.t
, dut
.f
, dut
.output
]) as sim
:
41 sim
.add_process(process
)
44 def test_formal(self
):
46 dut
= BitwiseMux(width
)
48 m
.submodules
.dut
= dut
49 m
.d
.comb
+= dut
.sel
.eq(AnyConst(width
))
50 m
.d
.comb
+= dut
.f
.eq(AnyConst(width
))
51 m
.d
.comb
+= dut
.t
.eq(AnyConst(width
))
52 for i
in range(width
):
53 with m
.If(dut
.sel
[i
]):
54 m
.d
.comb
+= Assert(dut
.t
[i
] == dut
.output
[i
])
56 m
.d
.comb
+= Assert(dut
.f
[i
] == dut
.output
[i
])
60 class TestBitwiseLut(FHDLTestCase
):
63 mask
= 2 ** dut
.width
- 1
64 lut_mask
= 2 ** dut
.lut
.width
- 1
65 if cls
is TreeBitwiseLut
:
66 mux_inputs
= {k
: s
.name
for k
, s
in dut
._mux
_inputs
.items()}
67 self
.assertEqual(mux_inputs
, {
68 (): 'mux_input_0bxxx',
69 (False,): 'mux_input_0bxx0',
70 (False, False): 'mux_input_0bx00',
71 (False, False, False): 'mux_input_0b000',
72 (False, False, True): 'mux_input_0b100',
73 (False, True): 'mux_input_0bx10',
74 (False, True, False): 'mux_input_0b010',
75 (False, True, True): 'mux_input_0b110',
76 (True,): 'mux_input_0bxx1',
77 (True, False): 'mux_input_0bx01',
78 (True, False, False): 'mux_input_0b001',
79 (True, False, True): 'mux_input_0b101',
80 (True, True): 'mux_input_0bx11',
81 (True, True, False): 'mux_input_0b011',
82 (True, True, True): 'mux_input_0b111'
85 def case(in0
, in1
, in2
, lut
):
87 for i
in range(dut
.width
):
95 if lut
& 2 ** lut_index
:
97 with self
.subTest(in0
=bin(in0
), in1
=bin(in1
), in2
=bin(in2
),
99 yield dut
.inputs
[0].eq(in0
)
100 yield dut
.inputs
[1].eq(in1
)
101 yield dut
.inputs
[2].eq(in2
)
102 yield dut
.lut
.eq(lut
)
104 output
= yield dut
.output
105 with self
.subTest(output
=bin(output
), expected
=bin(expected
)):
106 self
.assertEqual(expected
, output
)
109 for case_index
in range(100):
110 with self
.subTest(case_index
=case_index
):
111 in0
= hash_256(f
"{case_index} in0") & mask
112 in1
= hash_256(f
"{case_index} in1") & mask
113 in2
= hash_256(f
"{case_index} in2") & mask
114 lut
= hash_256(f
"{case_index} lut") & lut_mask
115 yield from case(in0
, in1
, in2
, lut
)
116 with
do_sim(self
, dut
, [*dut
.inputs
, dut
.lut
, dut
.output
]) as sim
:
117 sim
.add_process(process
)
120 def tst_formal(self
, cls
):
123 m
.submodules
.dut
= dut
124 m
.d
.comb
+= dut
.inputs
[0].eq(AnyConst(dut
.width
))
125 m
.d
.comb
+= dut
.inputs
[1].eq(AnyConst(dut
.width
))
126 m
.d
.comb
+= dut
.inputs
[2].eq(AnyConst(dut
.width
))
127 m
.d
.comb
+= dut
.lut
.eq(AnyConst(dut
.lut
.width
))
128 for i
in range(dut
.width
):
129 lut_index
= Signal(dut
.input_count
, name
=f
"lut_index_{i}")
130 for j
in range(dut
.input_count
):
131 m
.d
.comb
+= lut_index
[j
].eq(dut
.inputs
[j
][i
])
132 for j
in range(dut
.lut
.width
):
133 with m
.If(lut_index
== j
):
134 m
.d
.comb
+= Assert(dut
.lut
[j
] == dut
.output
[i
])
141 self
.tst(TreeBitwiseLut
)
143 def test_formal(self
):
144 self
.tst_formal(BitwiseLut
)
146 def test_tree_formal(self
):
147 self
.tst_formal(TreeBitwiseLut
)
150 if __name__
== "__main__":