a72f9243feb53bb65d6201a03fe288dbe72d3cd9
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
4 from nmigen
.hdl
.ast
import Const
5 from .algorithm
import (div_rem
, UnsignedDivRem
, DivRem
,
6 Fixed
, fixed_sqrt
, FixedSqrt
, fixed_rsqrt
, FixedRSqrt
)
11 class TestDivRemFn(unittest
.TestCase
):
12 def test_signed(self
):
14 # numerator, denominator, quotient, remainder
127 (-8, -1, -8, 0), # overflows and wraps around
272 for (n
, d
, q
, r
) in test_cases
:
273 self
.assertEqual(div_rem(n
, d
, 4, True), (q
, r
))
275 def test_unsigned(self
):
282 # div_rem matches // and % for unsigned integers
285 self
.assertEqual(div_rem(n
, d
, 4, False), (q
, r
))
288 class TestUnsignedDivRem(unittest
.TestCase
):
289 def helper(self
, log2_radix
):
291 for n
in range(1 << bit_width
):
292 for d
in range(1 << bit_width
):
293 q
, r
= div_rem(n
, d
, bit_width
, False)
294 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
):
295 udr
= UnsignedDivRem(n
, d
, bit_width
, log2_radix
)
296 for _
in range(250 * bit_width
):
297 self
.assertEqual(n
, udr
.quotient
* udr
.divisor
299 if udr
.calculate_stage():
302 self
.fail("infinite loop")
303 self
.assertEqual(n
, udr
.quotient
* udr
.divisor
305 self
.assertEqual(udr
.quotient
, q
)
306 self
.assertEqual(udr
.remainder
, r
)
308 def test_radix_2(self
):
311 def test_radix_4(self
):
314 def test_radix_8(self
):
317 def test_radix_16(self
):
321 class TestDivRem(unittest
.TestCase
):
322 def helper(self
, log2_radix
):
324 for n
in range(1 << bit_width
):
325 for d
in range(1 << bit_width
):
326 for signed
in False, True:
327 n
= Const
.normalize(n
, (bit_width
, signed
))
328 d
= Const
.normalize(d
, (bit_width
, signed
))
329 q
, r
= div_rem(n
, d
, bit_width
, signed
)
330 with self
.subTest(n
=n
, d
=d
, q
=q
, r
=r
, signed
=signed
):
331 dr
= DivRem(n
, d
, bit_width
, signed
, log2_radix
)
332 for _
in range(250 * bit_width
):
333 if dr
.calculate_stage():
336 self
.fail("infinite loop")
337 self
.assertEqual(dr
.quotient
, q
)
338 self
.assertEqual(dr
.remainder
, r
)
340 def test_radix_2(self
):
343 def test_radix_4(self
):
346 def test_radix_8(self
):
349 def test_radix_16(self
):
353 class TestFixed(unittest
.TestCase
):
354 def test_constructor(self
):
355 value
= Fixed(0, 0, 1, False)
356 self
.assertEqual(value
.bits
, 0)
357 self
.assertEqual(value
.fract_width
, 0)
358 self
.assertEqual(value
.bit_width
, 1)
359 self
.assertEqual(value
.signed
, False)
360 value
= Fixed(1, 2, 3, True)
361 self
.assertEqual(value
.bits
, -4)
362 self
.assertEqual(value
.fract_width
, 2)
363 self
.assertEqual(value
.bit_width
, 3)
364 self
.assertEqual(value
.signed
, True)
365 value
= Fixed(1, 2, 4, True)
366 self
.assertEqual(value
.bits
, 4)
367 self
.assertEqual(value
.fract_width
, 2)
368 self
.assertEqual(value
.bit_width
, 4)
369 self
.assertEqual(value
.signed
, True)
370 value
= Fixed(1.25, 4, 8, True)
371 self
.assertEqual(value
.bits
, 0x14)
372 self
.assertEqual(value
.fract_width
, 4)
373 self
.assertEqual(value
.bit_width
, 8)
374 self
.assertEqual(value
.signed
, True)
375 value
= Fixed(Fixed(2, 0, 12, False), 4, 8, True)
376 self
.assertEqual(value
.bits
, 0x20)
377 self
.assertEqual(value
.fract_width
, 4)
378 self
.assertEqual(value
.bit_width
, 8)
379 self
.assertEqual(value
.signed
, True)
380 value
= Fixed(0x2FF / 2 ** 8, 8, 12, False)
381 self
.assertEqual(value
.bits
, 0x2FF)
382 self
.assertEqual(value
.fract_width
, 8)
383 self
.assertEqual(value
.bit_width
, 12)
384 self
.assertEqual(value
.signed
, False)
385 value
= Fixed(value
, 4, 8, True)
386 self
.assertEqual(value
.bits
, 0x2F)
387 self
.assertEqual(value
.fract_width
, 4)
388 self
.assertEqual(value
.bit_width
, 8)
389 self
.assertEqual(value
.signed
, True)
391 def helper_test_from_bits(self
, bit_width
, fract_width
):
393 for bits
in range(1 << bit_width
):
394 with self
.subTest(bit_width
=bit_width
,
395 fract_width
=fract_width
,
398 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
399 self
.assertEqual(value
.bit_width
, bit_width
)
400 self
.assertEqual(value
.fract_width
, fract_width
)
401 self
.assertEqual(value
.signed
, signed
)
402 self
.assertEqual(value
.bits
, bits
)
404 for bits
in range(-1 << (bit_width
- 1), 1 << (bit_width
- 1)):
405 with self
.subTest(bit_width
=bit_width
,
406 fract_width
=fract_width
,
409 value
= Fixed
.from_bits(bits
, fract_width
, bit_width
, signed
)
410 self
.assertEqual(value
.bit_width
, bit_width
)
411 self
.assertEqual(value
.fract_width
, fract_width
)
412 self
.assertEqual(value
.signed
, signed
)
413 self
.assertEqual(value
.bits
, bits
)
415 def test_from_bits(self
):
416 for bit_width
in range(1, 5):
417 for fract_width
in range(bit_width
):
418 self
.helper_test_from_bits(bit_width
, fract_width
)
421 self
.assertEqual(repr(Fixed
.from_bits(1, 2, 3, False)),
422 "Fixed.from_bits(1, 2, 3, False)")
423 self
.assertEqual(repr(Fixed
.from_bits(-4, 2, 3, True)),
424 "Fixed.from_bits(-4, 2, 3, True)")
425 self
.assertEqual(repr(Fixed
.from_bits(-4, 7, 10, True)),
426 "Fixed.from_bits(-4, 7, 10, True)")
428 def test_trunc(self
):
429 for i
in range(-8, 8):
430 value
= Fixed
.from_bits(i
, 2, 4, True)
431 with self
.subTest(value
=repr(value
)):
432 self
.assertEqual(math
.trunc(value
), math
.trunc(i
/ 4))
435 for i
in range(-8, 8):
436 value
= Fixed
.from_bits(i
, 2, 4, True)
437 with self
.subTest(value
=repr(value
)):
438 self
.assertEqual(int(value
), math
.trunc(value
))
440 def test_float(self
):
441 for i
in range(-8, 8):
442 value
= Fixed
.from_bits(i
, 2, 4, True)
443 with self
.subTest(value
=repr(value
)):
444 self
.assertEqual(float(value
), i
/ 4)
446 def test_floor(self
):
447 for i
in range(-8, 8):
448 value
= Fixed
.from_bits(i
, 2, 4, True)
449 with self
.subTest(value
=repr(value
)):
450 self
.assertEqual(math
.floor(value
), math
.floor(i
/ 4))
453 for i
in range(-8, 8):
454 value
= Fixed
.from_bits(i
, 2, 4, True)
455 with self
.subTest(value
=repr(value
)):
456 self
.assertEqual(math
.ceil(value
), math
.ceil(i
/ 4))
459 for i
in range(-8, 8):
460 value
= Fixed
.from_bits(i
, 2, 4, True)
461 expected
= -i
/ 4 if i
!= -8 else -2.0 # handle wrap-around
462 with self
.subTest(value
=repr(value
)):
463 self
.assertEqual(float(-value
), expected
)
466 for i
in range(-8, 8):
467 value
= Fixed
.from_bits(i
, 2, 4, True)
468 with self
.subTest(value
=repr(value
)):
470 self
.assertEqual(value
.bits
, i
)
473 for i
in range(-8, 8):
474 value
= Fixed
.from_bits(i
, 2, 4, True)
475 expected
= abs(i
) / 4 if i
!= -8 else -2.0 # handle wrap-around
476 with self
.subTest(value
=repr(value
)):
477 self
.assertEqual(float(abs(value
)), expected
)
480 for i
in range(-8, 8):
481 value
= Fixed
.from_bits(i
, 2, 4, True)
482 with self
.subTest(value
=repr(value
)):
483 self
.assertEqual(float(~value
), (~i
) / 4)
486 def get_test_values(max_bit_width
, include_int
):
487 for signed
in False, True:
489 for bits
in range(1 << max_bit_width
):
490 int_value
= Const
.normalize(bits
, (max_bit_width
, signed
))
492 for bit_width
in range(1, max_bit_width
):
493 for fract_width
in range(bit_width
+ 1):
494 for bits
in range(1 << bit_width
):
495 yield Fixed
.from_bits(bits
,
500 def binary_op_test_helper(self
,
503 width_combine_op
=max,
504 adjust_bits_op
=None):
505 def default_adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
506 return bits
<< (out_fract_width
- in_fract_width
)
507 if adjust_bits_op
is None:
508 adjust_bits_op
= default_adjust_bits_op
510 for lhs
in self
.get_test_values(max_bit_width
, True):
511 lhs_is_int
= isinstance(lhs
, int)
512 for rhs
in self
.get_test_values(max_bit_width
, not lhs_is_int
):
513 rhs_is_int
= isinstance(rhs
, int)
515 assert not rhs_is_int
516 lhs_int
= adjust_bits_op(lhs
, rhs
.fract_width
, 0)
517 int_result
= operation(lhs_int
, rhs
.bits
)
519 expected
= Fixed
.from_bits(int_result
,
524 expected
= int_result
526 rhs_int
= adjust_bits_op(rhs
, lhs
.fract_width
, 0)
527 int_result
= operation(lhs
.bits
, rhs_int
)
529 expected
= Fixed
.from_bits(int_result
,
534 expected
= int_result
535 elif lhs
.signed
!= rhs
.signed
:
538 fract_width
= width_combine_op(lhs
.fract_width
,
540 int_width
= width_combine_op(lhs
.bit_width
544 bit_width
= fract_width
+ int_width
545 lhs_int
= adjust_bits_op(lhs
.bits
,
548 rhs_int
= adjust_bits_op(rhs
.bits
,
551 int_result
= operation(lhs_int
, rhs_int
)
553 expected
= Fixed
.from_bits(int_result
,
558 expected
= int_result
559 with self
.subTest(lhs
=repr(lhs
),
561 expected
=repr(expected
)):
562 result
= operation(lhs
, rhs
)
564 self
.assertEqual(result
.bit_width
, expected
.bit_width
)
565 self
.assertEqual(result
.signed
, expected
.signed
)
566 self
.assertEqual(result
.fract_width
,
567 expected
.fract_width
)
568 self
.assertEqual(result
.bits
, expected
.bits
)
570 self
.assertEqual(result
, expected
)
573 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
+ rhs
)
576 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
- rhs
)
579 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
& rhs
)
582 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs | rhs
)
585 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs ^ rhs
)
588 def adjust_bits_op(bits
, out_fract_width
, in_fract_width
):
590 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
* rhs
,
592 lambda l_width
, r_width
: l_width
+ r_width
,
602 self
.binary_op_test_helper(cmp, False)
605 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
< rhs
, False)
608 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
<= rhs
, False)
611 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
== rhs
, False)
614 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
!= rhs
, False)
617 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
> rhs
, False)
620 self
.binary_op_test_helper(lambda lhs
, rhs
: lhs
>= rhs
, False)
623 for v
in self
.get_test_values(6, False):
624 with self
.subTest(v
=repr(v
)):
625 self
.assertEqual(bool(v
), bool(v
.bits
))
628 self
.assertEqual(str(Fixed
.from_bits(0x1234, 0, 16, False)),
630 self
.assertEqual(str(Fixed
.from_bits(-0x1234, 0, 16, True)),
632 self
.assertEqual(str(Fixed
.from_bits(0x12345, 3, 20, True)),
634 self
.assertEqual(str(Fixed(123.625, 3, 12, True)),
637 self
.assertEqual(str(Fixed
.from_bits(0x1, 0, 20, True)),
639 self
.assertEqual(str(Fixed
.from_bits(0x2, 1, 20, True)),
641 self
.assertEqual(str(Fixed
.from_bits(0x4, 2, 20, True)),
643 self
.assertEqual(str(Fixed
.from_bits(0x9, 3, 20, True)),
645 self
.assertEqual(str(Fixed
.from_bits(0x12, 4, 20, True)),
647 self
.assertEqual(str(Fixed
.from_bits(0x24, 5, 20, True)),
649 self
.assertEqual(str(Fixed
.from_bits(0x48, 6, 20, True)),
651 self
.assertEqual(str(Fixed
.from_bits(0x91, 7, 20, True)),
653 self
.assertEqual(str(Fixed
.from_bits(0x123, 8, 20, True)),
655 self
.assertEqual(str(Fixed
.from_bits(0x246, 9, 20, True)),
657 self
.assertEqual(str(Fixed
.from_bits(0x48d, 10, 20, True)),
659 self
.assertEqual(str(Fixed
.from_bits(0x91a, 11, 20, True)),
661 self
.assertEqual(str(Fixed
.from_bits(0x1234, 12, 20, True)),
663 self
.assertEqual(str(Fixed
.from_bits(0x2468, 13, 20, True)),
665 self
.assertEqual(str(Fixed
.from_bits(0x48d1, 14, 20, True)),
667 self
.assertEqual(str(Fixed
.from_bits(0x91a2, 15, 20, True)),
669 self
.assertEqual(str(Fixed
.from_bits(0x12345, 16, 20, True)),
671 self
.assertEqual(str(Fixed
.from_bits(0x2468a, 17, 20, True)),
673 self
.assertEqual(str(Fixed
.from_bits(0x48d14, 18, 20, True)),
675 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, True)),
677 self
.assertEqual(str(Fixed
.from_bits(0x91a28, 19, 20, False)),
681 # FIXME: add tests for fract_sqrt, FractSqrt, fract_rsqrt, and FractRSqrt