1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 """ Carry-less Division and Remainder.
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
12 from dataclasses
import dataclass
, field
, fields
13 from nmigen
.hdl
.ir
import Elaboratable
14 from nmigen
.hdl
.ast
import Signal
, Value
15 from nmigen
.hdl
.dsl
import Module
16 from nmutil
.singlepipe
import ControlBase
17 from nmutil
.clz
import CLZ
, clz
20 def equal_leading_zero_count_reference(a
, b
, width
):
21 """checks if `clz(a) == clz(b)`.
22 Reference code for algorithm used in `EqualLeadingZeroCount`.
24 assert isinstance(width
, int) and 0 <= width
25 assert isinstance(a
, int) and 0 <= a
< (1 << width
)
26 assert isinstance(b
, int) and 0 <= b
< (1 << width
)
27 eq
= True # both have no leading zeros so far...
28 for i
in range(width
):
31 # `both_ones` is set if both have no leading zeros so far
32 both_ones
= a_bit
& b_bit
33 # `different` is set if there are a different number of leading
35 different
= a_bit
!= b_bit
41 pass # propagate from lower bits
45 class EqualLeadingZeroCount(Elaboratable
):
46 """checks if `clz(a) == clz(b)`.
50 the width in bits of `a` and `b`.
51 a: Signal of width `width`
53 b: Signal of width `width`
55 out: Signal of width `1`
56 output, set if the number of leading zeros in `a` is the same as in
60 def __init__(self
, width
):
61 assert isinstance(width
, int)
63 self
.a
= Signal(width
)
64 self
.b
= Signal(width
)
67 def elaborate(self
, platform
):
68 # the operation is converted into calculation of the carry-out of a
69 # binary addition, allowing FPGAs to re-use their specialized
70 # carry-propagation logic. This should be simplified by yosys to
71 # remove the extraneous xor gates from addition when targeting
72 # FPGAs/ASICs, so no efficiency is lost.
74 # see `equal_leading_zero_count_reference` for a Python version of
75 # the algorithm, but without conversion to carry-propagation.
76 # note that it's possible to do all the bits at once: a for-loop
77 # (unlike in the reference-code) is not necessary
80 both_ones
= Signal(self
.width
)
81 different
= Signal(self
.width
)
83 # build `both_ones` and `different` such that:
84 # for every bit index `i`:
85 # * if `both_ones[i]` is set, then both addends bits at index `i` are
86 # set in order to set the carry bit out, since `cin + 1 + 1` always
88 # * if `different[i]` is set, then both addends bits at index `i` are
89 # zeros in order to clear the carry bit out, since `cin + 0 + 0`
90 # never has a carry out.
91 # * otherwise exactly one of the addends bits at index `i` is set and
92 # the other is clear in order to propagate the carry bit from
93 # less significant bits, since `cin + 1 + 0` has a carry out that is
96 # `both_ones` is set if both have no leading zeros so far
97 m
.d
.comb
+= both_ones
.eq(self
.a
& self
.b
)
98 # `different` is set if there are a different number of leading
100 m
.d
.comb
+= different
.eq(self
.a ^ self
.b
)
102 # now [ab]use add: the last bit [carry-out] is the result
103 csum
= Signal(self
.width
+ 1)
104 carry_in
= 1 # both have no leading zeros so far, so set carry in
105 m
.d
.comb
+= csum
.eq(both_ones
+ (~different
) + carry_in
)
106 m
.d
.comb
+= self
.out
.eq(csum
[self
.width
]) # out is carry-out
111 def cldivrem_shifting(n
, d
, width
):
112 """ Carry-less Division and Remainder based on shifting at start and end
113 allowing us to get away with checking a single bit each iteration
114 rather than checking for equal degrees every iteration.
115 `n` and `d` are integers, `width` is the number of bits needed to hold
117 Returns a tuple `q, r` of the quotient and remainder.
119 assert isinstance(width
, int) and width
>= 1
120 assert isinstance(n
, int) and 0 <= n
< 1 << width
121 assert isinstance(d
, int) and 0 <= d
< 1 << width
122 assert d
!= 0, "TODO: decide what happens on division by zero"
124 shape
= CLDivRemShape(width
)
126 # `clz(d, width)`, but maxes out at `width - 1` instead of `width` in
127 # order to both fit in `shape.shift_width` bits and to not shift by more
129 shift
= clz(d
>> 1, width
- 1)
130 assert 0 <= shift
< 1 << shape
.shift_width
, "shift overflow"
132 assert 0 <= d
< 1 << shape
.d_width
, "d overflow"
134 assert 0 <= r
< 1 << shape
.r_width
, "r overflow"
136 for step
in range(width
):
139 if r
>> (width
* 2 - 1) != 0:
142 assert 0 <= q
< 1 << shape
.q_width
, "q overflow"
143 assert 0 <= r
< 1 << shape
.r_width
, "r overflow"
149 @dataclass(frozen
=True, unsafe_hash
=True)
153 def __post_init__(self
):
154 assert isinstance(self
.width
, int) and self
.width
>= 1, "invalid width"
158 """the step number when iteration is finished
159 -- the largest `CLDivRemState.step` will get
164 def step_range(self
):
165 """the range that `CLDivRemState.step` will fall in.
169 return range(self
.done_step
+ 1)
173 """bit-width of the internal signal `CLDivRemState.d`"""
178 """bit-width of the internal signal `CLDivRemState.r`"""
179 return self
.width
* 2
183 """bit-width of the internal signal `CLDivRemState.q`"""
def shift_width(self):
    """bit-width of the internal signal `CLDivRemState.shift`

    This is the bit length of `width - 1`, i.e. just enough bits to hold
    every value in `range(width)` (zero bits when `width` is 1).
    """
    # the largest value the signal has to hold is `width - 1`
    largest_shift = self.width - 1
    return largest_shift.bit_length()
