move Base eqs to separate mixin class
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux)
22 import enum
23
24 # TODO, move to new (suitable) location
25 #from ieee754.fpcommon.getop import FPPipeContext
26
27
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Create a ``DivPipeCoreConfig`` instance. """
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix

    def __repr__(self):
        """ Return ``DivPipeCoreConfig(bit_width, fract_width, log2_radix)``. """
        args = f"{self.bit_width}, {self.fract_width}, {self.log2_radix}"
        return f"DivPipeCoreConfig({args})"

    @property
    def num_calculate_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances needed.

        Each stage produces ``log2_radix`` bits of ``quotient_root``, so this
        is ``bit_width / log2_radix`` rounded up.
        """
        bits_per_stage = self.log2_radix
        return (self.bit_width + bits_per_stage - 1) // bits_per_stage
53
54
class DivPipeCoreOperation(enum.IntEnum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    UDivRem = 0
    SqrtRem = 1
    RSqrtRem = 2

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: extra source-location depth; incremented by one
            before being passed to ``Signal`` (presumably to skip this
            helper's own frame).
        :param kwargs: forwarded unchanged to ``Signal``.
        :returns: a ``Signal`` whose range covers every member of ``cls``
            and which uses ``cls`` as its decoder.
        """
        # ``Signal``'s ``max`` bound is exclusive, so add one. Without it the
        # signal is too narrow for the largest member (``RSqrtRem`` == 2),
        # which would then be truncated to ``UDivRem``.
        return Signal(min=int(min(cls)),
                      max=int(max(cls)) + 1,
                      src_loc_at=(src_loc_at + 1),
                      decoder=cls,
                      **kwargs)
75
76
77 # TODO: move to suitable location
class DivPipeBaseData:
    """ input data base type for ``DivPipe``.

    Holds the members shared by the ``DivPipe`` data types: the ``out_do_z``
    flag, the ``oz`` value and the pipeline context ``ctx`` (plus ``muxid``,
    an alias of ``ctx.muxid``).
    """

    def __init__(self, width, pspec):
        """ Create a ``DivPipeBaseData`` instance.

        :param width: bit-width of ``oz``; also passed to ``FPPipeContext``.
        :param pspec: pipeline spec, passed through to ``FPPipeContext``.
        """
        self.out_do_z = Signal(reset_less=True)
        self.oz = Signal(width, reset_less=True)

        # NOTE(review): the ``FPPipeContext`` import appears to be commented
        # out at the top of this file; this raises NameError until restored.
        self.ctx = FPPipeContext(width, pspec) # context: muxid, operator etc.
        self.muxid = self.ctx.muxid             # annoying. complicated.

    def __iter__(self):
        """ Get member signals: ``out_do_z``, ``oz``, then those of ``ctx``. """
        yield self.out_do_z
        yield self.oz
        yield from self.ctx

    def eq(self, rhs):
        """ Assign member signals from ``rhs``.

        :param rhs: object providing ``out_do_z``, ``oz`` and ``ctx``.
        :returns: list of assignments; the result of ``ctx.eq`` is included
            as a single element.
        """
        # Previously this read an undefined name ``i`` (NameError).
        return [self.out_do_z.eq(rhs.out_do_z), self.oz.eq(rhs.oz),
                self.ctx.eq(rhs.ctx)]
100
101
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreInputData`` instance. """
        self.core_config = core_config
        self.dividend = Signal(core_config.bit_width + core_config.fract_width,
                               reset_less=True)
        self.divisor_radicand = Signal(core_config.bit_width, reset_less=True)

        # FIXME: this goes into (is replaced by) self.ctx.op
        self.operation = DivPipeCoreOperation.create_signal(reset_less=True)

    def __iter__(self):
        """ Get member signals: ``dividend``, ``divisor_radicand``,
        ``operation``.
        """
        yield self.dividend
        yield self.divisor_radicand
        yield self.operation # FIXME: delete. already covered by self.ctx
        # ``out_do_z``/``oz``/``ctx`` are not members of this core type; they
        # belong to ``DivPipeBaseData`` and are yielded by its ``__iter__``.

    def eq(self, rhs):
        """ Assign member signals from ``rhs``.

        :returns: list of assignments in ``__iter__`` order.
        """
        return [self.dividend.eq(rhs.dividend),
                self.divisor_radicand.eq(rhs.divisor_radicand),
                self.operation.eq(rhs.operation)] # FIXME: delete.
142
143
144 # TODO: move to suitable location
class DivPipeInputData(DivPipeCoreInputData, DivPipeBaseData):
    """ input data type for ``DivPipe``.

    Combines the ``DivPipeCoreInputData`` signals with the shared
    ``DivPipeBaseData`` members.
    """

    def __init__(self, core_config, width, pspec):
        """ Create a ``DivPipeInputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` for the core signals.
        :param width: bit-width forwarded to ``DivPipeBaseData``.
        :param pspec: pipeline spec forwarded to ``DivPipeBaseData``.
        """
        DivPipeCoreInputData.__init__(self, core_config)
        # creates ``out_do_z``, ``oz``, ``ctx`` and ``muxid``
        DivPipeBaseData.__init__(self, width, pspec)

    def __iter__(self):
        """ Get member signals: core signals first, then base signals. """
        yield from DivPipeCoreInputData.__iter__(self)
        yield from DivPipeBaseData.__iter__(self)

    def eq(self, rhs):
        """ Assign member signals: base signals first, then core signals. """
        return DivPipeBaseData.eq(self, rhs) + \
            DivPipeCoreInputData.eq(self, rhs)
168
169
170
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config):
        """ Build every interstage signal, sized from ``core_config``. """
        self.core_config = core_config
        width = core_config.bit_width
        # keep the ``self.name = Signal(...)`` form: nmigen derives each
        # signal's name from the attribute being assigned
        self.divisor_radicand = Signal(width, reset_less=True)
        # XXX FIXME: delete. already covered by self.ctx.op
        self.operation = DivPipeCoreOperation.create_signal(reset_less=True)
        self.quotient_root = Signal(width, reset_less=True)
        self.root_times_radicand = Signal(width * 2, reset_less=True)
        self.compare_lhs = Signal(width * 3, reset_less=True)
        self.compare_rhs = Signal(width * 3, reset_less=True)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        # XXX FIXME: "operation" should go; it is already in self.ctx.op
        for member in ("divisor_radicand", "operation", "quotient_root",
                       "root_times_radicand", "compare_lhs", "compare_rhs"):
            yield getattr(self, member)

    def eq(self, rhs):
        """ Return assignments copying each member of ``rhs`` into ``self``.

        The list follows the same order as ``__iter__``.
        """
        # FIXME: "operation" should be deleted.
        members = ("divisor_radicand", "operation", "quotient_root",
                   "root_times_radicand", "compare_lhs", "compare_rhs")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in members]
