1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
4 from nmigen
.hdl
.ast
import Const
5 from .algorithm
import div_rem
, UnsignedDivRem
, DivRem
9 class TestDivRemFn(unittest
.TestCase
):
10 def test_signed(self
):
12 # numerator, denominator, quotient, remainder
125 (-8, -1, -8, 0), # overflows and wraps around
270 for (n
, d
, q
, r
) in test_cases
:
271 self
.assertEqual(div_rem(n
, d
, 4, True), (q
, r
))
273 def test_unsigned(self
):
280 # div_rem matches // and % for unsigned integers
283 self
.assertEqual(div_rem(n
, d
, 4, False), (q
, r
))
286 class TestUnsignedDivRem(unittest
.TestCase
):
287 def helper(self
, log2_radix
):
289 for n
in range(1 << bit_width
):
290 for d
in range(1 << bit_width
):
291 q
, r
= div_rem(n
, d
, bit_width
, False)
292 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
):
293 udr
= UnsignedDivRem(n
, d
, bit_width
, log2_radix
)
294 for _
in range(250 * bit_width
):
295 self
.assertEqual(n
, udr
.quotient
* udr
.divisor
297 if udr
.calculate_stage():
300 self
.fail("infinite loop")
301 self
.assertEqual(n
, udr
.quotient
* udr
.divisor
303 self
.assertEqual(udr
.quotient
, q
)
304 self
.assertEqual(udr
.remainder
, r
)
306 def test_radix_2(self
):
309 def test_radix_4(self
):
312 def test_radix_8(self
):
315 def test_radix_16(self
):
319 class TestDivRem(unittest
.TestCase
):
320 def helper(self
, log2_radix
):
322 for n
in range(1 << bit_width
):
323 for d
in range(1 << bit_width
):
324 for signed
in False, True:
325 n
= Const
.normalize(n
, (bit_width
, signed
))
326 d
= Const
.normalize(d
, (bit_width
, signed
))
327 q
, r
= div_rem(n
, d
, bit_width
, signed
)
328 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
, signed
=signed
):
329 dr
= DivRem(n
, d
, bit_width
, signed
, log2_radix
)
330 for _
in range(250 * bit_width
):
331 if dr
.calculate_stage():
334 self
.fail("infinite loop")
335 self
.assertEqual(dr
.quotient
, q
)
336 self
.assertEqual(dr
.remainder
, r
)
338 def test_radix_2(self
):
341 def test_radix_4(self
):
344 def test_radix_8(self
):
347 def test_radix_16(self
):