continue reducing length of signals in div core
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
22 from nmigen.lib.coding import PriorityEncoder
23 from nmutil.util import treereduce
24 import enum
25 import operator
26
27
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Convert to int. """
        return self.value

    @classmethod
    def _decode(cls, value):
        """ Convert a raw signal value to a human-readable string.

        :param value: integer value currently held by the signal.
        :returns: ``str()`` of the matching enum member, or the plain number
            as a string when ``value`` is not a member.

        The signal created by ``create_signal`` can hold values that are not
        members of this enum, so the decoder must not raise on them.
        """
        try:
            return str(cls(value))
        except ValueError:
            return str(value)

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: number of extra stack frames to skip; it is
            incremented by one before being handed to ``Signal`` to account
            for this method's own frame.
        :param kwargs: forwarded unchanged to ``Signal``.
        :returns: the new ``Signal``.
        """
        return Signal(range(min(map(int, cls)), max(map(int, cls)) + 2),
                      src_loc_at=(src_loc_at + 1),
                      decoder=cls._decode,
                      **kwargs)
51
52
53 DP = DivPipeCoreOperation
54
55
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    :attribute supported: list of the ``DivPipeCoreOperation`` values the
        pipeline has to implement. Always stored as a ``list`` so that
        equality tests against list literals behave the same no matter
        which kind of iterable the caller supplied.
    """

    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
        """ Create a ``DivPipeCoreConfig`` instance.

        :param bit_width: base bit-width.
        :param fract_width: base fract-width.
        :param log2_radix: bits of ``quotient_root`` computed per stage.
        :param supported: iterable of supported operations; ``None`` selects
            all three operations.
        """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        if supported is None:
            supported = [DP.SqrtRem, DP.UDivRem, DP.RSqrtRem]
        # Copy into a fresh list: tuples and other iterables are accepted,
        # and later mutation of the caller's object cannot alter the config.
        self.supported = list(supported)
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Get repr. """
        return (f"DivPipeCoreConfig({self.bit_width}, "
                f"{self.fract_width}, {self.log2_radix})")

    @property
    def n_stages(self):
        """ Get the number of ``DivPipeCoreCalculateStage`` needed.

        This is ``bit_width / log2_radix`` rounded up.
        """
        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
85
86
class DivPipeCoreInputData:
    """ Input record consumed by ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` this record was built
        for.
    :attribute dividend: dividend for div/rem. Signal that is
        ``core_config.bit_width + core_config.fract_width`` bits wide, with a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal that is ``core_config.bit_width`` bits wide, with
        a fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the input signals for ``core_config``. """
        self.core_config = core_config
        width = core_config.bit_width
        self.dividend = Signal(width + core_config.fract_width,
                               reset_less=reset_less)
        self.divisor_radicand = Signal(width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """ Build the list of assignments copying ``rhs`` into this record. """
        fields = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in fields]