225
226
227 # TODO: move to suitable location
class DivPipeInterstageData(DivPipeCoreInterstageData, DivPipeBaseData):
    """ interstage data type for ``DivPipe``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    """

    def __init__(self, core_config, width, pspec):
        """ Create a ``DivPipeInterstageData`` instance.

        :param core_config: ``DivPipeCoreConfig`` for the core signals.
        :param width: bit-width forwarded to ``DivPipeBaseData``.
        :param pspec: pipeline spec forwarded to ``DivPipeBaseData``.
        """
        DivPipeCoreInterstageData.__init__(self, core_config)
        DivPipeBaseData.__init__(self, width, pspec)

    def __iter__(self):
        """ Get member signals: core signals first, then base signals. """
        # Delegate to the *core* class explicitly: calling
        # ``DivPipeInterstageData.__iter__`` here would recurse forever.
        yield from DivPipeCoreInterstageData.__iter__(self)
        yield from DivPipeBaseData.__iter__(self)

    def eq(self, rhs):
        """ Assign member signals: base signals first, then core signals. """
        return DivPipeBaseData.eq(self, rhs) + \
            DivPipeCoreInterstageData.eq(self, rhs)
249
250
class DivPipeCoreOutputData:
    """ output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config):
        """ Build the result signals, sized from ``core_config``. """
        self.core_config = core_config
        width = core_config.bit_width
        self.quotient_root = Signal(width, reset_less=True)
        self.remainder = Signal(width * 3, reset_less=True)

    def __iter__(self):
        """ Yield ``quotient_root`` and then ``remainder``. """
        for member in ("quotient_root", "remainder"):
            yield getattr(self, member)

    def eq(self, rhs):
        """ Return assignments copying ``rhs``'s results into ``self``. """
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in ("quotient_root", "remainder")]
280
281
282 # TODO: move to suitable location
class DivPipeOutputData(DivPipeCoreOutputData, DivPipeBaseData):
    """ output data type for ``DivPipe``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    """

    def __init__(self, core_config, width, pspec):
        """ Create a ``DivPipeOutputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` for the core signals.
        :param width: bit-width forwarded to ``DivPipeBaseData``.
        :param pspec: pipeline spec forwarded to ``DivPipeBaseData``.
        """
        DivPipeCoreOutputData.__init__(self, core_config)
        DivPipeBaseData.__init__(self, width, pspec)

    def __iter__(self):
        """ Get member signals: core signals first, then base signals. """
        # Delegate to the *core* class explicitly: calling
        # ``DivPipeOutputData.__iter__`` here would recurse forever.
        yield from DivPipeCoreOutputData.__iter__(self)
        yield from DivPipeBaseData.__iter__(self)

    def eq(self, rhs):
        """ Assign member signals: base signals first, then core signals. """
        return DivPipeBaseData.eq(self, rhs) + \
            DivPipeCoreOutputData.eq(self, rhs)
