73bdd1e9b75f144bcd81d5ad961b688f1a70f8c7
[nmigen-gf.git] / src / nmigen_gf / hdl / cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Division and Remainder.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from dataclasses import dataclass, field, fields
13 from nmigen.hdl.ir import Elaboratable
14 from nmigen.hdl.ast import Signal, Value
15 from nmigen.hdl.dsl import Module
16 from nmutil.singlepipe import ControlBase
17 from nmutil.clz import CLZ, clz
18
19
def equal_leading_zero_count_reference(a, b, width):
    """Pure-Python model of `EqualLeadingZeroCount`: is `clz(a) == clz(b)`?

    `a` and `b` are non-negative integers that must fit in `width` bits.
    Returns `True` when both have the same number of leading zeros (counted
    within `width` bits) and `False` otherwise.
    """
    assert isinstance(width, int) and width >= 0
    limit = 1 << width
    assert isinstance(a, int) and 0 <= a < limit
    assert isinstance(b, int) and 0 <= b < limit
    # Walk from the least significant bit upwards: each non-zero bit pair
    # overrides the verdict of everything below it, so the most significant
    # non-zero pair is the one that decides the final answer.
    result = True  # nothing set yet -- the leading zero counts agree so far
    for index in range(width):
        bits = ((a >> index) & 1, (b >> index) & 1)
        if bits == (1, 1):
            # both inputs have a set bit here: counts agree from here up
            result = True
        elif bits != (0, 0):
            # exactly one input has a set bit here: counts disagree
            result = False
        # both bits clear: keep the verdict of the less significant bits
    return result
43
44
class EqualLeadingZeroCount(Elaboratable):
    """Combinatorial check for `clz(a) == clz(b)`.

    Attributes:
    width: int
        bit-width of both `a` and `b`.
    a: Signal of width `width`
        first input.
    b: Signal of width `width`
        second input.
    out: Signal of width `1`
        output, set when `a` and `b` have the same number of leading zeros.
    """

    def __init__(self, width):
        assert isinstance(width, int)
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.out = Signal()

    def elaborate(self, platform):
        # The comparison is expressed as the carry-out of a binary addition
        # so that FPGAs can map it onto their dedicated carry chains. The
        # xor gates that a real adder would also need are unused here and
        # are expected to be optimized away by yosys.
        #
        # `equal_leading_zero_count_reference` is the bit-by-bit Python
        # version of the same algorithm; here all bits are handled at once,
        # so no loop is needed.
        #
        # The two addends are `both_ones` and `~different`. For every bit
        # index `i`:
        # * `both_ones[i]` set  -> both addend bits are 1, so a carry is
        #   always generated (`cin + 1 + 1`).
        # * `different[i]` set  -> both addend bits are 0, so the carry is
        #   always killed (`cin + 0 + 0`).
        # * neither set         -> exactly one addend bit is 1, so the carry
        #   from the less significant bits is passed on (`cin + 1 + 0`).
        m = Module()
        width = self.width
        both_ones = Signal(width)
        different = Signal(width)
        # one extra bit on top to capture the carry-out
        csum = Signal(width + 1)
        # carry-in is 1: with no bits inspected yet, the counts are equal
        carry_in = 1

        m.d.comb += [
            # bits where both inputs are set
            both_ones.eq(self.a & self.b),
            # bits where exactly one input is set
            different.eq(self.a ^ self.b),
            # [ab]use the adder purely for its carry propagation
            csum.eq(both_ones + (~different) + carry_in),
            # the carry-out of the addition is the result
            self.out.eq(csum[width]),
        ]

        return m
