1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
4 """ Algorithms for div/rem/sqrt/rsqrt.
6 code for simulating/testing the various algorithms
9 from nmigen
.hdl
.ast
import Const
def div_rem(dividend, divisor, bit_width, signed):
    """ Compute the quotient/remainder following the RISC-V M extension.

    NOT the same as the // or % operators: the quotient is truncated towards
    zero, the remainder takes the sign of the dividend, and dividing by zero
    gives a quotient of -1 (all bits set) with the dividend as the remainder.

    :param dividend: the dividend/numerator
    :param divisor: the divisor/denominator
    :param bit_width: the bit width of the inputs/outputs
    :param signed: if the inputs/outputs are signed instead of unsigned
    :returns tuple: ``(quotient, remainder)``, both normalized to
        ``(bit_width, signed)``
    """
    shape = (bit_width, signed)
    dividend = Const.normalize(dividend, shape)
    divisor = Const.normalize(divisor, shape)
    if divisor == 0:
        # division by zero never traps: all-ones quotient, dividend unchanged
        quotient = -1
        remainder = dividend
    else:
        # divide the magnitudes, then fix up the signs separately
        quotient = abs(dividend) // abs(divisor)
        remainder = abs(dividend) % abs(divisor)
        if (dividend < 0) != (divisor < 0):
            quotient = -quotient
        if dividend < 0:
            remainder = -remainder
    # normalizing also wraps the signed overflow case (most-negative / -1)
    quotient = Const.normalize(quotient, shape)
    remainder = Const.normalize(remainder, shape)
    return quotient, remainder
36 """ Unsigned integer division/remainder following the RISC-V M extension.
38 NOT the same as the // or % operators
40 :attribute remainder: the remainder and/or dividend
41 :attribute divisor: the divisor
42 :attribute bit_width: the bit width of the inputs/outputs
43 :attribute log2_radix: the base-2 log of the division radix. The number of
44 bits of quotient that are calculated per pipeline stage.
45 :attribute quotient: the quotient
46 :attribute current_shift: the current bit index
def __init__(self, dividend, divisor, bit_width, log2_radix=3):
    """ Create an UnsignedDivRem.

    :param dividend: the dividend/numerator
    :param divisor: the divisor/denominator
    :param bit_width: the bit width of the inputs/outputs
    :param log2_radix: the base-2 log of the division radix. The number of
        bits of quotient that are calculated per pipeline stage.
    """
    # the running remainder starts out as the (unsigned) dividend
    self.remainder = Const.normalize(dividend, (bit_width, False))
    self.divisor = Const.normalize(divisor, (bit_width, False))
    self.bit_width = bit_width
    self.log2_radix = log2_radix
    # calculate_stage() ORs quotient digits in, so it must start at zero
    self.quotient = 0
    self.current_shift = bit_width
def calculate_stage(self):
    """ Calculate the next pipeline stage of the division.

    Each stage produces up to ``log2_radix`` quotient bits by picking the
    largest digit whose shifted multiple of the divisor still fits in the
    running remainder.

    :returns bool: True if this is the last pipeline stage.
    :raises ValueError: if ``log2_radix`` is not positive (no progress could
        be made).
    """
    if self.current_shift == 0:
        return True
    # the final stage may have fewer than log2_radix bits left to produce
    log2_radix = min(self.log2_radix, self.current_shift)
    if log2_radix <= 0:
        raise ValueError("log2_radix must be positive")
    self.current_shift -= log2_radix
    radix = 1 << log2_radix
    # candidate remainders, one per possible quotient digit
    remainders = [self.remainder - ((self.divisor * digit) << self.current_shift)
                  for digit in range(radix)]
    # digit 0 always qualifies, so the selection is never empty
    quotient_bits = max(digit for digit in range(radix)
                        if remainders[digit] >= 0)
    self.remainder = remainders[quotient_bits]
    self.quotient |= quotient_bits << self.current_shift
    return self.current_shift == 0
