438b547fca5cc822f91c6add9f71cadd2ddfe8cb
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 import unittest
8 from nmigen.hdl.ast import Signal, Const, unsigned
9 from nmigen.hdl.dsl import Module
10 from nmutil.formaltest import FHDLTestCase
11 from nmigen_gf.hdl.cldivrem import (CLDivRemFSMStage, CLDivRemInputData,
12 CLDivRemOutputData, CLDivRemShape,
13 cldivrem_shifting, CLDivRemState)
14 from nmigen.sim import Delay, Tick
15 from nmutil.sim_util import do_sim, hash_256
16 from nmigen_gf.reference.cldivrem import cldivrem
17
18
class TestCLDivRemShifting(FHDLTestCase):
    """Check the ``cldivrem_shifting`` model against the reference
    ``cldivrem`` implementation on plain Python integers."""

    def tst(self, width, full):
        """Compare both models on ``width``-bit operands.

        :param width: operand width in bits.
        :param full: if true, every ``(n, d)`` pair of ``width``-bit values
            is tried; otherwise 100 pseudo-random pairs derived from
            ``hash_256`` are tried.
        """
        def operand_pairs():
            # Yield the (n, d) inputs lazily, in the order they are checked.
            if full:
                limit = 1 << width
                for n in range(limit):
                    for d in range(limit):
                        yield n, d
                return
            # NOTE: these seed strings match the ones TestCLDivRemComb uses;
            # they are left unchanged so the generated inputs stay the same.
            for idx in range(100):
                n = Const.normalize(hash_256(f"cldivrem comb n {idx}"),
                                    unsigned(width))
                d = Const.normalize(hash_256(f"cldivrem comb d {idx}"),
                                    unsigned(width))
                yield n, d

        for n, d in operand_pairs():
            assert isinstance(n, int)
            assert isinstance(d, int)
            if d == 0:
                # Neither model is exercised for a zero divisor.
                expected_q = expected_r = q = r = 0
            else:
                expected_q, expected_r = cldivrem(n, d, width=width)
                q, r = cldivrem_shifting(n, d, width=width)
            with self.subTest(n=hex(n), d=hex(d),
                              expected_q=hex(expected_q),
                              expected_r=hex(expected_r),
                              q=hex(q), r=hex(r)):
                self.assertEqual(expected_q, q)
                self.assertEqual(expected_r, r)

    # exhaustive check of all 6-bit operand pairs
    def test_6(self):
        self.tst(6, full=True)

    # sampled check of 64-bit operands
    def test_64(self):
        self.tst(64, full=False)
53
54
class TestCLDivRemComb(FHDLTestCase):
    """Chain one ``CLDivRemState`` per step of ``shape.step_range``,
    simulate it without a clock, and compare the final outputs with
    ``cldivrem_shifting``."""

    def tst(self, shape, full):
        """Build the step chain for ``shape`` and check it in simulation.

        :param shape: the ``CLDivRemShape`` under test.
        :param full: if true, every ``(n, d)`` pair is tried; otherwise 100
            pseudo-random pairs derived from ``hash_256`` are tried.
        """
        assert isinstance(shape, CLDivRemShape)
        width = shape.width
        m = Module()
        # The variable names below are kept on purpose: nmigen names each
        # Signal after the variable it is assigned to.
        n_in = Signal(width)
        d_in = Signal(width)
        q_out = Signal(width)
        r_out = Signal(width)

        # Step 0 is fed from the input signals; each later step is driven
        # from the state of the step before it.
        states: "list[CLDivRemState]" = []
        for step_idx in shape.step_range:
            states.append(CLDivRemState(shape, name=f"state_{step_idx}"))
            if step_idx != 0:
                states[step_idx].set_to_next(m, states[step_idx - 1])
            else:
                states[step_idx].set_to_initial(m, n=n_in, d=d_in)
        final_q, final_r = states[-1].get_output()
        m.d.comb += [q_out.eq(final_q), r_out.eq(final_r)]

        def check_case(n, d):
            # Drive one (n, d) pair, then check per-step flags and outputs.
            assert isinstance(n, int)
            assert isinstance(d, int)
            if d == 0:
                expected_q = expected_r = 0
            else:
                expected_q, expected_r = cldivrem_shifting(n, d, width)
            with self.subTest(n=hex(n), d=hex(d),
                              expected_q=hex(expected_q),
                              expected_r=hex(expected_r)):
                yield n_in.eq(n)
                yield d_in.eq(d)
                # no clock here: just let the logic settle before sampling
                yield Delay(1e-6)
                for i in shape.step_range:
                    with self.subTest(i=i):
                        done = yield states[i].done
                        step = yield states[i].step
                        self.assertEqual(done, i >= shape.done_step)
                        self.assertEqual(step, i)
                q = yield q_out
                r = yield r_out
                with self.subTest(q=hex(q), r=hex(r)):
                    # outputs are only compared for a nonzero divisor
                    if d != 0:
                        self.assertEqual(q, expected_q)
                        self.assertEqual(r, expected_r)

        def process():
            if full:
                for n in range(1 << width):
                    for d in range(1 << width):
                        yield from check_case(n, d)
                return
            for i in range(100):
                n = Const.normalize(hash_256(f"cldivrem comb n {i}"),
                                    unsigned(width))
                d = Const.normalize(hash_256(f"cldivrem comb d {i}"),
                                    unsigned(width))
                yield from check_case(n, d)

        with do_sim(self, m, [n_in, d_in, q_out, r_out]) as sim:
            sim.add_process(process)
            sim.run()

    # exhaustive check of all 4-bit operand pairs
    def test_4(self):
        self.tst(CLDivRemShape(width=4), full=True)

    # exhaustive check of all 6-bit operand pairs
    def test_6(self):
        self.tst(CLDivRemShape(width=6), full=True)

    # sampled check of 8-bit operands
    def test_8(self):
        self.tst(CLDivRemShape(width=8), full=False)