109
110
def cldivrem_shifting(n, d, width):
    """ Carry-less Division and Remainder, shift-based reference algorithm.

    The divisor and numerator are shifted left up front (and the remainder
    shifted back at the end) so that every iteration only has to look at
    one fixed bit, rather than comparing degrees each time.

    `n` is the numerator and `d` the (non-zero) divisor, both integers that
    fit in `width` bits; `width` is the number of bits needed to hold each
    input/output.
    Returns a tuple `q, r` of the quotient and remainder.
    """
    assert isinstance(width, int) and width >= 1
    assert isinstance(n, int) and 0 <= n < 1 << width
    assert isinstance(d, int) and 0 <= d < 1 << width
    assert d != 0, "TODO: decide what happens on division by zero"

    shape = CLDivRemShape(width)

    # Equivalent to `clz(d, width)` except that it saturates at `width - 1`
    # rather than `width`: that keeps it within `shape.shift_width` bits and
    # avoids shifting further than necessary.
    shift = clz(d >> 1, width - 1)
    assert 0 <= shift < 1 << shape.shift_width, "shift overflow"
    d <<= shift
    assert 0 <= d < 1 << shape.d_width, "d overflow"
    remainder = n << shift
    assert 0 <= remainder < 1 << shape.r_width, "r overflow"
    quotient = 0

    # loop invariants: the bit inspected every step, and the divisor lined
    # up with the upper half of the double-width remainder
    top_bit = width * 2 - 1
    aligned_d = d << width
    for _ in range(width):
        quotient <<= 1
        remainder <<= 1
        if remainder >> top_bit:
            remainder ^= aligned_d
            quotient |= 1
        assert 0 <= quotient < 1 << shape.q_width, "q overflow"
        assert 0 <= remainder < 1 << shape.r_width, "r overflow"

    # the result sits in the upper half; also undo the initial shift
    return quotient, (remainder >> width) >> shift
147
148
@dataclass(frozen=True, unsafe_hash=True)
class CLDivRemShape:
    """Bit-widths and step count derived from the operand width `width`.

    `width` must be an `int` that is at least 1.
    """
    width: int

    def __post_init__(self):
        width = self.width
        assert isinstance(width, int) and width >= 1, "invalid width"

    @property
    def done_step(self):
        """step number at which iteration has finished, i.e. the largest
        value `CLDivRemState.step` reaches.
        """
        return self.width

    @property
    def step_range(self):
        """all values `CLDivRemState.step` can take.

        returns: range
        """
        return range(0, self.done_step + 1)

    @property
    def d_width(self):
        """number of bits in the internal signal `CLDivRemState.d`"""
        return self.width

    @property
    def r_width(self):
        """number of bits in the internal signal `CLDivRemState.r`"""
        return 2 * self.width

    @property
    def q_width(self):
        """number of bits in the internal signal `CLDivRemState.q`"""
        return self.width

    @property
    def shift_width(self):
        """number of bits in the internal signal `CLDivRemState.shift`"""
        largest_shift = self.width - 1
        return largest_shift.bit_length()
