Add rudimentary test for partitioned add with carry
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable
6 from nmigen.back.pysim import Simulator, Delay, Tick, Passive
7 from nmigen.cli import verilog, rtlil
8
9 from ieee754.part.partsig import PartitionedSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 import unittest
13 import itertools
14
15
def perms(k):
    """Return a lazy iterator over every k-character string of '0'/'1'.

    The strings come out in binary counting order, e.g. perms(2) yields
    '00', '01', '10', '11'; perms(0) yields a single empty string.
    """
    return (''.join(bits) for bits in itertools.product('01', repeat=k))
18
19
def create_ilang(dut, traces, test_name):
    """Write the RTLIL netlist of *dut* to ``<test_name>.il``.

    *traces* is handed to ``rtlil.convert`` as the list of ports.
    """
    il_path = "%s.il" % test_name
    il_text = rtlil.convert(dut, ports=traces)
    with open(il_path, "w") as il_file:
        il_file.write(il_text)
24
25
def create_simulator(module, traces, test_name):
    """Emit ``<test_name>.il`` for *module* and return a Simulator for it.

    The Simulator receives freshly opened ``<test_name>.vcd`` and
    ``<test_name>.gtkw`` files plus *traces* as the signals to record.
    """
    create_ilang(module, traces, test_name)
    # NOTE(review): these handles are not closed in this function; they are
    # left to the Simulator / interpreter exit -- confirm for the nmigen
    # version in use.
    vcd_handle = open(test_name + ".vcd", "w")
    gtkw_handle = open(test_name + ".gtkw", "w")
    return Simulator(module,
                     vcd_file=vcd_handle,
                     gtkw_file=gtkw_handle,
                     traces=traces)