122
123
class DivPipeCoreInterstageData:
    """ Record passed between the stages of ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` this record was built
        for.
    :attribute compare_len: bit-width used for ``compare_lhs`` and
        ``compare_rhs``: ``core_config.bit_width * 2`` when the only
        supported operation is ``UDivRem``, otherwise
        ``core_config.bit_width * 3``.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal that is ``core_config.bit_width`` bits wide, with
        a fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal that is ``core_config.bit_width`` bits wide, with a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal that is ``core_config.bit_width * 2`` bits wide, with a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: the left-hand-side of the comparison in the
        equation to be solved. Signal that is ``compare_len`` bits wide, with
        a fract-width of ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: the right-hand-side of the comparison in the
        equation to be solved. Signal that is ``compare_len`` bits wide, with
        a fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the interstage signals for ``core_config``. """
        self.core_config = core_config
        width = core_config.bit_width
        # A divide-only core gets by with a narrower comparison datapath.
        div_only = core_config.supported == [DP.UDivRem]
        self.compare_len = width * 2 if div_only else width * 3
        self.divisor_radicand = Signal(width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.root_times_radicand = Signal(width * 2, reset_less=reset_less)
        self.compare_lhs = Signal(self.compare_len, reset_less=reset_less)
        self.compare_rhs = Signal(self.compare_len, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (self.divisor_radicand,
                    self.operation,
                    self.quotient_root,
                    self.root_times_radicand,
                    self.compare_lhs,
                    self.compare_rhs)

    def eq(self, rhs):
        """ Build the list of assignments copying ``rhs`` into this record. """
        fields = ("divisor_radicand",
                  "operation",
                  "quotient_root",
                  "root_times_radicand",
                  "compare_lhs",
                  "compare_rhs")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in fields]
181
182
class DivPipeCoreOutputData:
    """ Output record produced by ``DivPipeCore``.

    :attribute core_config: the ``DivPipeCoreConfig`` this record was built
        for.
    :attribute compare_len: bit-width used for ``remainder``:
        ``core_config.bit_width * 2`` when the only supported operation is
        ``UDivRem``, otherwise ``core_config.bit_width * 3``.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal that is ``core_config.bit_width`` bits wide, with a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal that is ``compare_len`` bits wide, with a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the output signals for ``core_config``. """
        self.core_config = core_config
        width = core_config.bit_width
        div_only = core_config.supported == [DP.UDivRem]
        self.compare_len = width * 2 if div_only else width * 3
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.remainder = Signal(self.compare_len, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Build the list of assignments copying ``rhs`` into this record. """
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in ("quotient_root", "remainder")]
217
218
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` record into the initial
    ``DivPipeCoreInterstageData`` record: the partial results are zeroed and
    ``compare_lhs`` is loaded with the operation-specific left-hand side.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Build a fresh input record for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Build a fresh output record for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Register this stage in ``m`` and drive its input from ``i``. """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the stage's output record (``i`` is ignored). """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        inp, out = self.i, self.o
        fw = self.core_config.fract_width

        # Left-hand side of the comparison; which input feeds it, and by how
        # much it is shifted, depends on the requested operation.
        lhs = Signal(self.core_config.bit_width * 3, reset_less=True)
        with m.Switch(inp.operation):
            with m.Case(int(DP.UDivRem)):
                comb += lhs.eq(inp.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                comb += lhs.eq(inp.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                comb += lhs.eq(1 << (fw * 3))

        # Nothing has been computed yet: pass the inputs through and start
        # every accumulated value at zero.
        comb += [
            out.divisor_radicand.eq(inp.divisor_radicand),
            out.quotient_root.eq(0),
            out.root_times_radicand.eq(0),
            out.compare_lhs.eq(lhs),
            out.compare_rhs.eq(0),
            out.operation.eq(inp.operation),
        ]

        return m
270
271
class Trial(Elaboratable):
    """ Compute the candidate ``compare_rhs`` for one trial digit.

    Given the state accumulated so far (``compare_rhs``, ``quotient_root``,
    ``root_times_radicand``) and a fixed candidate value ``trial_bits`` for
    the next digit of ``quotient_root``, this module outputs in
    ``trial_compare_rhs`` what ``compare_rhs`` would become if that digit
    were accepted. Only the operations listed in ``core_config.supported``
    get hardware generated for them.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: ``DivPipeCoreConfig`` in use.
        :param trial_bits: constant (Python int) candidate digit value.
        :param current_shift: bit position within ``quotient_root`` at which
            the candidate digit would be placed.
        :param log2_radix: number of bits in the candidate digit.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        # a divide-only core uses a narrower comparison datapath
        if core_config.supported == [DP.UDivRem]:
            self.compare_len = bw * 2
        else:
            self.compare_len = bw * 3
        # inputs
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(self.compare_len, reset_less=True)
        # output
        self.trial_compare_rhs = Signal(self.compare_len, reset_less=True)
        # input: selects which formula below drives the output
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """

        m = Module()
        comb = m.d.comb

        cc = self.core_config
        dr = self.divisor_radicand

        # the candidate digit and its square, as constants
        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                    self.log2_radix * 2)

        # width of a bit_width-wide value multiplied by the candidate digit
        tblen = self.core_config.bit_width+self.log2_radix

        # UDivRem
        # rhs = compare_rhs
        #       + ((divisor_radicand * trial_bits)
        #          << (fract_width + current_shift))
        if DP.UDivRem in cc.supported:
            with m.If(self.operation == int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
                div_rhs = self.compare_rhs

                div_term1 = dr_times_trial_bits
                div_term1_shift = self.core_config.fract_width
                div_term1_shift += self.current_shift
                div_rhs += div_term1 << div_term1_shift

                comb += self.trial_compare_rhs.eq(div_rhs)

        # SqrtRem
        # rhs = compare_rhs
        #       + ((quotient_root * trial_bits)
        #          << (fract_width + current_shift + 1))
        #       + ((trial_bits * trial_bits)
        #          << (fract_width + current_shift * 2))
        if DP.SqrtRem in cc.supported:
            with m.If(self.operation == int(DP.SqrtRem)):
                qr = self.quotient_root
                qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
                sqrt_rhs = self.compare_rhs

                sqrt_term1 = qr_times_trial_bits
                sqrt_term1_shift = self.core_config.fract_width
                sqrt_term1_shift += self.current_shift + 1
                sqrt_rhs += sqrt_term1 << sqrt_term1_shift
                sqrt_term2 = trial_bits_sqrd_sig
                sqrt_term2_shift = self.core_config.fract_width
                sqrt_term2_shift += self.current_shift * 2
                sqrt_rhs += sqrt_term2 << sqrt_term2_shift

                comb += self.trial_compare_rhs.eq(sqrt_rhs)

        # RSqrtRem
        # rhs = compare_rhs
        #       + ((root_times_radicand * trial_bits)
        #          << (current_shift + 1))
        #       + ((divisor_radicand * trial_bits * trial_bits)
        #          << (current_shift * 2))
        if DP.RSqrtRem in cc.supported:
            with m.If(self.operation == int(DP.RSqrtRem)):
                rr = self.root_times_radicand
                tblen2 = self.core_config.bit_width+self.log2_radix*2
                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
                comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
                rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
                rsqrt_rhs = self.compare_rhs

                rsqrt_term1 = rr_times_trial_bits
                rsqrt_term1_shift = self.current_shift + 1
                rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
                rsqrt_term2 = dr_times_trial_bits_sqrd
                rsqrt_term2_shift = self.current_shift * 2
                rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift

                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
358
359
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines up to ``core_config.log2_radix`` further bits of
    ``quotient_root`` by instantiating one ``Trial`` per candidate digit and
    keeping the largest candidate whose trial right-hand side does not
    exceed ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` in use.
        :param stage_index: position of this stage; must lie in
            ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        bw = core_config.bit_width
        # a divide-only core uses a narrower comparison datapath
        if core_config.supported == [DP.UDivRem]:
            self.compare_len = bw * 2
        else:
            self.compare_len = bw * 3
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        # submodule name carries the stage index so every stage is distinct
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        cc = self.core_config

        # copy invariant inputs to outputs (for next stage)
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # constants
        # current_shift starts as the number of quotient_root bits not yet
        # handled by earlier stages. log2_radix is clamped to that amount so
        # the last stage may handle fewer bits. After the final subtraction
        # current_shift is the position of the lowest bit this stage handles.
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # trials within this radix range. carried out by Trial module,
        # results stored in pass_flags. pass_flags are unary priority.
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            comb += t.quotient_root.eq(self.i.quotient_root)
            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            comb += t.compare_rhs.eq(self.i.compare_rhs)
            comb += t.operation.eq(self.i.operation)

            # get the trial output (needed even in pass_flags[0] case)
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # make the trial comparison against the [invariant] lhs.
            # Trial only ever adds non-negative terms that scale with
            # trial_bits, so trial_compare_rhs does not decrease as
            # trial_bits increases: once one trial fails, all larger
            # trials are expected to fail too.
            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            if trial_bits == 0:
                # do not do first comparison: no point.
                comb += pass_flag.eq(1)
            else:
                comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
            pfl.append(pass_flag)

        # Cat all the pass flags list together (easier to handle, below)
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # The encoder is fed the inverted flags, so (presumably, per the
        # nmigen PriorityEncoder contract) pe.o is the index of the first
        # failing trial and pe.n is set when no trial failed. The chosen
        # digit is therefore one below pe.o, or radix-1 when all passed.

        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial. use treereduce because
        # Array on such massively long numbers is insanely gate-hungry
        #
        # Each trial value is masked by a one-bit "is selected" flag
        # replicated to full width. next_bits matches exactly one index, so
        # OR-ing all masked values together yields the selected one.
        crhs = []
        tcrh = trial_compare_rhs_values
        for i in range(radix):
            nbe = Signal(reset_less=True)
            comb += nbe.eq(next_bits == i)
            crhs.append(Repl(nbe, self.compare_len) & tcrh[i])
        comb += self.o.compare_rhs.eq(treereduce(crhs, operator.or_,
                                                 lambda x:x))

        # create outputs for next phase
        # insert the chosen digit into quotient_root at current_shift
        qr = self.i.quotient_root | (next_bits << current_shift)
        comb += self.o.quotient_root.eq(qr)
        # root_times_radicand is only maintained when RSqrtRem is supported
        if DP.RSqrtRem in cc.supported:
            rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
                                                next_bits) << current_shift)
            comb += self.o.root_times_radicand.eq(rr)

        return m
486
487
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Turns the last ``DivPipeCoreInterstageData`` record into a
    ``DivPipeCoreOutputData`` record: ``quotient_root`` is passed through and
    the remainder is formed as ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Build a fresh input record for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Build a fresh output record for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Register this stage in ``m`` and drive its input from ``i``. """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the stage's output record (``i`` is ignored). """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        inp, out = self.i, self.o

        remainder = inp.compare_lhs - inp.compare_rhs
        m.d.comb += [
            out.quotient_root.eq(inp.quotient_root),
            out.remainder.eq(remainder),
        ]

        return m