190
191
@dataclass(frozen=True, eq=False)
class CLDivRemState:
    """Bundle of signals holding the state of one carry-less div/rem
    computation.

    Every signal is named `{name}_<field>`; the widths come from `shape`.
    """
    shape: CLDivRemShape
    name: str
    step: Signal = field(init=False)
    d: Signal = field(init=False)
    r: Signal = field(init=False)
    q: Signal = field(init=False)
    shift: Signal = field(init=False)

    def __init__(self, shape, *, name=None, src_loc_at=0):
        assert isinstance(shape, CLDivRemShape)
        if name is None:
            # create a throw-away Signal purely to obtain the name it is
            # given for the caller's frame (`src_loc_at` frames further up)
            name = Signal(src_loc_at=1 + src_loc_at).name
        assert isinstance(name, str)
        attributes = {
            "shape": shape,
            "name": name,
            "step": Signal(shape.step_range, name=f"{name}_step"),
            "d": Signal(shape.d_width, name=f"{name}_d"),
            "r": Signal(shape.r_width, name=f"{name}_r"),
            "q": Signal(shape.q_width, name=f"{name}_q"),
            "shift": Signal(shape.shift_width, name=f"{name}_shift"),
        }
        # the dataclass is frozen, so normal attribute assignment is blocked
        for attribute, value in attributes.items():
            object.__setattr__(self, attribute, value)

    def eq(self, rhs):
        """ Yield assignments of every signal of `rhs` to the matching
        signal of `self` (`shape` and `name` are not signals: skipped)."""
        assert isinstance(rhs, CLDivRemState)
        for f in fields(CLDivRemState):
            if f.name not in ("shape", "name"):
                yield getattr(self, f.name).eq(getattr(rhs, f.name))

    @staticmethod
    def like(other, *, name=None, src_loc_at=0):
        """ Create a new state with the same shape as `other`. """
        assert isinstance(other, CLDivRemState)
        return CLDivRemState(other.shape, name=name, src_loc_at=src_loc_at + 1)

    @property
    def done(self):
        """ Expression that is true once `step` has reached the final
        step. """
        return self.will_be_done_after(0)

    def will_be_done_after(self, steps):
        """ Returns an expression that is true if this state will be done
        after another `steps` passes through `set_to_next`."""
        assert isinstance(steps, int) and steps >= 0
        threshold = max(0, self.shape.done_step - steps)
        return self.step >= threshold

    def get_output(self):
        """ Returns the pair `quotient, remainder`: the remainder is taken
        from the upper half of `r` with the initial shift undone. """
        remainder = (self.r >> self.shape.width) >> self.shift
        return self.q, remainder

    def set_to_initial(self, m, n, d):
        """ Drive this state combinatorially with the starting state for
        numerator `n` and divisor `d`. """
        assert isinstance(m, Module)
        n = Value.cast(n)  # convert to Value
        d = Value.cast(d)  # convert to Value
        lz_counter = CLZ(self.shape.width - 1)
        # added anonymously: a named submodule would clash if this method
        # is called more than once on the same Module
        m.submodules += lz_counter
        assert lz_counter.lz.width == self.shape.shift_width, \
            "internal inconsistency -- mismatched shift signal width"
        m.d.comb += [
            lz_counter.sig_in.eq(d >> 1),
            self.shift.eq(lz_counter.lz),
            self.d.eq(d << self.shift),
            self.r.eq(n << self.shift),
            self.q.eq(0),
            self.step.eq(0),
        ]

    def set_to_next(self, m, state_in):
        """ Drive this state combinatorially with the state that follows
        `state_in`; a state that is already done is passed through
        unchanged. """
        assert isinstance(m, Module)
        assert isinstance(state_in, CLDivRemState)
        assert state_in.shape == self.shape
        assert self is not state_in, "a.set_to_next(m, a) is not allowed"
        width = self.shape.width

        # both shifted one place left, ready for the next quotient bit
        shifted_q = state_in.q << 1
        shifted_r = state_in.r << 1

        with m.If(state_in.done):
            m.d.comb += self.eq(state_in)
        with m.Else():
            m.d.comb += [
                self.step.eq(state_in.step + 1),
                self.d.eq(state_in.d),
                self.shift.eq(state_in.shift),
            ]
            # top bit of the shifted remainder set: the divisor is xor-ed
            # in and the new quotient bit is 1
            with m.If(shifted_r[2 * width - 1]):
                m.d.comb += [
                    self.q.eq(shifted_q | 1),
                    self.r.eq(shifted_r ^ (state_in.d << width)),
                ]
            with m.Else():
                m.d.comb += [
                    self.q.eq(shifted_q),
                    self.r.eq(shifted_r),
                ]
293
294
class CLDivRemInputData:
    """Input bundle: numerator `n` and divisor `d`, each `shape.width`
    bits wide."""

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        self.n = Signal(shape.width)
        self.d = Signal(shape.width)

    def __iter__(self):
        """ Iterate over the member signals: `n`, then `d`. """
        yield from (self.n, self.d)

    def eq(self, rhs):
        """ Return assignments of `rhs`'s members to this object's. """
        pairs = ((self.n, rhs.n), (self.d, rhs.d))
        return [target.eq(source) for target, source in pairs]