89 """ Calculate the results of the division.
93 while not self
.calculate_stage():
99 """ integer division/remainder following the RISC-V M extension.
101 NOT the same as the // or % operators
103 :attribute dividend: the dividend
104 :attribute divisor: the divisor
105 :attribute signed: if the inputs/outputs are signed instead of unsigned
106 :attribute quotient: the quotient
107 :attribute remainder: the remainder
108 :attribute divider: the base UnsignedDivRem
def __init__(self, dividend, divisor, bit_width, signed, log2_radix=3):
    """ Create the signed/unsigned divider wrapping an UnsignedDivRem.

    :param dividend: the dividend/numerator
    :param divisor: the divisor/denominator
    :param bit_width: the bit width of the inputs/outputs
    :param signed: if the inputs/outputs are signed instead of unsigned
    :param log2_radix: the base-2 log of the division radix. The number of
        bits of quotient that are calculated per pipeline stage.
    """
    self.dividend = Const.normalize(dividend, (bit_width, signed))
    self.divisor = Const.normalize(divisor, (bit_width, signed))
    # calculate_stage() reads self.signed when normalizing the results
    self.signed = signed
    self.quotient = 0
    self.remainder = 0
    # divide the magnitudes of the *normalized* operands so that inputs given
    # in a different representation (e.g. 0xFF for a signed -1) still work
    self.divider = UnsignedDivRem(abs(self.dividend), abs(self.divisor),
                                  bit_width, log2_radix)
def calculate_stage(self):
    """ Calculate the next pipeline stage of the division.

    Advances the underlying unsigned divider; once it finishes, applies the
    signs and stores the normalized ``quotient`` and ``remainder``.

    :returns bool: True if this is the last pipeline stage.
    """
    if not self.divider.calculate_stage():
        return False
    divisor_sign = self.divisor < 0
    dividend_sign = self.dividend < 0
    # a zero divisor keeps the all-ones quotient produced by the divider
    if self.divisor != 0 and divisor_sign != dividend_sign:
        quotient = -self.divider.quotient
    else:
        quotient = self.divider.quotient
    # the remainder takes the sign of the dividend
    if dividend_sign:
        remainder = -self.divider.remainder
    else:
        remainder = self.divider.remainder
    bit_width = self.divider.bit_width
    self.quotient = Const.normalize(quotient, (bit_width, self.signed))
    self.remainder = Const.normalize(remainder, (bit_width, self.signed))
    return True
