53647082033ed646fad84d4067ef07d60fdf233c
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable
6 from nmigen.back.pysim import Simulator, Delay
7 from nmigen.cli import rtlil
8
9 from ieee754.part.partsig import PartitionedSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15
16
def perms(k):
    """Lazily produce every length-k string over the alphabet '0', '1'.

    Strings come out in the order itertools.product generates them,
    i.e. '00..0' first and '11..1' last. The result is a one-shot
    iterator, not a list.
    """
    bit_tuples = itertools.product('01', repeat=k)
    return (''.join(bits) for bits in bit_tuples)
19
20
def create_ilang(dut, traces, test_name):
    """Convert `dut` to RTLIL and save it as "<test_name>.il".

    `traces` is passed to rtlil.convert as the `ports` argument.
    """
    il_text = rtlil.convert(dut, ports=traces)
    il_path = "%s.il" % test_name
    with open(il_path, "w") as il_file:
        il_file.write(il_text)
25
26
def create_simulator(module, traces, test_name):
    """Write the RTLIL dump for `module` and build a Simulator for it.

    Opens "<test_name>.vcd" and "<test_name>.gtkw" for writing and hands
    both file objects to the Simulator together with `traces`. The files
    are not closed here; they are passed on to the returned Simulator.
    """
    create_ilang(module, traces, test_name)
    vcd_out = open(test_name + ".vcd", "w")
    gtkw_out = open(test_name + ".gtkw", "w")
    return Simulator(module,
                     vcd_file=vcd_out,
                     gtkw_file=gtkw_out,
                     traces=traces)
