a1a8edefc343b21e1c48db9bedd13e3f8c126137
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """ Core of the div/rem/sqrt/rsqrt pipeline.
4
5 Special case handling, input/output conversion, and muxid handling are handled
6 outside of these classes.
7
8 Algorithms based on ``algorithm.FixedUDivRemSqrtRSqrt``.
9
10 Formulas solved are:
11 * div/rem:
12 ``dividend == quotient_root * divisor_radicand``
13 * sqrt/rem:
14 ``divisor_radicand == quotient_root * quotient_root``
15 * rsqrt/rem:
16 ``1 == quotient_root * quotient_root * divisor_radicand``
17
18 The remainder is the left-hand-side of the comparison minus the
19 right-hand-side of the comparison in the above formulas.
20 """
21 from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat)
22 import enum
23
24
class DivPipeCoreConfig:
    """ Parameters shared by every stage of the div/rem/sqrt/rsqrt core.

    :attribute bit_width: base bit-width.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage.
    """

    def __init__(self, bit_width, fract_width, log2_radix):
        """ Store the three configuration values unchanged. """
        self.bit_width, self.fract_width, self.log2_radix = (
            bit_width, fract_width, log2_radix)

    def __repr__(self):
        """ Return a constructor-style description of this configuration. """
        return "DivPipeCoreConfig({}, {}, {})".format(
            self.bit_width, self.fract_width, self.log2_radix)

    @property
    def num_calculate_stages(self):
        """ Number of ``DivPipeCoreCalculateStage`` instances needed.

        This is ``bit_width / log2_radix`` rounded up, since every stage
        produces at most ``log2_radix`` bits of ``quotient_root``.
        """
        bits_per_stage = self.log2_radix
        # adding (divisor - 1) before the floor division rounds upwards
        padded_width = self.bit_width + (bits_per_stage - 1)
        return padded_width // bits_per_stage
50
51
class DivPipeCoreOperation(enum.IntEnum):
    """ Operation for ``DivPipeCore``.

    :attribute UDivRem: unsigned divide/remainder.
    :attribute SqrtRem: square-root/remainder.
    :attribute RSqrtRem: reciprocal-square-root/remainder.
    """

    UDivRem = 0
    SqrtRem = 1
    RSqrtRem = 2

    @classmethod
    def _signal_bounds(cls):
        """ Get ``(min, max)`` for a signal holding any member of ``cls``.

        ``min`` is inclusive and ``max`` is exclusive, which is the
        convention of the ``min``/``max`` arguments of ``Signal``.
        """
        values = [int(member) for member in cls]
        return min(values), max(values) + 1

    @classmethod
    def create_signal(cls, *, src_loc_at=0, **kwargs):
        """ Create a signal that can contain a ``DivPipeCoreOperation``.

        :param src_loc_at: number of additional call frames to skip; it is
            passed on to ``Signal`` incremented by one (presumably so that
            this wrapper's own frame is skipped -- confirm against the
            nmigen version in use).
        :param kwargs: passed through to ``Signal`` unchanged.
        :returns: a ``Signal`` wide enough for every member of this enum,
            using this enum as its decoder.
        """
        # ``max`` must be one past the largest member: passing the largest
        # member itself would size the signal for 0..1 only and silently
        # truncate ``RSqrtRem`` (2).
        lower, upper = cls._signal_bounds()
        return Signal(min=lower,
                      max=upper,
                      src_loc_at=(src_loc_at + 1),
                      decoder=cls,
                      **kwargs)
