# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
""" Core of the div/rem/sqrt/rsqrt pipeline.

Special case handling, input/output conversion, and muxid handling are handled
outside of these classes.

Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.

Formulas solved are:

* div/rem:
    ``dividend == quotient_root * divisor_radicand``
* sqrt/rem:
    ``divisor_radicand == quotient_root * quotient_root``
* rsqrt/rem:
    ``1 == quotient_root * quotient_root * divisor_radicand``

The remainder is the left-hand-side of the comparison minus the
right-hand-side of the comparison in the above formulas.
"""
import enum
import operator

from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
from nmigen.lib.coding import PriorityEncoder
from nmutil.util import treereduce
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    # NOTE(review): the member definitions were not present in the reviewed
    # text.  The names come from the docstring above; the numeric values are
    # reconstructed and must be confirmed against the upstream file before
    # anything outside this module relies on them.
    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Convert to int. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: forwarded to ``Signal`` after adding one for this
            helper's own call frame.
        :param kwargs: any further keyword arguments for ``Signal``
            (e.g. ``reset_less``).
        :returns: a ``Signal`` whose value range spans the smallest member
            value up to (and including) the largest member value plus one,
            and whose decoder shows the member instead of the raw number.
        """
        member_values = [int(member) for member in cls]
        # ``range`` excludes its end point, so ``max + 2`` admits ``max + 1``.
        value_range = range(min(member_values), max(member_values) + 2)
        return Signal(value_range,
                      src_loc_at=(src_loc_at + 1),
                      decoder=lambda v: str(cls(v)),
                      **kwargs)


# short alias used throughout this module
DP = DivPipeCoreOperation
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    :attribute supported: ``frozenset`` of the operations the core handles.
    """

    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
        """ Create a ``DivPipeCoreConfig`` instance.

        :param supported: iterable of supported operations; ``None`` selects
            every member of ``DivPipeCoreOperation``.
        """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        self.supported = frozenset(DP if supported is None else supported)
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Get repr. """
        return (f"DivPipeCoreConfig({self.bit_width}, "
                f"{self.fract_width}, {self.log2_radix}, "
                f"supported={self.supported})")

    @property
    def n_stages(self):
        """ Get the number of ``DivPipeCoreCalculateStage`` needed. """
        # ceil(bit_width / log2_radix) using integer arithmetic only
        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInputData`` instance. """
        self.core_config = core_config
        bw = core_config.bit_width
        fw = core_config.fract_width
        self.dividend = Signal(bw + fw, reset_less=reset_less)
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield self.dividend
        yield self.divisor_radicand
        yield self.operation

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing the same three members.
        :returns: list with one assignment per member, in the order
            dividend, divisor_radicand, operation.
        """
        return [self.dividend.eq(rhs.dividend),
                self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                ]
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInterstageData`` instance. """
        self.core_config = core_config
        bw = core_config.bit_width
        # TODO(programmerjake): re-enable once bit_width reduction is fixed
        # (the ``False and`` keeps the narrower width switched off for now)
        if False and core_config.supported == {DP.UDivRem}:
            self.compare_len = bw * 2
        else:
            self.compare_len = bw * 3
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
        self.compare_lhs = Signal(self.compare_len, reset_less=reset_less)
        self.compare_rhs = Signal(self.compare_len, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield self.divisor_radicand
        yield self.operation
        yield self.quotient_root
        yield self.root_times_radicand
        yield self.compare_lhs
        yield self.compare_rhs

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing the same six members.
        :returns: list with one assignment per member.
        """
        return [self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                self.quotient_root.eq(rhs.quotient_root),
                self.root_times_radicand.eq(rhs.root_times_radicand),
                self.compare_lhs.eq(rhs.compare_lhs),
                self.compare_rhs.eq(rhs.compare_rhs)]
class DivPipeCoreOutputData:
    """ output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreOutputData`` instance. """
        self.core_config = core_config
        bw = core_config.bit_width
        # TODO(programmerjake): re-enable once bit_width reduction is fixed
        # (the ``False and`` keeps the narrower width switched off for now)
        if False and core_config.supported == {DP.UDivRem}:
            self.compare_len = bw * 2
        else:
            self.compare_len = bw * 3
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.remainder = Signal(self.compare_len, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield self.quotient_root
        yield self.remainder

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing ``quotient_root`` and ``remainder``.
        :returns: list with one assignment per member.
        """
        return [self.quotient_root.eq(rhs.quotient_root),
                self.remainder.eq(rhs.remainder)]
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the partial results are zeroed and the
    left-hand-side of the comparison is selected according to the operation.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()
        bw = core_config.bit_width
        # TODO(programmerjake): re-enable once bit_width reduction is fixed
        # (the ``False and`` keeps the narrower width switched off for now)
        if False and core_config.supported == {DP.UDivRem}:
            self.compare_len = bw * 2
        else:
            self.compare_len = bw * 3

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        # nothing has been computed yet: partial results start at zero
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.quotient_root.eq(0)
        comb += self.o.root_times_radicand.eq(0)

        lhs = Signal(self.compare_len, reset_less=True)
        fw = self.core_config.fract_width

        # left-hand-side of the equation being solved, per operation
        with m.Switch(self.i.operation):
            with m.Case(int(DP.UDivRem)):
                comb += lhs.eq(self.i.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                # the constant 1 shifted left by three fract-widths
                comb += lhs.eq(1 << (fw * 3))

        comb += self.o.compare_lhs.eq(lhs)
        comb += self.o.compare_rhs.eq(0)
        comb += self.o.operation.eq(self.i.operation)

        return m
class Trial(Elaboratable):
    """ Evaluate one candidate value for the next bits of ``quotient_root``.

    Inputs (``divisor_radicand``, ``quotient_root``, ``root_times_radicand``,
    ``compare_rhs``, ``operation``) are driven by the enclosing calculate
    stage.  The output ``trial_compare_rhs`` is ``compare_rhs`` plus the
    amount the right-hand-side of the equation grows by if ``trial_bits``,
    placed at bit position ``current_shift``, is added to ``quotient_root``:

    * div:   ``(dr * t) << (fw + s)``
    * sqrt:  ``(qr * t) << (fw + s + 1)`` plus ``(t * t) << (fw + 2 * s)``
    * rsqrt: ``(rr * t) << (s + 1)`` plus ``(dr * t * t) << (2 * s)``

    with ``t = trial_bits``, ``s = current_shift`` and
    ``fw = core_config.fract_width``.  Only operations listed in
    ``core_config.supported`` get any logic generated.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: ``DivPipeCoreConfig`` in use.
        :param trial_bits: constant candidate value (a Python int).
        :param current_shift: bit position of the candidate within
            ``quotient_root``.
        :param log2_radix: number of bits of the candidate.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        # TODO(programmerjake): re-enable once bit_width reduction is fixed
        # (the ``False and`` keeps the narrower width switched off for now)
        if False and core_config.supported == {DP.UDivRem}:
            self.compare_len = bw * 2
        else:
            self.compare_len = bw * 3
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(self.compare_len, reset_less=True)
        self.trial_compare_rhs = Signal(self.compare_len, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        cc = self.core_config
        dr = self.divisor_radicand

        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
        # NOTE(review): the width argument was not visible in the reviewed
        # text; twice log2_radix is what a squared log2_radix-bit value
        # needs -- confirm against upstream.
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                    self.log2_radix * 2)

        tblen = self.core_config.bit_width+self.log2_radix

        # UDivRem
        if DP.UDivRem in cc.supported:
            with m.If(self.operation == int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
                div_rhs = self.compare_rhs

                div_term1 = dr_times_trial_bits
                div_term1_shift = self.core_config.fract_width
                div_term1_shift += self.current_shift
                # builds a new expression; self.compare_rhs is not modified
                div_rhs = div_rhs + (div_term1 << div_term1_shift)

                comb += self.trial_compare_rhs.eq(div_rhs)

        # SqrtRem
        if DP.SqrtRem in cc.supported:
            with m.If(self.operation == int(DP.SqrtRem)):
                qr = self.quotient_root
                qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
                sqrt_rhs = self.compare_rhs

                # cross term 2 * qr * t: the factor 2 is the extra "+ 1"
                sqrt_term1 = qr_times_trial_bits
                sqrt_term1_shift = self.core_config.fract_width
                sqrt_term1_shift += self.current_shift + 1
                sqrt_rhs = sqrt_rhs + (sqrt_term1 << sqrt_term1_shift)
                # square term t * t
                sqrt_term2 = trial_bits_sqrd_sig
                sqrt_term2_shift = self.core_config.fract_width
                sqrt_term2_shift += self.current_shift * 2
                sqrt_rhs = sqrt_rhs + (sqrt_term2 << sqrt_term2_shift)

                comb += self.trial_compare_rhs.eq(sqrt_rhs)

        # RSqrtRem
        if DP.RSqrtRem in cc.supported:
            with m.If(self.operation == int(DP.RSqrtRem)):
                rr = self.root_times_radicand
                tblen2 = self.core_config.bit_width+self.log2_radix*2
                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
                comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
                rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
                rsqrt_rhs = self.compare_rhs

                # cross term 2 * rr * t: the factor 2 is the extra "+ 1"
                rsqrt_term1 = rr_times_trial_bits
                rsqrt_term1_shift = self.current_shift + 1
                rsqrt_rhs = rsqrt_rhs + (rsqrt_term1 << rsqrt_term1_shift)
                # square term dr * t * t
                rsqrt_term2 = dr_times_trial_bits_sqrd
                rsqrt_term2_shift = self.current_shift * 2
                rsqrt_rhs = rsqrt_rhs + (rsqrt_term2 << rsqrt_term2_shift)

                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines up to ``core_config.log2_radix`` further bits of
    ``quotient_root`` by evaluating every candidate value in parallel (one
    ``Trial`` per candidate) and keeping the largest candidate whose
    right-hand-side does not exceed ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` in use.
        :param stage_index: position of this stage, in
            ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        bw = core_config.bit_width
        # TODO(programmerjake): re-enable once bit_width reduction is fixed
        # (the ``False and`` keeps the narrower width switched off for now)
        if False and core_config.supported == {DP.UDivRem}:
            self.compare_len = bw * 2
        else:
            self.compare_len = bw * 3
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        cc = self.core_config

        # copy invariant inputs to outputs (for next stage)
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # constants: which bits of quotient_root this stage works out.
        # Stage 0 handles the most significant bits; the last stage may
        # handle fewer than log2_radix bits.
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # trials within this radix range. carried out by Trial module,
        # results stored in pass_flags. pass_flags are unary priority.
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            comb += t.quotient_root.eq(self.i.quotient_root)
            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            comb += t.compare_rhs.eq(self.i.compare_rhs)
            comb += t.operation.eq(self.i.operation)

            # get the trial output (needed even in pass_flags[0] case)
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # make the trial comparison against the [invariant] lhs.
            # Each Trial only adds non-negative terms scaled by trial_bits to
            # compare_rhs, so trial_compare_rhs does not shrink as trial_bits
            # increases (the earlier wording "decreasing" was inverted).
            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            if trial_bits == 0:
                # do not do first comparison: no point.
                comb += pass_flag.eq(1)
            else:
                comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
            pfl.append(pass_flag)

        # Cat all the pass flags list together (easier to handle, below)
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).

        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        # the encoder reports the lowest *failing* trial; the trial just
        # below it is the highest passing one
        comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            # no failing trial at all: every candidate passed
            comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial. use treereduce because
        # Array on such massively long numbers is insanely gate-hungry
        crhs = []
        tcrh = trial_compare_rhs_values
        for idx in range(radix):
            nbe = Signal(reset_less=True)
            comb += nbe.eq(next_bits == idx)
            # mask: all-ones for the selected trial, all-zeros otherwise
            crhs.append(Repl(nbe, self.compare_len) & tcrh[idx])
        comb += self.o.compare_rhs.eq(treereduce(crhs, operator.or_,
                                                 lambda x: x))

        # create outputs for next phase
        qr = self.i.quotient_root | (next_bits << current_shift)
        comb += self.o.quotient_root.eq(qr)
        if DP.RSqrtRem in cc.supported:
            rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
                                                next_bits) << current_shift)
            comb += self.o.root_times_radicand.eq(rr)

        return m
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``: the quotient/root is passed through and the
    remainder is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        comb += self.o.quotient_root.eq(self.i.quotient_root)
        comb += self.o.remainder.eq(self.i.compare_lhs - self.i.compare_rhs)

        return m