32
class TestAddMod(Elaboratable):
    """Test harness around a pair of PartitionedSignals ``a`` and ``b``.

    elaborate() wires the six partitioned comparisons, add-with-carry and
    PMux of ``a``/``b`` into plain Signals that a simulation can sample.
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.add_output = Signal(width)
        # comparison results and the per-partition controls all share
        # this width
        n_bits = len(partpoints) + 1
        self.eq_output = Signal(n_bits)
        self.gt_output = Signal(n_bits)
        self.ge_output = Signal(n_bits)
        self.ne_output = Signal(n_bits)
        self.lt_output = Signal(n_bits)
        self.le_output = Signal(n_bits)
        self.mux_sel = Signal(n_bits)
        self.mux_out = Signal(width)
        self.carry_in = Signal(n_bits)
        self.carry_out = Signal(n_bits)

    def elaborate(self, platform):
        m = Module()
        a, b = self.a, self.b
        # the partitioned operators below are built inside this module
        a.set_module(m)
        b.set_module(m)
        comb = m.d.comb

        comb += self.lt_output.eq(a < b)
        comb += self.ne_output.eq(a != b)
        comb += self.le_output.eq(a <= b)
        comb += self.gt_output.eq(a > b)
        comb += self.eq_output.eq(a == b)
        comb += self.ge_output.eq(a >= b)

        # addc hands back a (sum, carry) pair
        add_out, add_carry = a.addc(b, self.carry_in)
        comb += [self.add_output.eq(add_out),
                 self.carry_out.eq(add_carry)]

        comb += self.mux_out.eq(PMux(m, self.partpoints, self.mux_sel, a, b))

        return m
67
68
class TestPartitionPoints(unittest.TestCase):
    """Simulation test of the partitioned operators wired up by TestAddMod.

    A 16-bit TestAddMod is driven with fixed operand pairs under three
    ``part_mask`` settings -- 0, 0b10 and 0b1111, labelled "16-bit",
    "8-bit" and "4-bit" -- and each output is checked against a
    pure-Python model evaluated lane by lane.
    """

    def test(self):
        width = 16
        part_mask = Signal(4)  # divide into 4-bits
        module = TestAddMod(width, part_mask)

        sim = create_simulator(module,
                               [part_mask,
                                module.a.sig,
                                module.b.sig,
                                module.add_output,
                                module.eq_output],
                               "part_sig_add")

        def expand_maskbits(maskbit_list):
            """Convert per-nibble enable bits into 16-bit lane masks.

            Bit i of each entry selects nibble i (bits 4*i .. 4*i+3) of the
            returned mask, e.g. 0b0011 -> 0x00FF and 0b1100 -> 0xFF00.
            """
            mask_list = []
            for maskbits in maskbit_list:
                mask = 0
                for i in range(4):
                    if maskbits & (1 << i):
                        mask |= 0xf << (i * 4)
                mask_list.append(mask)
            return mask_list

        def async_process():
            def test_add(msg_prefix, carry, *mask_list):
                """Check add_output for each lane mask in *mask_list*.

                Each lane is modelled as an independent add that wraps
                within the lane; when *carry* is set, carry_in is driven to
                0xf and the model adds 1 at each lane's lowest bit.
                """
                carry_sig = 0xf if carry else 0
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    yield module.carry_in.eq(carry_sig)
                    yield Delay(0.1e-6)
                    y = 0
                    for mask in mask_list:
                        # mask & ~(mask - 1) isolates the lane's lowest bit
                        lsb = mask & ~(mask - 1) if carry else 0
                        y |= mask & ((a & mask) + (b & mask) + lsb)
                    outval = (yield module.add_output)
                    # TODO: module.carry_out is not checked yet.
                    msg = (f"{msg_prefix}: 0x{a:X} + 0x{b:X}"
                           f" (carry={carry})"
                           f" => 0x{y:X} != 0x{outval:X}")
                    self.assertEqual(y, outval, msg)

            yield part_mask.eq(0)
            yield from test_add("16-bit", 1, 0xFFFF)
            yield from test_add("16-bit", 0, 0xFFFF)
            yield part_mask.eq(0b10)
            yield from test_add("8-bit", 0, 0xFF00, 0x00FF)
            yield from test_add("8-bit", 1, 0xFF00, 0x00FF)
            yield part_mask.eq(0b1111)
            yield from test_add("4-bit", 0, 0xF000, 0x0F00, 0x00F0, 0x000F)
            yield from test_add("4-bit", 1, 0xF000, 0x0F00, 0x00F0, 0x000F)

            def test_ne_fn(a, b, mask):
                return (a & mask) != (b & mask)

            def test_lt_fn(a, b, mask):
                return (a & mask) < (b & mask)

            def test_le_fn(a, b, mask):
                return (a & mask) <= (b & mask)

            def test_eq_fn(a, b, mask):
                return (a & mask) == (b & mask)

            def test_gt_fn(a, b, mask):
                return (a & mask) > (b & mask)

            def test_ge_fn(a, b, mask):
                return (a & mask) >= (b & mask)

            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                """Check ``<mod_attr>_output`` against *test_fn* per lane.

                Each entry of *maskbit_list* enables some nibbles (one
                lane); the expected output ORs in that entry whenever
                test_fn holds for the lane.
                """
                # lane masks depend only on maskbit_list: compute them once
                mask_list = expand_maskbits(maskbit_list)
                output = getattr(module, "%s_output" % mod_attr)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF),
                             (0xABCD, 0xABCE),
                             (0x8000, 0x0000),
                             (0xBEEF, 0xFEED)]:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    yield Delay(0.1e-6)
                    y = 0
                    # do the partitioned tests: OR in the lane's maskbit
                    # entry when its comparison holds
                    for maskbits, mask in zip(maskbit_list, mask_list):
                        if test_fn(a, b, mask):
                            y |= maskbits
                    outval = (yield output)
                    msg = (f"{msg_prefix}: {mod_attr} 0x{a:X}, 0x{b:X}"
                           f" => 0x{y:X} != 0x{outval:X},"
                           f" masklist {maskbit_list}")
                    self.assertEqual(y, outval, msg)

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_binop("8-bit", test_fn, mod_attr,
                                      0b1100, 0b0011)
                yield part_mask.eq(0b1111)
                yield from test_binop("4-bit", test_fn, mod_attr,
                                      0b1000, 0b0100, 0b0010, 0b0001)

            def test_muxop(msg_prefix, *maskbit_list):
                """Check mux_out for every combination of per-lane selects.

                For each lane, the model picks ``a``'s bits when that
                lane's select is set and ``b``'s bits otherwise.
                """
                mask_list = expand_maskbits(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:
                    # walk every select pattern across the lanes
                    for p in perms(len(mask_list)):
                        sel = 0
                        selmask = 0
                        for bit, maskbits, mask in zip(p, maskbit_list,
                                                       mask_list):
                            if bit == '1':
                                sel |= maskbits
                                selmask |= mask

                        yield module.a.eq(a)
                        yield module.b.eq(b)
                        yield module.mux_sel.eq(sel)
                        yield Delay(0.1e-6)
                        y = 0
                        for mask in mask_list:
                            if selmask & mask:
                                y |= (a & mask)
                            else:
                                y |= (b & mask)
                        outval = (yield module.mux_out)
                        msg = (f"{msg_prefix}: mux "
                               f"0x{sel:X} ? 0x{a:X} : 0x{b:X}"
                               f" => 0x{y:X} != 0x{outval:X},"
                               f" masklist {maskbit_list}")
                        self.assertEqual(y, outval, msg)

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            yield part_mask.eq(0b1111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        sim.run()
239
if __name__ == "__main__":
    # Allow running this file directly as well as through a test runner.
    unittest.main()
242