split out div/sqrt/rsqrt trials to separate module
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat)
22 import enum
23
24
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Store the three configuration values without modification. """
        self.bit_width, self.fract_width, self.log2_radix = (
            bit_width, fract_width, log2_radix)

    def __repr__(self):
        """ Return ``DivPipeCoreConfig(bit_width, fract_width, log2_radix)``.
        """
        return (f"DivPipeCoreConfig({self.bit_width}, "
                f"{self.fract_width}, {self.log2_radix})")

    @property
    def n_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances needed.

        For positive values this is ``bit_width / log2_radix`` rounded up.
        """
        bits_per_stage = self.log2_radix
        return (self.bit_width + bits_per_stage - 1) // bits_per_stage
50
51
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    UDivRem = 0
    SqrtRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Return the member's integer value. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: forwarded to ``Signal`` incremented by one, to
            account for this method's own stack frame.
        :param kwargs: forwarded unchanged to ``Signal``.
        :returns: the new ``Signal``; its decoder renders a value as the
            ``str`` of the matching enum member.
        """
        codes = [int(member) for member in cls]

        def decode(value):
            return str(cls(value))

        # NOTE(review): the upper bound is the largest member value plus two;
        # the reason for "+ 2" is not visible in this file -- confirm before
        # changing it.
        return Signal(min=min(codes),
                      max=max(codes) + 2,
                      src_loc_at=(src_loc_at + 1),
                      decoder=decode,
                      **kwargs)
76
77
# Short module-level alias for ``DivPipeCoreOperation``; the classes below
# refer to the operation enum through this name.
DP = DivPipeCoreOperation
79
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized from ``core_config``.

        :param core_config: supplies ``bit_width`` and ``fract_width``.
        :param reset_less: forwarded to every member signal.
        """
        bit_width = core_config.bit_width
        fract_width = core_config.fract_width
        self.core_config = core_config
        self.dividend = Signal(bit_width + fract_width,
                               reset_less=reset_less)
        self.divisor_radicand = Signal(bit_width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Yield ``dividend``, ``divisor_radicand``, ``operation`` in order.
        """
        for member in ("dividend", "divisor_radicand", "operation"):
            yield getattr(self, member)

    def eq(self, rhs):
        """ Return the list of assignments of each member from ``rhs``.

        :param rhs: object exposing the same three member attributes.
        """
        members = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in members]
115
116
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized from ``core_config``.

        :param core_config: supplies ``bit_width``.
        :param reset_less: forwarded to every member signal.
        """
        bit_width = core_config.bit_width
        self.core_config = core_config
        self.divisor_radicand = Signal(bit_width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(bit_width, reset_less=reset_less)
        self.root_times_radicand = Signal(bit_width * 2,
                                          reset_less=reset_less)
        self.compare_lhs = Signal(bit_width * 3, reset_less=reset_less)
        self.compare_rhs = Signal(bit_width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield the six member signals in declaration order. """
        for member in ("divisor_radicand", "operation", "quotient_root",
                       "root_times_radicand", "compare_lhs", "compare_rhs"):
            yield getattr(self, member)

    def eq(self, rhs):
        """ Return the list of assignments of each member from ``rhs``.

        :param rhs: object exposing the same six member attributes.
        """
        members = ("divisor_radicand", "operation", "quotient_root",
                   "root_times_radicand", "compare_lhs", "compare_rhs")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in members]
174
175
class DivPipeCoreOutputData:
    """ output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized from ``core_config``.

        :param core_config: supplies ``bit_width``.
        :param reset_less: forwarded to every member signal.
        """
        bit_width = core_config.bit_width
        self.core_config = core_config
        self.quotient_root = Signal(bit_width, reset_less=reset_less)
        self.remainder = Signal(bit_width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield ``quotient_root`` then ``remainder``. """
        for member in ("quotient_root", "remainder"):
            yield getattr(self, member)

    def eq(self, rhs):
        """ Return the list of assignments of each member from ``rhs``.

        :param rhs: object exposing ``quotient_root`` and ``remainder``.
        """
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in ("quotient_root", "remainder")]
207
208
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the quotient/root accumulators and
    ``compare_rhs`` start at zero and ``compare_lhs`` is loaded with the
    operation-specific left-hand side.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh ``DivPipeCoreInputData`` for this stage's input."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Return a fresh ``DivPipeCoreInterstageData`` for this stage's
        output."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Register this stage as a submodule of ``m`` and drive its input
        from ``i``. """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return this stage's output data; the argument ``i`` is ignored.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        stage_in = self.i
        stage_out = self.o
        fract_width = self.core_config.fract_width

        # Pass the divisor/radicand through and clear both accumulators.
        m.d.comb += [
            stage_out.divisor_radicand.eq(stage_in.divisor_radicand),
            stage_out.quotient_root.eq(0),
            stage_out.root_times_radicand.eq(0),
        ]

        # Load compare_lhs with the left-hand side of the equation for the
        # selected operation, shifted left by the amount each branch shows.
        with m.If(stage_in.operation == int(DP.UDivRem)):
            m.d.comb += stage_out.compare_lhs.eq(
                stage_in.dividend << fract_width)
        with m.Elif(stage_in.operation == int(DP.SqrtRem)):
            m.d.comb += stage_out.compare_lhs.eq(
                stage_in.divisor_radicand << (fract_width * 2))
        with m.Else():  # DivPipeCoreOperation.RSqrtRem
            m.d.comb += stage_out.compare_lhs.eq(1 << (fract_width * 3))

        m.d.comb += [
            stage_out.compare_rhs.eq(0),
            stage_out.operation.eq(stage_in.operation),
        ]

        return m
257
258
class Trial(Elaboratable):
    """ Computes one candidate ``compare_rhs`` for a fixed ``trial_bits``.

    ``trial_compare_rhs`` is driven with ``compare_rhs`` plus the
    operation-specific terms built from the constant ``trial_bits``, shifted
    according to ``current_shift`` (and ``core_config.fract_width`` for
    div and sqrt).

    Inputs (signals): ``divisor_radicand``, ``quotient_root``,
    ``root_times_radicand``, ``compare_rhs``, ``operation``.
    Output (signal): ``trial_compare_rhs``.
    """

    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
        """ Create a ``Trial`` instance.

        :param core_config: supplies ``bit_width`` and ``fract_width``.
        :param trial_bits: integer candidate value for the next result bits.
        :param current_shift: bit position at which the candidate is placed.
        :param log2_radix: bit-width used for the ``trial_bits`` constant.
        """
        self.core_config = core_config
        self.trial_bits = trial_bits
        self.current_shift = current_shift
        self.log2_radix = log2_radix

        width = core_config.bit_width
        self.divisor_radicand = Signal(width, reset_less=True)
        self.quotient_root = Signal(width, reset_less=True)
        self.root_times_radicand = Signal(width * 2, reset_less=True)
        self.compare_rhs = Signal(width * 3, reset_less=True)
        self.trial_compare_rhs = Signal(width * 3, reset_less=True)
        self.operation = DP.create_signal(reset_less=True)

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()

        bit_width = self.core_config.bit_width
        fract_width = self.core_config.fract_width
        shift = self.current_shift
        radix_bits = self.log2_radix

        # trial_bits and its square as constants of fixed width.
        trial_const = Const(self.trial_bits, radix_bits)
        trial_sqrd_const = Const(self.trial_bits * self.trial_bits,
                                 radix_bits * 2)

        product_width = bit_width + radix_bits

        # Computed unconditionally; only the rsqrt branch consumes it.
        dr_times_trial_bits_sqrd = Signal(bit_width + radix_bits * 2,
                                          reset_less=True)
        m.d.comb += dr_times_trial_bits_sqrd.eq(
            self.divisor_radicand * trial_sqrd_const)

        # UDivRem
        with m.If(self.operation == int(DP.UDivRem)):
            dr_times_trial_bits = Signal(product_width, reset_less=True)
            m.d.comb += dr_times_trial_bits.eq(
                self.divisor_radicand * trial_const)

            m.d.comb += self.trial_compare_rhs.eq(
                self.compare_rhs
                + (dr_times_trial_bits << (fract_width + shift)))

        # SqrtRem
        with m.Elif(self.operation == int(DP.SqrtRem)):
            qr_times_trial_bits = Signal((product_width + 1) * 2,
                                         reset_less=True)
            m.d.comb += qr_times_trial_bits.eq(
                self.quotient_root * trial_const)

            m.d.comb += self.trial_compare_rhs.eq(
                self.compare_rhs
                + (qr_times_trial_bits << (fract_width + shift + 1))
                + (trial_sqrd_const << (fract_width + shift * 2)))

        # RSqrtRem
        with m.Else():
            rr_times_trial_bits = Signal((product_width + 1) * 3,
                                         reset_less=True)
            m.d.comb += rr_times_trial_bits.eq(
                self.root_times_radicand * trial_const)

            m.d.comb += self.trial_compare_rhs.eq(
                self.compare_rhs
                + (rr_times_trial_bits << (shift + 1))
                + (dr_times_trial_bits_sqrd << (shift * 2)))

        return m
336
337
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines up to ``core_config.log2_radix`` further bits of
    ``quotient_root`` by evaluating one ``Trial`` per candidate value and
    keeping the largest candidate whose ``trial_compare_rhs`` does not exceed
    ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` describing the pipeline.
        :param stage_index: index of this stage; must lie in
            ``range(core_config.n_stages)`` (checked with ``assert``).
        """
        self.core_config = core_config
        assert stage_index in range(core_config.n_stages)
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Register this stage as a submodule of ``m`` (named after its
        ``stage_index``) and drive its input from ``i``. """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return this stage's output data; the argument ``i`` is ignored.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()

        # Values that pass through the stage unchanged.
        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.operation.eq(self.i.operation)
        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # Work out how many bits this stage resolves and where they land.
        # The last stage may have fewer than log2_radix bits left, hence the
        # clamp with min().
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        radix = 1 << log2_radix

        # One Trial per candidate value of the next bits; each reports the
        # compare_rhs that would result from choosing that candidate.
        trial_compare_rhs_values = []
        pass_flags = []
        for trial_bits in range(radix):
            t = Trial(self.core_config, trial_bits,
                      current_shift, log2_radix)
            setattr(m.submodules, "trial%d" % trial_bits, t)
            m.d.comb += t.divisor_radicand.eq(self.i.divisor_radicand)
            m.d.comb += t.quotient_root.eq(self.i.quotient_root)
            m.d.comb += t.root_times_radicand.eq(self.i.root_times_radicand)
            m.d.comb += t.compare_rhs.eq(self.i.compare_rhs)
            m.d.comb += t.operation.eq(self.i.operation)

            trial_compare_rhs_values.append(t.trial_compare_rhs)

            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            m.d.comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
            pass_flags.append(pass_flag)

        # convert pass_flags to next_bits.
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # Under those assumptions pass_flags[0..k] are set, where k is the
        # wanted value.  Among the indices that are multiples of 2**b,
        # floor(k / 2**b) + 1 flags are set, so XOR-ing those flags together
        # with a leading 1 leaves the parity of floor(k / 2**b), which is
        # bit b of k.
        next_bits = Signal(log2_radix, reset_less=True)
        for bit_index in range(log2_radix):
            bit_value = 1
            for j in range(0, radix, 1 << bit_index):
                bit_value ^= pass_flags[j]
            # Plain indexing selects the single bit; this replaces the
            # deprecated ``Value.part(bit_index, 1)`` call.
            m.d.comb += next_bits[bit_index].eq(bit_value)

        # merge/select multi-bit trial_compare_rhs_values, to go
        # into compare_rhs.  A candidate is selected when its own flag is set
        # and the next candidate's flag is clear; under the assumptions above
        # exactly one candidate qualifies, so OR-ing the muxed values yields
        # that candidate's trial_compare_rhs.
        next_compare_rhs = 0
        for index in range(radix):
            next_flag = pass_flags[index + 1] if index + 1 < radix else 0
            selected = Signal(name=f"selected_{index}", reset_less=True)
            m.d.comb += selected.eq(pass_flags[index] & ~next_flag)
            next_compare_rhs |= Mux(selected,
                                    trial_compare_rhs_values[index],
                                    0)

        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
        # Fold the newly found bits into both accumulators at current_shift.
        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                  + ((self.i.divisor_radicand
                                                      * next_bits)
                                                     << current_shift))
        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
                                            | (next_bits << current_shift))
        return m
433
434
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``: ``quotient_root`` is passed through and
    ``remainder`` is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance."""
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh ``DivPipeCoreInterstageData`` for this stage's
        input."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Return a fresh ``DivPipeCoreOutputData`` for this stage's output.
        """
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Register this stage as a submodule of ``m`` and drive its input
        from ``i``. """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Return this stage's output data; the argument ``i`` is ignored.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        stage_in = self.i
        stage_out = self.o

        m.d.comb += [
            stage_out.quotient_root.eq(stage_in.quotient_root),
            # Remainder is left-hand side minus right-hand side.
            stage_out.remainder.eq(stage_in.compare_lhs
                                   - stage_in.compare_rhs),
        ]

        return m