313
314
class CLDivRemOutputData:
    """Output bundle: quotient `q` and remainder `r`, each `shape.width`
    bits wide."""

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        self.q = Signal(shape.width)
        self.r = Signal(shape.width)

    def __iter__(self):
        """ Iterate over the member signals: `q`, then `r`. """
        yield from (self.q, self.r)

    def eq(self, rhs):
        """ Return assignments of `rhs`'s members to this object's. """
        pairs = ((self.q, rhs.q), (self.r, rhs.r))
        return [target.eq(source) for target, source in pairs]
333
334
class CLDivRemFSMStage(ControlBase):
    """Carry-less div/rem as a finite-state-machine pipeline stage.

    Attributes:
    shape: CLDivRemShape
        the shape
    steps_per_clock: int
        how many algorithm steps are chained within one clock cycle
    pspec:
        pipe-spec
    empty: Signal()
        set while `self.saved_state` holds no computation
    saved_state: CLDivRemState()
        the state of the computation currently in progress.
    """

    def __init__(self, pspec, shape, *, steps_per_clock=8):
        assert isinstance(shape, CLDivRemShape)
        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
        self.shape = shape
        self.steps_per_clock = steps_per_clock
        # must be stored before `super().__init__`: needed by ispec/ospec
        self.pspec = pspec
        super().__init__(stage=self)
        self.empty = Signal(reset=1)
        self.saved_state = CLDivRemState(shape)

    def ispec(self):
        """ Create a new input data bundle. """
        return CLDivRemInputData(self.shape)

    def ospec(self):
        """ Create a new output data bundle. """
        return CLDivRemOutputData(self.shape)

    def setup(self, m, i):
        """ Nothing to set up. """
        pass

    def elaborate(self, platform):
        m = super().elaborate(platform)
        i_data: CLDivRemInputData = self.p.i_data
        o_data: CLDivRemOutputData = self.n.o_data

        # TODO: handle cancellation

        # output is valid once a stored computation has finished; new input
        # is accepted only while nothing is stored
        m.d.comb += self.n.o_valid.eq(~self.empty & self.saved_state.done)
        m.d.comb += self.p.o_ready.eq(self.empty)

        # combinatorial chain of `steps_per_clock` steps: element 0 mirrors
        # the saved state, each following element is one step further on
        next_chain = [
            CLDivRemState(self.shape, name=f"next_chain_{i}")
            for i in range(self.steps_per_clock + 1)
        ]
        for previous, following in zip(next_chain, next_chain[1:]):
            following.set_to_next(m, previous)
        m.d.comb += next_chain[0].eq(self.saved_state)

        out_q, out_r = self.saved_state.get_output()
        m.d.comb += o_data.q.eq(out_q)
        m.d.comb += o_data.r.eq(out_r)

        initial_state = CLDivRemState(self.shape)
        initial_state.set_to_initial(m, n=i_data.n, d=i_data.d)

        with m.If(self.empty):
            # keep loading the initial state for the current inputs; it only
            # sticks once valid input arrives and `empty` is cleared
            m.d.sync += self.saved_state.eq(initial_state)
            with m.If(self.p.i_valid):
                m.d.sync += self.empty.eq(0)
        with m.Else():
            # advance the stored computation by `steps_per_clock` steps
            m.d.sync += self.saved_state.eq(next_chain[-1])
            # result handed over to the next stage: free the slot
            with m.If(self.n.i_ready & self.n.o_valid):
                m.d.sync += self.empty.eq(1)
        return m

    def __iter__(self):
        yield from self.p
        yield from self.n

    def ports(self):
        """ Return everything yielded by iterating `self`, as a list. """
        return [*self]