add WIP goldschmidt division algorithm
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import dataclass
8 import math
9 import enum
10
11
@enum.unique
class RoundDir(enum.Enum):
    """Rounding direction applied when a fixed-point conversion has to drop
    fractional bits."""

    # round toward negative infinity (floor)
    DOWN = enum.auto()
    # round toward positive infinity (ceiling)
    UP = enum.auto()
    # round to the nearest representable value; exact halves go up
    NEAREST_TIES_UP = enum.auto()
    # do not round at all: the conversion must be exact or it is an error
    ERROR_IF_INEXACT = enum.auto()
18
19
@dataclass(frozen=True)
class FixedPoint:
    """An exact binary fixed-point number of unbounded size.

    The represented value is ``bits * 2 ** -frac_wid``. `bits` is a Python
    int (so it never overflows, and bitwise operations treat negative values
    as infinitely sign-extended two's complement). `frac_wid` is the number
    of fractional bits.

    Comparison, equality and hashing all work on the represented *value*,
    not on the `(bits, frac_wid)` pair. For example, ``FixedPoint(1, 0)``,
    ``FixedPoint(2, 1)``, ``1`` and ``1.0`` all compare and hash equal.
    The other operand of comparisons and arithmetic may be anything `cast`
    accepts.
    """

    # raw integer; the represented value is `bits / 2 ** frac_wid`
    bits: int
    # number of fractional bits; must be >= 0
    frac_wid: int

    def __post_init__(self):
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepted inputs:
        * `FixedPoint` -- returned unchanged.
        * `int` -- converted with 0 fractional bits.
        * `float` -- converted exactly, since a finite float is always an
          integer divided by a power of two.
        * `str` -- an optionally signed integer literal in any form that
          `int(..., base=0)` accepts (e.g. `"-0x10"`, `"0b101"`, `"1_000"`),
          or a hexadecimal literal containing a `.` (e.g. `"-0x1.8"`), where
          each hex digit after the `.` adds 4 fractional bits.

        raises `TypeError` for unsupported types and `ValueError` for
        malformed strings.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            # `int(...)` normalizes `bool` and other int subclasses
            return FixedPoint(int(value), 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            # `int()` below would accept a second sign, so e.g. "--5" would
            # otherwise silently parse as 5
            if value.startswith(("-", "+")):
                raise ValueError("too many signs in string")
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    # `str.isalnum` would also accept non-hex letters and
                    # non-ASCII digits, so check the exact hex digit set
                    if digit not in "0123456789abcdefABCDEF":
                        raise ValueError("invalid hexadecimal digit")
                    bits = (bits << 4) | int(digit, 16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` may be anything `cast` accepts. Adding fractional bits is
        always exact. When bits are removed, `round_dir` selects floor
        (`DOWN`), ceiling (`UP`), round-half-up (`NEAREST_TIES_UP`), or
        raising `ValueError` if any nonzero bits would be lost
        (`ERROR_IF_INEXACT`).
        """
        value = FixedPoint.cast(value)
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        # compute number of bits that should be removed from value
        del_bits = value.frac_wid - frac_wid
        if del_bits == 0:
            return value
        if del_bits < 0:  # add bits
            return FixedPoint(value.bits << -del_bits,
                              frac_wid)
        if round_dir == RoundDir.DOWN:
            # arithmetic right shift floors, also for negative values
            bits = value.bits >> del_bits
        elif round_dir == RoundDir.UP:
            # ceil(x) == -floor(-x)
            bits = -((-value.bits) >> del_bits)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            # keep one extra bit, add half an output ulp, then drop it
            bits = value.bits >> (del_bits - 1)
            bits += 1
            bits >>= 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            bits = value.bits >> del_bits
            if bits << del_bits != value.bits:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # int / int true division is correctly rounded and never converts
        # the (possibly huge) `bits` to float first, so large `frac_wid`
        # values underflow gracefully instead of raising OverflowError.
        return self.bits / (1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs.

        raises `TypeError` if `rhs` can't be converted by `cast`."""
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def _try_cmp(self, rhs):
        """like `cmp`, but returns `NotImplemented` instead of raising when
        `rhs` has a type that `cast` doesn't support."""
        try:
            rhs = FixedPoint.cast(rhs)
        except TypeError:
            return NotImplemented
        return self.cmp(rhs)

    def __eq__(self, rhs):
        result = self._try_cmp(rhs)
        return result if result is NotImplemented else result == 0

    def __ne__(self, rhs):
        result = self._try_cmp(rhs)
        return result if result is NotImplemented else result != 0

    def __gt__(self, rhs):
        result = self._try_cmp(rhs)
        return result if result is NotImplemented else result > 0

    def __lt__(self, rhs):
        result = self._try_cmp(rhs)
        return result if result is NotImplemented else result < 0

    def __ge__(self, rhs):
        result = self._try_cmp(rhs)
        return result if result is NotImplemented else result >= 0

    def __le__(self, rhs):
        result = self._try_cmp(rhs)
        return result if result is NotImplemented else result <= 0

    def __hash__(self):
        # Equal values must hash equally even when their `frac_wid`s differ.
        # Defined explicitly because `@dataclass(frozen=True)` would otherwise
        # generate a hash of `(bits, frac_wid)`, giving FixedPoint(1, 0) and
        # FixedPoint(2, 1) different hashes although they compare equal.
        # Hashing the exact rational also matches the hashes of equal
        # `int`s and `float`s, which `__eq__` treats as equal too.
        from fractions import Fraction
        return hash(Fraction(self.bits, 1 << self.frac_wid))

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        # masking works for negative values too, since Python ints behave
        # as infinitely sign-extended two's complement
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        """format as a hexadecimal literal such as `-0x1.8` that `cast`
        can parse back."""
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        # round the fraction width up to a whole number of hex digits
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # addition is commutative; enables e.g. `1 + fp`
        return self + lhs

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # enables e.g. `2 - fp`
        return FixedPoint.cast(lhs) - self

    def __mul__(self, rhs):
        # exact: the product's fraction width is the sum of both widths
        rhs = FixedPoint.cast(rhs)
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # multiplication is commutative; enables e.g. `2 * fp`
        return self * lhs

    def __floor__(self):
        return self.bits >> self.frac_wid
184
185
def goldschmidt_div(n, d, width):
    """ Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2*width`-bit unsigned integer.
        must be less than `d << width`, otherwise the quotient wouldn't
        fit in `width` bits.
    d: int
        denominator. a `width`-bit unsigned integer. must not be zero.
    width: int
        the bit-width of the inputs/outputs. must be a positive integer.

    returns: int
        the quotient. a `width`-bit unsigned integer. the final
        round-up-then-floor step is meant to produce `n // d`. NOTE: the
        precision parameters below are still WIP (see FIXME).
    """
    assert isinstance(width, int) and width >= 1
    assert isinstance(d, int) and 0 < d < (1 << width)
    assert isinstance(n, int) and 0 <= n < (d << width)

    # FIXME: calculate best values for extra_precision, table_addr_bits, and
    # table_data_bits -- these are wrong
    extra_precision = width + 3
    table_addr_bits = 4
    table_data_bits = 8

    # number of fractional bits of all intermediate values. kept separate
    # from `width` so the final rounding step can use `width` directly.
    frac_wid = width + extra_precision

    # table[i] is 1 / (1 + i / 2 ** table_addr_bits) rounded down to
    # `table_data_bits` fractional bits: an initial reciprocal estimate for
    # d in [1 + i / 2 ** table_addr_bits, 1 + (i + 1) / 2 ** table_addr_bits).
    # since 1 / (1 + i / 2**a) == 2**a / (2**a + i), floor division gives the
    # exact rounded-down value without going through floats, so it stays
    # correct even if the table parameters grow beyond float precision.
    table_scale = 1 << (table_addr_bits + table_data_bits)
    table = [FixedPoint(table_scale // ((1 << table_addr_bits) + i),
                        table_data_bits)
             for i in range(1 << table_addr_bits)]

    # this whole algorithm is done with fixed-point arithmetic where values
    # have `frac_wid` fractional bits. scaling n and d by the same factor
    # leaves the quotient unchanged.
    n = FixedPoint(n, frac_wid)
    d = FixedPoint(d, frac_wid)

    # normalize so 1 <= d < 2, scaling n by the same power of two.
    # can easily be done with count-leading-zeros and left shift
    while d < 1:
        n = (n * 2).to_frac_wid(frac_wid)
        d = (d * 2).to_frac_wid(frac_wid)

    # normalize so n < 2 (n may still be below 1, e.g. when n < d),
    # remembering the shift so the quotient can be scaled back at the end.
    # the halving is expected to be exact for the current `extra_precision`
    # because the loop above shifted n left at least as far;
    # `to_frac_wid` raises ValueError if it ever isn't.
    n_shift = 0
    while n >= 2:
        n = (n * 0.5).to_frac_wid(frac_wid)
        n_shift += 1

    # compute initial f by table lookup, indexed by the leading fraction
    # bits of d
    f = table[(d - 1).to_frac_wid(table_addr_bits, RoundDir.DOWN).bits]

    # slightly less than 2 to make the computation just a bitwise not.
    # loop-invariant, so it is computed only once.
    nearly_two = FixedPoint((2 << frac_wid) - 1, frac_wid)

    min_bits_of_precision = 1
    while min_bits_of_precision < frac_wid * 2:
        # multiply both n and d by f: the ratio n / d is preserved while d
        # is driven toward 1, so n is driven toward the quotient
        n *= f
        d *= f
        # rounding n down and d up means n / d never exceeds the true ratio
        n = n.to_frac_wid(frac_wid, round_dir=RoundDir.DOWN)
        d = d.to_frac_wid(frac_wid, round_dir=RoundDir.UP)

        f = (nearly_two - d).to_frac_wid(frac_wid)

        min_bits_of_precision *= 2

    # scale to correct value
    n *= 1 << n_shift

    # n is computed with round-down steps, so it may sit slightly below an
    # exact integer quotient. rounding up to `width` fractional bits before
    # taking the floor avoids incorrectly rounding down.
    n = n.to_frac_wid(width, round_dir=RoundDir.UP)
    return math.floor(n)