test_core.py doesn't crash anymore
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_core.py
index 74b10b087632e696321cb384bdaeec930902f16c..59beaddebb10e37688137d132ad25b24726219de 100755 (executable)
@@ -9,10 +9,11 @@ from .core import (DivPipeCoreConfig, DivPipeCoreSetupStage,
 from .algorithm import (FixedUDivRemSqrtRSqrt, Fixed, Operation, div_rem,
                         fixed_sqrt, fixed_rsqrt)
 import unittest
-from nmigen import Module, Elaboratable
+from nmigen import Module, Elaboratable, Signal
 from nmigen.hdl.ir import Fragment
 from nmigen.back import rtlil
 from nmigen.back.pysim import Simulator, Delay, Tick
+from itertools import chain
 
 
 def show_fixed(bits, fract_width, bit_width):
@@ -51,18 +52,18 @@ class TestCaseData:
     def __str__(self):
         bit_width = self.core_config.bit_width
         fract_width = self.core_config.fract_width
-        dividend_str = show_fixed(dividend,
+        dividend_str = show_fixed(self.dividend,
                                   fract_width * 2,
                                   bit_width + fract_width)
-        divisor_radicand_str = show_fixed(divisor_radicand,
+        divisor_radicand_str = show_fixed(self.divisor_radicand,
                                           fract_width,
                                           bit_width)
-        quotient_root_str = self.show_fixed(quotient_root,
-                                            fract_width,
-                                            bit_width)
-        remainder_str = self.show_fixed(remainder,
-                                        fract_width * 3,
-                                        bit_width * 3)
+        quotient_root_str = show_fixed(self.quotient_root,
+                                       fract_width,
+                                       bit_width)
+        remainder_str = show_fixed(self.remainder,
+                                   fract_width * 3,
+                                   bit_width * 3)
         return f"{{dividend={dividend_str}, " \
             + f"divisor_radicand={divisor_radicand_str}, " \
             + f"op={self.alg_op.name}, " \
@@ -110,35 +111,41 @@ def generate_test_case(core_config, dividend, divisor_radicand, alg_op):
 
 
 def get_test_cases(core_config,
-                   dividend_range=None,
-                   divisor_range=None,
-                   radicand_range=None):
-    if dividend_range is None:
-        dividend_range = range(1 << (core_config.bit_width
-                                     + core_config.fract_width))
-    if divisor_range is None:
-        divisor_range = range(1 << core_config.bit_width)
-    if radicand_range is None:
-        radicand_range = range(1 << core_config.bit_width)
+                   dividends=None,
+                   divisors=None,
+                   radicands=None):
+    if dividends is None:
+        dividends = range(1 << (core_config.bit_width
+                                + core_config.fract_width))
+    else:
+        assert isinstance(dividends, list)
+    if divisors is None:
+        divisors = range(1 << core_config.bit_width)
+    else:
+        assert isinstance(divisors, list)
+    if radicands is None:
+        radicands = range(1 << core_config.bit_width)
+    else:
+        assert isinstance(radicands, list)
 
     for alg_op in Operation:
         if alg_op is Operation.UDivRem:
-            for dividend in dividend_range:
-                for divisor in divisor_range:
+            for dividend in dividends:
+                for divisor in divisors:
                     yield from generate_test_case(core_config,
                                                   dividend,
                                                   divisor,
                                                   alg_op)
         else:
-            for radicand in radicand_range:
+            for radicand in radicands:
                 yield from generate_test_case(core_config,
-                                              dividend,
+                                              0,
                                               radicand,
                                               alg_op)
 
 
 class DivPipeCoreTestPipeline(Elaboratable):
-    def __init__(self, core_config):
+    def __init__(self, core_config, sync=True):
         self.setup_stage = DivPipeCoreSetupStage(core_config)
         self.calculate_stages = [
             DivPipeCoreCalculateStage(core_config, stage_index)
@@ -149,6 +156,7 @@ class DivPipeCoreTestPipeline(Elaboratable):
             for i in range(core_config.num_calculate_stages + 1)]
         self.i = DivPipeCoreInputData(core_config, reset_less=True)
         self.o = DivPipeCoreOutputData(core_config, reset_less=True)
+        self.sync = sync
 
     def elaborate(self, platform):
         m = Module()
@@ -157,28 +165,39 @@ class DivPipeCoreTestPipeline(Elaboratable):
         stage_outputs = [*self.interstage_signals, self.o]
         for stage, input, output in zip(stages, stage_inputs, stage_outputs):
             stage.setup(m, input)
-            m.d.sync += output.eq(stage.process(input))
-
+            assignments = output.eq(stage.process(input))
+            if self.sync:
+                m.d.sync += assignments
+            else:
+                m.d.comb += assignments
         return m
 
     def traces(self):
         yield from self.i
-        for interstage_signal in self.interstage_signals:
-            yield from interstage_signal
+        for interstage_signal in self.interstage_signals:
+            yield from interstage_signal
         yield from self.o
 
 
 class TestDivPipeCore(unittest.TestCase):
     def handle_case(self,
                     core_config,
-                    dividend_range=None,
-                    divisor_range=None,
-                    radicand_range=None):
+                    dividends=None,
+                    divisors=None,
+                    radicands=None,
+                    sync=True):
+        if dividends is not None:
+            dividends = list(dividends)
+        if divisors is not None:
+            divisors = list(divisors)
+        if radicands is not None:
+            radicands = list(radicands)
+
         def gen_test_cases():
             yield from get_test_cases(core_config,
-                                      dividend_range,
-                                      divisor_range,
-                                      radicand_range)
+                                      dividends,
+                                      divisors,
+                                      radicands)
         base_name = f"div_pipe_core_bit_width_{core_config.bit_width}"
         base_name += f"_fract_width_{core_config.fract_width}"
         base_name += f"_radix_{1 << core_config.log2_radix}"
@@ -189,23 +208,24 @@ class TestDivPipeCore(unittest.TestCase):
                 f.write(vl)
         dut = DivPipeCoreTestPipeline(core_config)
         with Simulator(dut,
-                       vcd_file=f"{base_name}.vcd",
-                       gtkw_file=f"{base_name}.gtkw",
+                       vcd_file=open(f"{base_name}.vcd", "w"),
+                       gtkw_file=open(f"{base_name}.gtkw", "w"),
                        traces=[*dut.traces()]) as sim:
             def generate_process():
                 for test_case in gen_test_cases():
                     yield dut.i.dividend.eq(test_case.dividend)
                     yield dut.i.divisor_radicand.eq(test_case.divisor_radicand)
-                    yield dut.i.operation.eq(test_case.core_op)
+                    yield dut.i.operation.eq(int(test_case.core_op))
                     yield Delay(1e-6)
                     yield Tick()
 
             def check_process():
                 # sync with generator
-                yield
-                for _ in core_config.num_calculate_stages:
+                if sync:
+                    yield
+                    for _ in range(core_config.num_calculate_stages):
+                        yield
                     yield
-                yield
 
                 # now synched with generator
                 for test_case in gen_test_cases():
@@ -225,7 +245,10 @@ class TestDivPipeCore(unittest.TestCase):
     def test_bit_width_8_fract_width_4_radix_2(self):
         self.handle_case(DivPipeCoreConfig(bit_width=8,
                                            fract_width=4,
-                                           log2_radix=1))
+                                           log2_radix=1),
+                         dividends=[*range(1 << 8),
+                                    *range(1 << 8, 1 << 12, 1 << 4)],
+                         sync=False)
 
     # FIXME: add more test_* functions