125
126
class TestCLDivRemFSM(FHDLTestCase):
    """Simulate ``CLDivRemFSMStage`` with a clock, checking its ready/valid
    handshake cycle by cycle and its results against ``cldivrem``."""

    def tst(self, shape, full, steps_per_clock):
        """Run the FSM stage for ``shape`` over a set of operand pairs.

        :param shape: the ``CLDivRemShape`` under test.
        :param full: if true, every ``(n, d)`` pair is tried; otherwise 100
            pseudo-random pairs derived from ``hash_256`` are tried.
        :param steps_per_clock: division steps the stage performs per clock;
            must be an int >= 1.
        """
        assert isinstance(shape, CLDivRemShape)
        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
        pspec = {}
        dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
        i_data: CLDivRemInputData = dut.p.i_data
        o_data: CLDivRemOutputData = dut.n.o_data
        self.assertEqual(i_data.n.shape(), unsigned(shape.width))
        self.assertEqual(i_data.d.shape(), unsigned(shape.width))
        self.assertEqual(o_data.q.shape(), unsigned(shape.width))
        self.assertEqual(o_data.r.shape(), unsigned(shape.width))

        def check_handshake(check_valid, check_ready):
            # Let signals settle, then apply the given assertion methods
            # (e.g. self.assertTrue / self.assertFalse) to the stage's
            # output-valid and input-ready flags.
            yield Delay(0.1e-6)
            valid = yield dut.n.o_valid
            ready = yield dut.p.o_ready
            with self.subTest():
                check_valid(valid)
                check_ready(ready)

        def case(n, d):
            # Push one (n, d) pair through the stage and check each phase.
            assert isinstance(n, int)
            assert isinstance(d, int)
            if d != 0:
                expected_q, expected_r = cldivrem(n, d, width=shape.width)
            else:
                expected_q = expected_r = 0
            with self.subTest(n=hex(n), d=hex(d),
                              expected_q=hex(expected_q),
                              expected_r=hex(expected_r)):
                # one idle clock with no valid input
                yield dut.p.i_valid.eq(0)
                yield Tick()
                # present the operands: stage must be ready, nothing valid
                yield i_data.n.eq(n)
                yield i_data.d.eq(d)
                yield dut.p.i_valid.eq(1)
                yield from check_handshake(self.assertFalse, self.assertTrue)
                yield Tick()
                # overwrite the inputs with all-ones, so the result checked
                # below can only match if the stage captured n/d on the
                # previous clock
                yield i_data.n.eq(-1)
                yield i_data.d.eq(-1)
                yield dut.p.i_valid.eq(0)
                # busy for ceil(done_step / steps_per_clock) clocks:
                # neither valid nor ready
                for _ in range(0, shape.done_step, steps_per_clock):
                    yield from check_handshake(self.assertFalse,
                                               self.assertFalse)
                    yield Tick()
                # result available; stage stays busy until it is consumed
                yield from check_handshake(self.assertTrue, self.assertFalse)
                q = yield o_data.q
                r = yield o_data.r
                with self.subTest(q=hex(q), r=hex(r)):
                    # skip the comparison for a zero divisor, or when the
                    # reference quotient does not fit in shape.width bits
                    if d != 0 and (expected_q >> shape.width) == 0:
                        self.assertEqual(q, expected_q)
                        self.assertEqual(r, expected_r)
                # consume the result; the stage should become ready again
                yield dut.n.i_ready.eq(1)
                yield Tick()
                yield from check_handshake(self.assertFalse, self.assertTrue)
                yield dut.n.i_ready.eq(0)

        def process():
            if full:
                for n in range(1 << shape.width):
                    for d in range(1 << shape.width):
                        yield from case(n, d)
            else:
                for i in range(100):
                    n = hash_256(f"cldivrem fsm n {i}")
                    n = Const.normalize(n, unsigned(shape.width))
                    d = hash_256(f"cldivrem fsm d {i}")
                    d = Const.normalize(d, unsigned(shape.width))
                    yield from case(n, d)

        with do_sim(self, dut, list(dut.ports())) as sim:
            sim.add_process(process)
            sim.add_clock(1e-6)
            sim.run()

    def test_4_step_1(self):
        self.tst(CLDivRemShape(width=4),
                 full=True,
                 steps_per_clock=1)

    def test_4_step_2(self):
        self.tst(CLDivRemShape(width=4),
                 full=True,
                 steps_per_clock=2)

    def test_4_step_3(self):
        self.tst(CLDivRemShape(width=4),
                 full=True,
                 steps_per_clock=3)

    def test_4_step_4(self):
        self.tst(CLDivRemShape(width=4),
                 full=True,
                 steps_per_clock=4)

    def test_8_step_4(self):
        self.tst(CLDivRemShape(width=8),
                 full=False,
                 steps_per_clock=4)

    def test_64_step_4(self):
        self.tst(CLDivRemShape(width=64),
                 full=False,
                 steps_per_clock=4)

    def test_64_step_8(self):
        self.tst(CLDivRemShape(width=64),
                 full=False,
                 steps_per_clock=8)
248
249
if __name__ == "__main__":
    # Allow running this file directly: unittest.main() runs the TestCase
    # classes defined in this module.
    unittest.main()