1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
13 class RoundDir(enum
.Enum
):
16 NEAREST_TIES_UP
= enum
.auto()
17 ERROR_IF_INEXACT
= enum
.auto()
20 @dataclass(frozen
=True)
def __post_init__(self):
    """sanity-check the fields once they have been assigned.

    `bits` is the raw scaled integer (the represented value is
    `bits * 2 ** -frac_wid`) and may have any sign; `frac_wid` counts
    fractional bits and so must be a non-negative `int`. violations raise a
    bare `AssertionError`.
    """
    bits_is_int = isinstance(self.bits, int)
    assert bits_is_int
    frac_wid = self.frac_wid
    # short-circuit: only compare against 0 once we know it's an int
    assert isinstance(frac_wid, int) and not frac_wid < 0
31 """convert `value` to a fixed-point number with enough fractional
32 bits to preserve its value."""
33 if isinstance(value
, FixedPoint
):
35 if isinstance(value
, int):
36 return FixedPoint(value
, 0)
37 if isinstance(value
, str):
39 neg
= value
.startswith("-")
40 if neg
or value
.startswith("+"):
42 if value
.startswith(("0x", "0X")) and "." in value
:
52 raise ValueError("too many `.` in string")
57 if not digit
.isalnum():
58 raise ValueError("invalid hexadecimal digit")
60 bits |
= int("0x" + digit
, base
=16)
62 bits
= int(value
, base
=0)
66 return FixedPoint(bits
, frac_wid
)
68 if isinstance(value
, float):
69 n
, d
= value
.as_integer_ratio()
70 log2_d
= d
.bit_length() - 1
71 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
72 "fail with float being IEEE 754")
73 return FixedPoint(n
, log2_d
)
74 raise TypeError("can't convert type to FixedPoint")
77 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
78 """convert `value` to the nearest fixed-point number with `frac_wid`
79 fractional bits, rounding according to `round_dir`."""
80 value
= FixedPoint
.cast(value
)
81 assert isinstance(frac_wid
, int) and frac_wid
>= 0
82 assert isinstance(round_dir
, RoundDir
)
83 # compute number of bits that should be removed from value
84 del_bits
= value
.frac_wid
- frac_wid
87 if del_bits
< 0: # add bits
88 return FixedPoint(value
.bits
<< -del_bits
,
90 if round_dir
== RoundDir
.DOWN
:
91 bits
= value
.bits
>> del_bits
92 elif round_dir
== RoundDir
.UP
:
93 bits
= -((-value
.bits
) >> del_bits
)
94 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
95 bits
= value
.bits
>> (del_bits
- 1)
98 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
99 bits
= value
.bits
>> del_bits
100 if bits
<< del_bits
!= value
.bits
:
101 raise ValueError("inexact conversion")
103 assert False, "unimplemented round_dir"
104 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """return `self` re-expressed with exactly `frac_wid` fractional bits.

    method-call form of `FixedPoint.with_frac_wid`: `round_dir` selects how
    any fractional bits that no longer fit are rounded; the default raises
    `ValueError` if the conversion would lose information.
    """
    converted = FixedPoint.with_frac_wid(self, frac_wid, round_dir=round_dir)
    return converted
112 return self
.bits
* 2.0 ** -self
.frac_wid
115 """compare self with rhs, returning a positive integer if self is
116 greater than rhs, zero if self is equal to rhs, and a negative integer
117 if self is less than rhs."""
118 rhs
= FixedPoint
.cast(rhs
)
119 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
120 lhs
= self
.to_frac_wid(common_frac_wid
)
121 rhs
= rhs
.to_frac_wid(common_frac_wid
)
122 return lhs
.bits
- rhs
.bits
def _cmp_or_not_implemented(self, rhs):
    """like `self.cmp(rhs)`, but return `NotImplemented` instead of raising
    when `rhs` is of a type `FixedPoint.cast` can't convert.

    `FixedPoint.cast` raises `TypeError` for unsupported types; returning
    `NotImplemented` lets Python try the reflected comparison and, for
    `==`/`!=`, fall back to identity -- so `FixedPoint(1, 0) == None` is
    `False` instead of an exception. ordering against such a type still
    raises `TypeError` (from Python itself).
    """
    try:
        rhs = FixedPoint.cast(rhs)
    except TypeError:
        return NotImplemented
    return self.cmp(rhs)

def __eq__(self, rhs):
    result = self._cmp_or_not_implemented(rhs)
    return result if result is NotImplemented else result == 0

def __ne__(self, rhs):
    result = self._cmp_or_not_implemented(rhs)
    return result if result is NotImplemented else result != 0

def __gt__(self, rhs):
    result = self._cmp_or_not_implemented(rhs)
    return result if result is NotImplemented else result > 0

def __lt__(self, rhs):
    result = self._cmp_or_not_implemented(rhs)
    return result if result is NotImplemented else result < 0

def __ge__(self, rhs):
    result = self._cmp_or_not_implemented(rhs)
    return result if result is NotImplemented else result >= 0

def __le__(self, rhs):
    result = self._cmp_or_not_implemented(rhs)
    return result if result is NotImplemented else result <= 0

def __hash__(self):
    """hash by numeric value, consistent with `__eq__`.

    comparisons are numeric (`cmp` first widens both sides to a common
    `frac_wid`), so e.g. `FixedPoint(1, 0) == FixedPoint(2, 1)`. a hash of
    the raw `(bits, frac_wid)` fields -- what a frozen dataclass generates
    by default -- would give those different hashes and break `set`/`dict`
    lookups. hashing the exact `Fraction` value also matches `hash()` of an
    equal `int` or `float`, which `cmp` likewise treats as equal.
    """
    # local import keeps this method self-contained
    from fractions import Fraction
    return hash(Fraction(self.bits, 1 << self.frac_wid))
143 """return the fractional part of `self`.
144 that is `self - math.floor(self)`.
146 fract_mask
= (1 << self
.frac_wid
) - 1
147 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
151 return "-" + str(-self
)
153 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
154 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
155 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
156 return hex(math
.floor(self
)) + "." + frac_str
159 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """return `self + rhs` exactly.

    `rhs` may be anything `FixedPoint.cast` accepts. both operands are
    first widened to the larger `frac_wid`, which never drops bits, so the
    sum is exact.
    """
    rhs = FixedPoint.cast(rhs)
    common_frac_wid = max(self.frac_wid, rhs.frac_wid)
    lhs = self.to_frac_wid(common_frac_wid)
    rhs = rhs.to_frac_wid(common_frac_wid)
    return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

def __radd__(self, lhs):
    """return `lhs + self` when `lhs` isn't a `FixedPoint` (e.g. `1 + x`).

    needed so that `int + FixedPoint` and `sum()` (which starts from `0`)
    work; addition is commutative, so this just reuses `__add__`.
    """
    return self.__add__(lhs)
169 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """return `self - rhs` exactly.

    `rhs` may be anything `FixedPoint.cast` accepts. both operands are
    first widened to the larger `frac_wid`, which never drops bits, so the
    difference is exact.
    """
    rhs = FixedPoint.cast(rhs)
    common_frac_wid = max(self.frac_wid, rhs.frac_wid)
    lhs = self.to_frac_wid(common_frac_wid)
    rhs = rhs.to_frac_wid(common_frac_wid)
    return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

def __rsub__(self, lhs):
    """return `lhs - self` when `lhs` isn't a `FixedPoint` (e.g. `1 - x`).

    subtraction isn't commutative, so `lhs` is converted and the operands
    are applied in their original order.
    """
    return FixedPoint.cast(lhs).__sub__(self)
def __mul__(self, rhs):
    """return `self * rhs` exactly.

    `rhs` may be anything `FixedPoint.cast` accepts. the product of the
    raw `bits` carries the sum of both `frac_wid`s, so no rounding happens
    here; narrow the result afterwards with `to_frac_wid` if needed.
    """
    rhs = FixedPoint.cast(rhs)
    return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

def __rmul__(self, lhs):
    """return `lhs * self` when `lhs` isn't a `FixedPoint` (e.g. `2 * x`).

    multiplication is commutative, so this just reuses `__mul__`.
    """
    return self.__mul__(lhs)
183 return self
.bits
>> self
.frac_wid
186 def goldschmidt_div(n
, d
, width
):
187 """ Goldschmidt division algorithm.
190 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
191 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
192 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
196 numerator. a `2*width`-bit unsigned integer.
197 must be less than `d << width`, otherwise the quotient wouldn't
200 denominator. a `width`-bit unsigned integer. must not be zero.
202 the bit-width of the inputs/outputs. must be a positive integer.
205 the quotient. a `width`-bit unsigned integer.
207 assert isinstance(width
, int) and width
>= 1
208 assert isinstance(d
, int) and 0 < d
< (1 << width
)
209 assert isinstance(n
, int) and 0 <= n
< (d
<< width
)
211 # FIXME: calculate best values for extra_precision, table_addr_bits, and
212 # table_data_bits -- these are wrong
213 extra_precision
= width
+ 3
217 width
+= extra_precision
220 for i
in range(1 << table_addr_bits
):
221 value
= 1 / (1 + i
* 2 ** -table_addr_bits
)
222 table
.append(FixedPoint
.with_frac_wid(value
, table_data_bits
,
225 # this whole algorithm is done with fixed-point arithmetic where values
226 # have `width` fractional bits
228 n
= FixedPoint(n
, width
)
229 d
= FixedPoint(d
, width
)
231 # normalize so 1 <= d < 2
232 # can easily be done with count-leading-zeros and left shift
234 n
= (n
* 2).to_frac_wid(width
)
235 d
= (d
* 2).to_frac_wid(width
)
238 # normalize so 1 <= n < 2
240 n
= (n
* 0.5).to_frac_wid(width
)
243 # compute initial f by table lookup
244 f
= table
[(d
- 1).to_frac_wid(table_addr_bits
, RoundDir
.DOWN
).bits
]
246 min_bits_of_precision
= 1
247 while min_bits_of_precision
< width
* 2:
248 # multiply both n and d by f
251 n
= n
.to_frac_wid(width
, round_dir
=RoundDir
.DOWN
)
252 d
= d
.to_frac_wid(width
, round_dir
=RoundDir
.UP
)
254 # slightly less than 2 to make the computation just a bitwise not
255 nearly_two
= FixedPoint
.with_frac_wid(2, width
)
256 nearly_two
= FixedPoint(nearly_two
.bits
- 1, width
)
257 f
= (nearly_two
- d
).to_frac_wid(width
)
259 min_bits_of_precision
*= 2
261 # scale to correct value
264 # avoid incorrectly rounding down
265 n
= n
.to_frac_wid(width
- extra_precision
, round_dir
=RoundDir
.UP
)