304
305
class DivPipeBaseStage:
    """ Base Mix-in for DivPipe*Stage

    Provides ``_elaborate``, which forwards the ``DivPipeBaseData`` members
    from the stage input ``self.i`` to the stage output ``self.o``.
    """

    def _elaborate(self, m, platform):
        """ Add combinatorial pass-through assignments to ``m``.

        Copies ``oz``, ``out_do_z`` and ``ctx`` (in that order) from
        ``self.i`` to ``self.o``. ``platform`` is not used.
        """
        src, dst = self.i, self.o
        for member in ("oz", "out_do_z", "ctx"):
            m.d.comb += getattr(dst, member).eq(getattr(src, member))
313
314
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline. """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` for this pipeline.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Starts the iteration with ``quotient_root``, ``root_times_radicand``
        and ``compare_rhs`` at zero, and loads ``compare_lhs`` with the
        left-hand-side of the selected operation's equation.
        """
        m = Module()

        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.quotient_root.eq(0)
        m.d.comb += self.o.root_times_radicand.eq(0)

        # Shift each left-hand-side so it has ``fract_width * 3`` fraction
        # bits, matching ``compare_lhs``: the dividend has ``fract_width * 2``
        # and the radicand has ``fract_width``.
        with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
            m.d.comb += self.o.compare_lhs.eq(self.i.dividend
                                              << self.core_config.fract_width)
        with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
            m.d.comb += self.o.compare_lhs.eq(
                self.i.divisor_radicand << (self.core_config.fract_width * 2))
        with m.Else():  # DivPipeCoreOperation.RSqrtRem
            m.d.comb += self.o.compare_lhs.eq(
                1 << (self.core_config.fract_width * 3))

        m.d.comb += self.o.compare_rhs.eq(0)
        m.d.comb += self.o.operation.eq(self.i.operation)

        # TODO(DivPipeSetupStage): also call
        # ``DivPipeBaseStage._elaborate(self, m, platform)`` there. It is not
        # called here because the core data types have no ``oz``/``out_do_z``/
        # ``ctx`` members.
        return m
