switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
22 from nmigen.lib.coding import PriorityEncoder
23 import enum
24
25
class DivPipeCoreConfig:
    """ Parameters shared by all stages of the div/rem/sqrt/rsqrt core.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width, i.e. the location of the
        base-2 radix point.
    :attribute log2_radix: how many bits of ``quotient_root`` each pipeline
        stage computes.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Store the parameters and print the resulting stage count. """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        summary = f"{self}: n_stages={self.n_stages}"
        print(summary)

    def __repr__(self):
        """ Return a constructor-style description of this configuration. """
        head = f"DivPipeCoreConfig({self.bit_width}, "
        tail = f"{self.fract_width}, {self.log2_radix})"
        return head + tail

    @property
    def n_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances needed.

        For positive widths this is ``bit_width / log2_radix`` rounded up.
        """
        rounded_up = self.bit_width + self.log2_radix - 1
        return rounded_up // self.log2_radix
52
53
class DivPipeCoreOperation(enum.Enum):
    """ Selects the computation carried out by ``DivPipeCore``.

    :attribute SqrtRem: square-root/remainder.
    :attribute UDivRem: unsigned divide/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Return the numeric encoding of this operation. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        ``src_loc_at`` is incremented by one before being handed to
        ``Signal``; any further keyword arguments are forwarded unchanged.
        """
        encodings = [int(member) for member in cls]

        def decode(value):
            # Render a raw value as the string form of the enum member.
            return str(cls(value))

        return Signal(min=min(encodings),
                      max=max(encodings) + 2,
                      src_loc_at=(src_loc_at + 1),
                      decoder=decode,
                      **kwargs)
78
79
# Short alias for ``DivPipeCoreOperation``.
DP = DivPipeCoreOperation
81
82
class DivPipeCoreInputData:
    """ Bundle of signals fed into ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` instance in use.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and
        a fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the input signals for the given configuration. """
        self.core_config = core_config
        base_width = core_config.bit_width
        dividend_width = base_width + core_config.fract_width
        self.dividend = Signal(dividend_width, reset_less=reset_less)
        self.divisor_radicand = Signal(base_width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        for field in ("dividend", "divisor_radicand", "operation"):
            yield getattr(self, field)

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` into ``self``. """
        fields = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, field).eq(getattr(rhs, field))
                for field in fields]
118
119
class DivPipeCoreInterstageData:
    """ Bundle of signals passed between ``DivPipeCore`` stages.

    :attribute core_config: the ``DivPipeCoreConfig`` instance in use.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and
        a fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result.
        Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: left-hand-side of the comparison in the equation
        being solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: right-hand-side of the comparison in the
        equation being solved. Same widths as ``compare_lhs``.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the interstage signals for the given configuration. """
        self.core_config = core_config
        single = core_config.bit_width
        double = single * 2
        triple = single * 3
        self.divisor_radicand = Signal(single, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(single, reset_less=reset_less)
        self.root_times_radicand = Signal(double, reset_less=reset_less)
        self.compare_lhs = Signal(triple, reset_less=reset_less)
        self.compare_rhs = Signal(triple, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        for field in ("divisor_radicand", "operation", "quotient_root",
                      "root_times_radicand", "compare_lhs", "compare_rhs"):
            yield getattr(self, field)

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` into ``self``. """
        fields = ("divisor_radicand", "operation", "quotient_root",
                  "root_times_radicand", "compare_lhs", "compare_rhs")
        return [getattr(self, field).eq(getattr(rhs, field))
                for field in fields]
