31650e53db40ea65b429629298001be5fe2f1232
[nmigen-gf.git] / src / nmigen_gf / hdl / cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Division and Remainder.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from dataclasses import dataclass, field, fields
13 from nmigen.hdl.ast import Signal, Value, Assert
14 from nmigen.hdl.dsl import Module
15 from nmutil.singlepipe import ControlBase
16 from nmutil.clz import CLZ, clz
17
18
def cldivrem_shifting(n, d, shape):
    """ Carry-less Division and Remainder, shift-based reference algorithm.

    The divisor is shifted left once at the start (and the remainder is
    shifted back right at the end), which lets every iteration test just a
    single bit of `r` rather than comparing degrees each time.

    `n` (the numerator) and `d` (the denominator, must be non-zero) are
    integers that each fit in `shape.width` bits. `shape` is the
    `CLDivRemShape` supplying the bit-width and the steps per clock cycle.
    Returns a tuple `q, r` of the quotient and remainder.
    """
    assert isinstance(shape, CLDivRemShape)
    assert isinstance(n, int) and 0 <= n < 1 << shape.width
    assert isinstance(d, int) and 0 <= d < 1 << shape.width
    assert d != 0, "TODO: decide what happens on division by zero"

    width = shape.width
    steps_per_clock = shape.steps_per_clock
    # number of completely filled clock cycles, and the number of steps
    # that are left over for the final, partially filled cycle
    full_clocks, tail_steps = divmod(width, steps_per_clock)

    # bind the state variables here so the helpers can use `nonlocal`
    shift = NotImplemented
    clock = NotImplemented
    substep = NotImplemented
    q = NotImplemented
    r = NotImplemented

    # the helpers below correspond one-to-one to the HDL parts:

    def set_to_initial():
        nonlocal d, r, q, clock, substep, shift
        # like `clz(d, width)`, except that it saturates at `width - 1`
        # rather than `width`: that fits in `shape.shift_width` bits and
        # never shifts further than necessary.
        shift = clz(d >> 1, width - 1)
        assert 0 <= shift < 1 << shape.shift_width, "shift overflow"
        d = d << shift
        assert 0 <= d < 1 << shape.d_width, "d overflow"
        r = n << shift
        assert 0 <= r < 1 << shape.r_width, "r overflow"
        q = clock = substep = 0

    def done():
        return clock == shape.done_clock

    def set_to_next():
        nonlocal r, q, clock, substep
        # the substep counter keeps running even after we are finished
        substep = (substep + 1) % steps_per_clock
        if done():
            return
        if substep == 0:
            clock += 1
        if clock == full_clocks and substep >= tail_steps:
            # all `width` steps have been taken
            clock = shape.done_clock
        q <<= 1
        r <<= 1
        if r >> (2 * width - 1):
            q |= 1
            r ^= d << width
        assert 0 <= q < 1 << shape.q_width, "q overflow"
        assert 0 <= r < 1 << shape.r_width, "r overflow"

    def get_output():
        remainder = (r >> width) >> shift
        return q, remainder

    set_to_initial()

    # each pass of the outer loop is one clock cycle
    while not done():
        for expected_substep in range(steps_per_clock):
            assert substep == expected_substep
            set_to_next()

    return get_output()
86
87
@dataclass(frozen=True, unsafe_hash=True)
class CLDivRemShape:
    """Parameters of a carry-less div/rem unit, plus the sizes derived
    from them.
    """

    width: int
    """bit-width of each of the carry-less div/rem inputs and outputs"""

    steps_per_clock: int = 8
    """number of steps that should be taken per clock cycle"""

    def __post_init__(self):
        assert isinstance(self.width, int), "invalid width"
        assert self.width >= 1, "invalid width"
        assert isinstance(self.steps_per_clock, int), \
            "invalid steps_per_clock"
        assert self.steps_per_clock >= 1, "invalid steps_per_clock"

    @property
    def done_clock(self):
        """the clock tick number when iteration is finished
        -- the largest `CLDivRemState.clock` will get
        """
        # `width` steps in total, `steps_per_clock` of them per clock tick,
        # rounded up so a partially filled last tick is counted too
        whole_clocks, leftover_steps = divmod(self.width,
                                              self.steps_per_clock)
        if leftover_steps:
            return whole_clocks + 1
        return whole_clocks

    @property
    def clock_range(self):
        """the range that `CLDivRemState.clock` will fall in.

        returns: range
        """
        last_clock = self.done_clock
        return range(last_clock + 1)

    @property
    def substep_range(self):
        """the range that `CLDivRemState.substep` will fall in.

        returns: range
        """
        return range(self.steps_per_clock)

    @property
    def d_width(self):
        """bit-width of the internal signal `CLDivRemState.d`"""
        return self.width

    @property
    def r_width(self):
        """bit-width of the internal signal `CLDivRemState.r`"""
        return 2 * self.width

    @property
    def q_width(self):
        """bit-width of the internal signal `CLDivRemState.q`"""
        return self.width

    @property
    def shift_width(self):
        """bit-width of the internal signal `CLDivRemState.shift`"""
        largest_shift = self.width - 1
        return largest_shift.bit_length()
