switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
22 from nmigen.lib.coding import PriorityEncoder
23 from nmutil.util import treereduce
24 import enum
25 import operator
26
27
28 class DivPipeCoreOperation(enum.Enum):
29 """ Operation for ``DivPipeCore``.
30
31 :attribute UDivRem: unsigned divide/remainder.
32 :attribute SqrtRem: square-root/remainder.
33 :attribute RSqrtRem: reciprocal-square-root/remainder.
34 """
35
36 SqrtRem = 0
37 UDivRem = 1
38 RSqrtRem = 2
39
40 def __int__(self):
41 """ Convert to int. """
42 return self.value
43
44 @classmethod
45 def create_signal(cls, *, src_loc_at=0, **kwargs):
46 """ Create a signal that can contain a ``DivPipeCoreOperation``. """
47 return Signal(range(min(map(int, cls)), max(map(int, cls)) + 2),
48 src_loc_at=(src_loc_at + 1),
49 decoder=lambda v: str(cls(v)),
50 **kwargs)
51
52
53 DP = DivPipeCoreOperation
54
55
56 class DivPipeCoreConfig:
57 """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.
58
59 :attribute bit_width: base bit-width.
60 :attribute fract_width: base fract-width. Specifies location of base-2
61 radix point.
62 :attribute log2_radix: number of bits of ``quotient_root`` that should be
63 computed per pipeline stage.
64 """
65
66 def __init__(self, bit_width, fract_width, log2_radix, supported=None):
67 """ Create a ``DivPipeCoreConfig`` instance. """
68 self.bit_width = bit_width
69 self.fract_width = fract_width
70 self.log2_radix = log2_radix
71 if supported is None:
72 supported = frozenset(DP)
73 else:
74 supported = frozenset(supported)
75 self.supported = supported
76 print(f"{self}: n_stages={self.n_stages}")
77
78 def __repr__(self):
79 """ Get repr. """
80 return f"DivPipeCoreConfig({self.bit_width}, " \
81 + f"{self.fract_width}, {self.log2_radix}, "\
82 + f"supported={self.supported})"
83
84 @property
85 def n_stages(self):
86 """ Get the number of ``DivPipeCoreCalculateStage`` needed. """
87 return (self.bit_width + self.log2_radix - 1) // self.log2_radix
88
89
90 class DivPipeCoreInputData:
91 """ input data type for ``DivPipeCore``.
92
93 :attribute core_config: ``DivPipeCoreConfig`` instance describing the
94 configuration to be used.
95 :attribute dividend: dividend for div/rem. Signal with a bit-width of
96 ``core_config.bit_width + core_config.fract_width`` and a fract-width
97 of ``core_config.fract_width * 2`` bits.
98 :attribute divisor_radicand: divisor for div/rem and radicand for
99 sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
100 fract-width of ``core_config.fract_width`` bits.
101 :attribute operation: the ``DivPipeCoreOperation`` to be computed.
102 """
103
104 def __init__(self, core_config, reset_less=True):
105 """ Create a ``DivPipeCoreInputData`` instance. """
106 self.core_config = core_config
107 bw = core_config.bit_width
108 fw = core_config.fract_width
109 self.dividend = Signal(bw + fw, reset_less=reset_less)
110 self.divisor_radicand = Signal(bw, reset_less=reset_less)
111 self.operation = DP.create_signal(reset_less=reset_less)
112
113 def __iter__(self):
114 """ Get member signals. """
115 yield self.dividend
116 yield self.divisor_radicand
117 yield self.operation
118
119 def eq(self, rhs):
120 """ Assign member signals. """
121 return [self.dividend.eq(rhs.dividend),
122 self.divisor_radicand.eq(rhs.divisor_radicand),
123 self.operation.eq(rhs.operation),
124 ]
125
126
127 class DivPipeCoreInterstageData:
128 """ interstage data type for ``DivPipeCore``.
129
130 :attribute core_config: ``DivPipeCoreConfig`` instance describing the
131 configuration to be used.
132 :attribute divisor_radicand: divisor for div/rem and radicand for
133 sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
134 fract-width of ``core_config.fract_width`` bits.
135 :attribute operation: the ``DivPipeCoreOperation`` to be computed.
136 :attribute quotient_root: the quotient or root part of the result of the
137 operation. Signal with a bit-width of ``core_config.bit_width`` and a
138 fract-width of ``core_config.fract_width`` bits.
139 :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
140 Signal with a bit-width of ``core_config.bit_width * 2`` and a
141 fract-width of ``core_config.fract_width * 2`` bits.
142 :attribute compare_lhs: The left-hand-side of the comparison in the
143 equation to be solved. Signal with a bit-width of
144 ``core_config.bit_width * 3`` and a fract-width of
145 ``core_config.fract_width * 3`` bits.
146 :attribute compare_rhs: The right-hand-side of the comparison in the
147 equation to be solved. Signal with a bit-width of
148 ``core_config.bit_width * 3`` and a fract-width of
149 ``core_config.fract_width * 3`` bits.
150 """
151
152 def __init__(self, core_config, reset_less=True):
153 """ Create a ``DivPipeCoreInterstageData`` instance. """
154 self.core_config = core_config
155 bw = core_config.bit_width
156 # TODO(programmerjake): re-enable once bit_width reduction is fixed
157 if False and core_config.supported == {DP.UDivRem}:
158 self.compare_len = bw * 2
159 else:
160 self.compare_len = bw * 3
161 self.divisor_radicand = Signal(bw, reset_less=reset_less)
162 self.operation = DP.create_signal(reset_less=reset_less)
163 self.quotient_root = Signal(bw, reset_less=reset_less)
164 self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
165 self.compare_lhs = Signal(self.compare_len, reset_less=reset_less)
166 self.compare_rhs = Signal(self.compare_len, reset_less=reset_less)
167
168 def __iter__(self):
169 """ Get member signals. """
170 yield self.divisor_radicand
171 yield self.operation
172 yield self.quotient_root
173 yield self.root_times_radicand
174 yield self.compare_lhs
175 yield self.compare_rhs
176
177 def eq(self, rhs):
178 """ Assign member signals. """
179 return [self.divisor_radicand.eq(rhs.divisor_radicand),
180 self.operation.eq(rhs.operation),
181 self.quotient_root.eq(rhs.quotient_root),
182 self.root_times_radicand.eq(rhs.root_times_radicand),
183 self.compare_lhs.eq(rhs.compare_lhs),
184 self.compare_rhs.eq(rhs.compare_rhs)]
185
186
187 class DivPipeCoreOutputData:
188 """ output data type for ``DivPipeCore``.
189
190 :attribute core_config: ``DivPipeCoreConfig`` instance describing the
191 configuration to be used.
192 :attribute quotient_root: the quotient or root part of the result of the
193 operation. Signal with a bit-width of ``core_config.bit_width`` and a
194 fract-width of ``core_config.fract_width`` bits.
195 :attribute remainder: the remainder part of the result of the operation.
196 Signal with a bit-width of ``core_config.bit_width * 3`` and a
197 fract-width of ``core_config.fract_width * 3`` bits.
198 """
199
200 def __init__(self, core_config, reset_less=True):
201 """ Create a ``DivPipeCoreOutputData`` instance. """
202 self.core_config = core_config
203 bw = core_config.bit_width
204 # TODO(programmerjake): re-enable once bit_width reduction is fixed
205 if False and core_config.supported == {DP.UDivRem}:
206 self.compare_len = bw * 2
207 else:
208 self.compare_len = bw * 3
209 self.quotient_root = Signal(bw, reset_less=reset_less)
210 self.remainder = Signal(self.compare_len, reset_less=reset_less)
211
212 def __iter__(self):
213 """ Get member signals. """
214 yield self.quotient_root
215 yield self.remainder
216 return
217
218 def eq(self, rhs):
219 """ Assign member signals. """
220 return [self.quotient_root.eq(rhs.quotient_root),
221 self.remainder.eq(rhs.remainder)]
222
223
224 class DivPipeCoreSetupStage(Elaboratable):
225 """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
226
227 def __init__(self, core_config):
228 """ Create a ``DivPipeCoreSetupStage`` instance."""
229 self.core_config = core_config
230 self.i = self.ispec()
231 self.o = self.ospec()
232 bw = core_config.bit_width
233 # TODO(programmerjake): re-enable once bit_width reduction is fixed
234 if False and core_config.supported == {DP.UDivRem}:
235 self.compare_len = bw * 2
236 else:
237 self.compare_len = bw * 3
238
239 def ispec(self):
240 """ Get the input spec for this pipeline stage."""
241 return DivPipeCoreInputData(self.core_config)
242
243 def ospec(self):
244 """ Get the output spec for this pipeline stage."""
245 return DivPipeCoreInterstageData(self.core_config)
246
247 def setup(self, m, i):
248 """ Pipeline stage setup. """
249 m.submodules.div_pipe_core_setup = self
250 m.d.comb += self.i.eq(i)
251
252 def process(self, i):
253 """ Pipeline stage process. """
254 return self.o # return processed data (ignore i)
255
256 def elaborate(self, platform):
257 """ Elaborate into ``Module``. """
258 m = Module()
259 comb = m.d.comb
260
261 comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
262 comb += self.o.quotient_root.eq(0)
263 comb += self.o.root_times_radicand.eq(0)
264
265 lhs = Signal(self.compare_len, reset_less=True)
266 fw = self.core_config.fract_width
267
268 with m.Switch(self.i.operation):
269 with m.Case(int(DP.UDivRem)):
270 comb += lhs.eq(self.i.dividend << fw)
271 with m.Case(int(DP.SqrtRem)):
272 comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
273 with m.Case(int(DP.RSqrtRem)):
274 comb += lhs.eq(1 << (fw * 3))
275
276 comb += self.o.compare_lhs.eq(lhs)
277 comb += self.o.compare_rhs.eq(0)
278 comb += self.o.operation.eq(self.i.operation)
279
280 return m
281
282
283 class Trial(Elaboratable):
284 def __init__(self, core_config, trial_bits, current_shift, log2_radix):
285 self.core_config = core_config
286 self.trial_bits = trial_bits
287 self.current_shift = current_shift
288 self.log2_radix = log2_radix
289 bw = core_config.bit_width
290 # TODO(programmerjake): re-enable once bit_width reduction is fixed
291 if False and core_config.supported == {DP.UDivRem}:
292 self.compare_len = bw * 2
293 else:
294 self.compare_len = bw * 3
295 self.divisor_radicand = Signal(bw, reset_less=True)
296 self.quotient_root = Signal(bw, reset_less=True)
297 self.root_times_radicand = Signal(bw * 2, reset_less=True)
298 self.compare_rhs = Signal(self.compare_len, reset_less=True)
299 self.trial_compare_rhs = Signal(self.compare_len, reset_less=True)
300 self.operation = DP.create_signal(reset_less=True)
301
302 def elaborate(self, platform):
303
304 m = Module()
305 comb = m.d.comb
306
307 cc = self.core_config
308 dr = self.divisor_radicand
309
310 trial_bits_sig = Const(self.trial_bits, self.log2_radix)
311 trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
312 self.log2_radix * 2)
313
314 tblen = self.core_config.bit_width+self.log2_radix
315
316 # UDivRem
317 if DP.UDivRem in cc.supported:
318 with m.If(self.operation == int(DP.UDivRem)):
319 dr_times_trial_bits = Signal(tblen, reset_less=True)
320 comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
321 div_rhs = self.compare_rhs
322
323 div_term1 = dr_times_trial_bits
324 div_term1_shift = self.core_config.fract_width
325 div_term1_shift += self.current_shift
326 div_rhs += div_term1 << div_term1_shift
327
328 comb += self.trial_compare_rhs.eq(div_rhs)
329
330 # SqrtRem
331 if DP.SqrtRem in cc.supported:
332 with m.If(self.operation == int(DP.SqrtRem)):
333 qr = self.quotient_root
334 qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
335 comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
336 sqrt_rhs = self.compare_rhs
337
338 sqrt_term1 = qr_times_trial_bits
339 sqrt_term1_shift = self.core_config.fract_width
340 sqrt_term1_shift += self.current_shift + 1
341 sqrt_rhs += sqrt_term1 << sqrt_term1_shift
342 sqrt_term2 = trial_bits_sqrd_sig
343 sqrt_term2_shift = self.core_config.fract_width
344 sqrt_term2_shift += self.current_shift * 2
345 sqrt_rhs += sqrt_term2 << sqrt_term2_shift
346
347 comb += self.trial_compare_rhs.eq(sqrt_rhs)
348
349 # RSqrtRem
350 if DP.RSqrtRem in cc.supported:
351 with m.If(self.operation == int(DP.RSqrtRem)):
352 rr = self.root_times_radicand
353 tblen2 = self.core_config.bit_width+self.log2_radix*2
354 dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
355 comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
356 rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
357 comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
358 rsqrt_rhs = self.compare_rhs
359
360 rsqrt_term1 = rr_times_trial_bits
361 rsqrt_term1_shift = self.current_shift + 1
362 rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
363 rsqrt_term2 = dr_times_trial_bits_sqrd
364 rsqrt_term2_shift = self.current_shift * 2
365 rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
366
367 comb += self.trial_compare_rhs.eq(rsqrt_rhs)
368
369 return m
370
371
372 class DivPipeCoreCalculateStage(Elaboratable):
373 """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
374
375 def __init__(self, core_config, stage_index):
376 """ Create a ``DivPipeCoreSetupStage`` instance. """
377 assert stage_index in range(core_config.n_stages)
378 self.core_config = core_config
379 bw = core_config.bit_width
380 # TODO(programmerjake): re-enable once bit_width reduction is fixed
381 if False and core_config.supported == {DP.UDivRem}:
382 self.compare_len = bw * 2
383 else:
384 self.compare_len = bw * 3
385 self.stage_index = stage_index
386 self.i = self.ispec()
387 self.o = self.ospec()
388
389 def ispec(self):
390 """ Get the input spec for this pipeline stage. """
391 return DivPipeCoreInterstageData(self.core_config)
392
393 def ospec(self):
394 """ Get the output spec for this pipeline stage. """
395 return DivPipeCoreInterstageData(self.core_config)
396
397 def setup(self, m, i):
398 """ Pipeline stage setup. """
399 setattr(m.submodules,
400 f"div_pipe_core_calculate_{self.stage_index}",
401 self)
402 m.d.comb += self.i.eq(i)
403
404 def process(self, i):
405 """ Pipeline stage process. """
406 return self.o
407
408 def elaborate(self, platform):
409 """ Elaborate into ``Module``. """
410 m = Module()
411 comb = m.d.comb
412 cc = self.core_config
413
414 # copy invariant inputs to outputs (for next stage)
415 comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
416 comb += self.o.operation.eq(self.i.operation)
417 comb += self.o.compare_lhs.eq(self.i.compare_lhs)
418
419 # constants
420 log2_radix = self.core_config.log2_radix
421 current_shift = self.core_config.bit_width
422 current_shift -= self.stage_index * log2_radix
423 log2_radix = min(log2_radix, current_shift)
424 assert log2_radix > 0
425 current_shift -= log2_radix
426 print(f"DivPipeCoreCalc: stage {self.stage_index}"
427 + f" of {self.core_config.n_stages} handling "
428 + f"bits [{current_shift}, {current_shift+log2_radix})"
429 + f" of {self.core_config.bit_width}")
430 radix = 1 << log2_radix
431
432 # trials within this radix range. carried out by Trial module,
433 # results stored in pass_flags. pass_flags are unary priority.
434 trial_compare_rhs_values = []
435 pfl = []
436 for trial_bits in range(radix):
437 t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
438 setattr(m.submodules, "trial%d" % trial_bits, t)
439
440 comb += t.divisor_radicand.eq(self.i.divisor_radicand)
441 comb += t.quotient_root.eq(self.i.quotient_root)
442 comb += t.root_times_radicand.eq(self.i.root_times_radicand)
443 comb += t.compare_rhs.eq(self.i.compare_rhs)
444 comb += t.operation.eq(self.i.operation)
445
446 # get the trial output (needed even in pass_flags[0] case)
447 trial_compare_rhs_values.append(t.trial_compare_rhs)
448
449 # make the trial comparison against the [invariant] lhs.
450 # trial_compare_rhs is always decreasing as trial_bits increases
451 pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
452 if trial_bits == 0:
453 # do not do first comparison: no point.
454 comb += pass_flag.eq(1)
455 else:
456 comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
457 pfl.append(pass_flag)
458
459 # Cat all the pass flags list together (easier to handle, below)
460 pass_flags = Signal(radix, reset_less=True)
461 comb += pass_flags.eq(Cat(*pfl))
462
463 # convert pass_flags (unary priority) to next_bits (binary index)
464 #
465 # Assumes that for each set bit in pass_flag, all previous bits are
466 # also set.
467 #
468 # Assumes that pass_flag[0] is always set (since
469 # compare_lhs >= compare_rhs is a pipeline invariant).
470
471 m.submodules.pe = pe = PriorityEncoder(radix)
472 next_bits = Signal(log2_radix, reset_less=True)
473 comb += pe.i.eq(~pass_flags)
474 with m.If(~pe.n):
475 comb += next_bits.eq(pe.o-1)
476 with m.Else():
477 comb += next_bits.eq(radix-1)
478
479 # get the highest passing rhs trial. use treereduce because
480 # Array on such massively long numbers is insanely gate-hungry
481 crhs = []
482 tcrh = trial_compare_rhs_values
483 for i in range(radix):
484 nbe = Signal(reset_less=True)
485 comb += nbe.eq(next_bits == i)
486 crhs.append(Repl(nbe, self.compare_len) & tcrh[i])
487 comb += self.o.compare_rhs.eq(treereduce(crhs, operator.or_,
488 lambda x:x))
489
490 # create outputs for next phase
491 qr = self.i.quotient_root | (next_bits << current_shift)
492 comb += self.o.quotient_root.eq(qr)
493 if DP.RSqrtRem in cc.supported:
494 rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
495 next_bits) << current_shift)
496 comb += self.o.root_times_radicand.eq(rr)
497
498 return m
499
500
501 class DivPipeCoreFinalStage(Elaboratable):
502 """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
503
504 def __init__(self, core_config):
505 """ Create a ``DivPipeCoreFinalStage`` instance."""
506 self.core_config = core_config
507 self.i = self.ispec()
508 self.o = self.ospec()
509
510 def ispec(self):
511 """ Get the input spec for this pipeline stage."""
512 return DivPipeCoreInterstageData(self.core_config)
513
514 def ospec(self):
515 """ Get the output spec for this pipeline stage."""
516 return DivPipeCoreOutputData(self.core_config)
517
518 def setup(self, m, i):
519 """ Pipeline stage setup. """
520 m.submodules.div_pipe_core_final = self
521 m.d.comb += self.i.eq(i)
522
523 def process(self, i):
524 """ Pipeline stage process. """
525 return self.o # return processed data (ignore i)
526
527 def elaborate(self, platform):
528 """ Elaborate into ``Module``. """
529 m = Module()
530 comb = m.d.comb
531
532 comb += self.o.quotient_root.eq(self.i.quotient_root)
533 comb += self.o.remainder.eq(self.i.compare_lhs - self.i.compare_rhs)
534
535 return m