implement CLDivRemFSMStage
[nmigen-gf.git] / src / nmigen_gf / hdl / cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Division and Remainder.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from dataclasses import dataclass, field, fields
13 from nmigen.hdl.ir import Elaboratable
14 from nmigen.hdl.ast import Signal, Value
15 from nmigen.hdl.dsl import Module
16 from nmutil.singlepipe import ControlBase
17
18
def equal_leading_zero_count_reference(a, b, width):
    """checks if `clz(a) == clz(b)`.
    Reference code for algorithm used in `EqualLeadingZeroCount`.

    Bits are visited from least to most significant, the same direction a
    carry travels through an adder. The most significant position where
    `a` and `b` are not both zero decides the answer. If no such position
    exists (both are zero), the leading-zero counts are equal.
    """
    assert isinstance(width, int) and 0 <= width
    assert isinstance(a, int) and 0 <= a < (1 << width)
    assert isinstance(b, int) and 0 <= b < (1 << width)
    # start out "equal": both have no leading zeros so far
    result = True
    for bit_index in range(width):
        bit_a = (a >> bit_index) & 1
        bit_b = (b >> bit_index) & 1
        if bit_a != bit_b:
            # only one of them has a one here, so from this bit upward
            # they have a different number of leading zeros
            result = False
        elif bit_a:
            # both have a one here: equal so far regardless of lower bits
            result = True
        # both bits are zero: keep the value decided by the lower bits
    return result
42
43
class EqualLeadingZeroCount(Elaboratable):
    """checks if `clz(a) == clz(b)`.

    Attributes:
        width: int
            the number of bits in each of `a` and `b`.
        a: Signal of width `width`
            input
        b: Signal of width `width`
            input
        out: Signal of width `1`
            output; set exactly when `a` and `b` have the same number of
            leading zeros.
    """

    def __init__(self, width):
        assert isinstance(width, int)
        self.width = width
        self.a = Signal(self.width)
        self.b = Signal(self.width)
        self.out = Signal()

    def elaborate(self, platform):
        # The comparison is expressed as the carry-out of a single
        # (width + 1)-bit addition rather than a bit-by-bit loop, so FPGA
        # carry-propagation logic can be reused. The xor gates implied by
        # the addition are expected to be optimized away by yosys, so no
        # efficiency is lost.
        #
        # `equal_leading_zero_count_reference` is a Python model of the same
        # algorithm, written without the carry-chain encoding.
        m = Module()

        # per bit: set if both inputs have a one ("no leading zeros so far")
        both_ones = Signal(self.width)
        # per bit: set if exactly one input has a one ("different number of
        # leading zeros so far")
        different = Signal(self.width)
        # the addition result; its top bit is the carry-out
        csum = Signal(self.width + 1)

        # Addend bits at each index `i` are (both_ones[i], ~different[i]):
        # * (1, 1) when both inputs are one: `cin + 1 + 1` always carries
        #   out, so the carry is set.
        # * (0, 0) when the inputs differ: `cin + 0 + 0` never carries out,
        #   so the carry is cleared.
        # * (0, 1) when both inputs are zero: `cin + 1 + 0` carries out
        #   exactly `cin`, so the lower bits' result propagates.
        # The carry-in of 1 makes two all-zero inputs compare equal.
        m.d.comb += [
            both_ones.eq(self.a & self.b),
            different.eq(self.a ^ self.b),
            csum.eq(both_ones + ~different + 1),
            self.out.eq(csum[-1]),  # the carry-out is the result
        ]

        return m
108
109
@dataclass(frozen=True, unsafe_hash=True)
class CLDivRemShape:
    """Bit widths describing one carry-less div/rem configuration.

    Attributes:
        width: int
            width of the denominator, quotient and remainder; must be > 0.
        n_width: int
            width of the numerator; must be at least `width`.
    """
    width: int
    n_width: int

    def __post_init__(self):
        assert 0 < self.width <= self.n_width

    @property
    def done_step(self):
        """the step count at which the algorithm has finished."""
        return self.width

    @property
    def step_range(self):
        """every step count the algorithm passes through: 0 to `done_step`
        inclusive."""
        last_step = self.done_step
        return range(last_step + 1)