72
73
class DivPipeCoreInputData:
    """ Bundle of signals fed into ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute dividend: dividend for div/rem. Signal with a bit-width of
        ``core_config.bit_width + core_config.fract_width`` and a fract-width
        of ``core_config.fract_width * 2`` bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the member signals.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config

        dividend_width = core_config.bit_width + core_config.fract_width
        self.dividend = Signal(dividend_width, reset_less=reset_less)
        self.divisor_radicand = Signal(core_config.bit_width,
                                       reset_less=reset_less)

        # FIXME: this goes into (is replaced by) self.ctx.op
        self.operation = DivPipeCoreOperation.create_signal(
            reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals, in declaration order. """
        yield from (
            self.dividend,
            self.divisor_radicand,
            self.operation,  # FIXME: delete. already covered by self.ctx
        )

    def eq(self, rhs):
        """ Build the assignments copying every member signal from ``rhs``.

        :param rhs: object providing ``dividend``, ``divisor_radicand`` and
            ``operation`` attributes.
        :returns: list of assignment statements, one per member signal.
        """
        pairs = (
            (self.dividend, rhs.dividend),
            (self.divisor_radicand, rhs.divisor_radicand),
            (self.operation, rhs.operation),  # FIXME: delete.
        )
        return [target.eq(source) for target, source in pairs]
112
113
class DivPipeCoreInterstageData:
    """ Bundle of signals passed between the stages of ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute operation: the ``DivPipeCoreOperation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Signal with a bit-width of ``core_config.bit_width * 2`` and a
        fract-width of ``core_config.fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Signal with a bit-width of
        ``core_config.bit_width * 3`` and a fract-width of
        ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the member signals.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width

        self.divisor_radicand = Signal(width, reset_less=reset_less)
        # FIXME: delete self.operation. already covered by self.ctx.op
        self.operation = DivPipeCoreOperation.create_signal(
            reset_less=reset_less)
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.root_times_radicand = Signal(width * 2, reset_less=reset_less)
        self.compare_lhs = Signal(width * 3, reset_less=reset_less)
        self.compare_rhs = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals, in declaration order. """
        yield from (
            self.divisor_radicand,
            self.operation,  # FIXME: delete. already in self.ctx.op
            self.quotient_root,
            self.root_times_radicand,
            self.compare_lhs,
            self.compare_rhs,
        )

    def eq(self, rhs):
        """ Build the assignments copying every member signal from ``rhs``.

        :param rhs: object providing the same six signal attributes.
        :returns: list of assignment statements, one per member signal.
        """
        pairs = (
            (self.divisor_radicand, rhs.divisor_radicand),
            (self.operation, rhs.operation),  # FIXME: delete.
            (self.quotient_root, rhs.quotient_root),
            (self.root_times_radicand, rhs.root_times_radicand),
            (self.compare_lhs, rhs.compare_lhs),
            (self.compare_rhs, rhs.compare_rhs),
        )
        return [target.eq(source) for target, source in pairs]
173
174
class DivPipeCoreOutputData:
    """ Bundle of signals produced by ``DivPipeCore``.

    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
        configuration to be used.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Signal with a bit-width of ``core_config.bit_width`` and a
        fract-width of ``core_config.fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Signal with a bit-width of ``core_config.bit_width * 3`` and a
        fract-width of ``core_config.fract_width * 3`` bits.
    """

    def __init__(self, core_config, reset_less=True):
        """ Allocate the member signals.

        :param core_config: ``DivPipeCoreConfig`` giving the signal widths.
        :param reset_less: forwarded to every member signal.
        """
        self.core_config = core_config
        width = core_config.bit_width
        self.quotient_root = Signal(width, reset_less=reset_less)
        self.remainder = Signal(width * 3, reset_less=reset_less)

    def __iter__(self):
        """ Iterate over the member signals, in declaration order. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Build the assignments copying every member signal from ``rhs``.

        :param rhs: object providing ``quotient_root`` and ``remainder``.
        :returns: list of assignment statements, one per member signal.
        """
        pairs = (
            (self.quotient_root, rhs.quotient_root),
            (self.remainder, rhs.remainder),
        )
        return [target.eq(source) for target, source in pairs]