153 """ Fixed-point number.
155 the value is bits * 2 ** -fract_width
157 :attribute bits: the bits of the fixed-point number
158 :attribute fract_width: the number of bits in the fractional portion
159 :attribute bit_width: the total number of bits
160 :attribute signed: if the type is signed
def from_bits(bits, fract_width, bit_width, signed):
    """ Create a new Fixed from its raw bits.

    Takes no ``self``; other methods call it as
    ``self.from_bits(bits, fract_width, bit_width, signed)``.

    :param bits: the bits of the fixed-point number
    :param fract_width: the number of bits in the fractional portion
    :param bit_width: the total number of bits
    :param signed: if the type is signed
    :returns Fixed: the new Fixed.
    """
    # build a zero of the right type, then overwrite its bits directly so the
    # value is not scaled by 2 ** fract_width
    retval = Fixed(0, fract_width, bit_width, signed)
    retval.bits = Const.normalize(bits, (bit_width, signed))
    return retval
def __init__(self, value, fract_width, bit_width, signed):
    """ Create a new Fixed.

    Note: ``value`` is not the same as ``bits``. To put a particular number
    in ``bits``, use ``Fixed.from_bits``.

    :param value: the value of the fixed-point number
    :param fract_width: the number of bits in the fractional portion
    :param bit_width: the total number of bits
    :param signed: if the type is signed
    """
    assert fract_width >= 0
    if isinstance(value, Fixed):
        # re-scale an existing Fixed to the requested fractional width
        if fract_width < value.fract_width:
            bits = value.bits >> (value.fract_width - fract_width)
        else:
            bits = value.bits << (fract_width - value.fract_width)
    elif isinstance(value, int):
        bits = value << fract_width
    else:
        # local import: ``math`` is not visibly imported at module level
        import math
        bits = math.floor(value * 2 ** fract_width)
    self.bits = Const.normalize(bits, (bit_width, signed))
    self.fract_width = fract_width
    self.bit_width = bit_width
    # other methods (arithmetic, repr) read self.signed
    self.signed = signed
def with_bits(self, bits):
    """ Create a new Fixed with the specified bits.

    The fractional width, total width and signedness are copied from
    ``self``.

    :param bits: the new bits.
    :returns Fixed: the new Fixed.
    """
    return self.from_bits(bits,
                          self.fract_width,
                          self.bit_width,
                          self.signed)
def with_value(self, value):
    """ Create a new Fixed with the specified value.

    The fractional width, total width and signedness are copied from
    ``self``.

    :param value: the new value.
    :returns Fixed: the new Fixed.
    """
    return Fixed(value,
                 self.fract_width,
                 self.bit_width,
                 self.signed)
226 """ Get representation."""
227 retval
= f
"Fixed.from_bits({self.bits}, {self.fract_width}, "
228 return retval
+ f
"{self.bit_width}, {self.signed})"
231 """ Truncate to integer."""
233 return self
.__ceil
__()
234 return self
.__floor
__()
237 """ Truncate to integer."""
238 return self
.__trunc
__()
241 """ Convert to float."""
242 return self
.bits
* 2.0 ** -self
.fract_width
245 """ Floor to integer."""
246 return self
.bits
>> self
.fract_width
249 """ Ceil to integer."""
250 return -((-self
.bits
) >> self
.fract_width
)
254 return self
.from_bits(-self
.bits
, self
.fract_width
,
255 self
.bit_width
, self
.signed
)
258 """ Unary Positive."""
262 """ Absolute Value."""
263 return self
.from_bits(abs(self
.bits
), self
.fract_width
,
264 self
.bit_width
, self
.signed
)
266 def __invert__(self
):
268 return self
.from_bits(~self
.bits
, self
.fract_width
,
269 self
.bit_width
, self
.signed
)
271 def _binary_op(self
, rhs
, operation
, full
=False):
272 """ Handle binary arithmetic operators. """
273 if isinstance(rhs
, int):
276 int_width
= self
.bit_width
- self
.fract_width
277 elif isinstance(rhs
, Fixed
):
278 if self
.signed
!= rhs
.signed
:
279 return TypeError("signedness must match")
280 rhs_fract_width
= rhs
.fract_width
282 int_width
= max(self
.bit_width
- self
.fract_width
,
283 rhs
.bit_width
- rhs
.fract_width
)
285 return NotImplemented
286 fract_width
= max(self
.fract_width
, rhs_fract_width
)
287 rhs_bits
<<= fract_width
- rhs_fract_width
288 lhs_bits
= self
.bits
<< fract_width
- self
.fract_width
289 bit_width
= int_width
+ fract_width
291 return operation(lhs_bits
, rhs_bits
,
292 fract_width
, bit_width
, self
.signed
)
293 bits
= operation(lhs_bits
, rhs_bits
,
295 return self
.from_bits(bits
, fract_width
, bit_width
, self
.signed
)
297 def __add__(self
, rhs
):
299 return self
._binary
_op
(rhs
, lambda lhs
, rhs
, fract_width
: lhs
+ rhs
)
301 def __radd__(self
, lhs
):
302 """ Reverse Addition."""
303 return self
.__add
__(lhs
)
305 def __sub__(self
, rhs
):
307 return self
._binary
_op
(rhs
, lambda lhs
, rhs
, fract_width
: lhs
- rhs
)
309 def __rsub__(self
, lhs
):
310 """ Reverse Subtraction."""
311 # note swapped argument and parameter order
312 return self
._binary
_op
(lhs
, lambda rhs
, lhs
, fract_width
: lhs
- rhs
)
314 def __and__(self
, rhs
):
316 return self
._binary
_op
(rhs
, lambda lhs
, rhs
, fract_width
: lhs
& rhs
)
318 def __rand__(self
, lhs
):
319 """ Reverse Bitwise And."""
320 return self
.__and
__(lhs
)
322 def __or__(self
, rhs
):
324 return self
._binary
_op
(rhs
, lambda lhs
, rhs
, fract_width
: lhs | rhs
)
326 def __ror__(self
, lhs
):
327 """ Reverse Bitwise Or."""
328 return self
.__or
__(lhs
)
330 def __xor__(self
, rhs
):
332 return self
._binary
_op
(rhs
, lambda lhs
, rhs
, fract_width
: lhs ^ rhs
)
334 def __rxor__(self
, lhs
):
335 """ Reverse Bitwise Xor."""
336 return self
.__xor
__(lhs
)
338 def __mul__(self
, rhs
):
339 """ Multiplication. """
340 if isinstance(rhs
, int):
343 int_width
= self
.bit_width
- self
.fract_width
344 elif isinstance(rhs
, Fixed
):
345 if self
.signed
!= rhs
.signed
:
346 return TypeError("signedness must match")
347 rhs_fract_width
= rhs
.fract_width
349 int_width
= (self
.bit_width
- self
.fract_width
350 + rhs
.bit_width
- rhs
.fract_width
)
352 return NotImplemented
353 fract_width
= self
.fract_width
+ rhs_fract_width
354 bit_width
= int_width
+ fract_width
355 bits
= self
.bits
* rhs_bits
356 return self
.from_bits(bits
, fract_width
, bit_width
, self
.signed
)
358 def __rmul__(self
, rhs
):
359 """ Reverse Multiplication. """
360 return self
.__mul
__(rhs
)
363 def _cmp_impl(lhs
, rhs
, fract_width
, bit_width
, signed
):
371 """ Compare self with rhs.
373 :returns int: returns -1 if self is less than rhs, 0 if they're equal,
374 and 1 for greater than.
375 Returns NotImplemented for unimplemented cases
377 return self
._binary
_op
(rhs
, self
._cmp
_impl
, full
=True)
379 def __lt__(self
, rhs
):
381 return self
.cmp(rhs
) < 0
383 def __le__(self
, rhs
):
384 """ Less Than or Equal."""
385 return self
.cmp(rhs
) <= 0
387 def __eq__(self
, rhs
):
389 return self
.cmp(rhs
) == 0
391 def __ne__(self
, rhs
):
393 return self
.cmp(rhs
) != 0
395 def __gt__(self
, rhs
):
397 return self
.cmp(rhs
) > 0
399 def __ge__(self
, rhs
):
400 """ Greater Than or Equal."""
401 return self
.cmp(rhs
) >= 0
404 """ Convert to bool."""
405 return bool(self
.bits
)
408 """ Get text representation."""
409 # don't just use self.__float__() in order to work with numbers more
416 int_part
= bits
>> self
.fract_width
417 fract_part
= bits
& ~
(-1 << self
.fract_width
)
418 # round up fract_width to nearest multiple of 4
419 fract_width
= (self
.fract_width
+ 3) & ~
3
420 fract_part
<<= (fract_width
- self
.fract_width
)
421 fract_width_in_hex_digits
= fract_width
// 4
422 retval
+= f
"0x{int_part:x}."
423 if fract_width_in_hex_digits
!= 0:
424 retval
+= f
"{fract_part:x}".zfill(fract_width_in_hex_digits
)
429 """ A polynomial root and remainder.
431 :attribute root: the polynomial root.
432 :attribute remainder: the remainder.
435 def __init__(self
, root
, remainder
):
436 """ Create a new RootRemainder.
438 :param root: the polynomial root.
439 :param remainder: the remainder.
442 self
.remainder
= remainder
445 """ Get the representation as a string. """
446 return f
"RootRemainder({repr(self.root)}, {repr(self.remainder)})"
449 """ Convert to a string. """
450 return f
"RootRemainder({str(self.root)}, {str(self.remainder)})"
def fixed_sqrt(radicand):
    """ Compute the Square Root and Remainder.

    Solves the polynomial ``radicand - x * x == 0``

    :param radicand: the ``Fixed`` (or ``int``) to take the square root of.
    :returns RootRemainder: the largest root whose square does not exceed
        ``radicand`` together with ``radicand - root * root``. For an ``int``
        radicand both are converted back to ``int``.
    :raises TypeError: if ``radicand`` is neither an ``int`` nor a ``Fixed``.
    """
    # Written for correctness, not speed
    # NOTE(review): negative radicands get no special handling here -- confirm
    # what callers expect for them.
    is_int = isinstance(radicand, int)
    if is_int:
        # one extra bit so that the value fits in a signed Fixed
        radicand = Fixed(radicand, 0, radicand.bit_length() + 1, True)
    elif not isinstance(radicand, Fixed):
        raise TypeError("radicand must be an int or a Fixed")

    def is_remainder_non_negative(root):
        return radicand >= root * root

    # build the root one bit at a time, most significant bit first
    root = radicand.with_bits(0)
    for i in reversed(range(root.bit_width)):
        new_root = root.with_bits(root.bits | (1 << i))
        if new_root < 0:  # skip sign bit
            continue
        if is_remainder_non_negative(new_root):
            root = new_root
    remainder = radicand - root * root
    if is_int:
        root = int(root)
        remainder = int(remainder)
    return RootRemainder(root, remainder)
def fixed_rsqrt(radicand):
    """ Compute the Reciprocal Square Root (not implemented yet).

    :param radicand: the value to take the reciprocal square root of.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError()