1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 """ Carry-less Division and Remainder.
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
12 from nmigen
.hdl
.ast
import Signal
, Value
, Assert
13 from nmigen
.hdl
.dsl
import Module
14 from nmutil
.singlepipe
import ControlBase
15 from nmutil
.clz
import CLZ
, clz
16 from nmutil
.plain_data
import plain_data
, fields
19 def cldivrem_shifting(n
, d
, shape
):
20 """ Carry-less Division and Remainder based on shifting at start and end
21 allowing us to get away with checking a single bit each iteration
22 rather than checking for equal degrees every iteration.
23 `n` and `d` are integers, `width` is the number of bits needed to hold
25 Returns a tuple `q, r` of the quotient and remainder.
27 assert isinstance(shape
, CLDivRemShape
)
28 assert isinstance(n
, int) and 0 <= n
< 1 << shape
.width
29 assert isinstance(d
, int) and 0 <= d
< 1 << shape
.width
30 assert d
!= 0, "TODO: decide what happens on division by zero"
32 # declare locals so nonlocal works
33 r
= q
= shift
= clock
= substep
= NotImplemented
35 # functions match up to HDL parts:
38 nonlocal d
, r
, q
, clock
, substep
, shift
39 # `clz(d, shape.width)`, but maxes out at `shape.width - 1` instead of
40 # `shape.width` in order to both fit in `shape.shift_width` bits and
41 # to not shift by more than needed.
42 shift
= clz(d
>> 1, shape
.width
- 1)
43 assert 0 <= shift
< 1 << shape
.shift_width
, "shift overflow"
45 assert 0 <= d
< 1 << shape
.d_width
, "d overflow"
47 assert 0 <= r
< 1 << shape
.r_width
, "r overflow"
53 return clock
== shape
.done_clock
56 nonlocal r
, q
, clock
, substep
58 substep
%= shape
.steps_per_clock
63 if clock
== shape
.width
// shape
.steps_per_clock \
64 and substep
>= shape
.width
% shape
.steps_per_clock
:
65 clock
= shape
.done_clock
68 if r
>> (shape
.width
* 2 - 1) != 0:
71 assert 0 <= q
< 1 << shape
.q_width
, "q overflow"
72 assert 0 <= r
< 1 << shape
.r_width
, "r overflow"
75 return q
, (r
>> shape
.width
) >> shift
79 # one clock-cycle per outer loop
81 for expected_substep
in range(shape
.steps_per_clock
):
82 assert substep
== expected_substep
88 @plain_data(frozen
=True, unsafe_hash
=True)
90 __slots__
= "width", "steps_per_clock"
92 def __init__(self
, width
, steps_per_clock
=8):
93 assert isinstance(width
, int) and width
>= 1, "invalid width"
94 assert (isinstance(steps_per_clock
, int)
95 and steps_per_clock
>= 1), "invalid steps_per_clock"
97 """bit-width of each of the carry-less div/rem inputs and outputs"""
99 self
.steps_per_clock
= steps_per_clock
100 """number of steps that should be taken per clock cycle"""
def done_clock(self):
    """Clock tick number at which iteration is finished.

    This is the largest value `CLDivRemState.clock` will get: `width`
    divided by `steps_per_clock`, rounded up.
    """
    full_clocks, leftover_steps = divmod(self.width, self.steps_per_clock)
    # a partially-used final clock cycle still counts as a whole tick
    return full_clocks + 1 if leftover_steps else full_clocks
def clock_range(self):
    """Range of values that `CLDivRemState.clock` will fall in.

    Runs from 0 up to and including `done_clock`.
    """
    last_clock = self.done_clock
    return range(last_clock + 1)
def substep_range(self):
    """Range of values that `CLDivRemState.substep` will fall in.

    Holds one value for each of the `steps_per_clock` steps.
    """
    step_count = self.steps_per_clock
    return range(step_count)
129 """bit-width of the internal signal `CLDivRemState.d`"""
134 """bit-width of the internal signal `CLDivRemState.r`"""
135 return self
.width
* 2
139 """bit-width of the internal signal `CLDivRemState.q`"""
def shift_width(self):
    """Bit-width of the internal signal `CLDivRemState.shift`."""
    # the shift amount never exceeds `width - 1`, so size for that value
    largest_shift = self.width - 1
    return largest_shift.bit_length()
