897ced521446fbc66ee00cc0c5f7541921e820b2
[nmigen-gf.git] / src / nmigen_gf / hdl / cldivrem.py
1 # SPDX-License-Identifier: LGPL-3.0-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Division and Remainder.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from dataclasses import dataclass, field, fields
13 from nmigen.hdl.ir import Elaboratable
14 from nmigen.hdl.ast import Signal, Value
15 from nmigen.hdl.dsl import Module
16 from nmutil.singlepipe import ControlBase
17 from nmutil.clz import CLZ, clz
18
19
def equal_leading_zero_count_reference(a, b, width):
    """Return True when `a` and `b` have the same number of leading zeros.

    Plain-Python model of the algorithm that `EqualLeadingZeroCount`
    implements in hardware: bits are visited from least to most
    significant, and the running verdict is updated at every position
    where at least one operand has a set bit (this is what lets the
    hardware express it as a carry chain).

    Parameters:
        a, b: int
            operands, each in the range `0 <= x < 2 ** width`.
        width: int
            number of bits in which leading zeros are counted (`>= 0`).

    Returns: bool
        `clz(a) == clz(b)` with both viewed as `width`-bit numbers.
    """
    assert isinstance(width, int) and 0 <= width
    assert isinstance(a, int) and 0 <= a < (1 << width)
    assert isinstance(b, int) and 0 <= b < (1 << width)
    # Before any bit is examined both values look like zero, so their
    # leading-zero counts trivially match.
    same_clz = True
    for bit_index in range(width):
        bit_a = (a >> bit_index) & 1
        bit_b = (b >> bit_index) & 1
        if bit_a and bit_b:
            # truncated to bits 0..bit_index, neither has leading zeros
            same_clz = True
        elif bit_a or bit_b:
            # exactly one bit set: the leading-zero counts differ so far
            same_clz = False
        # both bits clear: keep the verdict from the lower bits
    return same_clz
43
44
class EqualLeadingZeroCount(Elaboratable):
    """ Combinational check of whether `clz(a) == clz(b)`.

    Properties:
        width: int
            the width in bits of `a` and `b`.
        a: Signal of width `width`
            input
        b: Signal of width `width`
            input
        out: Signal of width `1`
            output, set when `a` and `b` have the same number of leading
            zeros (counted within `width` bits).

    `equal_leading_zero_count_reference` models the same algorithm in
    plain Python.
    """

    def __init__(self, width):
        assert isinstance(width, int)
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.out = Signal()

    def elaborate(self, platform):
        # The LSB-to-MSB scan of the reference model is a carry chain, so it
        # is written as the carry-out of one `width`-bit addition. FPGAs can
        # then map it onto their dedicated carry logic, and synthesis (e.g.
        # yosys) is expected to drop the unused sum xor gates. All bits are
        # handled in parallel; no loop is needed.
        #
        # The addends for bit i are `both_ones[i]` and `~different[i]`:
        #
        #   a[i] b[i] | addend bits | carry out of bit i
        #   ----------+-------------+------------------------------------
        #    1    1   |   1 + 1     | 1   neither has leading zeros yet
        #    1    0   |   0 + 0     | 0   leading-zero counts differ
        #    0    1   |   0 + 0     | 0   leading-zero counts differ
        #    0    0   |   1 + 0     | carry in (verdict of lower bits)
        m = Module()
        ones = Signal(self.width, name="both_ones")
        differ = Signal(self.width, name="different")
        m.d.comb += [
            ones.eq(self.a & self.b),
            differ.eq(self.a ^ self.b),
        ]

        # Carry in = 1: with no bits examined, both values look like zero,
        # so their leading-zero counts match.
        total = Signal(self.width + 1, name="csum")
        m.d.comb += total.eq(ones + ~differ + 1)
        # Bit `width` of the sum is the carry out, i.e. the answer.
        m.d.comb += self.out.eq(total[self.width])
        return m
