test_core.py doesn't crash anymore
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 10 Jul 2019 06:48:19 +0000 (23:48 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 10 Jul 2019 06:48:19 +0000 (23:48 -0700)
.gitignore
src/ieee754/div_rem_sqrt_rsqrt/core.py
src/ieee754/div_rem_sqrt_rsqrt/test_core.py

index 31fdbbe5574abf7346eea73e2ff0f0d1e0033800..9c4f3cb0749a103fbe36a369bab7f8701ea94069 100644 (file)
@@ -7,3 +7,4 @@ __pycache__
 *.il
 .eggs
 *.egg-info
 *.il
 .eggs
 *.egg-info
+*.gtkw
index a1a8edefc343b21e1c48db9bedd13e3f8c126137..141deb7365456d727952588c6e8c848c91652625 100644 (file)
@@ -49,7 +49,7 @@ class DivPipeCoreConfig:
         return (self.bit_width + self.log2_radix - 1) // self.log2_radix
 
 
         return (self.bit_width + self.log2_radix - 1) // self.log2_radix
 
 
-class DivPipeCoreOperation(enum.IntEnum):
+class DivPipeCoreOperation(enum.Enum):
     """ Operation for ``DivPipeCore``.
 
     :attribute UDivRem: unsigned divide/remainder.
     """ Operation for ``DivPipeCore``.
 
     :attribute UDivRem: unsigned divide/remainder.
@@ -61,13 +61,17 @@ class DivPipeCoreOperation(enum.IntEnum):
     SqrtRem = 1
     RSqrtRem = 2
 
     SqrtRem = 1
     RSqrtRem = 2
 
+    def __int__(self):
+        """ Convert to int. """
+        return self.value
+
     @classmethod
     def create_signal(cls, *, src_loc_at=0, **kwargs):
         """ Create a signal that can contain a ``DivPipeCoreOperation``. """
     @classmethod
     def create_signal(cls, *, src_loc_at=0, **kwargs):
         """ Create a signal that can contain a ``DivPipeCoreOperation``. """
-        return Signal(min=int(min(cls)),
-                      max=int(max(cls)),
+        return Signal(min=min(map(int, cls)),
+                      max=max(map(int, cls)),
                       src_loc_at=(src_loc_at + 1),
                       src_loc_at=(src_loc_at + 1),
-                      decoder=cls,
+                      decoder=lambda v: str(cls(v)),
                       **kwargs)
 
 
                       **kwargs)
 
 
@@ -239,10 +243,10 @@ class DivPipeCoreSetupStage(Elaboratable):
         m.d.comb += self.o.quotient_root.eq(0)
         m.d.comb += self.o.root_times_radicand.eq(0)
 
         m.d.comb += self.o.quotient_root.eq(0)
         m.d.comb += self.o.root_times_radicand.eq(0)
 
-        with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+        with m.If(self.i.operation == int(DivPipeCoreOperation.UDivRem)):
             m.d.comb += self.o.compare_lhs.eq(self.i.dividend
                                               << self.core_config.fract_width)
             m.d.comb += self.o.compare_lhs.eq(self.i.dividend
                                               << self.core_config.fract_width)
-        with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+        with m.Elif(self.i.operation == int(DivPipeCoreOperation.SqrtRem)):
             m.d.comb += self.o.compare_lhs.eq(
                 self.i.divisor_radicand << (self.core_config.fract_width * 2))
         with m.Else():  # DivPipeCoreOperation.RSqrtRem
             m.d.comb += self.o.compare_lhs.eq(
                 self.i.divisor_radicand << (self.core_config.fract_width * 2))
         with m.Else():  # DivPipeCoreOperation.RSqrtRem
@@ -331,9 +335,9 @@ class DivPipeCoreCalculateStage(Elaboratable):
                 self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
                 reset_less=True)
 
                 self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
                 reset_less=True)
 
-            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+            with m.If(self.i.operation == int(DivPipeCoreOperation.UDivRem)):
                 m.d.comb += trial_compare_rhs.eq(div_rhs)
                 m.d.comb += trial_compare_rhs.eq(div_rhs)
-            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+            with m.Elif(self.i.operation == int(DivPipeCoreOperation.SqrtRem)):
                 m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
             with m.Else():  # DivPipeCoreOperation.RSqrtRem
                 m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
                 m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
             with m.Else():  # DivPipeCoreOperation.RSqrtRem
                 m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
