switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat)
22 import enum
23
24
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Create a ``DivPipeCoreConfig`` instance.

        :param bit_width: base bit-width.
        :param fract_width: base fract-width (position of the radix point).
        :param log2_radix: ``quotient_root`` bits computed per stage.
        """
        self.bit_width, self.fract_width, self.log2_radix = \
            bit_width, fract_width, log2_radix

    def __repr__(self):
        """ Return ``DivPipeCoreConfig(bit_width, fract_width, log2_radix)``.
        """
        return (f"DivPipeCoreConfig({self.bit_width}, "
                f"{self.fract_width}, {self.log2_radix})")

    @property
    def num_calculate_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances required.

        This is ``bit_width / log2_radix`` rounded up.
        """
        stage_count, _ = divmod(self.bit_width + self.log2_radix - 1,
                                self.log2_radix)
        return stage_count
50
51
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    UDivRem = 0
    SqrtRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Return the integer encoding of this operation. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a ``Signal`` able to hold any ``DivPipeCoreOperation``.

        :param src_loc_at: forwarded to ``Signal`` as ``src_loc_at + 1``.
        :param kwargs: forwarded unchanged to ``Signal``.
        """
        encodings = [int(op) for op in cls]

        def decode(value):
            # show raw signal values as the matching enum member
            return str(cls(value))

        # NOTE(review): ``+ 1`` would already cover the largest encoding if
        # nMigen's ``max`` is exclusive; ``+ 2`` gives the same width for the
        # current encodings -- confirm intent before changing.
        return Signal(min=min(encodings),
                      max=max(encodings) + 2,
                      src_loc_at=src_loc_at + 1,
                      decoder=decode,
                      **kwargs)
76
77
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width
        # each Signal is bound directly to its attribute so that nMigen can
        # derive the signal name from the assignment target
        self.dividend = Signal(width + core_config.fract_width,
                               reset_less=reset_less)
        self.divisor_radicand = Signal(width, reset_less=reset_less)

        # FIXME: this goes into (is replaced by) self.ctx.op
        self.operation = DivPipeCoreOperation.create_signal(
            reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        yield from (self.dividend,
                    self.divisor_radicand,
                    self.operation)  # FIXME: delete. already covered by self.ctx

    def eq(self, rhs):
        """ Return assignments copying every member signal from ``rhs``. """
        pairs = ((self.dividend, rhs.dividend),
                 (self.divisor_radicand, rhs.divisor_radicand),
                 (self.operation, rhs.operation))  # FIXME: delete operation.
        return [dst.eq(src) for dst, src in pairs]
116
117
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreInterstageData`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width
        # each Signal is bound directly to its attribute so that nMigen can
        # derive the signal name from the assignment target
        self.divisor_radicand = Signal(width, reset_less=reset_less)
        # FIXME: delete self.operation. already covered by self.ctx.op
        self.operation = DivPipeCoreOperation.create_signal(
            reset_less=reset_less)
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.root_times_radicand = Signal(width * 2, reset_less=reset_less)
        self.compare_lhs = Signal(width * 3, reset_less=reset_less)
        self.compare_rhs = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        yield from (self.divisor_radicand,
                    self.operation,  # FIXME: delete. already in self.ctx.op
                    self.quotient_root,
                    self.root_times_radicand,
                    self.compare_lhs,
                    self.compare_rhs)

    def eq(self, rhs):
        """ Return assignments copying every member signal from ``rhs``. """
        pairs = ((self.divisor_radicand, rhs.divisor_radicand),
                 (self.operation, rhs.operation),  # FIXME: delete.
                 (self.quotient_root, rhs.quotient_root),
                 (self.root_times_radicand, rhs.root_times_radicand),
                 (self.compare_lhs, rhs.compare_lhs),
                 (self.compare_rhs, rhs.compare_rhs))
        return [dst.eq(src) for dst, src in pairs]
177
178
class DivPipeCoreOutputData:
    """ output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create a ``DivPipeCoreOutputData`` instance.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width
        # each Signal is bound directly to its attribute so that nMigen can
        # derive the signal name from the assignment target
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.remainder = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Yield the member signals in declaration order. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Return assignments copying every member signal from ``rhs``. """
        pairs = ((self.quotient_root, rhs.quotient_root),
                 (self.remainder, rhs.remainder))
        return [dst.eq(src) for dst, src in pairs]
210
211
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Turns ``DivPipeCoreInputData`` into ``DivPipeCoreInterstageData`` with
    ``quotient_root``, ``root_times_radicand`` and ``compare_rhs`` zeroed and
    ``compare_lhs`` set from the requested operation.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` instance describing the
            configuration to be used.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a new ``DivPipeCoreInputData`` for this stage's input. """
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Return a new ``DivPipeCoreInterstageData`` for this stage's
        output. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup: add this stage to ``m`` as submodule
        ``div_pipe_core_setup`` and drive its input from ``i``. """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process: return the output data.

        ``i`` is not used here; ``setup`` is what drives ``self.i``.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        ``compare_lhs`` is the left-hand side of the selected operation's
        equation, scaled to ``fract_width * 3`` fractional bits.
        """
        m = Module()
        fract_width = self.core_config.fract_width
        inp, out = self.i, self.o

        m.d.comb += [out.divisor_radicand.eq(inp.divisor_radicand),
                     out.quotient_root.eq(0),
                     out.root_times_radicand.eq(0)]

        is_div = inp.operation == int(DivPipeCoreOperation.UDivRem)
        is_sqrt = inp.operation == int(DivPipeCoreOperation.SqrtRem)
        with m.If(is_div):
            # dividend carries fract_width * 2 fractional bits
            m.d.comb += out.compare_lhs.eq(inp.dividend << fract_width)
        with m.Elif(is_sqrt):
            # radicand carries fract_width fractional bits
            m.d.comb += out.compare_lhs.eq(
                inp.divisor_radicand << (fract_width * 2))
        with m.Else():  # DivPipeCoreOperation.RSqrtRem: lhs is 1.0
            m.d.comb += out.compare_lhs.eq(1 << (fract_width * 3))

        m.d.comb += [out.compare_rhs.eq(0),
                     out.operation.eq(inp.operation)]
        return m