109
110
def cldivrem_shifting(n, d, width):
    """ Carry-less Division and Remainder based on shifting at start and end.

    The divisor is first shifted left until its top set bit sits at bit
    `width - 1`. The numerator is shifted by the same amount, which leaves
    the quotient unchanged and scales the remainder. This lets each
    iteration inspect a single bit of the remainder instead of comparing
    degrees. The remainder is shifted back at the end.

    `n` and `d` are integers, `width` is the number of bits needed to hold
    each input/output, and `d` must be non-zero.
    Returns a tuple `q, r` of the quotient and remainder, such that
    `n` equals the carry-less product of `q` and `d`, xor-ed with `r`, and
    `r` has fewer significant bits than `d`.
    """
    assert isinstance(width, int) and width >= 1
    assert isinstance(n, int) and 0 <= n < 1 << width
    assert isinstance(d, int) and 0 <= d < 1 << width
    assert d != 0, "TODO: decide what happens on division by zero"
    # number of leading zeros of `d` within `width` bits
    shift = width - d.bit_length()
    # normalized divisor, parked in the upper half of a 2*width-bit window
    divisor = (d << shift) << width
    rem = n << shift
    quot = 0
    top_bit = 2 * width - 1
    for _ in range(width):
        rem <<= 1
        quot <<= 1
        if rem >> top_bit:
            # the remainder's degree matches the divisor's: subtract (xor)
            rem ^= divisor
            quot |= 1
    # undo the window offset, then the normalization shift
    return quot, rem >> width >> shift
138
139
@dataclass(frozen=True, unsafe_hash=True)
class CLDivRemShape:
    """ Bit-widths used by the carry-less div/rem hardware.

    width: int
        width of the divisor, quotient and remainder (`> 0`).
    n_width: int
        width of the numerator (`>= width`).
    """
    width: int
    n_width: int

    def __post_init__(self):
        # the divisor width must be positive and fit within the numerator
        assert 0 < self.width <= self.n_width

    @property
    def done_step(self):
        """ the step count at which a division is finished """
        return self.width

    @property
    def step_range(self):
        """ all step values from 0 through `done_step` inclusive """
        last = self.done_step
        return range(last + 1)
155
156
@dataclass(frozen=True, eq=False)
class CLDivRemState:
    """ State of an in-progress carry-less division.

    One instance holds the values between two steps of the iterative
    algorithm. `set_to_initial` starts a division, and `set_to_next`
    computes one step from a previous instance.

    Fields:
        shape: CLDivRemShape
            the bit-widths in use.
        name: str
            prefix for the names of the Signals below.
        d: Signal of width `2 * shape.width`
            the divisor. `set_to_initial` loads it shifted left by
            `shape.width`, and each step shifts it right by one bit.
        r: Signal of width `shape.n_width`
            the running remainder; starts as the numerator.
        q: Signal of width `shape.width`
            the quotient bits produced so far; each step shifts in one bit.
        step: Signal of width `shape.width`
            the number of steps taken so far; the division is done once it
            reaches `shape.done_step`.
    """
    shape: CLDivRemShape
    name: str
    d: Signal = field(init=False)
    r: Signal = field(init=False)
    q: Signal = field(init=False)
    step: Signal = field(init=False)

    def __init__(self, shape, *, name=None, src_loc_at=0):
        """ Create the Signals for a new state.

        shape: CLDivRemShape
        name: str or None
            prefix for the Signal names. If None, the name nmigen assigns
            to a throw-away unnamed Signal is used. `src_loc_at` adds extra
            caller frames to skip, as with `Signal`, so the name can come
            from the caller's assignment target.
        """
        assert isinstance(shape, CLDivRemShape)
        if name is None:
            name = Signal(src_loc_at=1 + src_loc_at).name
        assert isinstance(name, str)
        d = Signal(2 * shape.width, name=f"{name}_d")
        r = Signal(shape.n_width, name=f"{name}_r")
        q = Signal(shape.width, name=f"{name}_q")
        step = Signal(shape.width, name=f"{name}_step")
        # The dataclass is frozen, so plain attribute assignment would
        # raise; bypass it with `object.__setattr__`.
        object.__setattr__(self, "shape", shape)
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "d", d)
        object.__setattr__(self, "r", r)
        object.__setattr__(self, "q", q)
        object.__setattr__(self, "step", step)

    def eq(self, rhs):
        """ Yield (as a generator) nmigen assignments that copy every Signal
        field (`d`, `r`, `q`, `step`) of `rhs` into `self`.
        `shape` and `name` are not Signals and are skipped.
        """
        assert isinstance(rhs, CLDivRemState)
        for f in fields(CLDivRemState):
            if f.name in ("shape", "name"):
                continue
            l = getattr(self, f.name)
            r = getattr(rhs, f.name)
            yield l.eq(r)

    @staticmethod
    def like(other, *, name=None, src_loc_at=0):
        """ Create a new state with the same shape as `other`. """
        assert isinstance(other, CLDivRemState)
        return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at)

    @property
    def done(self):
        """ nmigen expression that is true once all steps have been taken. """
        return self.will_be_done_after(steps=0)

    def will_be_done_after(self, steps):
        """ Returns an nmigen expression that is true if this state will be
        done after another `steps` passes through `set_to_next`."""
        assert isinstance(steps, int) and steps >= 0
        return self.step >= max(0, self.shape.done_step - steps)

    def set_to_initial(self, m, n, d):
        """ Drive this state combinationally (in `m`) to the start of a
        division of numerator `n` by divisor `d`.
        """
        assert isinstance(m, Module)
        m.d.comb += [
            # Pre-shift left by `width`. The first `set_to_next` step shifts
            # right by one, aligning the divisor with quotient bit
            # `width - 1`.
            self.d.eq(Value.cast(d) << self.shape.width),
            self.r.eq(n),
            self.q.eq(0),
            self.step.eq(0),
        ]

    def set_to_next(self, m, state_in):
        """ Drive this state combinationally (in `m`) to the result of one
        division step applied to `state_in`. If `state_in` is already done,
        it is copied unchanged. Each call adds an `EqualLeadingZeroCount`
        submodule to `m`.
        """
        assert isinstance(m, Module)
        assert isinstance(state_in, CLDivRemState)
        assert state_in.shape == self.shape
        assert self is not state_in, "a.set_to_next(m, a) is not allowed"

        equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width)
        # can't name submodule since it would conflict if this function is
        # called multiple times in a Module
        m.submodules += equal_leading_zero_count

        with m.If(state_in.done):
            # already finished: pass the state through unchanged
            m.d.comb += self.eq(state_in)
        with m.Else():
            m.d.comb += [
                self.step.eq(state_in.step + 1),
                # move the divisor down one bit, aligning it with the next
                # (lower) quotient bit
                self.d.eq(state_in.d >> 1),
                # `a` is only `n_width` bits wide, so any bits of `self.d`
                # at or above `n_width` are dropped here; `d_top` below
                # checks them separately
                equal_leading_zero_count.a.eq(self.d),
                equal_leading_zero_count.b.eq(state_in.r),
            ]
            d_top = self.d[self.shape.n_width:]
            # The shifted divisor is xor-ed out of the remainder only when
            # both have their top set bit at the same position (equal
            # leading-zero counts) and no divisor bits were lost to the
            # truncation above.
            with m.If(equal_leading_zero_count.out & (d_top == 0)):
                m.d.comb += [
                    self.r.eq(state_in.r ^ self.d),
                    # this quotient bit is 1
                    self.q.eq((state_in.q << 1) | 1),
                ]
            with m.Else():
                m.d.comb += [
                    self.r.eq(state_in.r),
                    # this quotient bit is 0
                    self.q.eq(state_in.q << 1),
                ]