125
126
@dataclass(frozen=True, eq=False)
class CLDivRemState:
    """State of the carry-less division algorithm between steps.

    Fields (the four Signals are created in `__init__`):
        shape: CLDivRemShape
            the widths of the operation.
        name: str
            prefix used for the names of the Signals below.
        d: Signal(2 * shape.width)
            the denominator, shifted left by the number of steps remaining
            until done (`shape.width` initially, 0 when done).
        r: Signal(shape.n_width)
            the partial remainder; starts out as the numerator.
        q: Signal(shape.width)
            the partial quotient; one bit is shifted in per step.
        step: Signal(shape.step_range)
            the number of steps taken so far; it stops advancing once it
            reaches `shape.done_step`.
    """
    shape: CLDivRemShape
    name: str
    d: Signal = field(init=False)
    r: Signal = field(init=False)
    q: Signal = field(init=False)
    step: Signal = field(init=False)

    def __init__(self, shape, *, name=None, src_loc_at=0):
        """Create the Signals for a state of the given `shape`.

        Parameters:
            shape: CLDivRemShape
            name: str or None
                prefix for the Signal names. If None, the automatically
                generated name of a throwaway Signal is used (presumably
                nmigen derives it from the caller's assignment target).
            src_loc_at: int
                extra stack depth for that automatic naming.
        """
        assert isinstance(shape, CLDivRemShape)
        if name is None:
            name = Signal(src_loc_at=1 + src_loc_at).name
        assert isinstance(name, str)
        d = Signal(2 * shape.width, name=f"{name}_d")
        r = Signal(shape.n_width, name=f"{name}_r")
        q = Signal(shape.width, name=f"{name}_q")
        # `step` only ever holds values in `0..=shape.done_step`, so size it
        # by that range instead of spending `shape.width` bits on it.
        step = Signal(shape.step_range, name=f"{name}_step")
        # the dataclass is frozen, so bypass its `__setattr__`
        object.__setattr__(self, "shape", shape)
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "d", d)
        object.__setattr__(self, "r", r)
        object.__setattr__(self, "q", q)
        object.__setattr__(self, "step", step)

    def eq(self, rhs):
        """Return a list of statements assigning every Signal of `rhs` to
        the corresponding Signal of `self`.

        A list (like the other data classes' `eq`) rather than a one-shot
        generator, so the result can safely be iterated more than once.
        """
        assert isinstance(rhs, CLDivRemState)
        return [
            getattr(self, f.name).eq(getattr(rhs, f.name))
            for f in fields(CLDivRemState)
            if f.name not in ("shape", "name")
        ]

    @staticmethod
    def like(other, *, name=None, src_loc_at=0):
        """Create a new state with the same shape as `other`."""
        assert isinstance(other, CLDivRemState)
        return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at)

    @property
    def done(self):
        """nmigen expression, true when this state has finished."""
        return self.will_be_done_after(steps=0)

    def will_be_done_after(self, steps):
        """ Returns True if this state will be done after
        another `steps` passes through `set_to_next`."""
        assert isinstance(steps, int) and steps >= 0
        return self.step >= max(0, self.shape.done_step - steps)

    def set_to_initial(self, m, n, d):
        """Combinationally drive `self` with the initial state for dividing
        numerator `n` by denominator `d`.

        `d` is shifted left by `shape.width`, so the first `set_to_next`
        step compares against `d << (shape.width - 1)`.
        """
        assert isinstance(m, Module)
        m.d.comb += [
            self.d.eq(Value.cast(d) << self.shape.width),
            self.r.eq(n),
            self.q.eq(0),
            self.step.eq(0),
        ]

    def set_to_next(self, m, state_in):
        """Combinationally drive `self` with the state one step after
        `state_in`.

        If `state_in` is already done, it is copied through unchanged.
        Otherwise `d` is shifted right by one. If the shifted `d` fits in
        `shape.n_width` bits and has the same number of leading zeros as
        `state_in.r`, it is XORed out of the remainder and a 1 is shifted
        into the quotient; otherwise a 0 is shifted in.

        Parameters:
            m: Module
                the module to add the logic and a submodule to.
            state_in: CLDivRemState
                the previous state; must have the same shape and must not be
                `self`.
        """
        assert isinstance(m, Module)
        assert isinstance(state_in, CLDivRemState)
        assert state_in.shape == self.shape
        assert self is not state_in, "a.set_to_next(m, a) is not allowed"

        equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width)
        # can't name submodule since it would conflict if this function is
        # called multiple times in a Module
        m.submodules += equal_leading_zero_count

        with m.If(state_in.done):
            m.d.comb += self.eq(state_in)
        with m.Else():
            m.d.comb += [
                self.step.eq(state_in.step + 1),
                self.d.eq(state_in.d >> 1),
                # `a` is only `shape.n_width` bits, so `self.d` may be
                # truncated here; the `d_top == 0` check below rejects the
                # cases where set bits would be lost.
                equal_leading_zero_count.a.eq(self.d),
                equal_leading_zero_count.b.eq(state_in.r),
            ]
            # bits of the shifted denominator that don't fit in `r`
            d_top = self.d[self.shape.n_width:]
            with m.If(equal_leading_zero_count.out & (d_top == 0)):
                m.d.comb += [
                    self.r.eq(state_in.r ^ self.d),
                    self.q.eq((state_in.q << 1) | 1),
                ]
            with m.Else():
                m.d.comb += [
                    self.r.eq(state_in.r),
                    self.q.eq(state_in.q << 1),
                ]
216
217
class CLDivRemInputData:
    """Input data of `CLDivRemFSMStage`.

    Attributes:
        shape: CLDivRemShape
        n: Signal(shape.n_width)
            the numerator.
        d: Signal(shape.width)
            the denominator.
    """

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        self.n = Signal(shape.n_width)
        self.d = Signal(shape.width)

    def __iter__(self):
        """ Get member signals. """
        yield from (self.n, self.d)

    def eq(self, rhs):
        """ Assign member signals. """
        assign_n = self.n.eq(rhs.n)
        assign_d = self.d.eq(rhs.d)
        return [assign_n, assign_d]