260
261
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage computes the next ``log2_radix`` bits of ``quotient_root``,
    most-significant bits first (stage 0 computes the top bits).  The last
    stage computes fewer bits when ``bit_width`` is not a multiple of
    ``log2_radix``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` instance describing the
            configuration to be used.
        :param stage_index: position of this stage; must be in
            ``range(core_config.num_calculate_stages)``.
        """
        self.core_config = core_config
        assert stage_index in range(core_config.num_calculate_stages)
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup: add this stage to ``m`` as a submodule
        named after ``stage_index`` and drive its input from ``i``. """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process: return the output data (``i`` is not
        used here; ``setup`` is what drives ``self.i``). """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``.

        Every possible value ``t`` of this stage's digit is tried in
        parallel.  Let ``q`` be the incoming ``quotient_root``, ``d`` be
        ``divisor_radicand``, ``s`` be ``current_shift`` (the digit's bit
        position) and ``F`` be ``fract_width``.  The trial right-hand side is
        the incoming ``compare_rhs`` plus:

        * UDivRem:  ``(d * t) << (F + s)``
        * SqrtRem:  ``(q * t) << (F + s + 1)`` and ``(t * t) << (F + 2 * s)``
        * RSqrtRem: ``(q * d * t) << (s + 1)`` and ``(d * t * t) << (2 * s)``

        These terms expand ``(q + (t << s)) * d``, ``(q + (t << s)) ** 2`` and
        ``(q + (t << s)) ** 2 * d`` respectively, at ``fract_width * 3``
        fractional bits.  The largest ``t`` whose trial right-hand side does
        not exceed ``compare_lhs`` becomes this stage's digit.
        """
        m = Module()
        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.operation.eq(self.i.operation)
        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        fract_width = self.core_config.fract_width
        log2_radix = self.core_config.log2_radix
        # quotient_root bits not produced by earlier stages
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        # the last stage may have fewer than log2_radix bits left to produce
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        # bit position of this stage's digit within quotient_root
        current_shift -= log2_radix
        radix = 1 << log2_radix

        # loop-invariant operation decode
        is_udivrem = self.i.operation == int(DivPipeCoreOperation.UDivRem)
        is_sqrtrem = self.i.operation == int(DivPipeCoreOperation.SqrtRem)

        trial_compare_rhs_values = []
        pass_flags = []
        for trial_bits in range(radix):
            trial_bits_sig = Const(trial_bits, log2_radix)
            trial_bits_sqrd_sig = Const(trial_bits * trial_bits,
                                        log2_radix * 2)

            dr_times_trial_bits = self.i.divisor_radicand * trial_bits_sig
            dr_times_trial_bits_sqrd = (self.i.divisor_radicand
                                        * trial_bits_sqrd_sig)
            qr_times_trial_bits = self.i.quotient_root * trial_bits_sig
            rr_times_trial_bits = self.i.root_times_radicand * trial_bits_sig

            # for trial_bits == 0 every term would be multiplied by zero, so
            # the trial value is just the incoming compare_rhs
            div_rhs = self.i.compare_rhs
            sqrt_rhs = self.i.compare_rhs
            rsqrt_rhs = self.i.compare_rhs
            if trial_bits != 0:
                # UDivRem: (q + (t << s)) * d
                div_rhs += dr_times_trial_bits << (fract_width
                                                   + current_shift)
                # SqrtRem: (q + (t << s)) ** 2
                sqrt_rhs += qr_times_trial_bits << (fract_width
                                                    + current_shift + 1)
                sqrt_rhs += trial_bits_sqrd_sig << (fract_width
                                                    + current_shift * 2)
                # RSqrtRem: (q + (t << s)) ** 2 * d
                rsqrt_rhs += rr_times_trial_bits << (current_shift + 1)
                rsqrt_rhs += dr_times_trial_bits_sqrd << (current_shift * 2)

            trial_compare_rhs = Signal.like(
                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
                reset_less=True)
            with m.If(is_udivrem):
                m.d.comb += trial_compare_rhs.eq(div_rhs)
            with m.Elif(is_sqrtrem):
                m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
            with m.Else():  # DivPipeCoreOperation.RSqrtRem
                m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
            trial_compare_rhs_values.append(trial_compare_rhs)

            # the trial digit is acceptable if it does not overshoot the lhs
            pass_flag = Signal(name=f"pass_flag_{trial_bits}",
                               reset_less=True)
            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
            pass_flags.append(pass_flag)

        # Convert pass_flags to next_bits.
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set, and that pass_flags[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).  If k is the
        # largest passing trial value, pass_flags[0..k] are set and the rest
        # are clear.  XOR-ing the flags at every multiple of
        # ``1 << bit_index`` therefore gives the parity of
        # ``(k >> bit_index) + 1``.  Starting from 1 turns that parity into
        # bit ``bit_index`` of k.
        next_bits = Signal(log2_radix, reset_less=True)
        for bit_index in range(log2_radix):
            bit_value = 1
            for flag_index in range(0, radix, 1 << bit_index):
                bit_value ^= pass_flags[flag_index]
            m.d.comb += next_bits.part(bit_index, 1).eq(bit_value)

        # Under the same assumption exactly one ``selected_*`` flag is set:
        # the one for the largest passing trial value.  OR-ing the masked
        # trial values therefore yields that trial's compare_rhs.
        next_compare_rhs = 0
        for trial_bits in range(radix):
            selected = Signal(name=f"selected_{trial_bits}", reset_less=True)
            if trial_bits + 1 < radix:
                # passes, but the next larger trial value does not
                m.d.comb += selected.eq(pass_flags[trial_bits]
                                        & ~pass_flags[trial_bits + 1])
            else:
                # largest trial value: selected whenever it passes
                m.d.comb += selected.eq(pass_flags[trial_bits])
            next_compare_rhs |= Mux(selected,
                                    trial_compare_rhs_values[trial_bits],
                                    0)

        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
        # fold the chosen digit into the running product and the result
        rr_increment = (self.i.divisor_radicand * next_bits) << current_shift
        m.d.comb += self.o.root_times_radicand.eq(
            self.i.root_times_radicand + rr_increment)
        m.d.comb += self.o.quotient_root.eq(
            self.i.quotient_root | (next_bits << current_shift))
        return m
396
397
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Turns ``DivPipeCoreInterstageData`` into ``DivPipeCoreOutputData``: the
    quotient/root passes straight through and the remainder is
    ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` instance describing the
            configuration to be used.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a new ``DivPipeCoreInterstageData`` for this stage's
        input. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Return a new ``DivPipeCoreOutputData`` for this stage's output.
        """
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup: add this stage to ``m`` as submodule
        ``div_pipe_core_final`` and drive its input from ``i``. """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process: return the output data.

        ``i`` is not used here; ``setup`` is what drives ``self.i``.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        src, dst = self.i, self.o
        remainder = src.compare_lhs - src.compare_rhs
        m.d.comb += [dst.quotient_root.eq(src.quotient_root),
                     dst.remainder.eq(remainder)]
        return m