tidyup a bit
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat)
22 import enum
23
24
class DivPipeCoreConfig:
    """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Store the three configuration values on the instance. """
        self.bit_width, self.fract_width, self.log2_radix = (
            bit_width, fract_width, log2_radix)

    def __repr__(self):
        """ Return a constructor-style text form of this configuration. """
        return (f"DivPipeCoreConfig({self.bit_width}, "
                f"{self.fract_width}, {self.log2_radix})")

    @property
    def n_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances needed.

        This is ``bit_width / log2_radix`` rounded up, so that every bit of
        ``quotient_root`` is covered even when ``bit_width`` is not a
        multiple of ``log2_radix``.
        """
        bits_per_stage = self.log2_radix
        # adding (divisor - 1) before the floor division rounds up
        return (self.bit_width + bits_per_stage - 1) // bits_per_stage
50
51
class DivPipeCoreOperation(enum.Enum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    UDivRem = 0
    SqrtRem = 1
    RSqrtRem = 2

    def __int__(self):
        """ Return the integer value of this member. """
        return self.value

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: forwarded to ``Signal`` after adding one to
            account for this method's own stack frame.
        :param kwargs: any further keyword arguments for ``Signal``.
        """
        encodings = [int(member) for member in cls]

        def decode(value):
            # textual form of the member with the given integer value
            return str(cls(value))

        # NOTE(review): the upper bound is the largest member value plus 2,
        # exactly as before; confirm against the ``Signal(min=, max=)``
        # semantics of the nmigen version in use whether plus 1 was meant.
        # The ``Signal`` call must stay directly in this method so that
        # ``src_loc_at`` keeps referring to the same caller frame.
        return Signal(min=min(encodings),
                      max=max(encodings) + 2,
                      src_loc_at=(src_loc_at + 1),
                      decoder=decode,
                      **kwargs)