33
34
class TestAddMod(Elaboratable):
    """Harness that routes PartitionedSignal operations to plain Signals.

    Two PartitionedSignal operands (`a`, `b`) are built from the same
    partition points. Their add, sub, negate, comparison and mux results
    are copied combinatorially into ordinary Signals so a simulation
    process can read them back.

    Parameters:
        width: bit width of the operands and of the data-wide outputs.
        partpoints: partition points handed to PartitionedSignal and PMux;
            `len(partpoints) + 1` sets the width of the per-partition
            outputs (comparison results, mux select, carries).
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.add_output = Signal(width)
        self.sub_output = Signal(width)
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        self.mux_sel = Signal(len(partpoints)+1)
        self.mux_out = Signal(width)
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        """Build the module wiring every operation result to its output."""
        m = Module()
        comb = m.d.comb
        # the partitioned operands need the module before operators are used
        self.a.set_module(m)
        self.b.set_module(m)

        # comparisons
        comb += self.lt_output.eq(self.a < self.b)
        comb += self.ne_output.eq(self.a != self.b)
        comb += self.le_output.eq(self.a <= self.b)
        comb += self.gt_output.eq(self.a > self.b)
        comb += self.eq_output.eq(self.a == self.b)
        comb += self.ge_output.eq(self.a >= self.b)

        # add
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        comb += self.add_output.eq(add_out)
        comb += self.add_carry_out.eq(add_carry)

        # sub: the carry output comes from the subtractor itself
        # (it was previously wired to the adder's carry by mistake)
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        comb += self.sub_output.eq(sub_out)
        comb += self.sub_carry_out.eq(sub_carry)

        # negate
        comb += self.neg_output.eq(-self.a)

        # partitioned mux between a and b
        ppts = self.partpoints
        comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))

        return m
80
81
class TestPartitionPoints(unittest.TestCase):
    """Simulates TestAddMod and checks it against pure-Python models.

    The partition mask is set to 0, 0b10 and 0b1111 (labelled "16-bit",
    "8-bit" and "4-bit" in the failure messages). For each setting the
    add/sub/neg outputs, the six comparison outputs and the mux output
    are compared with values computed partition by partition in Python.
    """

    def test(self):
        width = 16
        part_mask = Signal(4)  # divide into 4-bits
        module = TestAddMod(width, part_mask)

        sim = create_simulator(module,
                               [part_mask,
                                module.a.sig,
                                module.b.sig,
                                module.add_output,
                                module.eq_output],
                               "part_sig_add")

        # largest operand value that fits in `width` bits; randint() is
        # inclusive at both ends, so the bound must not be 1 << width
        max_val = (1 << width) - 1

        def async_process():

            def expand_masks(maskbit_list):
                # Turn each 4-bit partition selector into a 16-bit data
                # mask: selector bit i covers data nibble i.
                mask_list = []
                for mb in maskbit_list:
                    v = 0
                    for i in range(4):
                        if mb & (1 << i):
                            v |= 0xf << (i*4)
                    mask_list.append(v)
                return mask_list

            # ---- arithmetic reference models (one partition at a time) ----

            def test_add_fn(carry_in, a, b, mask):
                # carry-in is added at the lowest bit of the partition
                lsb = mask & ~(mask-1) if carry_in else 0
                return mask & ((a & mask) + (b & mask) + lsb)

            def test_sub_fn(carry_in, a, b, mask):
                # a + ~b + carry, restricted to the partition
                lsb = mask & ~(mask-1) if carry_in else 0
                return mask & ((a & mask) + (~b & mask) + lsb)

            def test_neg_fn(carry_in, a, b, mask):
                # NOTE(review): this adds an all-ones partition value to a
                # (carry_in and b are ignored), which is not two's-complement
                # negation - confirm against PartitionedSignal.__neg__.
                return mask & ((a & mask) + (~0 & mask))

            def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                rand_data = [(randint(0, max_val), randint(0, max_val))
                             for _ in range(100)]
                carry_sig = 0xf if carry else 0
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)] + rand_data:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    yield module.carry_in.eq(carry_sig)
                    yield Delay(0.1e-6)
                    # expected value: OR together the per-partition results
                    y = 0
                    for mask in mask_list:
                        y |= test_fn(carry, a, b, mask)
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    # TODO: get (and test) carry output as well
                    print(a, b, outval, carry)
                    msg = f"{msg_prefix}: 0x{a:X} + 0x{b:X}" + \
                        f" => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)

            for (test_fn, mod_attr) in ((test_add_fn, "add"),
                                        (test_sub_fn, "sub"),
                                        (test_neg_fn, "neg"),
                                        ):
                yield part_mask.eq(0)
                yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
                yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)
                yield part_mask.eq(0b10)
                yield from test_op("8-bit", 0, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield from test_op("8-bit", 1, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield part_mask.eq(0b1111)
                yield from test_op("4-bit", 0, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)
                yield from test_op("4-bit", 1, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)

            # ---- comparison reference models ----

            def test_ne_fn(a, b, mask):
                return (a & mask) != (b & mask)

            def test_lt_fn(a, b, mask):
                return (a & mask) < (b & mask)

            def test_le_fn(a, b, mask):
                return (a & mask) <= (b & mask)

            def test_eq_fn(a, b, mask):
                return (a & mask) == (b & mask)

            def test_gt_fn(a, b, mask):
                return (a & mask) > (b & mask)

            def test_ge_fn(a, b, mask):
                return (a & mask) >= (b & mask)

            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                # the data masks depend only on the selectors, so build
                # them once rather than for every operand pair
                mask_list = expand_masks(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF),
                             (0xABCD, 0xABCE),
                             (0x8000, 0x0000),
                             (0xBEEF, 0xFEED)]:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    yield Delay(0.1e-6)
                    y = 0
                    # do the partitioned tests
                    for maskbits, mask in zip(maskbit_list, mask_list):
                        if test_fn(a, b, mask):
                            # set every selector bit of this partition
                            y |= maskbits
                    # check the result
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    msg = f"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
                        f" => 0x{y:X} != 0x{outval:X}," + \
                        f" masklist {maskbit_list}"
                    print(msg)
                    self.assertEqual(y, outval, msg)

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_binop("8-bit", test_fn, mod_attr,
                                      0b1100, 0b0011)
                yield part_mask.eq(0b1111)
                yield from test_binop("4-bit", test_fn, mod_attr,
                                      0b1000, 0b0100, 0b0010, 0b0001)

            # ---- mux ----

            def test_muxop(msg_prefix, *maskbit_list):
                mask_list = expand_masks(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:

                    # TODO: sel needs to go through permutations of mask_list
                    for p in perms(len(mask_list)):

                        # '1' at position i selects partition i
                        sel = 0
                        selmask = 0
                        for i, v in enumerate(p):
                            if v == '1':
                                sel |= maskbit_list[i]
                                selmask |= mask_list[i]

                        yield module.a.eq(a)
                        yield module.b.eq(b)
                        yield module.mux_sel.eq(sel)
                        yield Delay(0.1e-6)
                        y = 0
                        # do the partitioned tests: a selected partition
                        # is expected to carry a, any other one b
                        for mask in mask_list:
                            if (selmask & mask):
                                y |= (a & mask)
                            else:
                                y |= (b & mask)
                        # check the result
                        outval = (yield module.mux_out)
                        msg = f"{msg_prefix}: mux " + \
                            f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
                            f" => 0x{y:X} != 0x{outval:X}," + \
                            f" masklist {maskbit_list}"
                        self.assertEqual(y, outval, msg)

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            yield part_mask.eq(0b1111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        sim.run()
278
279
# Run the test cases in this file when it is executed as a script.
if __name__ == '__main__':
    unittest.main()