192 @dataclass(frozen
=True, eq
=False)
196 step
: Signal
= field(init
=False)
197 d
: Signal
= field(init
=False)
198 r
: Signal
= field(init
=False)
199 q
: Signal
= field(init
=False)
200 shift
: Signal
= field(init
=False)
202 def __init__(self
, shape
, *, name
=None, src_loc_at
=0):
203 assert isinstance(shape
, CLDivRemShape
)
205 name
= Signal(src_loc_at
=1 + src_loc_at
).name
206 assert isinstance(name
, str)
207 step
= Signal(shape
.step_range
, name
=f
"{name}_step")
208 d
= Signal(shape
.d_width
, name
=f
"{name}_d")
209 r
= Signal(shape
.r_width
, name
=f
"{name}_r")
210 q
= Signal(shape
.q_width
, name
=f
"{name}_q")
211 shift
= Signal(shape
.shift_width
, name
=f
"{name}_shift")
212 object.__setattr
__(self
, "shape", shape
)
213 object.__setattr
__(self
, "name", name
)
214 object.__setattr
__(self
, "step", step
)
215 object.__setattr
__(self
, "d", d
)
216 object.__setattr
__(self
, "r", r
)
217 object.__setattr
__(self
, "q", q
)
218 object.__setattr
__(self
, "shift", shift
)
221 assert isinstance(rhs
, CLDivRemState
)
222 for f
in fields(CLDivRemState
):
223 if f
.name
in ("shape", "name"):
225 l
= getattr(self
, f
.name
)
226 r
= getattr(rhs
, f
.name
)
def like(other, *, name=None, src_loc_at=0):
    """Create a new `CLDivRemState` that uses the same shape as `other`.

    other: CLDivRemState
        the state whose `shape` is reused for the new state.
    name: optional
        forwarded unchanged to the `CLDivRemState` constructor.
    src_loc_at: int
        forwarded to the `CLDivRemState` constructor after adding one
        (presumably to skip this function's own stack frame when a name
        is derived from the caller's source location -- TODO confirm).

    Returns the newly constructed `CLDivRemState`.
    """
    assert isinstance(other, CLDivRemState)
    forwarded_src_loc_at = 1 + src_loc_at
    new_state = CLDivRemState(
        other.shape,
        name=name,
        src_loc_at=forwarded_src_loc_at,
    )
    return new_state
236 return self
.will_be_done_after(steps
=0)
def will_be_done_after(self, steps):
    """ Tells whether this state will have finished once it has gone
    through `set_to_next` a further `steps` times.

    steps: int
        non-negative number of additional passes to account for.

    Returns the result of comparing `self.step` against
    `shape.done_step - steps`, clamped so it never goes below zero.
    """
    assert isinstance(steps, int) and steps >= 0
    # step value from which `steps` more passes are enough to finish;
    # clamp at zero since the step counter never goes negative
    threshold = self.shape.done_step - steps
    if not threshold > 0:
        threshold = 0
    return self.step >= threshold
def get_output(self):
    """Return the `(quotient, remainder)` pair held in this state.

    The quotient is `self.q` as-is.  The remainder is taken from `self.r`
    by discarding its low `shape.width` bits and then shifting right by
    `self.shift` (presumably undoing the left-shift applied to the inputs
    in `set_to_initial` -- TODO confirm).
    """
    quotient = self.q
    # drop the low `width` bits of `r` first...
    upper_r = self.r >> self.shape.width
    # ...then shift right by the stored shift amount
    remainder = upper_r >> self.shift
    return quotient, remainder