145
146
@dataclass(frozen=True, eq=False)
class CLDivRemState:
    """The set of signals describing one state of the carry-less div/rem
    iteration.

    Each instance creates its own signals, all named `<name>_<field>`.
    The `set_to_*` methods add combinatorial logic to a `Module` that
    drives this instance's signals.
    """
    # the `CLDivRemShape` that all signal widths/ranges are taken from
    shape: CLDivRemShape
    # prefix used when naming the signals below
    name: str
    # clock-tick counter; equals `shape.done_clock` once finished
    clock: Signal = field(init=False)
    # step counter within the current clock tick
    substep: Signal = field(init=False)
    # the divisor, after being shifted left by `shift`
    d: Signal = field(init=False)
    # working remainder, `shape.r_width` bits wide
    r: Signal = field(init=False)
    # quotient bits produced so far
    q: Signal = field(init=False)
    # amount `d` and `n` were shifted left by in `set_to_initial`
    shift: Signal = field(init=False)

    def __init__(self, shape, *, name=None, src_loc_at=0):
        """Create the signals for a new state.

        shape: CLDivRemShape
            determines the width/range of every signal.
        name: str | None
            prefix for the signal names; derived automatically when `None`.
        src_loc_at: int
            forwarded (plus one) to `Signal` when deriving the default
            name -- presumably the usual nmigen "extra stack frames to
            skip" count, TODO confirm.
        """
        assert isinstance(shape, CLDivRemShape)
        if name is None:
            # no name given: create a throw-away Signal only to borrow the
            # name it chooses for itself
            name = Signal(src_loc_at=1 + src_loc_at).name
        assert isinstance(name, str)
        clock = Signal(shape.clock_range, name=f"{name}_clock")
        substep = Signal(shape.substep_range, name=f"{name}_substep", reset=0)
        d = Signal(shape.d_width, name=f"{name}_d")
        r = Signal(shape.r_width, name=f"{name}_r")
        q = Signal(shape.q_width, name=f"{name}_q")
        shift = Signal(shape.shift_width, name=f"{name}_shift")
        # the dataclass is frozen, so plain attribute assignment is not
        # available -- go through `object.__setattr__` instead
        object.__setattr__(self, "shape", shape)
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "clock", clock)
        object.__setattr__(self, "substep", substep)
        object.__setattr__(self, "d", d)
        object.__setattr__(self, "r", r)
        object.__setattr__(self, "q", q)
        object.__setattr__(self, "shift", shift)

    def eq(self, rhs):
        """Yield one assignment per signal, setting each of this state's
        signals to the matching signal of `rhs` (a `CLDivRemState`).
        """
        assert isinstance(rhs, CLDivRemState)
        for f in fields(CLDivRemState):
            # `shape` and `name` are plain Python values, not signals
            if f.name in ("shape", "name"):
                continue
            l = getattr(self, f.name)
            r = getattr(rhs, f.name)
            yield l.eq(r)

    @staticmethod
    def like(other, *, name=None, src_loc_at=0):
        """Create a new `CLDivRemState` with the same shape as `other`.

        `name` and `src_loc_at` are handled as in `__init__`; one is added
        to `src_loc_at` for the call made from here.
        """
        assert isinstance(other, CLDivRemState)
        return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at)

    @property
    def done(self):
        """Expression that is true when `clock` equals
        `shape.done_clock`.
        """
        return self.clock == self.shape.done_clock

    def get_output(self):
        """Return the pair `q, r` of output expressions: the quotient
        signal, and the upper part of the working remainder shifted back
        right by `shift`.
        """
        return self.q, (self.r >> self.shape.width) >> self.shift

    def set_to_initial(self, m, n, d):
        """Drive this state combinatorially with the starting state for
        dividing `n` by `d`.

        m: Module
            the module the logic (and a `CLZ` submodule) is added to.
        n, d:
            numerator and denominator; anything `Value.cast` accepts.
        """
        assert isinstance(m, Module)
        n = Value.cast(n)  # convert to Value
        d = Value.cast(d)  # convert to Value
        # the CLZ is one bit narrower than `d` and is fed `d >> 1`
        clz_mod = CLZ(self.shape.width - 1)
        # can't name submodule since it would conflict if this function is
        # called multiple times in a Module
        m.submodules += clz_mod
        assert clz_mod.lz.width == self.shape.shift_width, \
            "internal inconsistency -- mismatched shift signal width"
        m.d.comb += [
            clz_mod.sig_in.eq(d >> 1),
            self.shift.eq(clz_mod.lz),
            # both operands are shifted left by the same amount
            self.d.eq(d << self.shift),
            self.r.eq(n << self.shift),
            self.q.eq(0),
            self.clock.eq(0),
            self.substep.eq(0),
        ]

    def eq_but_zero_substep(self, rhs, do_assert):
        """Same as `eq`, except that `substep` is assigned the constant 0
        instead of `rhs.substep`.

        When `do_assert` is true, an `Assert` that `rhs.substep` is zero
        is yielded as well.
        """
        assert isinstance(rhs, CLDivRemState)
        for f in fields(CLDivRemState):
            # `shape` and `name` are plain Python values, not signals
            if f.name in ("shape", "name"):
                continue
            l = getattr(self, f.name)
            r = getattr(rhs, f.name)
            if f.name == "substep":
                if do_assert:
                    yield Assert(r == 0)
                # assign the constant whether or not the assert is emitted
                r = 0
            yield l.eq(r)

    def set_to_next(self, m, state_in):
        """Drive this state combinatorially with the state that follows
        `state_in` by one step.

        m: Module
            the module the logic is added to.
        state_in: CLDivRemState
            the state to advance; must have the same shape and must not
            be `self`.
        """
        assert isinstance(m, Module)
        assert isinstance(state_in, CLDivRemState)
        assert state_in.shape == self.shape
        assert self is not state_in, "a.set_to_next(m, a) is not allowed"
        width = self.shape.width
        # substep counts 0 .. steps_per_clock - 1, then wraps to 0; this is
        # done regardless of whether `state_in` is already finished
        substep_wraps = state_in.substep >= self.shape.steps_per_clock - 1
        with m.If(substep_wraps):
            m.d.comb += self.substep.eq(0)
        with m.Else():
            m.d.comb += self.substep.eq(state_in.substep + 1)

        with m.If(state_in.done):
            # already finished: everything except substep is passed through
            m.d.comb += [
                self.clock.eq(state_in.clock),
                self.d.eq(state_in.d),
                self.r.eq(state_in.r),
                self.q.eq(state_in.q),
                self.shift.eq(state_in.shift),
            ]
        with m.Else():
            # clock advances by one whenever substep wraps
            clock = state_in.clock + substep_wraps
            # note: the comparison uses the *new* substep (`self.substep`);
            # when it holds, clock is set straight to `done_clock`
            with m.If((clock == width // self.shape.steps_per_clock)
                      & (self.substep >= width % self.shape.steps_per_clock)):
                m.d.comb += self.clock.eq(self.shape.done_clock)
            with m.Else():
                m.d.comb += self.clock.eq(clock)
            m.d.comb += [
                self.d.eq(state_in.d),
                self.shift.eq(state_in.shift),
            ]
            q = state_in.q << 1
            r = state_in.r << 1
            # bit `2 * width - 1` of the shifted remainder decides whether
            # the next quotient bit is set and `d << width` is XORed in
            with m.If(r[width * 2 - 1]):
                m.d.comb += [
                    self.q.eq(q | 1),
                    self.r.eq(r ^ (state_in.d << width)),
                ]
            with m.Else():
                m.d.comb += [
                    self.q.eq(q),
                    self.r.eq(r),
                ]
275
276
class CLDivRemInputData:
    """Input operands `n` and `d`, each `shape.width` bits wide."""

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        operand_width = shape.width
        self.n = Signal(operand_width)
        self.d = Signal(operand_width)

    def __iter__(self):
        """ Get member signals. """
        yield from (self.n, self.d)

    def eq(self, rhs):
        """ Assign member signals. """
        pairs = (
            (self.n, rhs.n),
            (self.d, rhs.d),
        )
        return [target.eq(source) for target, source in pairs]
295
296
class CLDivRemOutputData:
    """Output values `q` and `r`, each `shape.width` bits wide."""

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        result_width = shape.width
        self.q = Signal(result_width)
        self.r = Signal(result_width)

    def __iter__(self):
        """ Get member signals. """
        yield from (self.q, self.r)

    def eq(self, rhs):
        """ Assign member signals. """
        pairs = (
            (self.q, rhs.q),
            (self.r, rhs.r),
        )
        return [target.eq(source) for target, source in pairs]

    def eq_output(self, state):
        """ Assign member signals from the output of `state`, a
        `CLDivRemState` of the same shape.
        """
        assert isinstance(state, CLDivRemState)
        assert state.shape == self.shape
        quotient, remainder = state.get_output()
        return [self.q.eq(quotient), self.r.eq(remainder)]
321
322
class CLDivRemFSMStage(ControlBase):
    """carry-less div/rem, as an iterative (FSM) pipeline stage

    Attributes:
    shape: CLDivRemShape
        the shape
    pspec:
        pipe-spec
    empty: Signal()
        true if nothing is stored in `self.saved_state`
    saved_state: CLDivRemState()
        the saved state that is currently being worked on.
    """

    def __init__(self, pspec, shape):
        assert isinstance(shape, CLDivRemShape)
        # both are stored before `super().__init__` runs -- the original
        # note says they are needed by ispec and ospec
        self.shape = shape
        self.pspec = pspec
        super().__init__(stage=self)
        self.empty = Signal(reset=1)
        self.saved_state = CLDivRemState(shape)

    def ispec(self):
        """Create a fresh input-data object for this stage's shape."""
        return CLDivRemInputData(self.shape)

    def ospec(self):
        """Create a fresh output-data object for this stage's shape."""
        return CLDivRemOutputData(self.shape)

    def setup(self, m, i):
        """Nothing to set up; all logic is added in `elaborate`."""

    def elaborate(self, platform):
        """Build the stage: the handshake signals, the combinatorial chain
        of `steps_per_clock` algorithm steps, and the registered state.
        """
        m = super().elaborate(platform)
        prev_port = self.p
        next_port = self.n
        data_in = prev_port.i_data
        data_out = next_port.o_data
        saved = self.saved_state
        steps_per_clock = self.shape.steps_per_clock

        # TODO: handle cancellation

        # a result is available once the stored state has finished; new
        # input is only accepted while nothing is stored
        m.d.comb += next_port.o_valid.eq(~self.empty & saved.done)
        m.d.comb += prev_port.o_ready.eq(self.empty)

        # next_chain[0] mirrors the saved state; each following entry is
        # one algorithm step further along
        next_chain = [
            CLDivRemState(self.shape, name=f"next_chain_{i}")
            for i in range(steps_per_clock + 1)
        ]
        for before, after in zip(next_chain, next_chain[1:]):
            after.set_to_next(m, before)
        m.d.comb += next_chain[0].eq(saved)
        m.d.comb += data_out.eq_output(saved)
        initial_state = CLDivRemState(self.shape)
        initial_state.set_to_initial(m, n=data_in.n, d=data_in.d)

        do_assert = platform == "formal"

        with m.If(self.empty):
            # idle: keep loading the initial state for the current inputs
            m.d.sync += saved.eq_but_zero_substep(initial_state, do_assert)
            with m.If(prev_port.i_valid):
                m.d.sync += self.empty.eq(0)
        with m.Else():
            # busy: advance by a whole clock's worth of steps
            m.d.sync += saved.eq_but_zero_substep(next_chain[-1], do_assert)
            with m.If(next_port.i_ready & next_port.o_valid):
                m.d.sync += self.empty.eq(1)
        return m

    def __iter__(self):
        for port in (self.p, self.n):
            yield from port

    def ports(self):
        """Return the signals yielded by iterating over this stage."""
        return [*self]