173
174
class DivPipeCoreOutputData:
    """ Bundle of signals produced by ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` instance in use.
    :attribute quotient_root: the quotient or root part of the result.
        Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result. Signal with a
        bit-width of ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the output signals for the given configuration. """
        self.core_config = core_config
        base_width = core_config.bit_width
        self.quotient_root = Signal(base_width, reset_less=reset_less)
        self.remainder = Signal(base_width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        for field in ("quotient_root", "remainder"):
            yield getattr(self, field)

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` into ``self``. """
        fields = ("quotient_root", "remainder")
        return [getattr(self, field).eq(getattr(rhs, field))
                for field in fields]
205
206
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the quotient/root, the root-times-radicand
    product and ``compare_rhs`` start at zero, and ``compare_lhs`` is chosen
    according to the requested operation.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance. """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh input data bundle for this stage. """
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Return a fresh output data bundle for this stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Register this stage as a submodule of ``m`` and wire ``i`` in. """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return this stage's output bundle; the argument is ignored. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        inp = self.i
        out = self.o
        fw = self.core_config.fract_width

        # pass the divisor/radicand through and clear the running results
        comb += [
            out.divisor_radicand.eq(inp.divisor_radicand),
            out.quotient_root.eq(0),
            out.root_times_radicand.eq(0),
        ]

        lhs = Signal(self.core_config.bit_width * 3, reset_less=True)

        # the left-hand side of the solved equation depends on the operation;
        # each variant is shifted by a different multiple of fract_width
        with m.Switch(inp.operation):
            with m.Case(int(DP.UDivRem)):
                comb += lhs.eq(inp.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                comb += lhs.eq(inp.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                comb += lhs.eq(1 << (fw * 3))

        comb += [
            out.compare_lhs.eq(lhs),
            out.compare_rhs.eq(0),
            out.operation.eq(inp.operation),
        ]

        return m
258
259
class Trial(Elaboratable):
    """ Computes the candidate ``compare_rhs`` for one value of trial bits.

    The inputs ``divisor_radicand``, ``quotient_root``,
    ``root_times_radicand``, ``compare_rhs`` and ``operation`` are driven by
    the enclosing calculate stage; the result is ``trial_compare_rhs``.

    :attribute core_config: the ``DivPipeCoreConfig`` instance in use.
    :attribute trial_bits: the constant candidate value tested by this trial.
    :attribute current_shift: bit position at which the candidate is placed.
    :attribute log2_radix: number of bits in the candidate.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance and allocate its signals. """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(bw * 3, reset_less=True)
        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        fract_width = self.core_config.fract_width
        shift = self.current_shift
        base_rhs = self.compare_rhs

        # the candidate bits and their square, as constants
        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                    self.log2_radix * 2)

        tblen = self.core_config.bit_width + self.log2_radix
        tblen2 = self.core_config.bit_width + self.log2_radix * 2

        # computed unconditionally; only the RSqrtRem case reads it
        dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
        comb += dr_times_trial_bits_sqrd.eq(
            self.divisor_radicand * trial_bits_sqrd_sig)

        with m.Switch(self.operation):
            # UDivRem: rhs + ((dr * trial) << (fract_width + shift))
            with m.Case(int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(
                    self.divisor_radicand * trial_bits_sig)

                div_rhs = base_rhs + (
                    dr_times_trial_bits << (fract_width + shift))
                comb += self.trial_compare_rhs.eq(div_rhs)

            # SqrtRem: rhs + ((qr * trial) << (fract_width + shift + 1))
            #              + (trial**2 << (fract_width + shift * 2))
            with m.Case(int(DP.SqrtRem)):
                qr_times_trial_bits = Signal((tblen + 1) * 2, reset_less=True)
                comb += qr_times_trial_bits.eq(
                    self.quotient_root * trial_bits_sig)

                sqrt_rhs = base_rhs + (
                    qr_times_trial_bits << (fract_width + shift + 1))
                sqrt_rhs = sqrt_rhs + (
                    trial_bits_sqrd_sig << (fract_width + shift * 2))
                comb += self.trial_compare_rhs.eq(sqrt_rhs)

            # RSqrtRem: rhs + ((rr * trial) << (shift + 1))
            #               + ((dr * trial**2) << (shift * 2))
            with m.Case(int(DP.RSqrtRem)):
                rr_times_trial_bits = Signal((tblen + 1) * 3, reset_less=True)
                comb += rr_times_trial_bits.eq(
                    self.root_times_radicand * trial_bits_sig)

                rsqrt_rhs = base_rhs + (rr_times_trial_bits << (shift + 1))
                rsqrt_rhs = rsqrt_rhs + (
                    dr_times_trial_bits_sqrd << (shift * 2))
                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
339
340
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines up to ``log2_radix`` further bits of
    ``quotient_root`` by evaluating one ``Trial`` per candidate value and
    keeping the highest candidate whose ``compare_rhs`` does not exceed
    ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        ``stage_index`` must lie in ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh input data bundle for this stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Return a fresh output data bundle for this stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Register this stage as a submodule of ``m`` and wire ``i`` in. """
        submodule_name = f"div_pipe_core_calculate_{self.stage_index}"
        setattr(m.submodules, submodule_name, self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return this stage's output bundle; the argument is ignored. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        inp = self.i
        out = self.o

        # values that do not change within a stage go straight through
        comb += [
            out.divisor_radicand.eq(inp.divisor_radicand),
            out.operation.eq(inp.operation),
            out.compare_lhs.eq(inp.compare_lhs),
        ]

        # work out which bits this stage handles; the radix is clamped so
        # that a stage never handles more bits than are still outstanding
        bits_outstanding = (self.core_config.bit_width
                            - self.stage_index * self.core_config.log2_radix)
        log2_radix = min(self.core_config.log2_radix, bits_outstanding)
        assert log2_radix > 0
        current_shift = bits_outstanding - log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # one Trial submodule per candidate value within this radix range;
        # each contributes a candidate rhs and a one-bit pass flag
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += [
                t.divisor_radicand.eq(inp.divisor_radicand),
                t.quotient_root.eq(inp.quotient_root),
                t.root_times_radicand.eq(inp.root_times_radicand),
                t.compare_rhs.eq(inp.compare_rhs),
                t.operation.eq(inp.operation),
            ]

            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # compare against the [invariant] lhs. the trial rhs is expected
            # to be non-decreasing as trial_bits increases, which is what
            # makes the flags below a unary priority code.
            flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            comb += flag.eq(inp.compare_lhs >= t.trial_compare_rhs)
            pfl.append(flag)

        # gather the individual flags into a single vector
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flags, all previous bits are
        # also set.
        #
        # Assumes that pass_flags[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # The encoder is fed the inverted flags, so the index it reports is
        # one above the highest passing trial; when it reports no input set,
        # every trial passed and the top candidate is taken.
        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            comb += next_bits.eq(radix-1)

        # select the rhs of the highest passing trial (indexed by next_bits)
        ta = Array(trial_compare_rhs_values)
        comb += out.compare_rhs.eq(ta[next_bits])

        # merge the new bits into the running results for the next stage
        updated_root = inp.quotient_root | (next_bits << current_shift)
        added_product = (inp.divisor_radicand * next_bits) << current_shift
        updated_product = inp.root_times_radicand + added_product
        comb += out.quotient_root.eq(updated_root)
        comb += out.root_times_radicand.eq(updated_product)

        return m
449
450
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into
    ``DivPipeCoreOutputData``: the quotient/root is passed through and the
    remainder is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance. """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh input data bundle for this stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Return a fresh output data bundle for this stage. """
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Register this stage as a submodule of ``m`` and wire ``i`` in. """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return this stage's output bundle; the argument is ignored. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()

        difference = self.i.compare_lhs - self.i.compare_rhs
        m.d.comb += [
            self.o.quotient_root.eq(self.i.quotient_root),
            self.o.remainder.eq(difference),
        ]

        return m
484
485 return m