cut root_times_radicand if not doing Sqrt
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
22 from nmigen.lib.coding import PriorityEncoder
23 import enum
24
25
class DivPipeCoreOperation(enum.Enum):
    """ Operations that ``DivPipeCore`` can be asked to compute.

    :attribute SqrtRem: square-root/remainder (encoded as 0).
    :attribute UDivRem: unsigned divide/remainder (encoded as 1).
    :attribute RSqrtRem: reciprocal-square-root/remainder (encoded as 2).
    """

    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Return the integer encoding of this operation. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Build a ``Signal`` that can hold any ``DivPipeCoreOperation``.

        :param src_loc_at: extra number of stack frames to skip; one more is
            added here to account for this method's own frame before the
            value is handed on to ``Signal``.
        :param kwargs: forwarded unchanged to ``Signal``.
        :returns: the newly created ``Signal``.
        """
        encodings = [int(member) for member in cls]
        # NOTE(review): the upper bound is ``max + 2`` although a range end
        # is exclusive, so ``max + 1`` would appear sufficient -- confirm the
        # intent before changing it.
        shape = range(min(encodings), max(encodings) + 2)

        def decode(v):
            # turn a raw value back into the member's string form
            return str(cls(v))

        return Signal(shape,
                      src_loc_at=(src_loc_at + 1),
                      decoder=decode,
                      **kwargs)
49
50
# Short alias for ``DivPipeCoreOperation`` used by the rest of this module.
DP = DivPipeCoreOperation
52
53
class DivPipeCoreConfig:
    """ Parameters shared by the stages of the div/rem/sqrt/rsqrt core.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    :attribute supported: collection of ``DivPipeCoreOperation`` values that
        the core should implement.
    """

    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
        """ Store the configuration and print the resulting stage count.

        :param bit_width: base bit-width.
        :param fract_width: base fract-width.
        :param log2_radix: bits of ``quotient_root`` computed per stage.
        :param supported: operations to implement; ``None`` selects
            ``SqrtRem``, ``UDivRem`` and ``RSqrtRem``.
        """
        if supported is None:
            supported = [DP.SqrtRem, DP.UDivRem, DP.RSqrtRem]
        self.bit_width, self.fract_width = bit_width, fract_width
        self.log2_radix = log2_radix
        self.supported = supported
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Return a constructor-like description (without ``supported``). """
        head = f"DivPipeCoreConfig({self.bit_width}, "
        tail = f"{self.fract_width}, {self.log2_radix})"
        return head + tail

    @property
    def n_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances needed.

        This is ``bit_width`` divided by ``log2_radix``, rounded up.
        """
        bits_per_stage = self.log2_radix
        return (self.bit_width + bits_per_stage - 1) // bits_per_stage
83
84
class DivPipeCoreInputData:
    """ Bundle of signals fed into ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the member signals for the given ``core_config``.

        :param core_config: ``DivPipeCoreConfig`` giving the widths.
        :param reset_less: passed on to every created signal.
        """
        self.core_config = core_config
        radicand_width = core_config.bit_width
        dividend_width = radicand_width + core_config.fract_width
        self.dividend = Signal(dividend_width, reset_less=reset_less)
        self.divisor_radicand = Signal(radicand_width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        for member in ("dividend", "divisor_radicand", "operation"):
            yield getattr(self, member)

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` into ``self``. """
        members = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in members]
