change to use plain_data.fields
[nmigen-gf.git] / src / nmigen_gf / hdl / cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Division and Remainder.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from nmigen.hdl.ast import Signal, Value, Assert
13 from nmigen.hdl.dsl import Module
14 from nmutil.singlepipe import ControlBase
15 from nmutil.clz import CLZ, clz
16 from nmutil.plain_data import plain_data, fields
17
18
def cldivrem_shifting(n, d, shape):
    """ Carry-less Division and Remainder based on shifting at start and end
    allowing us to get away with checking a single bit each iteration
    rather than checking for equal degrees every iteration.

    This is the pure-Python model of the HDL: the nested helper functions
    below match up with the `CLDivRemState` methods of the same names, and
    the main loop runs `shape.steps_per_clock` steps per simulated clock
    cycle.

    Parameters:
        n: int
            the dividend; must satisfy `0 <= n < 1 << shape.width`.
        d: int
            the divisor; must satisfy `0 < d < 1 << shape.width` (what
            happens on division by zero is not decided yet, so it is
            rejected).
        shape: CLDivRemShape
            the shape; `shape.width` is the number of bits needed to hold
            each input/output.

    Returns a tuple `q, r` of the quotient and remainder.

    Raises AssertionError when an argument is out of range, or when an
    intermediate value does not fit in the bit-width of the corresponding
    HDL signal.
    """
    assert isinstance(shape, CLDivRemShape)
    assert isinstance(n, int) and 0 <= n < 1 << shape.width
    assert isinstance(d, int) and 0 <= d < 1 << shape.width
    assert d != 0, "TODO: decide what happens on division by zero"

    # declare locals so nonlocal works
    r = q = shift = clock = substep = NotImplemented

    # functions match up to HDL parts:

    def set_to_initial():
        nonlocal d, r, q, clock, substep, shift
        # `clz(d, shape.width)`, but maxes out at `shape.width - 1` instead of
        # `shape.width` in order to both fit in `shape.shift_width` bits and
        # to not shift by more than needed.
        shift = clz(d >> 1, shape.width - 1)
        assert 0 <= shift < 1 << shape.shift_width, "shift overflow"
        # normalize: d's most-significant set bit ends up at bit
        # `shape.width - 1`, so each step only needs to test one bit of r
        d <<= shift
        assert 0 <= d < 1 << shape.d_width, "d overflow"
        # scale n by the same amount; the remainder is un-scaled at the end
        r = n << shift
        assert 0 <= r < 1 << shape.r_width, "r overflow"
        q = 0
        clock = 0
        substep = 0

    def done():
        return clock == shape.done_clock

    def set_to_next():
        nonlocal r, q, clock, substep
        substep += 1
        substep %= shape.steps_per_clock
        if done():
            # finished: only `substep` keeps advancing
            return
        elif substep == 0:
            clock += 1
        # when `shape.width` is not a multiple of `shape.steps_per_clock`
        # the last step lands partway through a clock cycle; jump straight
        # to `done_clock` once the `shape.width`-th step is being taken
        if clock == shape.width // shape.steps_per_clock \
                and substep >= shape.width % shape.steps_per_clock:
            clock = shape.done_clock
        # one long-division step: shift left, and if the bit shifted into
        # the top of r is set, XOR out the aligned divisor and record a 1
        # quotient bit
        q <<= 1
        r <<= 1
        if r >> (shape.width * 2 - 1) != 0:
            r ^= d << shape.width
            q |= 1
        assert 0 <= q < 1 << shape.q_width, "q overflow"
        assert 0 <= r < 1 << shape.r_width, "r overflow"

    def get_output():
        # the remainder is the upper half of r, still scaled by the initial
        # normalizing shift
        return q, (r >> shape.width) >> shift

    set_to_initial()

    # one clock-cycle per outer loop
    while not done():
        for expected_substep in range(shape.steps_per_clock):
            assert substep == expected_substep
            set_to_next()

    return get_output()