246
247
class CLDivRemInputData:
    """ Data entering `CLDivRemFSMStage`.

    n: Signal of width `shape.n_width`
        the numerator.
    d: Signal of width `shape.width`
        the divisor.
    """

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        self.n = Signal(shape.n_width)
        self.d = Signal(shape.width)

    def __iter__(self):
        """ Iterate over the member Signals: `n`, then `d`. """
        return iter((self.n, self.d))

    def eq(self, rhs):
        """ Return a list of assignments copying `rhs.n` and `rhs.d` into
        this object's Signals. """
        pairs = ((self.n, rhs.n), (self.d, rhs.d))
        return [dst.eq(src) for dst, src in pairs]
266
267
class CLDivRemOutputData:
    """ Data leaving `CLDivRemFSMStage`.

    q: Signal of width `shape.width`
        the quotient.
    r: Signal of width `shape.width`
        the remainder.
    """

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        self.q = Signal(shape.width)
        self.r = Signal(shape.width)

    def __iter__(self):
        """ Iterate over the member Signals: `q`, then `r`. """
        return iter((self.q, self.r))

    def eq(self, rhs):
        """ Return a list of assignments copying `rhs.q` and `rhs.r` into
        this object's Signals. """
        pairs = ((self.q, rhs.q), (self.r, rhs.r))
        return [dst.eq(src) for dst, src in pairs]