148 @plain_data(frozen
=True, eq
=False)
150 __slots__
= "shape", "name", "clock", "substep", "d", "r", "q", "shift"
151 def __init__(self
, shape
, *, name
=None, src_loc_at
=0):
152 assert isinstance(shape
, CLDivRemShape
)
154 name
= Signal(src_loc_at
=1 + src_loc_at
).name
155 assert isinstance(name
, str)
158 self
.clock
= Signal(shape
.clock_range
, name
=f
"{name}_clock")
159 self
.substep
= Signal(shape
.substep_range
, name
=f
"{name}_substep",
161 self
.d
= Signal(shape
.d_width
, name
=f
"{name}_d")
162 self
.r
= Signal(shape
.r_width
, name
=f
"{name}_r")
163 self
.q
= Signal(shape
.q_width
, name
=f
"{name}_q")
164 self
.shift
= Signal(shape
.shift_width
, name
=f
"{name}_shift")
167 assert isinstance(rhs
, CLDivRemState
)
168 for f
in fields(CLDivRemState
):
169 if f
in ("shape", "name"):
def like(other, *, name=None, src_loc_at=0):
    """Create a new `CLDivRemState` that has the same shape as `other`.

    `name` is forwarded unchanged to the `CLDivRemState` constructor;
    `src_loc_at` is forwarded increased by one (presumably to skip this
    function's own stack frame when a name is inferred -- confirm against
    the constructor).
    """
    assert isinstance(other, CLDivRemState)
    caller_depth = 1 + src_loc_at
    # keep the construction directly in this function: the frame depth
    # passed on as `src_loc_at` is relative to this call site
    return CLDivRemState(other.shape, name=name, src_loc_at=caller_depth)
182 return self
.clock
== self
.shape
.done_clock
def get_output(self):
    """Return the `(q, r)` result pair derived from this state.

    `q` is passed through as-is.  The returned `r` is `self.r` shifted
    right by `shape.width` and then by `self.shift`.
    """
    upper_r = self.r >> self.shape.width
    remainder = upper_r >> self.shift
    return self.q, remainder
187 def set_to_initial(self
, m
, n
, d
):
188 assert isinstance(m
, Module
)
189 n
= Value
.cast(n
) # convert to Value
190 d
= Value
.cast(d
) # convert to Value
191 clz_mod
= CLZ(self
.shape
.width
- 1)
192 # can't name submodule since it would conflict if this function is
193 # called multiple times in a Module
194 m
.submodules
+= clz_mod
195 assert clz_mod
.lz
.width
== self
.shape
.shift_width
, \
196 "internal inconsistency -- mismatched shift signal width"
198 clz_mod
.sig_in
.eq(d
>> 1),
199 self
.shift
.eq(clz_mod
.lz
),
200 self
.d
.eq(d
<< self
.shift
),
201 self
.r
.eq(n
<< self
.shift
),
207 def eq_but_zero_substep(self
, rhs
, do_assert
):
208 assert isinstance(rhs
, CLDivRemState
)
209 for f
in fields(CLDivRemState
):
210 if f
in ("shape", "name"):
220 def set_to_next(self
, m
, state_in
):
221 assert isinstance(m
, Module
)
222 assert isinstance(state_in
, CLDivRemState
)
223 assert state_in
.shape
== self
.shape
224 assert self
is not state_in
, "a.set_to_next(m, a) is not allowed"
225 width
= self
.shape
.width
226 substep_wraps
= state_in
.substep
>= self
.shape
.steps_per_clock
- 1
227 with m
.If(substep_wraps
):
228 m
.d
.comb
+= self
.substep
.eq(0)
230 m
.d
.comb
+= self
.substep
.eq(state_in
.substep
+ 1)
232 with m
.If(state_in
.done
):
234 self
.clock
.eq(state_in
.clock
),
235 self
.d
.eq(state_in
.d
),
236 self
.r
.eq(state_in
.r
),
237 self
.q
.eq(state_in
.q
),
238 self
.shift
.eq(state_in
.shift
),
241 clock
= state_in
.clock
+ substep_wraps
242 with m
.If((clock
== width
// self
.shape
.steps_per_clock
)
243 & (self
.substep
>= width
% self
.shape
.steps_per_clock
)):
244 m
.d
.comb
+= self
.clock
.eq(self
.shape
.done_clock
)
246 m
.d
.comb
+= self
.clock
.eq(clock
)
248 self
.d
.eq(state_in
.d
),
249 self
.shift
.eq(state_in
.shift
),
253 with m
.If(r
[width
* 2 - 1]):
256 self
.r
.eq(r ^
(state_in
.d
<< width
)),
265 class CLDivRemInputData
:
266 def __init__(self
, shape
):
267 assert isinstance(shape
, CLDivRemShape
)
269 self
.n
= Signal(shape
.width
)
270 self
.d
= Signal(shape
.width
)
273 """ Get member signals. """
278 """ Assign member signals. """
285 class CLDivRemOutputData
:
286 def __init__(self
, shape
):
287 assert isinstance(shape
, CLDivRemShape
)
289 self
.q
= Signal(shape
.width
)
290 self
.r
= Signal(shape
.width
)
293 """ Get member signals. """
298 """ Assign member signals. """
def eq_output(self, state):
    """Build the assignments that copy the result of `state` into `self`.

    `state` must be a `CLDivRemState` whose shape equals `self.shape`.
    Returns a two-element list: the assignment to `self.q` followed by the
    assignment to `self.r`, both taken from `state.get_output()`.
    """
    assert isinstance(state, CLDivRemState)
    assert state.shape == self.shape
    quotient, remainder = state.get_output()
    assignments = [self.q.eq(quotient)]
    assignments.append(self.r.eq(remainder))
    return assignments
311 class CLDivRemFSMStage(ControlBase
):
312 """carry-less div/rem
320 true if nothing is stored in `self.saved_state`
321 saved_state: CLDivRemState()
322 the saved state that is currently being worked on.
325 def __init__(self
, pspec
, shape
):
326 assert isinstance(shape
, CLDivRemShape
)
328 self
.pspec
= pspec
# store now: used in ispec and ospec
329 super().__init
__(stage
=self
)
330 self
.empty
= Signal(reset
=1)
331 self
.saved_state
= CLDivRemState(shape
)
334 return CLDivRemInputData(self
.shape
)
337 return CLDivRemOutputData(self
.shape
)
339 def setup(self
, m
, i
):
342 def elaborate(self
, platform
):
343 m
= super().elaborate(platform
)
344 i_data
: CLDivRemInputData
= self
.p
.i_data
345 o_data
: CLDivRemOutputData
= self
.n
.o_data
346 steps_per_clock
= self
.shape
.steps_per_clock
348 # TODO: handle cancellation
350 m
.d
.comb
+= self
.n
.o_valid
.eq(~self
.empty
& self
.saved_state
.done
)
351 m
.d
.comb
+= self
.p
.o_ready
.eq(self
.empty
)
354 return CLDivRemState(self
.shape
, name
=f
"next_chain_{i}")
355 next_chain
= [make_nc(i
) for i
in range(steps_per_clock
+ 1)]
356 for i
in range(steps_per_clock
):
357 next_chain
[i
+ 1].set_to_next(m
, next_chain
[i
])
358 m
.d
.comb
+= next_chain
[0].eq(self
.saved_state
)
359 m
.d
.comb
+= o_data
.eq_output(self
.saved_state
)
360 initial_state
= CLDivRemState(self
.shape
)
361 initial_state
.set_to_initial(m
, n
=i_data
.n
, d
=i_data
.d
)
363 do_assert
= platform
== "formal"
365 with m
.If(self
.empty
):
366 m
.d
.sync
+= self
.saved_state
.eq_but_zero_substep(initial_state
,
368 with m
.If(self
.p
.i_valid
):
369 m
.d
.sync
+= self
.empty
.eq(0)
371 m
.d
.sync
+= self
.saved_state
.eq_but_zero_substep(next_chain
[-1],
373 with m
.If(self
.n
.i_ready
& self
.n
.o_valid
):
374 m
.d
.sync
+= self
.empty
.eq(1)