86
87
@plain_data(frozen=True, unsafe_hash=True)
class CLDivRemShape:
    """Immutable, hashable parameters of a carry-less div/rem unit.

    Stores the two user-chosen parameters, `width` and `steps_per_clock`, and
    derives from them the widths and ranges of the `CLDivRemState` signals.
    """
    __slots__ = "width", "steps_per_clock"

    def __init__(self, width, steps_per_clock=8):
        assert isinstance(width, int) and width >= 1, "invalid width"
        assert (isinstance(steps_per_clock, int)
                and steps_per_clock >= 1), "invalid steps_per_clock"
        # bit-width of each of the carry-less div/rem inputs and outputs
        self.width = width
        # number of steps that should be taken per clock cycle
        self.steps_per_clock = steps_per_clock

    @property
    def done_clock(self):
        """clock tick at which iteration is finished -- the largest value
        `CLDivRemState.clock` gets, i.e. `ceil(width / steps_per_clock)`.
        """
        full_clocks, leftover_steps = divmod(self.width, self.steps_per_clock)
        if leftover_steps:
            # an extra, partially-used clock cycle finishes the last steps
            return full_clocks + 1
        return full_clocks

    @property
    def clock_range(self):
        """`range` of every value `CLDivRemState.clock` can hold
        (0 up to and including `done_clock`).
        """
        last_clock = self.done_clock
        return range(last_clock + 1)

    @property
    def substep_range(self):
        """`range` of every value `CLDivRemState.substep` can hold."""
        return range(self.steps_per_clock)

    @property
    def d_width(self):
        """bit-width of `CLDivRemState.d` (same as the inputs)."""
        return self.width

    @property
    def r_width(self):
        """bit-width of `CLDivRemState.r` (twice the input width)."""
        return 2 * self.width

    @property
    def q_width(self):
        """bit-width of `CLDivRemState.q` (same as the inputs)."""
        return self.width

    @property
    def shift_width(self):
        """bit-width of `CLDivRemState.shift`: enough bits for any shift
        amount from 0 to `width - 1`.
        """
        max_shift = self.width - 1
        return max_shift.bit_length()
