97b4a998423b4a5d2a17808c721d3c8da53753cd
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat, Array)
22 from nmigen.lib.coding import PriorityEncoder
23 import enum
24
25
class DivPipeCoreConfig:
    """ Parameters shared by the stages of the div/rem/sqrt/rsqrt core.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Store the three configuration values unchanged. """
        self.bit_width, self.fract_width, self.log2_radix = (
            bit_width, fract_width, log2_radix)

    def __repr__(self):
        """ Return ``DivPipeCoreConfig(bit_width, fract_width, log2_radix)``.
        """
        bw, fw, radix_bits = (
            self.bit_width, self.fract_width, self.log2_radix)
        return f"DivPipeCoreConfig({bw}, {fw}, {radix_bits})"

    @property
    def n_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances needed.

        This is ``bit_width / log2_radix`` rounded up to a whole number.
        """
        radix_bits = self.log2_radix
        # adding (divisor - 1) before the floor division rounds upwards
        rounded_up = self.bit_width + radix_bits - 1
        return rounded_up // radix_bits
51
52
class DivPipeCoreOperation(enum.Enum):
    """ The operations that ``DivPipeCore`` can compute.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    UDivRem = 0
    SqrtRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Return the member's integer value. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: forwarded to ``Signal`` incremented by one, to
            account for this wrapper's own call frame.
        :param kwargs: forwarded to ``Signal`` unchanged.
        :returns: the new ``Signal``; its decoder maps a value to the
            ``str()`` of the matching enum member.
        """
        codes = [int(member) for member in cls]

        def decode(code):
            return str(cls(code))

        # ``Signal`` must be called directly from this frame so that
        # ``src_loc_at + 1`` refers to the caller of ``create_signal``.
        return Signal(min=min(codes),
                      max=max(codes) + 2,
                      src_loc_at=(src_loc_at + 1),
                      decoder=decode,
                      **kwargs)


# short alias for the operation enum, used by the classes in this module
DP = DivPipeCoreOperation
80
class DivPipeCoreInputData:
    """ Input record of ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and
        a fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized from ``core_config``.

        :param core_config: the ``DivPipeCoreConfig`` to size signals from.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        bit_width = core_config.bit_width
        # each signal is assigned directly to its attribute so that its
        # name is inferred from the attribute name
        self.dividend = Signal(bit_width + core_config.fract_width,
                               reset_less=reset_less)
        self.divisor_radicand = Signal(bit_width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Yield ``dividend``, ``divisor_radicand``, ``operation``. """
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """ Return the list of ``.eq()`` results, one per member signal,
        each taking the same-named member of ``rhs`` as its argument.
        """
        members = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in members]
116
117
class DivPipeCoreInterstageData:
    """ Record passed between the stages of ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and
        a fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and
        a fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized from ``core_config``.

        :param core_config: the ``DivPipeCoreConfig`` to size signals from.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width
        # each signal is assigned directly to its attribute so that its
        # name is inferred from the attribute name
        self.divisor_radicand = Signal(width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.root_times_radicand = Signal(width * 2, reset_less=reset_less)
        self.compare_lhs = Signal(width * 3, reset_less=reset_less)
        self.compare_rhs = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        yield from (self.divisor_radicand,
                    self.operation,
                    self.quotient_root,
                    self.root_times_radicand,
                    self.compare_lhs,
                    self.compare_rhs)

    def eq(self, rhs):
        """ Return the list of ``.eq()`` results, one per member signal,
        each taking the same-named member of ``rhs`` as its argument.
        """
        members = ("divisor_radicand", "operation", "quotient_root",
                   "root_times_radicand", "compare_lhs", "compare_rhs")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in members]
175
176
class DivPipeCoreOutputData:
    """ Output record of ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and
        a fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized from ``core_config``.

        :param core_config: the ``DivPipeCoreConfig`` to size signals from.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width
        # each signal is assigned directly to its attribute so that its
        # name is inferred from the attribute name
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.remainder = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield ``quotient_root`` then ``remainder``. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Return the list of ``.eq()`` results, one per member signal,
        each taking the same-named member of ``rhs`` as its argument.
        """
        members = ("quotient_root", "remainder")
        return [getattr(self, name).eq(getattr(rhs, name))
                for name in members]
208
209
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` record into the initial
    ``DivPipeCoreInterstageData`` record: zero ``quotient_root``,
    ``root_times_radicand`` and ``compare_rhs``, and an operation-dependent
    ``compare_lhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance.

        :param core_config: the ``DivPipeCoreConfig`` to use.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a new input record (``DivPipeCoreInputData``). """
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Return a new output record (``DivPipeCoreInterstageData``). """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Add this stage to ``m`` as submodule ``div_pipe_core_setup`` and
        combinatorially drive ``self.i`` from ``i``.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the output record; the argument ``i`` is ignored. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()

        fract_width = self.core_config.fract_width
        op = self.i.operation
        lhs = self.o.compare_lhs

        # values that do not depend on the selected operation
        m.d.comb += [
            self.o.divisor_radicand.eq(self.i.divisor_radicand),
            self.o.quotient_root.eq(0),
            self.o.root_times_radicand.eq(0),
            self.o.compare_rhs.eq(0),
            self.o.operation.eq(op),
        ]

        # the left-hand side of the comparison depends on the operation
        with m.If(op == int(DP.UDivRem)):
            m.d.comb += lhs.eq(self.i.dividend << fract_width)
        with m.Elif(op == int(DP.SqrtRem)):
            m.d.comb += lhs.eq(self.i.divisor_radicand << (fract_width * 2))
        with m.Else():  # DivPipeCoreOperation.RSqrtRem
            m.d.comb += lhs.eq(1 << (fract_width * 3))

        return m
258
259
class Trial(Elaboratable):
    """ Computes the candidate ``compare_rhs`` for one fixed trial value.

    ``trial_compare_rhs`` is ``compare_rhs`` plus an operation-dependent
    term built from the constant ``trial_bits``:

    * UDivRem: ``(divisor_radicand * trial_bits)
      << (fract_width + current_shift)``
    * SqrtRem: ``(quotient_root * trial_bits)
      << (fract_width + current_shift + 1)`` plus
      ``(trial_bits ** 2) << (fract_width + current_shift * 2)``
    * otherwise (RSqrtRem): ``(root_times_radicand * trial_bits)
      << (current_shift + 1)`` plus
      ``(divisor_radicand * trial_bits ** 2) << (current_shift * 2)``

    :attribute divisor_radicand: input, ``bit_width`` bits.
    :attribute quotient_root: input, ``bit_width`` bits.
    :attribute root_times_radicand: input, ``bit_width * 2`` bits.
    :attribute compare_rhs: input, ``bit_width * 3`` bits.
    :attribute operation: input, the ``DivPipeCoreOperation`` selector.
    :attribute trial_compare_rhs: output, ``bit_width * 3`` bits.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: the ``DivPipeCoreConfig`` to use.
        :param trial_bits: the constant trial value (a Python integer).
        :param current_shift: shift amount used to position the terms.
        :param log2_radix: bit-width of the ``trial_bits`` constant.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix
        width = core_config.bit_width
        # each signal is assigned directly to its attribute so that its
        # name is inferred from the attribute name
        self.divisor_radicand = Signal(width, reset_less=True)
        self.quotient_root = Signal(width, reset_less=True)
        self.root_times_radicand = Signal(width * 2, reset_less=True)
        self.compare_rhs = Signal(width * 3, reset_less=True)
        self.trial_compare_rhs = Signal(width * 3, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()

        bit_width = self.core_config.bit_width
        fract_width = self.core_config.fract_width
        shift = self.current_shift
        radix_bits = self.log2_radix
        digit = self.trial_bits
        op = self.operation
        base_rhs = self.compare_rhs

        trial_bits_sig = Const(digit, radix_bits)
        trial_bits_sqrd_sig = Const(digit * digit, radix_bits * 2)

        # computed unconditionally; only the RSqrtRem branch reads it
        dr_times_trial_bits_sqrd = Signal(bit_width + radix_bits * 2,
                                          reset_less=True)
        m.d.comb += dr_times_trial_bits_sqrd.eq(
            self.divisor_radicand * trial_bits_sqrd_sig)

        # NOTE: ``<<`` binds less tightly than ``+``, so every shifted term
        # below is explicitly parenthesised.

        # UDivRem
        with m.If(op == int(DP.UDivRem)):
            dr_times_trial_bits = Signal(bit_width + radix_bits,
                                         reset_less=True)
            m.d.comb += [
                dr_times_trial_bits.eq(
                    self.divisor_radicand * trial_bits_sig),
                self.trial_compare_rhs.eq(
                    base_rhs
                    + (dr_times_trial_bits << (fract_width + shift))),
            ]

        # SqrtRem
        with m.Elif(op == int(DP.SqrtRem)):
            qr_times_trial_bits = Signal((bit_width + radix_bits + 1) * 2,
                                         reset_less=True)
            m.d.comb += [
                qr_times_trial_bits.eq(self.quotient_root * trial_bits_sig),
                self.trial_compare_rhs.eq(
                    base_rhs
                    + (qr_times_trial_bits << (fract_width + shift + 1))
                    + (trial_bits_sqrd_sig << (fract_width + shift * 2))),
            ]

        # RSqrtRem
        with m.Else():
            rr_times_trial_bits = Signal((bit_width + radix_bits + 1) * 3,
                                         reset_less=True)
            m.d.comb += [
                rr_times_trial_bits.eq(
                    self.root_times_radicand * trial_bits_sig),
                self.trial_compare_rhs.eq(
                    base_rhs
                    + (rr_times_trial_bits << (shift + 1))
                    + (dr_times_trial_bits_sqrd << (shift * 2))),
            ]

        return m
337
338
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage instantiates one ``Trial`` per possible value of the next
    group of ``quotient_root`` bits, selects the highest trial value whose
    ``trial_compare_rhs`` does not exceed ``compare_lhs``, and merges it
    into ``quotient_root``, ``root_times_radicand`` and ``compare_rhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: the ``DivPipeCoreConfig`` to use.
        :param stage_index: index of this stage; must be in
            ``range(core_config.n_stages)``.
        """
        self.core_config = core_config
        assert stage_index in range(core_config.n_stages)
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Adds this stage to ``m`` as submodule
        ``div_pipe_core_calculate_<stage_index>`` and combinatorially
        drives ``self.i`` from ``i``.
        """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process: return the output record (ignore i). """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()

        # values passed through this stage unchanged
        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.operation.eq(self.i.operation)
        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # Work out which bits of quotient_root this stage produces:
        # start from the bits remaining before this stage, clamp the
        # number of bits computed here to what remains, then move down to
        # the position of the lowest bit produced by this stage.
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        radix = 1 << log2_radix

        # one Trial submodule per candidate value of the new bits
        trial_compare_rhs_values = []
        pfl = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits,
                      current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)
            m.d.comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            m.d.comb += t.quotient_root.eq(self.i.quotient_root)
            m.d.comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            m.d.comb += t.compare_rhs.eq(self.i.compare_rhs)
            m.d.comb += t.operation.eq(self.i.operation)

            trial_compare_rhs_values.append(t.trial_compare_rhs)

            # a trial passes when its rhs does not exceed the lhs
            pass_flag = Signal(name=f"pass_flag_{trial_bits}",
                               reset_less=True)
            m.d.comb += pass_flag.eq(self.i.compare_lhs
                                     >= t.trial_compare_rhs)
            pfl.append(pass_flag)
        pass_flags = Signal(radix, reset_less=True)
        m.d.comb += pass_flags.eq(Cat(*pfl))

        # convert pass_flags (unary priority) to next_bits (binary index)
        #
        # Assumes that for each set bit in pass_flags, all previous bits
        # are also set.
        #
        # Assumes that pass_flags[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).

        m.submodules.pe = pe = PriorityEncoder(radix)
        next_bits = Signal(log2_radix+1, reset_less=True)
        # the encoder is fed the inverted flags, so its output is the index
        # of the first failing trial; the one before it is the highest
        # passing trial
        m.d.comb += pe.i.eq(~pass_flags)
        with m.If(~pe.n):
            m.d.comb += next_bits.eq(pe.o-1)
        with m.Else():
            # no trial failed: take the largest trial value
            m.d.comb += next_bits.eq(radix-1)

        # get the highest passing rhs trial (indexed by next_bits)
        ta = Array(trial_compare_rhs_values)
        m.d.comb += self.o.compare_rhs.eq(ta[next_bits])

        # create outputs for next phase
        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                  + ((self.i.divisor_radicand
                                                      * next_bits)
                                                     << current_shift))
        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
                                            | (next_bits << current_shift))
        return m
430
431
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` record into a
    ``DivPipeCoreOutputData`` record: ``quotient_root`` is passed through
    and ``remainder`` is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance.

        :param core_config: the ``DivPipeCoreConfig`` to use.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a new input record (``DivPipeCoreInterstageData``). """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Return a new output record (``DivPipeCoreOutputData``). """
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Add this stage to ``m`` as submodule ``div_pipe_core_final`` and
        combinatorially drive ``self.i`` from ``i``.
        """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return the output record; the argument ``i`` is ignored. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()

        difference = self.i.compare_lhs - self.i.compare_rhs
        m.d.comb += [
            self.o.quotient_root.eq(self.i.quotient_root),
            self.o.remainder.eq(difference),
        ]

        return m