366
367
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline. """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` for this pipeline.
        :param stage_index: index of this stage; must be in
            ``range(core_config.num_calculate_stages)``.
        """
        self.core_config = core_config
        assert stage_index in range(core_config.num_calculate_stages)
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Computes the next ``log2_radix`` bits of ``quotient_root`` (fewer for
        the last stage when ``bit_width`` is not a multiple of it). Every
        candidate bit pattern is evaluated in parallel, and the largest one
        for which ``compare_lhs >= trial_compare_rhs`` holds is kept.
        """
        m = Module()
        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.operation.eq(self.i.operation)
        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        # the last stage may have fewer than log2_radix bits left to compute
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        radix = 1 << log2_radix
        trial_compare_rhs_values = []
        pass_flags = []
        for trial_bits in range(radix):
            shifted_trial_bits = Const(trial_bits, log2_radix) << current_shift
            shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits

            # UDivRem
            div_rhs = self.i.compare_rhs
            div_factor1 = self.i.divisor_radicand * shifted_trial_bits
            div_rhs += div_factor1 << self.core_config.fract_width

            # SqrtRem
            sqrt_rhs = self.i.compare_rhs
            sqrt_factor1 = self.i.quotient_root * (shifted_trial_bits << 1)
            sqrt_rhs += sqrt_factor1 << self.core_config.fract_width
            sqrt_factor2 = shifted_trial_bits_sqrd
            sqrt_rhs += sqrt_factor2 << self.core_config.fract_width

            # RSqrtRem
            rsqrt_rhs = self.i.compare_rhs
            rsqrt_rhs += self.i.root_times_radicand * (shifted_trial_bits << 1)
            rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd

            trial_compare_rhs = self.o.compare_rhs.like(
                name=f"trial_compare_rhs_{trial_bits}")

            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
                m.d.comb += trial_compare_rhs.eq(div_rhs)
            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
                m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
            with m.Else():  # DivPipeCoreOperation.RSqrtRem
                m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
            trial_compare_rhs_values.append(trial_compare_rhs)

            pass_flag = Signal(name=f"pass_flag_{trial_bits}")
            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
            pass_flags.append(pass_flag)

        # convert pass_flags to next_bits.
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # Under those assumptions next_bits is the index k of the highest set
        # flag. Bit i of k is the parity of the set flags at multiples of
        # 2**i, inverted (hence the initial 1), because those flags number
        # k // 2**i + 1.
        next_bits = Signal(log2_radix)
        for i in range(log2_radix):
            bit_value = 1
            for j in range(0, radix, 1 << i):
                bit_value ^= pass_flags[j]
            m.d.comb += next_bits.part(i, 1).eq(bit_value)

        # Select the trial rhs of the highest passing candidate: exactly one
        # index has its flag set while the next flag is clear.
        next_compare_rhs = 0
        for i in range(radix):
            next_flag = pass_flags[i + 1] if i + 1 < radix else 0
            next_compare_rhs |= Mux(pass_flags[i] & ~next_flag,
                                    trial_compare_rhs_values[i],
                                    0)

        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                  + ((self.i.divisor_radicand
                                                      * next_bits)
                                                     << current_shift))
        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
                                            | (next_bits << current_shift))

        # TODO(DivPipeCalculateStage): also call
        # ``DivPipeBaseStage._elaborate(self, m, platform)`` there. It is not
        # called here because the core data types have no ``oz``/``out_do_z``/
        # ``ctx`` members.
        return m
482
483
484
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline. """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` for this pipeline.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage."""
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage."""
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup. """
        # Use a name distinct from the setup stage's ``div_pipe_core_setup``
        # so the two stages do not clash when set up in the same module.
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process. """
        return self.o # return processed data (ignore i)

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Passes ``quotient_root`` through and computes the remainder as
        ``compare_lhs - compare_rhs``.
        """
        m = Module()

        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root)
        m.d.comb += self.o.remainder.eq(self.i.compare_lhs
                                        - self.i.compare_rhs)

        # TODO(DivPipeFinalStage): also call
        # ``DivPipeBaseStage._elaborate(self, m, platform)`` there. It is not
        # called here because the core data types have no ``oz``/``out_do_z``/
        # ``ctx`` members.
        return m
523