144 """bit-width of the internal signal `CLDivRemState.shift`"""
145 return (self.width - 1).bit_length()
146
147
@plain_data(frozen=True, eq=False)
class CLDivRemState:
    """HDL signals holding the state of an in-progress carry-less division.

    Attributes:
        shape: CLDivRemShape
            the shape
        name: str
            prefix used for the names of all the signals below
        clock: Signal(shape.clock_range)
            clock-cycle counter; equals `shape.done_clock` once finished
        substep: Signal(shape.substep_range)
            step index within the current clock cycle
        d: Signal(shape.d_width)
            the left-shifted (normalized) divisor
        r: Signal(shape.r_width)
            the working remainder
        q: Signal(shape.q_width)
            the quotient bits produced so far
        shift: Signal(shape.shift_width)
            the normalizing shift applied to the inputs
    """
    __slots__ = "shape", "name", "clock", "substep", "d", "r", "q", "shift"

    def __init__(self, shape, *, name=None, src_loc_at=0):
        assert isinstance(shape, CLDivRemShape)
        if name is None:
            # presumably relies on nmigen's Signal name inference so the
            # default prefix follows the caller's variable name -- confirm
            name = Signal(src_loc_at=1 + src_loc_at).name
        assert isinstance(name, str)
        self.shape = shape
        self.name = name
        self.clock = Signal(shape.clock_range, name=f"{name}_clock")
        self.substep = Signal(shape.substep_range, name=f"{name}_substep",
                              reset=0)
        self.d = Signal(shape.d_width, name=f"{name}_d")
        self.r = Signal(shape.r_width, name=f"{name}_r")
        self.q = Signal(shape.q_width, name=f"{name}_q")
        self.shift = Signal(shape.shift_width, name=f"{name}_shift")

    @staticmethod
    def _signal_field_names():
        """Yield the names of the Signal-holding fields (every field except
        `shape` and `name`), in the order `fields()` reports them."""
        for field_name in fields(CLDivRemState):
            if field_name not in ("shape", "name"):
                yield field_name

    def eq(self, rhs):
        """Generate statements assigning each Signal of `rhs` to the
        matching Signal of `self`."""
        assert isinstance(rhs, CLDivRemState)
        for field_name in self._signal_field_names():
            yield getattr(self, field_name).eq(getattr(rhs, field_name))

    @staticmethod
    def like(other, *, name=None, src_loc_at=0):
        """Create a new state with the same shape as `other`."""
        assert isinstance(other, CLDivRemState)
        return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at)

    @property
    def done(self):
        """Value that is true once `clock` has reached `shape.done_clock`."""
        return self.clock == self.shape.done_clock

    def get_output(self):
        """Return `(quotient, remainder)`: the remainder is the upper half of
        `r` with the normalizing `shift` undone."""
        upper_half = self.r >> self.shape.width
        return self.q, upper_half >> self.shift

    def set_to_initial(self, m, n, d):
        """Add combinational logic to `m` that loads this state from the
        dividend `n` and divisor `d`.

        `d` is shifted left by the leading-zero count of `d >> 1` (assuming
        `CLZ` counts leading zeros of `sig_in`), and `n` is shifted by the
        same amount into `r`; `q`, `clock` and `substep` start at 0.
        """
        assert isinstance(m, Module)
        dividend = Value.cast(n)
        divisor = Value.cast(d)
        lz_counter = CLZ(self.shape.width - 1)
        # added anonymously: a named submodule would conflict if this method
        # is called more than once on the same Module
        m.submodules += lz_counter
        assert lz_counter.lz.width == self.shape.shift_width, \
            "internal inconsistency -- mismatched shift signal width"
        m.d.comb += [
            # counting on `d >> 1` over `width - 1` bits caps the shift at
            # `width - 1`, so it fits in `shape.shift_width` bits
            lz_counter.sig_in.eq(divisor >> 1),
            self.shift.eq(lz_counter.lz),
            self.d.eq(divisor << self.shift),
            self.r.eq(dividend << self.shift),
            self.q.eq(0),
            self.clock.eq(0),
            self.substep.eq(0),
        ]

    def eq_but_zero_substep(self, rhs, do_assert):
        """Like `eq`, but `substep` is assigned 0 instead of being copied.

        If `do_assert` is true, also generate an `Assert` that
        `rhs.substep` is already 0.
        """
        assert isinstance(rhs, CLDivRemState)
        for field_name in self._signal_field_names():
            value = getattr(rhs, field_name)
            if field_name == "substep":
                if do_assert:
                    yield Assert(value == 0)
                value = 0
            yield getattr(self, field_name).eq(value)

    def set_to_next(self, m, state_in):
        """Add combinational logic to `m` that computes this state as one
        division step after `state_in`.

        Once `state_in` is done, only `substep` keeps advancing; everything
        else is passed through unchanged.
        """
        assert isinstance(m, Module)
        assert isinstance(state_in, CLDivRemState)
        assert state_in.shape == self.shape
        assert self is not state_in, "a.set_to_next(m, a) is not allowed"
        width = self.shape.width
        steps = self.shape.steps_per_clock

        # substep counts 0 .. steps - 1, then wraps back to 0
        wraps = state_in.substep >= steps - 1
        with m.If(wraps):
            m.d.comb += self.substep.eq(0)
        with m.Else():
            m.d.comb += self.substep.eq(state_in.substep + 1)

        with m.If(state_in.done):
            m.d.comb += [getattr(self, f).eq(getattr(state_in, f))
                         for f in ("clock", "d", "r", "q", "shift")]
        with m.Else():
            # the clock counter advances when substep wraps; since the last
            # step may fall partway through a clock cycle, jump to
            # `done_clock` as soon as the `width`-th step is being taken
            next_clock = state_in.clock + wraps
            last_step = ((next_clock == width // steps)
                         & (self.substep >= width % steps))
            with m.If(last_step):
                m.d.comb += self.clock.eq(self.shape.done_clock)
            with m.Else():
                m.d.comb += self.clock.eq(next_clock)
            m.d.comb += [
                self.d.eq(state_in.d),
                self.shift.eq(state_in.shift),
            ]
            # long-division step: shift left; if the bit shifted into the
            # top of `r` is set, XOR out the aligned divisor and record a 1
            # quotient bit
            shifted_q = state_in.q << 1
            shifted_r = state_in.r << 1
            with m.If(shifted_r[width * 2 - 1]):
                m.d.comb += [
                    self.q.eq(shifted_q | 1),
                    self.r.eq(shifted_r ^ (state_in.d << width)),
                ]
            with m.Else():
                m.d.comb += [
                    self.q.eq(shifted_q),
                    self.r.eq(shifted_r),
                ]
261 self.r.eq(r),
262 ]
263
264
class CLDivRemInputData:
    """Input record of the carry-less div/rem stage.

    Attributes:
        shape: CLDivRemShape
            the shape
        n: Signal(shape.width)
            the dividend
        d: Signal(shape.width)
            the divisor
    """

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        self.n = Signal(shape.width)
        self.d = Signal(shape.width)

    def __iter__(self):
        """Yield the member signals: `n`, then `d`."""
        yield from (self.n, self.d)

    def eq(self, rhs):
        """Return the statements assigning `rhs.n`/`rhs.d` to our members."""
        return [getattr(self, attr).eq(getattr(rhs, attr))
                for attr in ("n", "d")]