206
207
class DivPipeCoreSetupStage(Elaboratable):
    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts a ``DivPipeCoreInputData`` into the initial
    ``DivPipeCoreInterstageData``: the result accumulators start at zero and
    ``compare_lhs`` is loaded with the operation-specific left-hand side.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreSetupStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` used for both specs.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInputData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and drives ``self.i``
        combinatorially from ``i``.
        """
        m.submodules.div_pipe_core_setup = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process; returns ``self.o`` and ignores ``i``. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        fract_width = self.core_config.fract_width
        operation = self.i.operation

        # compare_lhs carries three fract-widths of fraction bits, so each
        # operand is shifted left by however many it is short of that.
        with m.If(operation == DivPipeCoreOperation.UDivRem):
            m.d.comb += self.o.compare_lhs.eq(self.i.dividend << fract_width)
        with m.Elif(operation == DivPipeCoreOperation.SqrtRem):
            m.d.comb += self.o.compare_lhs.eq(
                self.i.divisor_radicand << (fract_width * 2))
        with m.Else():  # DivPipeCoreOperation.RSqrtRem
            m.d.comb += self.o.compare_lhs.eq(1 << (fract_width * 3))

        # pass-through values and zeroed accumulators
        m.d.comb += [
            self.o.divisor_radicand.eq(self.i.divisor_radicand),
            self.o.operation.eq(operation),
            self.o.quotient_root.eq(0),
            self.o.root_times_radicand.eq(0),
            self.o.compare_rhs.eq(0),
        ]

        return m
256
257
class DivPipeCoreCalculateStage(Elaboratable):
    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Each stage determines up to ``core_config.log2_radix`` further bits of
    ``quotient_root`` by computing the candidate ``compare_rhs`` for every
    possible value of those bits and keeping the largest candidate that
    does not exceed ``compare_lhs``.
    """

    def __init__(self, core_config, stage_index):
        """ Create a ``DivPipeCoreCalculateStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` used for both specs.
        :param stage_index: position of this stage; must lie in
            ``range(core_config.num_calculate_stages)``. Stage 0 handles
            the most-significant bits of ``quotient_root``.
        """
        self.core_config = core_config
        assert stage_index in range(core_config.num_calculate_stages)
        self.stage_index = stage_index
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` (named after
        ``stage_index``) and drives ``self.i`` combinatorially from ``i``.
        """
        setattr(m.submodules,
                f"div_pipe_core_calculate_{self.stage_index}",
                self)
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process; returns ``self.o`` and ignores ``i``. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()
        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
        m.d.comb += self.o.operation.eq(self.i.operation)
        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)

        fract_width = self.core_config.fract_width

        # Work out which bits of quotient_root this stage produces: they
        # occupy [current_shift, current_shift + log2_radix).  The last
        # stage may get fewer than core_config.log2_radix bits.
        log2_radix = self.core_config.log2_radix
        current_shift = self.core_config.bit_width
        current_shift -= self.stage_index * log2_radix
        log2_radix = min(log2_radix, current_shift)
        assert log2_radix > 0
        current_shift -= log2_radix
        radix = 1 << log2_radix

        trial_compare_rhs_values = []
        pass_flags = []
        for trial_bits in range(radix):
            tb = trial_bits << current_shift
            tb_width = log2_radix + current_shift

            # With t = tb and q = quotient_root, each candidate is the
            # incoming compare_rhs plus the growth of the right-hand side
            # when q becomes q + t.  For t == 0 nothing is added.
            div_rhs = self.i.compare_rhs
            sqrt_rhs = self.i.compare_rhs
            rsqrt_rhs = self.i.compare_rhs
            if tb != 0:  # no point adding stuff that's multiplied by zero
                shifted_trial_bits = Const(tb, tb_width)
                shifted_trial_bits2 = Const(tb * 2, tb_width + 1)
                shifted_trial_bits_sqrd = Const(tb * tb, tb_width * 2)

                # UDivRem: (q + t) * d == q * d + t * d
                # (t itself, not 2 * t, multiplies the divisor)
                div_factor1 = self.i.divisor_radicand * shifted_trial_bits
                div_rhs = div_rhs + (div_factor1 << fract_width)

                # SqrtRem: (q + t)**2 == q**2 + 2 * t * q + t**2
                sqrt_factor1 = self.i.quotient_root * shifted_trial_bits2
                sqrt_factor2 = shifted_trial_bits_sqrd
                sqrt_rhs = sqrt_rhs + (sqrt_factor1 << fract_width)
                sqrt_rhs = sqrt_rhs + (sqrt_factor2 << fract_width)

                # RSqrtRem: (q + t)**2 * d
                #           == q**2 * d + 2 * t * (q * d) + t**2 * d
                rsqrt_rhs = rsqrt_rhs + (self.i.root_times_radicand
                                         * shifted_trial_bits2)
                rsqrt_rhs = rsqrt_rhs + (self.i.divisor_radicand
                                         * shifted_trial_bits_sqrd)

            trial_compare_rhs = Signal.like(
                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
                reset_less=True)

            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
                m.d.comb += trial_compare_rhs.eq(div_rhs)
            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
                m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
            with m.Else():  # DivPipeCoreOperation.RSqrtRem
                m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
            trial_compare_rhs_values.append(trial_compare_rhs)

            pass_flag = Signal(name=f"pass_flag_{trial_bits}",
                               reset_less=True)
            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
            pass_flags.append(pass_flag)

        # convert pass_flags to next_bits.
        #
        # Assumes that for each set bit in pass_flag, all previous bits are
        # also set.
        #
        # Assumes that pass_flag[0] is always set (since
        # compare_lhs >= compare_rhs is a pipeline invariant).
        #
        # Under those assumptions the number of set flags at non-zero
        # multiples of (1 << i) has the same parity as bit i of the index of
        # the highest set flag; the initial 1 cancels pass_flags[0].
        next_bits = Signal(log2_radix, reset_less=True)
        for i in range(log2_radix):
            bit_value = 1
            for j in range(0, radix, 1 << i):
                bit_value ^= pass_flags[j]
            m.d.comb += next_bits[i].eq(bit_value)

        # Select the candidate belonging to the highest passing trial: a
        # trial is the winner when it passes and its successor does not.
        # At most one select line is set (given the assumptions above), so
        # OR-ing the gated full-width candidates yields the winner.
        selected_values = []
        for i, trial_rhs in enumerate(trial_compare_rhs_values):
            next_flag = pass_flags[i + 1] if (i + 1 < radix) else Const(0)
            test = Signal(reset_less=True, name=f"test{i}")
            m.d.comb += test.eq(pass_flags[i] & ~next_flag)
            selected_values.append(Mux(test, trial_rhs, 0))

        next_compare_rhs = selected_values[0]
        for value in selected_values[1:]:
            next_compare_rhs = next_compare_rhs | value

        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                  + ((self.i.divisor_radicand
                                                      * next_bits)
                                                     << current_shift))
        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
                                            | (next_bits << current_shift))
        return m
