split step counter into clock and substep
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 import unittest
8 from nmigen.hdl.ast import Signal, Const, unsigned
9 from nmigen.hdl.dsl import Module
10 from nmutil.formaltest import FHDLTestCase
11 from nmigen_gf.hdl.cldivrem import (CLDivRemFSMStage, CLDivRemInputData,
12 CLDivRemOutputData, CLDivRemShape,
13 cldivrem_shifting, CLDivRemState)
14 from nmigen.sim import Delay, Tick
15 from nmutil.sim_util import do_sim, hash_256
16 from nmigen_gf.reference.cldivrem import cldivrem
17
18
19 class TestCLDivRemShifting(FHDLTestCase):
20 def tst(self, shape, full):
21 assert isinstance(shape, CLDivRemShape)
22
23 def case(n, d):
24 assert isinstance(n, int)
25 assert isinstance(d, int)
26 if d != 0:
27 expected_q, expected_r = cldivrem(n, d, width=shape.width)
28 q, r = cldivrem_shifting(n, d, shape)
29 else:
30 expected_q = expected_r = 0
31 q = r = 0
32 with self.subTest(expected_q=hex(expected_q),
33 expected_r=hex(expected_r),
34 q=hex(q), r=hex(r)):
35 self.assertEqual(expected_q, q)
36 self.assertEqual(expected_r, r)
37 if full:
38 for n in range(1 << shape.width):
39 for d in range(1 << shape.width):
40 with self.subTest(n=hex(n), d=hex(d)):
41 case(n, d)
42 else:
43 for i in range(100):
44 n = hash_256(f"cldivrem comb n {i}")
45 n = Const.normalize(n, unsigned(shape.width))
46 d = hash_256(f"cldivrem comb d {i}")
47 d = Const.normalize(d, unsigned(shape.width))
48 case(n, d)
49
50 def test_6_step_1(self):
51 self.tst(CLDivRemShape(width=6, steps_per_clock=1), full=True)
52
53 def test_6_step_2(self):
54 self.tst(CLDivRemShape(width=6, steps_per_clock=2), full=True)
55
56 def test_6_step_3(self):
57 self.tst(CLDivRemShape(width=6, steps_per_clock=3), full=True)
58
59 def test_6_step_4(self):
60 self.tst(CLDivRemShape(width=6, steps_per_clock=4), full=True)
61
62 def test_6_step_6(self):
63 self.tst(CLDivRemShape(width=6, steps_per_clock=6), full=True)
64
65 def test_6_step_10(self):
66 self.tst(CLDivRemShape(width=6, steps_per_clock=10), full=True)
67
68 def test_64_step_1(self):
69 self.tst(CLDivRemShape(width=64, steps_per_clock=1), full=False)
70
71 def test_64_step_2(self):
72 self.tst(CLDivRemShape(width=64, steps_per_clock=2), full=False)
73
74 def test_64_step_3(self):
75 self.tst(CLDivRemShape(width=64, steps_per_clock=3), full=False)
76
77 def test_64_step_4(self):
78 self.tst(CLDivRemShape(width=64, steps_per_clock=4), full=False)
79
80 def test_64_step_8(self):
81 self.tst(CLDivRemShape(width=64, steps_per_clock=8), full=False)
82
83
84 class TestCLDivRemComb(FHDLTestCase):
85 def tst(self, shape, full):
86 assert isinstance(shape, CLDivRemShape)
87 width = shape.width
88 m = Module()
89 n_in = Signal(width)
90 d_in = Signal(width)
91 q_out = Signal(width)
92 r_out = Signal(width)
93 states: "list[CLDivRemState]" = []
94 for i in range(shape.width + 1):
95 states.append(CLDivRemState(shape, name=f"state_{i}"))
96 if i == 0:
97 states[i].set_to_initial(m, n=n_in, d=d_in)
98 else:
99 states[i].set_to_next(m, states[i - 1])
100 q, r = states[-1].get_output()
101 m.d.comb += [q_out.eq(q), r_out.eq(r)]
102
103 def case(n, d):
104 assert isinstance(n, int)
105 assert isinstance(d, int)
106 if d != 0:
107 expected_q, expected_r = cldivrem_shifting(n, d, shape)
108 else:
109 expected_q = expected_r = 0
110 with self.subTest(n=hex(n), d=hex(d),
111 expected_q=hex(expected_q),
112 expected_r=hex(expected_r)):
113 yield n_in.eq(n)
114 yield d_in.eq(d)
115 yield Delay(1e-6)
116 for i, state in enumerate(states):
117 with self.subTest(i=i):
118 done = yield state.done
119 substep = yield state.substep
120 clock = yield state.clock
121 self.assertEqual(done, i >= shape.width)
122 if i % shape.steps_per_clock == 0:
123 self.assertEqual(substep, 0)
124 if i < shape.width:
125 self.assertEqual(substep
126 + clock * shape.steps_per_clock, i)
127 q = yield q_out
128 r = yield r_out
129 with self.subTest(q=hex(q), r=hex(r)):
130 # only check results when inputs are valid
131 if d != 0:
132 self.assertEqual(q, expected_q)
133 self.assertEqual(r, expected_r)
134
135 def process():
136 if full:
137 for n in range(1 << width):
138 for d in range(1 << width):
139 yield from case(n, d)
140 else:
141 for i in range(100):
142 n = hash_256(f"cldivrem comb n {i}")
143 n = Const.normalize(n, unsigned(width))
144 d = hash_256(f"cldivrem comb d {i}")
145 d = Const.normalize(d, unsigned(width))
146 yield from case(n, d)
147 with do_sim(self, m, [n_in, d_in, q_out, r_out]) as sim:
148 sim.add_process(process)
149 sim.run()
150
151 def test_4_step_1(self):
152 self.tst(CLDivRemShape(width=4, steps_per_clock=1), full=True)
153
154 def test_4_step_2(self):
155 self.tst(CLDivRemShape(width=4, steps_per_clock=2), full=True)
156
157 def test_4_step_3(self):
158 self.tst(CLDivRemShape(width=4, steps_per_clock=3), full=True)
159
160 def test_4_step_4(self):
161 self.tst(CLDivRemShape(width=4, steps_per_clock=4), full=True)
162
163 def test_6_step_1(self):
164 self.tst(CLDivRemShape(width=6, steps_per_clock=1), full=False)
165
166 def test_6_step_2(self):
167 self.tst(CLDivRemShape(width=6, steps_per_clock=2), full=False)
168
169 def test_6_step_6(self):
170 self.tst(CLDivRemShape(width=6, steps_per_clock=6), full=False)
171
172 def test_6_step_8(self):
173 self.tst(CLDivRemShape(width=6, steps_per_clock=8), full=False)
174
175 def test_8_step_1(self):
176 self.tst(CLDivRemShape(width=8, steps_per_clock=1), full=False)
177
178 def test_8_step_4(self):
179 self.tst(CLDivRemShape(width=8, steps_per_clock=4), full=False)
180
181 def test_8_step_8(self):
182 self.tst(CLDivRemShape(width=8, steps_per_clock=8), full=False)
183
184
185 class TestCLDivRemFSM(FHDLTestCase):
186 def tst(self, shape, full):
187 assert isinstance(shape, CLDivRemShape)
188 pspec = {}
189 dut = CLDivRemFSMStage(pspec, shape)
190 i_data: CLDivRemInputData = dut.p.i_data
191 o_data: CLDivRemOutputData = dut.n.o_data
192 self.assertEqual(i_data.n.shape(), unsigned(shape.width))
193 self.assertEqual(i_data.d.shape(), unsigned(shape.width))
194 self.assertEqual(o_data.q.shape(), unsigned(shape.width))
195 self.assertEqual(o_data.r.shape(), unsigned(shape.width))
196
197 def case(n, d):
198 assert isinstance(n, int)
199 assert isinstance(d, int)
200 if d != 0:
201 expected_q, expected_r = cldivrem(n, d, width=shape.width)
202 else:
203 expected_q = expected_r = 0
204 with self.subTest(n=hex(n), d=hex(d),
205 expected_q=hex(expected_q),
206 expected_r=hex(expected_r)):
207 yield dut.p.i_valid.eq(0)
208 yield Tick()
209 yield i_data.n.eq(n)
210 yield i_data.d.eq(d)
211 yield dut.p.i_valid.eq(1)
212 yield Delay(0.1e-6)
213 valid = yield dut.n.o_valid
214 ready = yield dut.p.o_ready
215 with self.subTest():
216 self.assertFalse(valid)
217 self.assertTrue(ready)
218 yield Tick()
219 yield i_data.n.eq(-1)
220 yield i_data.d.eq(-1)
221 yield dut.p.i_valid.eq(0)
222 for step in range(shape.done_clock):
223 yield Delay(0.1e-6)
224 valid = yield dut.n.o_valid
225 ready = yield dut.p.o_ready
226 with self.subTest():
227 self.assertFalse(valid)
228 self.assertFalse(ready)
229 yield Tick()
230 yield Delay(0.1e-6)
231 valid = yield dut.n.o_valid
232 ready = yield dut.p.o_ready
233 with self.subTest():
234 self.assertTrue(valid)
235 self.assertFalse(ready)
236 q = yield o_data.q
237 r = yield o_data.r
238 with self.subTest(q=hex(q), r=hex(r)):
239 # only check results when inputs are valid
240 if d != 0 and (expected_q >> shape.width) == 0:
241 self.assertEqual(q, expected_q)
242 self.assertEqual(r, expected_r)
243 yield dut.n.i_ready.eq(1)
244 yield Tick()
245 yield Delay(0.1e-6)
246 valid = yield dut.n.o_valid
247 ready = yield dut.p.o_ready
248 with self.subTest():
249 self.assertFalse(valid)
250 self.assertTrue(ready)
251 yield dut.n.i_ready.eq(0)
252
253 def process():
254 if full:
255 for n in range(1 << shape.width):
256 for d in range(1 << shape.width):
257 yield from case(n, d)
258 else:
259 for i in range(100):
260 n = hash_256(f"cldivrem fsm n {i}")
261 n = Const.normalize(n, unsigned(shape.width))
262 d = hash_256(f"cldivrem fsm d {i}")
263 d = Const.normalize(d, unsigned(shape.width))
264 yield from case(n, d)
265
266 with do_sim(self, dut, list(dut.ports())) as sim:
267 sim.add_process(process)
268 sim.add_clock(1e-6)
269 sim.run()
270
271 def test_4_step_1(self):
272 self.tst(CLDivRemShape(width=4, steps_per_clock=1), full=True)
273
274 def test_4_step_2(self):
275 self.tst(CLDivRemShape(width=4, steps_per_clock=2), full=True)
276
277 def test_4_step_3(self):
278 self.tst(CLDivRemShape(width=4, steps_per_clock=3), full=True)
279
280 def test_4_step_4(self):
281 self.tst(CLDivRemShape(width=4, steps_per_clock=4), full=True)
282
283 def test_8_step_4(self):
284 self.tst(CLDivRemShape(width=8, steps_per_clock=4), full=False)
285
286 def test_64_step_4(self):
287 self.tst(CLDivRemShape(width=64, steps_per_clock=4), full=False)
288
289 def test_64_step_8(self):
290 self.tst(CLDivRemShape(width=64, steps_per_clock=8), full=False)
291
292
293 if __name__ == "__main__":
294 unittest.main()