286
287
class CLDivRemFSMStage(ControlBase):
    """carry-less div/rem

    Iterative carry-less division/remainder built on nmutil's `ControlBase`
    ready/valid handshake. Only one division is in flight at a time. Each
    clock cycle, `steps_per_clock` steps of `CLDivRemState.set_to_next` are
    chained combinationally.

    Note: `self.p` and `self.n` are the input-side and output-side
    handshake ports set up by `ControlBase`; `self.n` is *not* the
    numerator.

    Attributes:
        shape: CLDivRemShape
            the shape
        steps_per_clock: int
            number of steps that should be taken per clock cycle
        pspec:
            the `pspec` constructor argument, stored unchanged.
        empty: Signal(), reset 1
            set when no division is in progress.
        saved_state: CLDivRemState
            division state registered between clock cycles.
        p.i_valid: Signal()
            input. true when the data inputs (`p.i_data.n` and
            `p.i_data.d`) are valid.
            data transfer in occurs when `p.i_valid & p.o_ready`.
        p.o_ready: Signal()
            output. true when this FSM is ready to accept input.
            data transfer in occurs when `p.i_valid & p.o_ready`.
        p.i_data.n: Signal(shape.n_width)
            numerator in, the value must be small enough that `q` and `r`
            don't overflow. having `n_width == width` is sufficient.
        p.i_data.d: Signal(shape.width)
            denominator in, must be non-zero.
        n.o_data.q: Signal(shape.width)
            quotient out.
        n.o_data.r: Signal(shape.width)
            remainder out.
        n.o_valid: Signal()
            output. true when the data outputs (`n.o_data.q` and
            `n.o_data.r`) are valid (or are junk because the inputs were
            out of range).
            data transfer out occurs when `n.o_valid & n.i_ready`.
        n.i_ready: Signal()
            input. true when the output can be read.
            data transfer out occurs when `n.o_valid & n.i_ready`.
    """

    def __init__(self, pspec, shape, *, steps_per_clock=4):
        assert isinstance(shape, CLDivRemShape)
        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
        # set before super().__init__(), which presumably builds
        # p.i_data / n.o_data through ispec() / ospec() (those read
        # self.shape)
        self.shape = shape
        self.steps_per_clock = steps_per_clock
        # stored before super().__init__(); NOTE(review): this class's
        # ispec/ospec only read self.shape, pspec itself isn't used here
        self.pspec = pspec
        super().__init__(stage=self)
        self.empty = Signal(reset=1)
        self.saved_state = CLDivRemState(shape)

    def ispec(self):
        return CLDivRemInputData(self.shape)

    def ospec(self):
        return CLDivRemOutputData(self.shape)

    def setup(self, m, i):
        # stage hook (presumably called by ControlBase); nothing to set up
        pass

    def elaborate(self, platform):
        m = super().elaborate(platform)
        i_data: CLDivRemInputData = self.p.i_data
        o_data: CLDivRemOutputData = self.n.o_data

        # TODO: handle cancellation

        # `saved_state` is this cycle's starting state, and `next_chain[-1]`
        # (below) is that state advanced by `steps_per_clock` steps. The
        # output is therefore valid this cycle exactly when `saved_state`
        # will be done after those steps.
        state_will_be_done = self.saved_state.will_be_done_after(
            self.steps_per_clock)
        m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done)
        # accept new input only while no division is in progress
        m.d.comb += self.p.o_ready.eq(self.empty)

        # Chain of `steps_per_clock + 1` states: next_chain[0] is this
        # cycle's starting state, and each later entry is one step further.
        # Each gets an explicit, unique Signal-name prefix.
        def make_nc(i):
            return CLDivRemState(self.shape, name=f"next_chain_{i}")
        next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)]
        for i in range(self.steps_per_clock):
            next_chain[i + 1].set_to_next(m, next_chain[i])
        # Register the advanced state every cycle. While `empty` is set,
        # this is computed from whatever is on the inputs, and it is only
        # used once an input has actually been accepted.
        m.d.sync += self.saved_state.eq(next_chain[-1])
        # outputs come combinationally from the end of the chain
        m.d.comb += o_data.q.eq(next_chain[-1].q)
        m.d.comb += o_data.r.eq(next_chain[-1].r)

        with m.If(self.empty):
            # Idle: start from the inputs. The first `steps_per_clock` steps
            # are computed in the same cycle the input is accepted.
            next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d)
            with m.If(self.p.i_valid):
                m.d.sync += self.empty.eq(0)
        with m.Else():
            # busy: continue from the registered state
            m.d.comb += next_chain[0].eq(self.saved_state)
            # result consumed: become idle again next cycle
            with m.If(self.n.i_ready & self.n.o_valid):
                m.d.sync += self.empty.eq(1)

        return m

    def __iter__(self):
        # all handshake and data Signals of the input and output ports
        yield from self.p
        yield from self.n

    def ports(self):
        return list(self)