381
382
class DivPipeCoreFinalStage(Elaboratable):
    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline.

    Converts the last ``DivPipeCoreInterstageData`` into a
    ``DivPipeCoreOutputData`` by forwarding ``quotient_root`` and computing
    the remainder as ``compare_lhs - compare_rhs``.
    """

    def __init__(self, core_config):
        """ Create a ``DivPipeCoreFinalStage`` instance.

        :param core_config: ``DivPipeCoreConfig`` used for both specs.
        """
        self.core_config = core_config
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ Get the input spec for this pipeline stage. """
        return DivPipeCoreInterstageData(self.core_config)

    def ospec(self):
        """ Get the output spec for this pipeline stage. """
        return DivPipeCoreOutputData(self.core_config)

    def setup(self, m, i):
        """ Pipeline stage setup.

        Registers this stage as a submodule of ``m`` and drives ``self.i``
        combinatorially from ``i``.
        """
        m.submodules.div_pipe_core_final = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ Pipeline stage process; returns ``self.o`` and ignores ``i``. """
        return self.o

    def elaborate(self, platform):
        """ Elaborate into ``Module``. """
        m = Module()

        remainder = self.i.compare_lhs - self.i.compare_rhs
        m.d.comb += [
            self.o.quotient_root.eq(self.i.quotient_root),
            self.o.remainder.eq(remainder),
        ]

        return m