whoops set pass_flag[0] always true
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Repl)
22 from nmigen.lib.coding import PriorityEncoder
23 from nmutil.util import treereduce
24 import enum
25 import operator
26
27
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    SqrtRem = 0
    UDivRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Convert to int (the member's numeric value). """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: forwarded to ``Signal`` incremented by one,
            presumably so that this helper's own stack frame is skipped
            when the signal's name/source location is inferred.
        :param kwargs: passed through to ``Signal`` unchanged.
        :returns: a ``Signal`` whose range spans every member value and
            whose decoder renders values as the enum member's ``str``.
        """
        values = [int(op) for op in cls]
        # NOTE(review): the upper bound is ``max + 2``, i.e. one more than is
        # needed to hold every member -- confirm whether that is intentional.
        # The decoder raises ``ValueError`` for any value that is not a
        # member of this enum.
        return Signal(range(min(values), max(values) + 2),
                      src_loc_at=(src_loc_at + 1),
                      decoder=lambda v: str(cls(v)),
                      **kwargs)


# short alias used throughout this module
DP = DivPipeCoreOperation
54
55
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    :attribute supported: collection of ``DivPipeCoreOperation`` members the
        pipeline should handle.
    """

    def __init__(self, bit_width, fract_width, log2_radix, supported=None):
        """ Create a ``DivPipeCoreConfig`` instance.

        :param bit_width: base bit-width.
        :param fract_width: base fract-width.
        :param log2_radix: bits of ``quotient_root`` computed per stage.
        :param supported: operations to support; ``None`` (the default)
            selects all three operations.
        """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        if supported is None:
            supported = [DP.SqrtRem, DP.UDivRem, DP.RSqrtRem]
        self.supported = supported
        # debug trace emitted every time a configuration is created
        print(f"{self}: n_stages={self.n_stages}")

    def __repr__(self):
        """ Get repr (``supported`` is not included). """
        return f"DivPipeCoreConfig({self.bit_width}, " \
            + f"{self.fract_width}, {self.log2_radix})"

    @property
    def n_stages(self):
        """ Get the number of ``DivPipeCoreCalculateStage`` needed.

        This is ``bit_width / log2_radix`` rounded up.
        """
        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
85
86
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member ``Signal``.
        """
        self.core_config = core_config
        bw = core_config.bit_width
        fw = core_config.fract_width
        self.dividend = Signal(bw + fw, reset_less=reset_less)
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Get member signals, in declaration order. """
        yield from (self.dividend,
                    self.divisor_radicand,
                    self.operation)

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing the same three member signals.
        :returns: list of assignment statements, one per member signal.
        """
        return [self.dividend.eq(rhs.dividend),
                self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                ]
122
123
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInterstageData`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member ``Signal``.
        """
        self.core_config = core_config
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.root_times_radicand = Signal(bw * 2, reset_less=reset_less)
        self.compare_lhs = Signal(bw * 3, reset_less=reset_less)
        self.compare_rhs = Signal(bw * 3, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals, in declaration order. """
        yield from (self.divisor_radicand,
                    self.operation,
                    self.quotient_root,
                    self.root_times_radicand,
                    self.compare_lhs,
                    self.compare_rhs)

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing the same six member signals.
        :returns: list of assignment statements, one per member signal.
        """
        return [self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation),
                self.quotient_root.eq(rhs.quotient_root),
                self.root_times_radicand.eq(rhs.root_times_radicand),
                self.compare_lhs.eq(rhs.compare_lhs),
                self.compare_rhs.eq(rhs.compare_rhs)]
177
178
class DivPipeCoreOutputData:
    """ output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreOutputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member ``Signal``.
        """
        self.core_config = core_config
        bw = core_config.bit_width
        self.quotient_root = Signal(bw, reset_less=reset_less)
        self.remainder = Signal(bw * 3, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals, in declaration order. """
        yield self.quotient_root
        yield self.remainder

    def eq(self, rhs):
        """ Assign member signals.

        :param rhs: object providing ``quotient_root`` and ``remainder``.
        :returns: list of assignment statements, one per member signal.
        """
        return [self.quotient_root.eq(rhs.quotient_root),
                self.remainder.eq(rhs.remainder)]
209
210
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the running results are zeroed and
    ``compare_lhs`` is loaded with the operation-specific left-hand side.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and combinatorially
        connects ``i`` to this stage's input.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        # nothing has been computed yet: start the running results at zero
        comb += self.o.quotient_root.eq(0)
        comb += self.o.root_times_radicand.eq(0)

        lhs = Signal(self.core_config.bit_width * 3, reset_less=True)
        fw = self.core_config.fract_width

        # left-hand side of the comparison for each operation, shifted so
        # that every case ends up with a fract-width of ``fw * 3`` bits.
        # No default case: for any other operation value ``lhs`` is not
        # assigned here.
        with m.Switch(self.i.operation):
            with m.Case(int(DP.UDivRem)):
                comb += lhs.eq(self.i.dividend << fw)
            with m.Case(int(DP.SqrtRem)):
                comb += lhs.eq(self.i.divisor_radicand << (fw * 2))
            with m.Case(int(DP.RSqrtRem)):
                comb += lhs.eq(1 << (fw * 3))

        comb += self.o.compare_lhs.eq(lhs)
        comb += self.o.compare_rhs.eq(0)
        comb += self.o.operation.eq(self.i.operation)

        return m
262
263
class Trial(Elaboratable):
    """ Computes the candidate ``compare_rhs`` for one fixed trial digit.

    Inputs (driven by the instantiating stage): ``divisor_radicand``,
    ``quotient_root``, ``root_times_radicand``, ``compare_rhs`` and
    ``operation``.  Output: ``trial_compare_rhs``.

    With ``tb = trial_bits``, ``fw = core_config.fract_width`` and
    ``sh = current_shift`` the output is:

    * UDivRem:
        ``compare_rhs + ((divisor_radicand * tb) << (fw + sh))``
    * SqrtRem:
        ``compare_rhs + ((quotient_root * tb) << (fw + sh + 1))
        + ((tb * tb) << (fw + sh * 2))``
    * RSqrtRem:
        ``compare_rhs + ((root_times_radicand * tb) << (sh + 1))
        + ((divisor_radicand * tb * tb) << (sh * 2))``

    Logic is only generated for operations listed in
    ``core_config.supported``; for any other operation this module does not
    assign ``trial_compare_rhs``.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving widths and the
            supported operations.
        :param trial_bits: the constant candidate digit this trial evaluates.
        :param current_shift: bit position of the digit being computed.
        :param log2_radix: bit-width used for the ``trial_bits`` constant.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        bw = core_config.bit_width
        self.divisor_radicand = Signal(bw, reset_less=True)
        self.quotient_root = Signal(bw, reset_less=True)
        self.root_times_radicand = Signal(bw * 2, reset_less=True)
        self.compare_rhs = Signal(bw * 3, reset_less=True)
        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        cc = self.core_config
        fw = cc.fract_width
        shift = self.current_shift
        dr = self.divisor_radicand

        # the trial digit and its square, as fixed-width constants
        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
                                    self.log2_radix * 2)

        # width of a ``bit_width``-wide value times the trial digit
        tblen = cc.bit_width + self.log2_radix

        # NOTE: the right-hand sides below are built as new expressions
        # (``a + (b << n)``); the input signal ``compare_rhs`` itself is
        # never modified.

        # UDivRem
        if DP.UDivRem in cc.supported:
            with m.If(self.operation == int(DP.UDivRem)):
                dr_times_trial_bits = Signal(tblen, reset_less=True)
                comb += dr_times_trial_bits.eq(dr * trial_bits_sig)

                div_term1_shift = fw + shift
                div_rhs = self.compare_rhs \
                    + (dr_times_trial_bits << div_term1_shift)

                comb += self.trial_compare_rhs.eq(div_rhs)

        # SqrtRem
        if DP.SqrtRem in cc.supported:
            with m.If(self.operation == int(DP.SqrtRem)):
                qr = self.quotient_root
                qr_times_trial_bits = Signal((tblen + 1) * 2,
                                             reset_less=True)
                comb += qr_times_trial_bits.eq(qr * trial_bits_sig)

                sqrt_term1_shift = fw + shift + 1
                sqrt_term2_shift = fw + shift * 2
                sqrt_rhs = self.compare_rhs \
                    + (qr_times_trial_bits << sqrt_term1_shift)
                sqrt_rhs = sqrt_rhs \
                    + (trial_bits_sqrd_sig << sqrt_term2_shift)

                comb += self.trial_compare_rhs.eq(sqrt_rhs)

        # RSqrtRem
        if DP.RSqrtRem in cc.supported:
            with m.If(self.operation == int(DP.RSqrtRem)):
                rr = self.root_times_radicand
                tblen2 = cc.bit_width + self.log2_radix * 2
                dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
                comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
                rr_times_trial_bits = Signal((tblen + 1) * 3,
                                             reset_less=True)
                comb += rr_times_trial_bits.eq(rr * trial_bits_sig)

                rsqrt_term1_shift = shift + 1
                rsqrt_term2_shift = shift * 2
                rsqrt_rhs = self.compare_rhs \
                    + (rr_times_trial_bits << rsqrt_term1_shift)
                rsqrt_rhs = rsqrt_rhs \
                    + (dr_times_trial_bits_sqrd << rsqrt_term2_shift)

                comb += self.trial_compare_rhs.eq(rsqrt_rhs)

        return m
346
347
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines up to ``core_config.log2_radix`` further bits of
    ``quotient_root`` by evaluating every candidate digit in parallel (one
    ``Trial`` submodule per candidate) and keeping the largest candidate
    whose ``trial_compare_rhs`` does not exceed ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` to use.
        :param stage_index: position of this stage; must be in
            ``range(core_config.n_stages)``.
        """
        assert stage_index in range(core_config.n_stages)
        self.core_config = core_config
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` (named after its stage
        index) and combinatorially connects ``i`` to this stage's input.
        """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb
        cc = self.core_config

        # copy invariant inputs to outputs (for next stage)
        comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        comb += self.o.operation.eq(self.i.operation)
        comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # constants.  Stages work from the most-significant bits downwards;
        # the last stage may handle fewer than ``cc.log2_radix`` bits when
        # bit_width is not a multiple of log2_radix.
        log2_radix = cc.log2_radix
        current_shift = cc.bit_width - self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        print(f"DivPipeCoreCalc: stage {self.stage_index}"
              + f" of {self.core_config.n_stages} handling "
              + f"bits [{current_shift}, {current_shift+log2_radix})"
              + f" of {self.core_config.bit_width}")
        radix = 1 << log2_radix

        # trials within this radix range.  carried out by Trial module,
        # results stored in pass_flags.  pass_flags are unary priority.
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(cc, trial_bits, current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)

            comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            comb += t.quotient_root.eq(self.i.quotient_root)
            comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            comb += t.compare_rhs.eq(self.i.compare_rhs)
            comb += t.operation.eq(self.i.operation)

            # get the trial output (needed even in pass_flags[0] case)
            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # make the trial comparison against the [invariant] lhs.
            # trial_compare_rhs grows as trial_bits increases (each Trial
            # adds terms scaled by trial_bits to compare_rhs), so the pass
            # flags go from 1 to 0 as trial_bits increases.
            pass_flag = Signal(name=f"pass_flag_{trial_bits}",
                               reset_less=True)
            if trial_bits == 0:
                # do not do first comparison: no point.
                comb += pass_flag.eq(1)
            else:
                comb += pass_flag.eq(self.i.compare_lhs
                                     >= t.trial_compare_rhs)
            pfl.append(pass_flag)

        # Cat all the pass flags list together (easier to handle, below)
        pass_flags = Signal(radix, reset_less=True)
        comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # The encoder is fed the inverted flags, so it reports the lowest
        # *failing* trial; the digit wanted is the one just below it.  When
        # no trial fails (pe.n set) the highest digit is taken.

        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix, reset_less=True)
        comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            comb += next_bits.eq(pe.o-1)
        with m.Else():
            comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial.  use treereduce because
        # Array on such massively long numbers is insanely gate-hungry:
        # each candidate is masked by "next_bits == index" and the masked
        # values are OR-ed together.
        bw = cc.bit_width
        crhs = []
        for index, trial_rhs in enumerate(trial_compare_rhs_values):
            nbe = Signal(reset_less=True)
            comb += nbe.eq(next_bits == index)
            crhs.append(Repl(nbe, bw*3) & trial_rhs)
        comb += self.o.compare_rhs.eq(treereduce(crhs, operator.or_,
                                                 lambda x: x))

        # create outputs for next phase
        qr = self.i.quotient_root | (next_bits << current_shift)
        comb += self.o.quotient_root.eq(qr)
        if DP.RSqrtRem in cc.supported:
            # root_times_radicand is only maintained when RSqrtRem is
            # supported; otherwise the output is not assigned here.
            rr = self.i.root_times_radicand + ((self.i.divisor_radicand *
                                                next_bits) << current_shift)
            comb += self.o.root_times_radicand.eq(rr)

        return m
470
471
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``: ``quotient_root`` is passed through and the
    remainder is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and combinatorially
        connects ``i`` to this stage's input.
        """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o  # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        comb = m.d.comb

        comb += self.o.quotient_root.eq(self.i.quotient_root)
        # remainder = left-hand side minus right-hand side of the comparison
        comb += self.o.remainder.eq(self.i.compare_lhs - self.i.compare_rhs)

        return m