236
237
class CLDivRemOutputData:
    """Output data of `CLDivRemFSMStage`.

    Attributes:
        shape: CLDivRemShape
        q: Signal(shape.width)
            the quotient.
        r: Signal(shape.width)
            the remainder.
    """

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        self.q = Signal(shape.width)
        self.r = Signal(shape.width)

    def __iter__(self):
        """ Get member signals. """
        yield from (self.q, self.r)

    def eq(self, rhs):
        """ Assign member signals. """
        assign_q = self.q.eq(rhs.q)
        assign_r = self.r.eq(rhs.r)
        return [assign_q, assign_r]
256
257
class CLDivRemFSMStage(ControlBase):
    """carry-less div/rem

    A multi-cycle FSM that takes `steps_per_clock` division steps per clock
    cycle. It communicates through the `ControlBase` ready/valid handshake:
    `self.p` is the input (previous-stage) side and `self.n` is the output
    (next-stage) side. Note that `self.n` is that control object, not the
    numerator.

    Attributes:
        shape: CLDivRemShape
            the shape
        steps_per_clock: int
            number of steps that should be taken per clock cycle
        pspec:
            the pipeline spec passed to `__init__`, stored unmodified.
        empty: Signal(reset=1)
            set when no operation is in progress, i.e. when this FSM can
            accept input.
        saved_state: CLDivRemState
            the registered algorithm state carried between clock cycles.
        p.i_valid:
            input. true when the data inputs (`p.i_data.n` and `p.i_data.d`)
            are valid. data transfer in occurs when `p.i_valid & p.o_ready`.
        p.o_ready:
            output. true when this FSM is ready to accept input (driven from
            `empty`). data transfer in occurs when `p.i_valid & p.o_ready`.
        p.i_data.n: Signal(shape.n_width)
            numerator in, the value must be small enough that `q` and `r`
            don't overflow. having `n_width == width` is sufficient.
        p.i_data.d: Signal(shape.width)
            denominator in, must be non-zero.
        n.o_data.q: Signal(shape.width)
            quotient out.
        n.o_data.r: Signal(shape.width)
            remainder out.
        n.o_valid:
            output. true when the data outputs (`n.o_data.q` and
            `n.o_data.r`) are valid (or are junk because the inputs were out
            of range). data transfer out occurs when `n.o_valid & n.i_ready`.
        n.i_ready:
            input. true when the output can be read.
            data transfer out occurs when `n.o_valid & n.i_ready`.
    """

    def __init__(self, pspec, shape, *, steps_per_clock=4):
        assert isinstance(shape, CLDivRemShape)
        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
        self.shape = shape
        self.steps_per_clock = steps_per_clock
        self.pspec = pspec  # store now: used in ispec and ospec
        super().__init__(stage=self)
        self.empty = Signal(reset=1)
        self.saved_state = CLDivRemState(shape)

    def ispec(self):
        """Create the input data record (`n` and `d`)."""
        return CLDivRemInputData(self.shape)

    def ospec(self):
        """Create the output data record (`q` and `r`)."""
        return CLDivRemOutputData(self.shape)

    def setup(self, m, i):
        # nothing to set up here: all logic is built in `elaborate`
        pass

    def elaborate(self, platform):
        m = super().elaborate(platform)
        i_data: CLDivRemInputData = self.p.i_data
        o_data: CLDivRemOutputData = self.n.o_data

        # TODO: handle cancellation

        # The outputs are taken from the end of this cycle's combinational
        # step chain, so they are valid as soon as `saved_state` will be
        # done after `steps_per_clock` more steps.
        state_will_be_done = self.saved_state.will_be_done_after(
            self.steps_per_clock)
        m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done)
        # only accept new input while no operation is in progress
        m.d.comb += self.p.o_ready.eq(self.empty)

        # chain of `steps_per_clock` combinational steps:
        # next_chain[0] -> next_chain[1] -> ... -> next_chain[-1]
        def make_nc(i):
            return CLDivRemState(self.shape, name=f"next_chain_{i}")
        next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)]
        for i in range(self.steps_per_clock):
            next_chain[i + 1].set_to_next(m, next_chain[i])
        m.d.sync += self.saved_state.eq(next_chain[-1])
        m.d.comb += o_data.q.eq(next_chain[-1].q)
        # `r` is truncated from `shape.n_width` to `shape.width` bits here
        m.d.comb += o_data.r.eq(next_chain[-1].r)

        with m.If(self.empty):
            # idle: start from the input data every cycle; the registered
            # result is only kept (by clearing `empty`) when it was valid.
            next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d)
            with m.If(self.p.i_valid):
                m.d.sync += self.empty.eq(0)
        with m.Else():
            # busy: continue from the registered state
            m.d.comb += next_chain[0].eq(self.saved_state)
            # the result was transferred out: become idle again
            with m.If(self.n.i_ready & self.n.o_valid):
                m.d.sync += self.empty.eq(1)

        return m

    def __iter__(self):
        yield from self.p
        yield from self.n

    def ports(self):
        return list(self)
345 def ports(self):
346 return list(self)