120
121
class DivPipeCoreInterstageData:
    """ Bundle of signals passed between stages of ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the member signals for the given ``core_config``.

        :param core_config: ``DivPipeCoreConfig`` giving the widths.
        :param reset_less: passed on to every created signal.
        """
        self.core_config = core_config
        single = core_config.bit_width
        double = single * 2
        triple = single * 3
        self.divisor_radicand = Signal(single, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(single, reset_less=reset_less)
        self.root_times_radicand = Signal(double, reset_less=reset_less)
        self.compare_lhs = Signal(triple, reset_less=reset_less)
        self.compare_rhs = Signal(triple, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        for member in ("divisor_radicand", "operation", "quotient_root",
                       "root_times_radicand", "compare_lhs", "compare_rhs"):
            yield getattr(self, member)

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` into ``self``. """
        members = ("divisor_radicand", "operation", "quotient_root",
                   "root_times_radicand", "compare_lhs", "compare_rhs")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in members]
175
176
class DivPipeCoreOutputData:
    """ Bundle of signals produced by ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the member signals for the given ``core_config``.

        :param core_config: ``DivPipeCoreConfig`` giving the widths.
        :param reset_less: passed on to every created signal.
        """
        self.core_config = core_config
        root_width = core_config.bit_width
        remainder_width = root_width * 3
        self.quotient_root = Signal(root_width, reset_less=reset_less)
        self.remainder = Signal(remainder_width, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        for member in ("quotient_root", "remainder"):
            yield getattr(self, member)

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` into ``self``. """
        members = ("quotient_root", "remainder")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in members]
207
208
class DivPipeCoreSetupStage(Elaboratable):
    """ First stage of the div/rem/sqrt/rsqrt core.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the partial results are zeroed and
    ``compare_lhs`` is set up according to the selected operation.
    """

    def __init__(self, core_config):
        """ Create the stage and its input/output signal bundles. """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh input bundle for this pipeline stage. """
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Return a fresh output bundle for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Register this stage as a submodule of ``m`` and drive its input
        bundle from ``i``.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the output bundle; the argument ``i`` is ignored. """
        return self.o

    def elaborate(self, platform):
        """ Build and return the ``Module`` for this stage. """
        m = Module()
        comb = m.d.comb
        cfg = self.core_config
        inp = self.i
        out = self.o

        # pass the divisor/radicand through, start with empty partial results
        comb += [
            out.divisor_radicand.eq(inp.divisor_radicand),
            out.quotient_root.eq(0),
            out.root_times_radicand.eq(0),
        ]

        lhs = Signal(cfg.bit_width * 3, reset_less=True)
        fw = cfg.fract_width

        # choose the left-hand side of the comparison for each operation
        with m.Switch(inp.operation):
            with m.Case(int(DP.UDivRem)):
                comb += lhs.eq(inp.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                comb += lhs.eq(inp.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                comb += lhs.eq(1 << (fw * 3))

        comb += [
            out.compare_lhs.eq(lhs),
            out.compare_rhs.eq(0),
            out.operation.eq(inp.operation),
        ]

        return m
260
261
class Trial(Elaboratable):
    """ Compute the candidate ``compare_rhs`` for one fixed ``trial_bits``.

    Inputs are the signals ``divisor_radicand``, ``quotient_root``,
    ``root_times_radicand``, ``compare_rhs`` and ``operation``; the output is
    ``trial_compare_rhs``.  ``trial_bits`` and ``current_shift`` are Python
    integers fixed at construction time.

    Depending on ``operation`` the output is:

    * UDivRem: ``compare_rhs
      + ((divisor_radicand * trial_bits) << (fract_width + current_shift))``
    * SqrtRem: ``compare_rhs
      + ((quotient_root * trial_bits) << (fract_width + current_shift + 1))
      + ((trial_bits * trial_bits) << (fract_width + current_shift * 2))``
    * RSqrtRem: ``compare_rhs
      + ((root_times_radicand * trial_bits) << (current_shift + 1))
      + ((divisor_radicand * trial_bits * trial_bits)
         << (current_shift * 2))``

    Logic is only generated for operations listed in
    ``core_config.supported``.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving widths and the
            supported operations.
        :param trial_bits: the constant digit value tried by this instance.
        :param current_shift: bit position of the digit being computed.
        :param log2_radix: width in bits of the digit being computed.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(bw * 3, reset_less=True)
        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """

        m = Module()
        comb = m.d.comb

        cc = self.core_config
        dr = self.divisor_radicand

        # trial digit and its square as constants of fixed width
        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                    self.log2_radix * 2)

        # width of (bit_width-wide value) * (log2_radix-wide trial digit)
        tblen = self.core_config.bit_width+self.log2_radix

        # UDivRem
        # (the ``in cc.supported`` tests run in Python at elaboration time,
        # so an unsupported operation contributes no statements at all)
        if DP.UDivRem in cc.supported:
            with m.If(self.operation == int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
                div_rhs = self.compare_rhs

                div_term1 = dr_times_trial_bits
                div_term1_shift = self.core_config.fract_width
                div_term1_shift += self.current_shift
                div_rhs += div_term1 << div_term1_shift

                comb += self.trial_compare_rhs.eq(div_rhs)

        # SqrtRem
        if DP.SqrtRem in cc.supported:
            with m.If(self.operation == int(DP.SqrtRem)):
                qr = self.quotient_root
                qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
                sqrt_rhs = self.compare_rhs

                # term1: quotient_root * trial_bits, shifted one extra bit
                sqrt_term1 = qr_times_trial_bits
                sqrt_term1_shift = self.core_config.fract_width
                sqrt_term1_shift += self.current_shift + 1
                sqrt_rhs += sqrt_term1 << sqrt_term1_shift
                # term2: trial_bits squared
                sqrt_term2 = trial_bits_sqrd_sig
                sqrt_term2_shift = self.core_config.fract_width
                sqrt_term2_shift += self.current_shift * 2
                sqrt_rhs += sqrt_term2 << sqrt_term2_shift

                comb += self.trial_compare_rhs.eq(sqrt_rhs)

        # RSqrtRem
        if DP.RSqrtRem in cc.supported:
            with m.If(self.operation == int(DP.RSqrtRem)):
                rr = self.root_times_radicand
                tblen2 = self.core_config.bit_width+self.log2_radix*2
                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
                comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
                rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
                rsqrt_rhs = self.compare_rhs

                # term1: root_times_radicand * trial_bits, one extra bit
                rsqrt_term1 = rr_times_trial_bits
                rsqrt_term1_shift = self.current_shift + 1
                rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
                # term2: divisor_radicand * trial_bits squared
                rsqrt_term2 = dr_times_trial_bits_sqrd
                rsqrt_term2_shift = self.current_shift * 2
                rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift

                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
344
345
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each instance determines up to ``core_config.log2_radix`` bits of
    ``quotient_root`` by instantiating one ``Trial`` per possible digit value
    and selecting the largest digit whose trial right-hand side does not
    exceed ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` for the pipeline.
        :param stage_index: index of this stage; must lie in
            ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage in ``m.submodules`` under a name containing
        ``stage_index`` and drives the input bundle from ``i``.
        """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process: return the output bundle (ignore i). """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        cc = self.core_config

        # copy invariant inputs to outputs (for next stage)
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # constants
        # Stages work from the most significant bits downwards.  The digit
        # width is clipped to the bits still remaining at this stage, after
        # which current_shift is the position of the lowest bit handled here.
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # trials within this radix range.  carried out by Trial module,
        # results stored in pass_flags.  pass_flags are unary priority.
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            comb += t.quotient_root.eq(self.i.quotient_root)
            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            comb += t.compare_rhs.eq(self.i.compare_rhs)
            comb += t.operation.eq(self.i.operation)

            # get the trial output
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # make the trial comparison against the [invariant] lhs.
            # trial_compare_rhs is non-decreasing as trial_bits increases
            # (Trial only adds non-negative terms scaled by trial_bits), so
            # once one trial fails all higher ones fail too.
            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
            pfl.append(pass_flag)

        # Cat all the pass flags list together (easier to handle, below)
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).

        # The encoder is fed the inverted flags, so it reports the index of
        # the lowest failing trial; the digit wanted is one below that.
        # pe.n set means no bit of the encoder input was set, i.e. every
        # trial passed, in which case the largest digit is used.
        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial (indexed by next_bits)
        ta = Array(trial_compare_rhs_values)
        comb += self.o.compare_rhs.eq(ta[next_bits])

        # create outputs for next phase
        qr = self.i.quotient_root | (next_bits << current_shift)
        comb += self.o.quotient_root.eq(qr)
        # root_times_radicand is only maintained when RSqrtRem is supported;
        # otherwise the output is not assigned by this stage.
        if DP.RSqrtRem in cc.supported:
            rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
                                                next_bits) << current_shift)
            comb += self.o.root_times_radicand.eq(rr)

        return m
456
457
class DivPipeCoreFinalStage(Elaboratable):
    """ Last stage of the div/rem/sqrt/rsqrt core.

    Converts the final ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``: ``quotient_root`` is passed through and the
    remainder is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create the stage and its input/output signal bundles. """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh input bundle for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Return a fresh output bundle for this pipeline stage. """
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Register this stage as a submodule of ``m`` and drive its input
        bundle from ``i``.
        """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the output bundle; the argument ``i`` is ignored. """
        return self.o

    def elaborate(self, platform):
        """ Build and return the ``Module`` for this stage. """
        m = Module()
        inp = self.i
        out = self.o

        # remainder is whatever is left of lhs after subtracting the final rhs
        difference = inp.compare_lhs - inp.compare_rhs
        m.d.comb += [
            out.quotient_root.eq(inp.quotient_root),
            out.remainder.eq(difference),
        ]

        return m