1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
""" Core of the div/rem/sqrt/rsqrt pipeline.

Special case handling, input/output conversion, and muxid handling are handled
outside of these classes.

Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.

Formulas solved are:
* div/rem:
    ``dividend == quotient_root * divisor_radicand``
* sqrt/rem:
    ``divisor_radicand == quotient_root * quotient_root``
* rsqrt/rem:
    ``1 == quotient_root * quotient_root * divisor_radicand``

The remainder is the left-hand-side of the comparison minus the
right-hand-side of the comparison in the above formulas.
"""
21 from nmigen
import (Elaboratable
, Module
, Signal
, Const
, Mux
, Cat
, Array
)
22 from nmigen
.lib
.coding
import PriorityEncoder
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    # NOTE(review): the member definitions did not survive in the damaged
    # listing. The names come from the docstring above; the numeric encoding
    # is reconstructed and must be confirmed against the upstream file.
    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Convert to int. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: number of extra stack frames to skip when the
            signal's name/source location is inferred; one is added here to
            account for this wrapper.
        :param kwargs: forwarded unchanged to ``Signal`` (callers in this
            file pass ``reset_less``).
        :returns: a ``Signal`` shaped from the range of the members' integer
            values, with a decoder that renders values as member names.
        """
        return Signal(range(min(map(int, cls)), max(map(int, cls)) + 2),
                      src_loc_at=(src_loc_at + 1),
                      decoder=lambda v: str(cls(v)),
                      **kwargs)


# Short alias used throughout this file.
DP = DivPipeCoreOperation
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    :attribute supported: collection of ``DivPipeCoreOperation`` members the
        pipeline should implement; hardware for an operation is only
        generated when the operation is listed here.
    """

    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
        """ Create a ``DivPipeCoreConfig`` instance.

        :param supported: operations to support. ``None`` (the default)
            selects all three operations.
        """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        # Only fall back to "everything" when the caller gave no list, so an
        # explicit (even empty) selection is honoured.
        if supported is None:
            supported = [DP.SqrtRem, DP.UDivRem, DP.RSqrtRem]
        self.supported = supported
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Get repr. """
        return f"DivPipeCoreConfig({self.bit_width}, " \
            + f"{self.fract_width}, {self.log2_radix})"

    @property
    def n_stages(self):
        """ Get the number of ``DivPipeCoreCalculateStage`` needed. """
        # ceil(bit_width / log2_radix) using integer arithmetic only
        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInputData`` instance. """
        self.core_config = core_config
        bw = core_config.bit_width
        fw = core_config.fract_width
        self.dividend = Signal(bw + fw, reset_less=reset_less)
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        # Same members, in the same order, as assigned by ``eq``.
        yield self.dividend
        yield self.divisor_radicand
        yield self.operation

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing ``dividend``, ``divisor_radicand`` and
            ``operation``.
        :returns: list of assignment statements, one per member.
        """
        return [self.dividend.eq(rhs.dividend),
                self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                ]
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInterstageData`` instance. """
        self.core_config = core_config
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
        self.compare_lhs = Signal(bw * 3, reset_less=reset_less)
        self.compare_rhs = Signal(bw * 3, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        # Same members, in the same order, as assigned by ``eq``.
        yield self.divisor_radicand
        yield self.operation
        yield self.quotient_root
        yield self.root_times_radicand
        yield self.compare_lhs
        yield self.compare_rhs

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing the same six members as this class.
        :returns: list of assignment statements, one per member.
        """
        return [self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                self.quotient_root.eq(rhs.quotient_root),
                self.root_times_radicand.eq(rhs.root_times_radicand),
                self.compare_lhs.eq(rhs.compare_lhs),
                self.compare_rhs.eq(rhs.compare_rhs)]
class DivPipeCoreOutputData:
    """ output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreOutputData`` instance. """
        self.core_config = core_config
        bw = core_config.bit_width
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.remainder = Signal(bw * 3, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        # Same members, in the same order, as assigned by ``eq``.
        yield self.quotient_root
        yield self.remainder

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing ``quotient_root`` and ``remainder``.
        :returns: list of assignment statements, one per member.
        """
        return [self.quotient_root.eq(rhs.quotient_root),
                self.remainder.eq(rhs.remainder)]
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the partial results are zeroed and the
    left-hand-side of the comparison is selected according to the operation.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and drives this stage's
        input from ``i``.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        # Nothing has been computed yet: pass the operand through and start
        # the partial results at zero.
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.quotient_root.eq(0)
        comb += self.o.root_times_radicand.eq(0)

        lhs = Signal(self.core_config.bit_width * 3, reset_less=True)
        fw = self.core_config.fract_width

        # Select the left-hand-side of the comparison for the requested
        # operation, shifted into the common ``bit_width * 3`` wide format.
        with m.Switch(self.i.operation):
            with m.Case(int(DP.UDivRem)):
                comb += lhs.eq(self.i.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                comb += lhs.eq(1 << (fw * 3))

        comb += self.o.compare_lhs.eq(lhs)
        comb += self.o.compare_rhs.eq(0)
        comb += self.o.operation.eq(self.i.operation)

        return m
class Trial(Elaboratable):
    """ Compute the comparison right-hand-side for one candidate digit.

    For the candidate value ``trial_bits`` placed at bit position
    ``current_shift`` of ``quotient_root``, drives ``trial_compare_rhs`` with
    the value ``compare_rhs`` would take if that candidate were accepted.
    Hardware is only generated for operations listed in
    ``core_config.supported``.

    :attribute divisor_radicand: input, ``core_config.bit_width`` bits.
    :attribute quotient_root: input, ``core_config.bit_width`` bits.
    :attribute root_times_radicand: input, ``core_config.bit_width * 2`` bits.
    :attribute compare_rhs: input, ``core_config.bit_width * 3`` bits.
    :attribute operation: input, the ``DivPipeCoreOperation`` to be computed.
    :attribute trial_compare_rhs: output, ``core_config.bit_width * 3`` bits.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: ``DivPipeCoreConfig`` to use.
        :param trial_bits: candidate digit value (a Python int).
        :param current_shift: bit position of the digit in ``quotient_root``.
        :param log2_radix: number of bits in the digit.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(bw * 3, reset_less=True)
        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        cc = self.core_config
        dr = self.divisor_radicand
        fract_width = cc.fract_width
        shift = self.current_shift

        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
        # NOTE(review): the width argument of this Const was lost from the
        # damaged listing; twice log2_radix is what a squared log2_radix-bit
        # value needs -- confirm against the upstream file.
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                    self.log2_radix * 2)

        tblen = cc.bit_width + self.log2_radix

        # UDivRem: rhs + ((divisor * trial) << (fract_width + shift))
        if DP.UDivRem in cc.supported:
            with m.If(self.operation == int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)

                div_rhs = self.compare_rhs \
                    + (dr_times_trial_bits << (fract_width + shift))

                comb += self.trial_compare_rhs.eq(div_rhs)

        # SqrtRem: rhs + ((root * trial) << (fract_width + shift + 1))
        #              + ((trial * trial) << (fract_width + shift * 2))
        if DP.SqrtRem in cc.supported:
            with m.If(self.operation == int(DP.SqrtRem)):
                qr = self.quotient_root
                qr_times_trial_bits = Signal((tblen + 1) * 2, reset_less=True)
                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)

                sqrt_rhs = self.compare_rhs \
                    + (qr_times_trial_bits << (fract_width + shift + 1)) \
                    + (trial_bits_sqrd_sig << (fract_width + shift * 2))

                comb += self.trial_compare_rhs.eq(sqrt_rhs)

        # RSqrtRem: rhs + ((root_times_radicand * trial) << (shift + 1))
        #               + ((radicand * trial * trial) << (shift * 2))
        if DP.RSqrtRem in cc.supported:
            with m.If(self.operation == int(DP.RSqrtRem)):
                rr = self.root_times_radicand
                tblen2 = cc.bit_width + self.log2_radix * 2
                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
                comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
                rr_times_trial_bits = Signal((tblen + 1) * 3, reset_less=True)
                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)

                rsqrt_rhs = self.compare_rhs \
                    + (rr_times_trial_bits << (shift + 1)) \
                    + (dr_times_trial_bits_sqrd << (shift * 2))

                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines up to ``core_config.log2_radix`` further bits of
    ``quotient_root`` by trying every candidate digit with a ``Trial``
    submodule and keeping the largest one whose comparison still passes.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` to use.
        :param stage_index: position of this stage; must be in
            ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` (named after its stage
        index) and drives this stage's input from ``i``.
        """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        cc = self.core_config

        # copy invariant inputs to outputs (for next stage)
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # Work out which bits of quotient_root this stage handles. Stages
        # work from the most significant bits downwards; the last stage may
        # handle fewer than log2_radix bits.
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # trials within this radix range. carried out by Trial module,
        # results stored in pass_flags. pass_flags are unary priority.
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            comb += t.quotient_root.eq(self.i.quotient_root)
            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            comb += t.compare_rhs.eq(self.i.compare_rhs)
            comb += t.operation.eq(self.i.operation)

            # get the trial output
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # make the trial comparison against the [invariant] lhs.
            # trial_compare_rhs grows as trial_bits increases (Trial only
            # ever adds terms scaled by trial_bits to compare_rhs), so the
            # flags pass for the low candidates and fail from some point on.
            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
            pfl.append(pass_flag)

        # Cat all the pass flags list together (easier to handle, below)
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).

        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        comb += pe.i.eq(~pass_flags)
        # pe.o is the index of the first failing trial, so the last passing
        # one is just below it; pe.n signals "no trial failed", in which case
        # the largest candidate is taken.
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial (indexed by next_bits)
        ta = Array(trial_compare_rhs_values)
        comb += self.o.compare_rhs.eq(ta[next_bits])

        # create outputs for next phase
        qr = self.i.quotient_root | (next_bits << current_shift)
        comb += self.o.quotient_root.eq(qr)
        if DP.RSqrtRem in cc.supported:
            rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
                                                next_bits) << current_shift)
            comb += self.o.root_times_radicand.eq(rr)

        return m
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``: the quotient/root is passed through and the
    remainder is formed from the two sides of the comparison.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and drives this stage's
        input from ``i``.
        """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        comb += self.o.quotient_root.eq(self.i.quotient_root)
        # remainder = left-hand-side minus right-hand-side of the comparison
        comb += self.o.remainder.eq(self.i.compare_lhs - self.i.compare_rhs)

        return m