283
284
class CLDivRemOutputData:
    """Output record of the carry-less div/rem stage.

    Attributes:
        shape: CLDivRemShape
            the shape
        q: Signal(shape.width)
            the quotient
        r: Signal(shape.width)
            the remainder
    """

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        self.q = Signal(shape.width)
        self.r = Signal(shape.width)

    def __iter__(self):
        """Yield the member signals: `q`, then `r`."""
        yield from (self.q, self.r)

    def eq(self, rhs):
        """Return the statements assigning `rhs.q`/`rhs.r` to our members."""
        return [getattr(self, attr).eq(getattr(rhs, attr))
                for attr in ("q", "r")]

    def eq_output(self, state):
        """Return the statements assigning the quotient and remainder
        produced by `state.get_output()` to our members."""
        assert isinstance(state, CLDivRemState)
        assert state.shape == self.shape
        quotient, remainder = state.get_output()
        return [self.q.eq(quotient), self.r.eq(remainder)]
309
310
class CLDivRemFSMStage(ControlBase):
    """Multi-cycle carry-less div/rem pipeline stage.

    Accepts one `CLDivRemInputData` at a time, performs
    `shape.steps_per_clock` division steps per clock cycle, and offers a
    `CLDivRemOutputData` once the division has finished.

    Attributes:
        shape: CLDivRemShape
            the shape
        pspec:
            pipe-spec
        empty: Signal()
            true if nothing is stored in `self.saved_state`
        saved_state: CLDivRemState()
            the saved state that is currently being worked on.
    """

    def __init__(self, pspec, shape):
        assert isinstance(shape, CLDivRemShape)
        # stored before `ControlBase.__init__`, which presumably calls
        # `ispec`/`ospec` (those read `self.shape`)
        self.shape = shape
        self.pspec = pspec
        super().__init__(stage=self)
        self.empty = Signal(reset=1)
        self.saved_state = CLDivRemState(shape)

    def ispec(self):
        """Create a new input record for this stage."""
        return CLDivRemInputData(self.shape)

    def ospec(self):
        """Create a new output record for this stage."""
        return CLDivRemOutputData(self.shape)

    def setup(self, m, i):
        # nothing to set up: all logic is built in `elaborate`
        pass

    def elaborate(self, platform):
        m = super().elaborate(platform)
        i_data: CLDivRemInputData = self.p.i_data
        o_data: CLDivRemOutputData = self.n.o_data
        saved = self.saved_state
        formal = platform == "formal"

        # TODO: handle cancellation

        # a result is offered once the held division is done; new input is
        # accepted only while nothing is held
        m.d.comb += self.n.o_valid.eq(~self.empty & saved.done)
        m.d.comb += self.p.o_ready.eq(self.empty)

        # combinational chain of `steps_per_clock` steps: chain[0] mirrors
        # the saved state and chain[-1] is the state for the next cycle
        chain = [CLDivRemState(self.shape, name=f"next_chain_{idx}")
                 for idx in range(self.shape.steps_per_clock + 1)]
        for before, after in zip(chain, chain[1:]):
            after.set_to_next(m, before)
        m.d.comb += chain[0].eq(saved)
        m.d.comb += o_data.eq_output(saved)
        initial_state = CLDivRemState(self.shape)
        initial_state.set_to_initial(m, n=i_data.n, d=i_data.d)

        with m.If(self.empty):
            # idle: keep loading the incoming operands; become busy once
            # they are valid
            m.d.sync += saved.eq_but_zero_substep(initial_state, formal)
            with m.If(self.p.i_valid):
                m.d.sync += self.empty.eq(0)
        with m.Else():
            # busy: advance by one clock cycle's worth of steps; become idle
            # once the result has been handed off
            m.d.sync += saved.eq_but_zero_substep(chain[-1], formal)
            with m.If(self.n.i_ready & self.n.o_valid):
                m.d.sync += self.empty.eq(1)
        return m

    def __iter__(self):
        """Yield the ports of the previous-stage side, then the next-stage
        side."""
        for control in (self.p, self.n):
            yield from control

    def ports(self):
        """Return all ports as a list."""
        return [*self]
381 def ports(self):
382 return list(self)