76
77
# Short module-level alias for ``DivPipeCoreOperation``; the classes below use
# it both to create operation signals and to compare against operation values.
DP = DivPipeCoreOperation
79
class DivPipeCoreInputData:
    """ input data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized according to ``core_config``.

        :param core_config: the ``DivPipeCoreConfig`` to size signals from.
        :param reset_less: passed through to every member signal.
        """
        self.core_config = core_config
        bit_width = core_config.bit_width
        dividend_width = bit_width + core_config.fract_width
        # signals are assigned straight to attributes, as before, so that
        # their automatically derived names stay the same
        self.dividend = Signal(dividend_width, reset_less=reset_less)
        self.divisor_radicand = Signal(bit_width, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` member-wise. """
        assignments = [
            self.dividend.eq(rhs.dividend),
            self.divisor_radicand.eq(rhs.divisor_radicand),
            self.operation.eq(rhs.operation),
        ]
        return assignments
115
116
class DivPipeCoreInterstageData:
    """ interstage data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized according to ``core_config``.

        :param core_config: the ``DivPipeCoreConfig`` to size signals from.
        :param reset_less: passed through to every member signal.
        """
        self.core_config = core_config
        single = core_config.bit_width
        double = core_config.bit_width * 2
        triple = core_config.bit_width * 3
        # signals are assigned straight to attributes, as before, so that
        # their automatically derived names stay the same
        self.divisor_radicand = Signal(single, reset_less=reset_less)
        self.operation = DP.create_signal(reset_less=reset_less)
        self.quotient_root = Signal(single, reset_less=reset_less)
        self.root_times_radicand = Signal(double, reset_less=reset_less)
        self.compare_lhs = Signal(triple, reset_less=reset_less)
        self.compare_rhs = Signal(triple, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (
            self.divisor_radicand,
            self.operation,
            self.quotient_root,
            self.root_times_radicand,
            self.compare_lhs,
            self.compare_rhs,
        )

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` member-wise. """
        assignments = [
            self.divisor_radicand.eq(rhs.divisor_radicand),
            self.operation.eq(rhs.operation),
            self.quotient_root.eq(rhs.quotient_root),
            self.root_times_radicand.eq(rhs.root_times_radicand),
            self.compare_lhs.eq(rhs.compare_lhs),
            self.compare_rhs.eq(rhs.compare_rhs),
        ]
        return assignments
174
175
class DivPipeCoreOutputData:
    """ output data type for ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Create the member signals sized according to ``core_config``.

        :param core_config: the ``DivPipeCoreConfig`` to size signals from.
        :param reset_less: passed through to every member signal.
        """
        self.core_config = core_config
        bit_width = core_config.bit_width
        # signals are assigned straight to attributes, as before, so that
        # their automatically derived names stay the same
        self.quotient_root = Signal(bit_width, reset_less=reset_less)
        self.remainder = Signal(bit_width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals in declaration order. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Return the list of assignments copying ``rhs`` member-wise. """
        assignments = [
            self.quotient_root.eq(rhs.quotient_root),
            self.remainder.eq(rhs.remainder),
        ]
        return assignments
207
208
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the partial results start at zero and
    ``compare_lhs`` is loaded according to the selected operation.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance.

        :param core_config: the ``DivPipeCoreConfig`` to use.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh input data object for this pipeline stage. """
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Return a fresh output data object for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Adds this stage to ``m`` as submodule ``div_pipe_core_setup`` and
        combinatorially drives this stage's input from ``i``.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process: return this stage's output.

        ``i`` is not used here; the input is connected in ``setup``.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        inp, out = self.i, self.o
        fract_width = self.core_config.fract_width

        # pass the divisor/radicand through and clear the partial results
        m.d.comb += [
            out.divisor_radicand.eq(inp.divisor_radicand),
            out.quotient_root.eq(0),
            out.root_times_radicand.eq(0),
        ]

        # load the left-hand side of the comparison; each source is shifted
        # left by a different multiple of fract_width
        with m.If(inp.operation == int(DP.UDivRem)):
            m.d.comb += out.compare_lhs.eq(inp.dividend << fract_width)
        with m.Elif(inp.operation == int(DP.SqrtRem)):
            m.d.comb += out.compare_lhs.eq(
                inp.divisor_radicand << (fract_width * 2))
        with m.Else():  # DivPipeCoreOperation.RSqrtRem
            m.d.comb += out.compare_lhs.eq(1 << (fract_width * 3))

        m.d.comb += [
            out.compare_rhs.eq(0),
            out.operation.eq(inp.operation),
        ]

        return m
257
258
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage decides up to ``core_config.log2_radix`` further bits of
    ``quotient_root``, starting from the most significant end. For every
    candidate value of those bits a trial ``compare_rhs`` is formed; the
    largest candidate whose trial value does not exceed ``compare_lhs`` is
    selected.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: the ``DivPipeCoreConfig`` to use.
        :param stage_index: zero-based position of this stage; must lie in
            ``range(core_config.n_stages)``.
        """
        self.core_config = core_config
        assert stage_index in range(core_config.n_stages)
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Adds this stage to ``m`` as submodule
        ``div_pipe_core_calculate_<stage_index>`` and combinatorially drives
        this stage's input from ``i``.
        """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process: return this stage's output.

        ``i`` is not used here; the input is connected in ``setup``.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        fract_width = self.core_config.fract_width

        # values that pass through this stage unchanged
        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.operation.eq(self.i.operation)
        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        # Number of quotient_root bits still undecided when this stage
        # starts. The last stage may have fewer than log2_radix bits left,
        # so the per-stage bit count is clamped to what remains.
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        # from here on current_shift is the bit position, within
        # quotient_root, of the least significant bit decided by this stage
        current_shift -= log2_radix
        radix = 1 << log2_radix

        trial_compare_rhs_values = []
        pass_flags = []
        for trial_bits in range(radix):
            trial_bits_sig = Const(trial_bits, log2_radix)
            trial_bits_sqrd_sig = Const(trial_bits * trial_bits,
                                        log2_radix * 2)

            dr_times_trial_bits = self.i.divisor_radicand * trial_bits_sig
            dr_times_trial_bits_sqrd = self.i.divisor_radicand \
                * trial_bits_sqrd_sig
            qr_times_trial_bits = self.i.quotient_root * trial_bits_sig
            rr_times_trial_bits = self.i.root_times_radicand * trial_bits_sig

            # default: compare_rhs unchanged (the only case for trial 0);
            # overridden below for non-zero trial values
            trial_compare_rhs = Signal.like(
                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
                reset_less=True)
            m.d.comb += trial_compare_rhs.eq(self.i.compare_rhs)

            if trial_bits != 0:  # no point adding multiply by zero
                # UDivRem: rhs grows by divisor * trial
                with m.If(self.i.operation == int(DP.UDivRem)):
                    div_term1 = dr_times_trial_bits
                    div_term1_shift = fract_width + current_shift
                    div_rhs = self.i.compare_rhs \
                        + (div_term1 << div_term1_shift)

                    m.d.comb += trial_compare_rhs.eq(div_rhs)

                # SqrtRem: rhs grows by 2 * root * trial + trial ** 2
                # (the "+ 1" in the first shift is the factor of two)
                with m.Elif(self.i.operation == int(DP.SqrtRem)):
                    sqrt_term1 = qr_times_trial_bits
                    sqrt_term1_shift = fract_width + current_shift + 1
                    sqrt_term2 = trial_bits_sqrd_sig
                    sqrt_term2_shift = fract_width + current_shift * 2
                    sqrt_rhs = self.i.compare_rhs \
                        + (sqrt_term1 << sqrt_term1_shift)
                    sqrt_rhs = sqrt_rhs + (sqrt_term2 << sqrt_term2_shift)

                    m.d.comb += trial_compare_rhs.eq(sqrt_rhs)

                # RSqrtRem: rhs grows by
                # 2 * (root * radicand) * trial + radicand * trial ** 2
                with m.Else():
                    rsqrt_term1 = rr_times_trial_bits
                    rsqrt_term1_shift = current_shift + 1
                    rsqrt_term2 = dr_times_trial_bits_sqrd
                    rsqrt_term2_shift = current_shift * 2
                    rsqrt_rhs = self.i.compare_rhs \
                        + (rsqrt_term1 << rsqrt_term1_shift)
                    rsqrt_rhs = rsqrt_rhs + (rsqrt_term2 << rsqrt_term2_shift)

                    m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)

            trial_compare_rhs_values.append(trial_compare_rhs)

            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
            pass_flags.append(pass_flag)

        # convert pass_flags to next_bits.
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # Under those assumptions the number of set flags at indices that are
        # multiples of (1 << i) is (next_bits >> i) + 1, so XOR-ing those
        # flags together with a constant 1 gives bit i of next_bits.

        next_bits = Signal(log2_radix, reset_less=True)
        for i in range(log2_radix):
            bit_value = 1
            for j in range(0, radix, 1 << i):
                bit_value ^= pass_flags[j]
            m.d.comb += next_bits.part(i, 1).eq(bit_value)

        # Select the trial_compare_rhs value that goes into compare_rhs.
        # ``selected_i`` is set when trial i passes and trial i + 1 does not
        # (or does not exist); under the assumptions above exactly one of
        # them is set, so OR-ing the gated values acts as a one-hot mux.
        next_compare_rhs = 0
        for i in range(radix):
            next_flag = pass_flags[i + 1] if i + 1 < radix else 0
            selected = Signal(name=f"selected_{i}", reset_less=True)
            m.d.comb += selected.eq(pass_flags[i] & ~next_flag)
            next_compare_rhs |= Mux(selected,
                                    trial_compare_rhs_values[i],
                                    0)

        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                  + ((self.i.divisor_radicand
                                                      * next_bits)
                                                     << current_shift))
        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
                                            | (next_bits << current_shift))
        return m
400
401
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData``: the quotient/root is passed through and the
    remainder is ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance.

        :param core_config: the ``DivPipeCoreConfig`` to use.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Return a fresh input data object for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Return a fresh output data object for this pipeline stage. """
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Adds this stage to ``m`` as submodule ``div_pipe_core_final`` and
        combinatorially drives this stage's input from ``i``.
        """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process: return this stage's output.

        ``i`` is not used here; the input is connected in ``setup``.
        """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        inp, out = self.i, self.o

        remainder = inp.compare_lhs - inp.compare_rhs
        m.d.comb += [
            out.quotient_root.eq(inp.quotient_root),
            out.remainder.eq(remainder),
        ]

        return m