index 74b10b087632e696321cb384bdaeec930902f16c..59beaddebb10e37688137d132ad25b24726219de 100755 (executable)
@@ -9,10 +9,11 @@ from .core import (DivPipeCoreConfig, DivPipeCoreSetupStage,
 from .algorithm import (FixedUDivRemSqrtRSqrt, Fixed, Operation, div_rem,
                         fixed_sqrt, fixed_rsqrt)
 import unittest
 from .algorithm import (FixedUDivRemSqrtRSqrt, Fixed, Operation, div_rem,
                         fixed_sqrt, fixed_rsqrt)
 import unittest
-from nmigen import Module, Elaboratable
+from nmigen import Module, Elaboratable, Signal
 from nmigen.hdl.ir import Fragment
 from nmigen.back import rtlil
 from nmigen.back.pysim import Simulator, Delay, Tick
 from nmigen.hdl.ir import Fragment
 from nmigen.back import rtlil
 from nmigen.back.pysim import Simulator, Delay, Tick
+from itertools import chain
 
 
 def show_fixed(bits, fract_width, bit_width):
 
 
 def show_fixed(bits, fract_width, bit_width):
@@ -51,18 +52,18 @@ class TestCaseData:
     def __str__(self):
         bit_width = self.core_config.bit_width
         fract_width = self.core_config.fract_width
     def __str__(self):
         bit_width = self.core_config.bit_width
         fract_width = self.core_config.fract_width
-        dividend_str = show_fixed(dividend,
+        dividend_str = show_fixed(self.dividend,
                                   fract_width * 2,
                                   bit_width + fract_width)
                                   fract_width * 2,
                                   bit_width + fract_width)
-        divisor_radicand_str = show_fixed(divisor_radicand,
+        divisor_radicand_str = show_fixed(self.divisor_radicand,
                                           fract_width,
                                           bit_width)
                                           fract_width,
                                           bit_width)
-        quotient_root_str = self.show_fixed(quotient_root,
-                                            fract_width,
-                                            bit_width)
-        remainder_str = self.show_fixed(remainder,
-                                        fract_width * 3,
-                                        bit_width * 3)
+        quotient_root_str = show_fixed(self.quotient_root,
+                                       fract_width,
+                                       bit_width)
+        remainder_str = show_fixed(self.remainder,
+                                   fract_width * 3,
+                                   bit_width * 3)
         return f"{{dividend={dividend_str}, " \
             + f"divisor_radicand={divisor_radicand_str}, " \
             + f"op={self.alg_op.name}, " \
         return f"{{dividend={dividend_str}, " \
             + f"divisor_radicand={divisor_radicand_str}, " \
             + f"op={self.alg_op.name}, " \
@@ -110,35 +111,41 @@ def generate_test_case(core_config, dividend, divisor_radicand, alg_op):
 
 
 def get_test_cases(core_config,
 
 
 def get_test_cases(core_config,
-                   dividend_range=None,
-                   divisor_range=None,
-                   radicand_range=None):
-    if dividend_range is None:
-        dividend_range = range(1 << (core_config.bit_width
-                                     + core_config.fract_width))
-    if divisor_range is None:
-        divisor_range = range(1 << core_config.bit_width)
-    if radicand_range is None:
-        radicand_range = range(1 << core_config.bit_width)
+                   dividends=None,
+                   divisors=None,
+                   radicands=None):
+    if dividends is None:
+        dividends = range(1 << (core_config.bit_width
+                                + core_config.fract_width))
+    else:
+        assert isinstance(dividends, list)
+    if divisors is None:
+        divisors = range(1 << core_config.bit_width)
+    else:
+        assert isinstance(divisors, list)
+    if radicands is None:
+        radicands = range(1 << core_config.bit_width)
+    else:
+        assert isinstance(radicands, list)
 
     for alg_op in Operation:
         if alg_op is Operation.UDivRem:
 
     for alg_op in Operation:
         if alg_op is Operation.UDivRem:
-            for dividend in dividend_range:
-                for divisor in divisor_range:
+            for dividend in dividends:
+                for divisor in divisors:
                     yield from generate_test_case(core_config,
                                                   dividend,
                                                   divisor,
                                                   alg_op)
         else:
                     yield from generate_test_case(core_config,
                                                   dividend,
                                                   divisor,
                                                   alg_op)
         else:
-            for radicand in radicand_range:
+            for radicand in radicands:
                 yield from generate_test_case(core_config,
                 yield from generate_test_case(core_config,
-                                              dividend,
+                                              0,
                                               radicand,
                                               alg_op)
 
 
 class DivPipeCoreTestPipeline(Elaboratable):
                                               radicand,
                                               alg_op)
 
 
 class DivPipeCoreTestPipeline(Elaboratable):
-    def __init__(self, core_config):
+    def __init__(self, core_config, sync=True):
         self.setup_stage = DivPipeCoreSetupStage(core_config)
         self.calculate_stages = [
             DivPipeCoreCalculateStage(core_config, stage_index)
         self.setup_stage = DivPipeCoreSetupStage(core_config)
         self.calculate_stages = [
             DivPipeCoreCalculateStage(core_config, stage_index)
@@ -149,6 +156,7 @@ class DivPipeCoreTestPipeline(Elaboratable):
             for i in range(core_config.num_calculate_stages + 1)]
         self.i = DivPipeCoreInputData(core_config, reset_less=True)
         self.o = DivPipeCoreOutputData(core_config, reset_less=True)
             for i in range(core_config.num_calculate_stages + 1)]
         self.i = DivPipeCoreInputData(core_config, reset_less=True)
         self.o = DivPipeCoreOutputData(core_config, reset_less=True)
+        self.sync = sync
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -157,28 +165,39 @@ class DivPipeCoreTestPipeline(Elaboratable):
         stage_outputs = [*self.interstage_signals, self.o]
         for stage, input, output in zip(stages, stage_inputs, stage_outputs):
             stage.setup(m, input)
         stage_outputs = [*self.interstage_signals, self.o]
         for stage, input, output in zip(stages, stage_inputs, stage_outputs):
             stage.setup(m, input)
-            m.d.sync += output.eq(stage.process(input))
-
+            assignments = output.eq(stage.process(input))
+            if self.sync:
+                m.d.sync += assignments
+            else:
+                m.d.comb += assignments
         return m
 
     def traces(self):
         yield from self.i
         return m
 
     def traces(self):
         yield from self.i
-        for interstage_signal in self.interstage_signals:
-            yield from interstage_signal
+        for interstage_signal in self.interstage_signals:
+            yield from interstage_signal
         yield from self.o
 
 
 class TestDivPipeCore(unittest.TestCase):
     def handle_case(self,
                     core_config,
         yield from self.o
 
 
 class TestDivPipeCore(unittest.TestCase):
     def handle_case(self,
                     core_config,
-                    dividend_range=None,
-                    divisor_range=None,
-                    radicand_range=None):
+                    dividends=None,
+                    divisors=None,
+                    radicands=None,
+                    sync=True):
+        if dividends is not None:
+            dividends = list(dividends)
+        if divisors is not None:
+            divisors = list(divisors)
+        if radicands is not None:
+            radicands = list(radicands)
+
         def gen_test_cases():
             yield from get_test_cases(core_config,
         def gen_test_cases():
             yield from get_test_cases(core_config,
-                                      dividend_range,
-                                      divisor_range,
-                                      radicand_range)
+                                      dividends,
+                                      divisors,
+                                      radicands)
         base_name = f"div_pipe_core_bit_width_{core_config.bit_width}"
         base_name += f"_fract_width_{core_config.fract_width}"
         base_name += f"_radix_{1 << core_config.log2_radix}"
         base_name = f"div_pipe_core_bit_width_{core_config.bit_width}"
         base_name += f"_fract_width_{core_config.fract_width}"
         base_name += f"_radix_{1 << core_config.log2_radix}"
@@ -189,23 +208,24 @@ class TestDivPipeCore(unittest.TestCase):
                 f.write(vl)
         dut = DivPipeCoreTestPipeline(core_config)
         with Simulator(dut,
                 f.write(vl)
         dut = DivPipeCoreTestPipeline(core_config)
         with Simulator(dut,
-                       vcd_file=f"{base_name}.vcd",
-                       gtkw_file=f"{base_name}.gtkw",
+                       vcd_file=open(f"{base_name}.vcd", "w"),
+                       gtkw_file=open(f"{base_name}.gtkw", "w"),
                        traces=[*dut.traces()]) as sim:
             def generate_process():
                 for test_case in gen_test_cases():
                     yield dut.i.dividend.eq(test_case.dividend)
                     yield dut.i.divisor_radicand.eq(test_case.divisor_radicand)
                        traces=[*dut.traces()]) as sim:
             def generate_process():
                 for test_case in gen_test_cases():
                     yield dut.i.dividend.eq(test_case.dividend)
                     yield dut.i.divisor_radicand.eq(test_case.divisor_radicand)
-                    yield dut.i.operation.eq(test_case.core_op)
+                    yield dut.i.operation.eq(int(test_case.core_op))
                     yield Delay(1e-6)
                     yield Tick()
 
             def check_process():
                 # sync with generator
                     yield Delay(1e-6)
                     yield Tick()
 
             def check_process():
                 # sync with generator
-                yield
-                for _ in core_config.num_calculate_stages:
+                if sync:
+                    yield
+                    for _ in range(core_config.num_calculate_stages):
+                        yield
                     yield
                     yield
-                yield
 
                 # now synched with generator
                 for test_case in gen_test_cases():
 
                 # now synched with generator
                 for test_case in gen_test_cases():
@@ -225,7 +245,10 @@ class TestDivPipeCore(unittest.TestCase):
     def test_bit_width_8_fract_width_4_radix_2(self):
         self.handle_case(DivPipeCoreConfig(bit_width=8,
                                            fract_width=4,
     def test_bit_width_8_fract_width_4_radix_2(self):
         self.handle_case(DivPipeCoreConfig(bit_width=8,
                                            fract_width=4,
-                                           log2_radix=1))
+                                           log2_radix=1),
+                         dividends=[*range(1 << 8),
+                                    *range(1 << 8, 1 << 12, 1 << 4)],
+                         sync=False)
 
     # FIXME: add more test_* functions
 
 
     # FIXME: add more test_* functions