247 def set_to_initial(self
, m
, n
, d
):
248 assert isinstance(m
, Module
)
249 n
= Value
.cast(n
) # convert to Value
250 d
= Value
.cast(d
) # convert to Value
251 clz_mod
= CLZ(self
.shape
.width
- 1)
252 # can't name submodule since it would conflict if this function is
253 # called multiple times in a Module
254 m
.submodules
+= clz_mod
255 assert clz_mod
.lz
.width
== self
.shape
.shift_width
, \
256 "internal inconsistency -- mismatched shift signal width"
258 clz_mod
.sig_in
.eq(d
>> 1),
259 self
.shift
.eq(clz_mod
.lz
),
260 self
.d
.eq(d
<< self
.shift
),
261 self
.r
.eq(n
<< self
.shift
),
266 def set_to_next(self
, m
, state_in
):
267 assert isinstance(m
, Module
)
268 assert isinstance(state_in
, CLDivRemState
)
269 assert state_in
.shape
== self
.shape
270 assert self
is not state_in
, "a.set_to_next(m, a) is not allowed"
271 width
= self
.shape
.width
273 with m
.If(state_in
.done
):
274 m
.d
.comb
+= self
.eq(state_in
)
277 self
.step
.eq(state_in
.step
+ 1),
278 self
.d
.eq(state_in
.d
),
279 self
.shift
.eq(state_in
.shift
),
283 with m
.If(r
[width
* 2 - 1]):
286 self
.r
.eq(r ^
(state_in
.d
<< width
)),
295 class CLDivRemInputData
:
296 def __init__(self
, shape
):
297 assert isinstance(shape
, CLDivRemShape
)
299 self
.n
= Signal(shape
.width
)
300 self
.d
= Signal(shape
.width
)
303 """ Get member signals. """
308 """ Assign member signals. """
315 class CLDivRemOutputData
:
316 def __init__(self
, shape
):
317 assert isinstance(shape
, CLDivRemShape
)
319 self
.q
= Signal(shape
.width
)
320 self
.r
= Signal(shape
.width
)
323 """ Get member signals. """
328 """ Assign member signals. """
335 class CLDivRemFSMStage(ControlBase
):
336 """carry-less div/rem
342 number of steps that should be taken per clock cycle
346 true if nothing is stored in `self.saved_state`
347 saved_state: CLDivRemState()
348 the saved state that is currently being worked on.
351 def __init__(self
, pspec
, shape
, *, steps_per_clock
=8):
352 assert isinstance(shape
, CLDivRemShape
)
353 assert isinstance(steps_per_clock
, int) and steps_per_clock
>= 1
355 self
.steps_per_clock
= steps_per_clock
356 self
.pspec
= pspec
# store now: used in ispec and ospec
357 super().__init
__(stage
=self
)
358 self
.empty
= Signal(reset
=1)
359 self
.saved_state
= CLDivRemState(shape
)
362 return CLDivRemInputData(self
.shape
)
365 return CLDivRemOutputData(self
.shape
)
367 def setup(self
, m
, i
):
370 def elaborate(self
, platform
):
371 m
= super().elaborate(platform
)
372 i_data
: CLDivRemInputData
= self
.p
.i_data
373 o_data
: CLDivRemOutputData
= self
.n
.o_data
375 # TODO: handle cancellation
377 m
.d
.comb
+= self
.n
.o_valid
.eq(~self
.empty
& self
.saved_state
.done
)
378 m
.d
.comb
+= self
.p
.o_ready
.eq(self
.empty
)
381 return CLDivRemState(self
.shape
, name
=f
"next_chain_{i}")
382 next_chain
= [make_nc(i
) for i
in range(self
.steps_per_clock
+ 1)]
383 for i
in range(self
.steps_per_clock
):
384 next_chain
[i
+ 1].set_to_next(m
, next_chain
[i
])
385 m
.d
.comb
+= next_chain
[0].eq(self
.saved_state
)
386 out_q
, out_r
= self
.saved_state
.get_output()
387 m
.d
.comb
+= o_data
.q
.eq(out_q
)
388 m
.d
.comb
+= o_data
.r
.eq(out_r
)
389 initial_state
= CLDivRemState(self
.shape
)
390 initial_state
.set_to_initial(m
, n
=i_data
.n
, d
=i_data
.d
)
392 with m
.If(self
.empty
):
393 m
.d
.sync
+= self
.saved_state
.eq(initial_state
)
394 with m
.If(self
.p
.i_valid
):
395 m
.d
.sync
+= self
.empty
.eq(0)
397 m
.d
.sync
+= self
.saved_state
.eq(next_chain
[-1])
398 with m
.If(self
.n
.i_ready
& self
.n
.o_valid
):
399 m
.d
.sync
+= self
.empty
.eq(1)