fce8418bbf8272cdbecbcb9faacda350d97285ca
[soc.git] / src / soc / fu / div / test / test_fsm.py
1 import unittest
2 from soc.fu.div.fsm import DivState, DivStateInit, DivStateNext
3 from nmigen import Elaboratable, Module, Signal, unsigned
4 from nmigen.cli import rtlil
5 try:
6 from nmigen.sim.pysim import Simulator, Delay, Tick
7 except ImportError:
8 from nmigen.back.pysim import Simulator, Delay, Tick
9
10
class CheckEvent(Elaboratable):
    """Waveform marker toggled by a test process whenever it checks values.

    Every call to :meth:`trigger` flips :attr:`event`, so each check shows
    up as an edge on that signal in the VCD trace.
    """

    def __init__(self):
        # 1-bit marker signal, flipped by trigger()
        self.event = Signal()

    def trigger(self):
        """Flip ``event``; use as ``yield from marker.trigger()`` in a process."""
        inverted = ~self.event
        yield self.event.eq(inverted)

    def elaborate(self, platform):
        m = Module()
        # ``event`` must feed some logic so the simulator knows about it;
        # drive an otherwise-unused anonymous signal from it.
        m.d.comb += Signal().eq(self.event)
        return m
25
26
class DivStateCombTest(Elaboratable):
    """Test stringing a bunch of copies of the FSM state-function together.

    Builds a purely combinational chain: ``init`` (a ``DivStateInit``)
    produces ``states[0]`` from ``dividend``, then for each ``i`` in
    ``range(quotient_width)`` the stage ``nexts[i]`` (a ``DivStateNext``)
    maps ``states[i]`` to ``states[i + 1]``.  ``quotient`` and
    ``remainder`` are taken from the final state.

    ``expected_quotient`` / ``expected_remainder`` are reference results
    computed directly with ``//`` and ``%``.  They are only driven while
    ``expected_valid`` is set, i.e. when ``divisor`` is non-zero and
    ``dividend < divisor << quotient_width`` (so the quotient fits in
    ``quotient_width`` bits).
    """

    def __init__(self, quotient_width):
        self.check_event = CheckEvent()
        self.quotient_width = quotient_width
        # operands: the dividend is twice as wide as the divisor
        self.dividend = Signal(unsigned(quotient_width * 2))
        self.divisor = Signal(unsigned(quotient_width))
        # results taken from the last state in the chain
        self.quotient = Signal(unsigned(quotient_width))
        self.remainder = Signal(unsigned(quotient_width))
        # reference results, meaningful only while expected_valid is set
        self.expected_quotient = Signal(unsigned(quotient_width))
        self.expected_remainder = Signal(unsigned(quotient_width))
        self.expected_valid = Signal()
        # states[i] is the state after i steps, so there is one more state
        # than there are step functions
        self.states = [DivState(quotient_width=quotient_width,
                                name=f"state{i}")
                       for i in range(quotient_width + 1)]
        self.init = DivStateInit(quotient_width)
        # one step function per quotient bit
        self.nexts = [DivStateNext(quotient_width)
                      for _ in range(quotient_width)]

    def elaborate(self, platform):
        m = Module()
        m.submodules.check_event = self.check_event
        m.submodules.init = self.init
        m.d.comb += self.init.dividend.eq(self.dividend)
        m.d.comb += self.states[0].eq(self.init.o)
        # wire the chain: states[i] -> nexts[i] -> states[i + 1]
        for i, step in enumerate(self.nexts):
            setattr(m.submodules, f"next{i}", step)
            m.d.comb += step.divisor.eq(self.divisor)
            m.d.comb += step.i.eq(self.states[i])
            m.d.comb += self.states[i + 1].eq(step.o)
        last_state = self.states[-1]
        m.d.comb += self.quotient.eq(last_state.quotient)
        m.d.comb += self.remainder.eq(last_state.remainder)
        # reference result is only defined without divide-by-zero and when
        # the quotient fits in quotient_width bits
        m.d.comb += self.expected_valid.eq(
            (self.dividend < (self.divisor << self.quotient_width))
            & (self.divisor != 0))
        with m.If(self.expected_valid):
            m.d.comb += self.expected_quotient.eq(
                self.dividend // self.divisor)
            m.d.comb += self.expected_remainder.eq(
                self.dividend % self.divisor)
        return m
74
75
class DivStateFSMTest(Elaboratable):
    """Clocked harness that iterates a single ``DivStateNext`` step.

    ``state`` is a register loaded with ``next_state`` on every clock.
    ``next_state`` is taken from ``DivStateInit`` whenever ``clear`` is
    asserted (it resets to 1) or the current ``state`` reports ``done``;
    otherwise it is one ``DivStateNext`` step applied to ``state``.
    ``expected_*`` hold reference results computed directly from
    ``dividend`` and ``divisor``; they are only driven while
    ``expected_valid`` is set.
    """

    def __init__(self, quotient_width):
        qw = quotient_width
        # markers toggled by the test bench so checks are visible in the VCD
        self.check_done_event = CheckEvent()
        self.check_event = CheckEvent()
        self.quotient_width = qw
        # operands (double-width dividend) and observed results
        self.dividend = Signal(unsigned(2 * qw))
        self.divisor = Signal(unsigned(qw))
        self.quotient = Signal(unsigned(qw))
        self.remainder = Signal(unsigned(qw))
        # reference values for checking
        self.expected_quotient = Signal(unsigned(qw))
        self.expected_remainder = Signal(unsigned(qw))
        self.expected_valid = Signal()
        # registered current state and combinational next state
        self.state = DivState(quotient_width=qw,
                              name="state")
        self.next_state = DivState(quotient_width=qw,
                                   name="next_state")
        self.init = DivStateInit(qw)
        self.next = DivStateNext(qw)
        # copies of the done flags of both states
        self.state_done = Signal()
        self.next_state_done = Signal()
        # forces re-initialisation; asserted out of reset
        self.clear = Signal(reset=1)

    def elaborate(self, platform):
        m = Module()
        m.submodules.check_event = self.check_event
        m.submodules.check_done_event = self.check_done_event
        m.submodules.init = self.init
        m.submodules.next = self.next

        # feed the step functions and expose the current results
        m.d.comb += [
            self.init.dividend.eq(self.dividend),
            self.next.divisor.eq(self.divisor),
            self.quotient.eq(self.state.quotient),
            self.remainder.eq(self.state.remainder),
            self.next.i.eq(self.state),
            self.state_done.eq(self.state.done),
            self.next_state_done.eq(self.next_state.done),
        ]

        # start over when finished or cleared, otherwise advance one step
        restart = self.state.done | self.clear
        with m.If(restart):
            m.d.comb += self.next_state.eq(self.init.o)
        with m.Else():
            m.d.comb += self.next_state.eq(self.next.o)

        m.d.sync += self.state.eq(self.next_state)

        # reference result exists only without divide-by-zero and when the
        # quotient fits in quotient_width bits
        m.d.comb += self.expected_valid.eq(
            (self.dividend < (self.divisor << self.quotient_width))
            & (self.divisor != 0))
        with m.If(self.expected_valid):
            m.d.comb += [
                self.expected_quotient.eq(self.dividend // self.divisor),
                self.expected_remainder.eq(self.dividend % self.divisor),
            ]
        return m
128
129
def get_cases(quotient_width):
    """Return sorted, de-duplicated ``quotient_width``-bit operand values.

    The cases are:

    * the smallest values, 0..3;
    * the largest values, mask-2..mask (-3..-1 in two's complement);
    * the values around half range, ``mask >> 1`` and its two neighbours.

    For narrow widths these groups overlap.  Duplicates are removed so each
    input combination is simulated only once.
    """
    mask = (1 << quotient_width) - 1
    cases = {i & mask for i in range(-3, 4)}
    half = mask >> 1
    cases.update((half + i) & mask for i in (-1, 0, 1))
    return sorted(cases)
139
140
class TestDivState(unittest.TestCase):
    """Simulation tests for the division FSM state functions.

    Both tests write RTLIL (``.il``) and waveform (``.vcd``/``.gtkw``)
    files into the current working directory.
    """

    def _check_result(self, dut, dividend, divisor, mask, **subtest_kwargs):
        """Simulator generator: check ``dut``'s results for one input pair.

        Compares ``dut.quotient``/``dut.remainder`` and the reference
        ``dut.expected_*`` signals against Python's ``//`` and ``%``.  When
        ``divisor`` is zero or the quotient exceeds ``mask``, it only checks
        that ``dut.expected_valid`` is false; the DUT's quotient/remainder
        are not checked in that case.  ``subtest_kwargs`` are added to the
        subTest that reports the expected values.
        """
        if divisor == 0:
            # divide by zero: there is no reference result
            self.assertFalse((yield dut.expected_valid))
            return
        quotient = dividend // divisor
        remainder = dividend % divisor
        if quotient > mask:
            # overflow: the quotient does not fit in quotient_width bits
            self.assertFalse((yield dut.expected_valid))
            return
        with self.subTest(quotient=f"{quotient:#x}",
                          remainder=f"{remainder:#x}",
                          **subtest_kwargs):
            self.assertTrue((yield dut.expected_valid))
            self.assertEqual((yield dut.expected_quotient), quotient)
            self.assertEqual((yield dut.expected_remainder), remainder)
            self.assertEqual((yield dut.quotient), quotient)
            self.assertEqual((yield dut.remainder), remainder)

    def test_div_state_comb(self, quotient_width=8):
        test_cases = get_cases(quotient_width)
        mask = ~(~0 << quotient_width)
        dut = DivStateCombTest(quotient_width)
        vl = rtlil.convert(dut,
                           ports=[dut.dividend,
                                  dut.divisor,
                                  dut.quotient,
                                  dut.remainder])
        with open("div_fsm_comb_pipeline.il", "w") as f:
            f.write(vl)
        # simulate a fresh instance rather than the one just converted
        dut = DivStateCombTest(quotient_width)

        def check(dividend, divisor):
            with self.subTest(dividend=f"{dividend:#x}",
                              divisor=f"{divisor:#x}"):
                yield from dut.check_event.trigger()
                for i in range(quotient_width + 1):
                    # done must be correct and eventually true
                    # even if a div-by-zero or overflow occurred
                    done = yield dut.states[i].done
                    self.assertEqual(done, i == quotient_width)
                yield from self._check_result(dut, dividend, divisor, mask)

        def process(gen):
            # gen=True drives the inputs; gen=False checks the outputs
            # halfway between input changes
            for dividend_high in test_cases:
                for dividend_low in test_cases:
                    dividend = dividend_low + \
                        (dividend_high << quotient_width)
                    for divisor in test_cases:
                        if gen:
                            yield Delay(0.5e-6)
                            yield dut.dividend.eq(dividend)
                            yield dut.divisor.eq(divisor)
                            yield Delay(0.5e-6)
                        else:
                            yield Delay(1e-6)
                            yield from check(dividend, divisor)

        def gen_process():
            yield from process(gen=True)

        def check_process():
            yield from process(gen=False)

        sim = Simulator(dut)
        with sim.write_vcd(vcd_file="div_fsm_comb_pipeline.vcd",
                           gtkw_file="div_fsm_comb_pipeline.gtkw"):

            sim.add_process(gen_process)
            sim.add_process(check_process)
            sim.run()

    def test_div_state_fsm(self, quotient_width=8):
        test_cases = get_cases(quotient_width)
        mask = ~(~0 << quotient_width)
        dut = DivStateFSMTest(quotient_width)
        vl = rtlil.convert(dut,
                           ports=[dut.dividend,
                                  dut.divisor,
                                  dut.quotient,
                                  dut.remainder])
        with open("div_fsm.il", "w") as f:
            f.write(vl)

        def check(dividend, divisor):
            with self.subTest(dividend=f"{dividend:#x}",
                              divisor=f"{divisor:#x}"):
                for i in range(quotient_width + 1):
                    yield Tick()
                    yield Delay(0.1e-6)
                    yield from dut.check_done_event.trigger()
                    with self.subTest():
                        # done must be correct and eventually true
                        # even if a div-by-zero or overflow occurred
                        done = yield dut.state.done
                        self.assertEqual(done, i == quotient_width)
                yield from dut.check_event.trigger()
                now = None
                try:
                    # FIXME(programmerjake): replace with public API
                    # see https://github.com/nmigen/nmigen/issues/443
                    now = sim._state.timeline.now
                except (AttributeError, KeyError):
                    # best-effort diagnostic only: this private API may be
                    # missing or shaped differently in other nmigen versions
                    pass
                yield from self._check_result(dut, dividend, divisor, mask,
                                              now=f"{now}")

        def process(gen):
            if gen:
                yield dut.clear.eq(1)
                yield Tick()
            else:
                yield from dut.check_event.trigger()
                yield from dut.check_done_event.trigger()
            for dividend_high in test_cases:
                for dividend_low in test_cases:
                    dividend = dividend_low + \
                        (dividend_high << quotient_width)
                    for divisor in test_cases:
                        if gen:
                            yield Delay(0.2e-6)
                            yield dut.clear.eq(0)
                            yield dut.dividend.eq(dividend)
                            yield dut.divisor.eq(divisor)
                            for _ in range(quotient_width + 1):
                                yield Tick()
                        else:
                            yield from check(dividend, divisor)

        def gen_process():
            yield from process(gen=True)

        def check_process():
            yield from process(gen=False)

        sim = Simulator(dut)
        with sim.write_vcd(vcd_file="div_fsm.vcd",
                           gtkw_file="div_fsm.gtkw"):

            sim.add_clock(1e-6)
            sim.add_process(gen_process)
            sim.add_process(check_process)
            sim.run()
294
295
if __name__ == "__main__":
    # Running this file directly executes the test cases defined in this
    # module through the standard unittest command-line runner.
    unittest.main()