switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / partitioned_signal_tester.py
index 27db86925f9c90d82564bec50ae3359e1a5dd9bc..2986ee440e334cf2dd700db050fd76bbca323674 100644 (file)
@@ -6,13 +6,14 @@ from nmigen.hdl.ast import (AnyConst, Assert, Signal, Value, ValueCastable)
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable, Fragment
 from nmigen.sim import Simulator, Delay
-from ieee754.part.partsig import PartitionedSignal, PartitionPoints
+from ieee754.part.partsig import SimdSignal, PartitionPoints
 import unittest
 import textwrap
 import subprocess
 from hashlib import sha256
 from nmigen.back import rtlil
-from nmutil.get_test_path import get_test_path, _StrPath
+from nmutil.get_test_path import get_test_path
+from collections.abc import Sequence
 
 
 def formal(test_case, hdl, *, base_path="formal_test_temp"):
@@ -51,15 +52,7 @@ def formal(test_case, hdl, *, base_path="formal_test_temp"):
             test_case.fail(f"Formal failed:\n{stdout}")
 
 
-@final
 class Layout:
-    __lane_starts_for_sizes
-    """keys are in sorted order"""
-
-    part_indexes
-    """bit indexes of partition points in sorted order, always includes
-    `0` and `self.width`"""
-
     @staticmethod
     def cast(layout, width=None):
         if isinstance(layout, Layout):
@@ -80,7 +73,10 @@ class Layout:
         part_indexes = list(part_indexes)
         part_indexes.sort()
         self.part_indexes = tuple(part_indexes)
-        sizes: List[int] = []
+        """bit indexes of partition points in sorted order, always
+        includes `0` and `self.width`"""
+
+        sizes = []
         for start_index in range(len(self.part_indexes)):
             start = self.part_indexes[start_index]
             for end in self.part_indexes[start_index + 1:]:
@@ -88,6 +84,8 @@ class Layout:
         sizes.sort()
         # build in sorted order
         self.__lane_starts_for_sizes = {size: {} for size in sizes}
+        """keys are in sorted order"""
+
         for start_index in range(len(self.part_indexes)):
             start = self.part_indexes[start_index]
             for end in self.part_indexes[start_index + 1:]:
@@ -109,11 +107,14 @@ class Layout:
         assert width >= 0
         return width
 
-    def partition_points_signals(self, nameNone,
+    def partition_points_signals(self, name=None,
                                  src_loc_at=0):
         if name is None:
             name = Signal(src_loc_at=1 + src_loc_at).name
-        return PartitionPoints({ i for i in self.part_indexes[1:-1] })
+        pps = {}
+        for i in self.part_indexes[1:-1]:
+            pps[i] = Signal(name=f"{name}_{i}", src_loc_at=1 + src_loc_at)
+        return PartitionPoints(pps)
 
     def __repr__(self):
         return f"Layout({self.part_indexes}, width={self.width})"
@@ -158,7 +159,6 @@ class Layout:
         return Lane(target_start, target_end - target_start, target_layout)
 
 
-@final
 class Lane:
     def __init__(self, start, size, layout):
         self.layout = Layout.cast(layout)
@@ -189,18 +189,6 @@ class Lane:
     def translate_to(self, target_layout):
         return self.layout.translate_lane_to(self, target_layout)
 
-    @overload
-    def is_active(self, partition_points): ...
-
-    @overload
-    def is_active(self, partition_points): ...
-
-    @overload
-    def is_active(self, partition_points): ...
-
-    @overload
-    def is_active(self, partition_points): ...
-
     def is_active(self, partition_points):
         def get_partition_point(index, invert):
             if index == 0 or index == len(self.layout.part_indexes) - 1:
@@ -228,7 +216,7 @@ class Lane:
         return retval
 
 
-class PartitionedSignalTester:
+class SimdSignalTester:
 
     def __init__(self, m, operation, reference, *layouts,
                  src_loc_at=0, additional_case_count=30,
@@ -244,7 +232,7 @@ class PartitionedSignalTester:
                 assert self.layouts[0].is_compatible(layout)
             self.layouts.append(layout)
             name = f"input_{len(self.inputs)}"
-            ps = PartitionedSignal(
+            ps = SimdSignal(
                 layout.partition_points_signals(name=name,
                                                 src_loc_at=1 + src_loc_at),
                 layout.width,
@@ -264,22 +252,23 @@ class PartitionedSignalTester:
         self.seed = seed
         self.case_number = Signal(64)
         self.test_output = operation(tuple(self.inputs))
-        assert isinstance(self.test_output, PartitionedSignal)
+        assert isinstance(self.test_output, SimdSignal)
         self.test_output_layout = Layout(
             self.test_output.partpoints, self.test_output.sig.width)
         assert self.test_output_layout.is_compatible(self.layouts[0])
-        self.reference_output_values = {
-            lane, tuple(
-                inp.sig[lane.translate_to(layout).as_slice()]
-                for inp, layout in zip(self.inputs, self.layouts))
-            for lane in self.layouts[0].lanes()
-        }
-        self.reference_outputs = {
-            lane, name=f"reference_output_{lane.start}_{lane.size}")
-            for lane, value in self.reference_output_values.items()
-        }
+        self.reference_output_values = {}
+        for lane in self.layouts[0].lanes():
+            in_t = []
+            for inp, layout in zip(self.inputs, self.layouts):
+                in_t.append(inp.sig[lane.translate_to(layout).as_slice()])
+            v = Value.cast(reference(lane, tuple(in_t)))
+            self.reference_output_values[lane] = v
+        self.reference_outputs = {}
         for lane, value in self.reference_output_values.items():
-            m.d.comb += self.reference_outputs[lane].eq(value)
+            s = Signal(value.shape(),
+                       name=f"reference_output_{lane.start}_{lane.size}")
+            self.reference_outputs[lane] = s
+            m.d.comb += s.eq(value)
 
     def __hash_256(self, v):
         return int.from_bytes(
@@ -301,13 +290,14 @@ class PartitionedSignalTester:
         bits = self.__hash(f"{case_number} trial {trial}",
                            self.layouts[0].part_signal_count)
         bits |= 1 | (1 << len(self.layouts[0].part_indexes)) | (bits << 1)
-        part_starts = tuple(
-            (bits & (1 << i)) != 0
-            for i in range(len(self.layouts[0].part_indexes)))
-        inputs = tuple(self.__hash(f"{case_number} input {i}",
-                                   self.layouts[i].width)
-                       for i in range(len(self.layouts)))
-        return part_starts, inputs
+        part_starts = []
+        for i in range(len(self.layouts[0].part_indexes)):
+            part_starts.append((bits & (1 << i)) != 0)
+        inputs = []
+        for i in range(len(self.layouts)):
+            inputs.append(self.__hash(f"{case_number} input {i}",
+                                   self.layouts[i].width))
+        return tuple(part_starts), tuple(inputs)
 
     def __format_case(self, case):
         part_starts, inputs = case
@@ -326,9 +316,7 @@ class PartitionedSignalTester:
         for i in range(len(self.inputs)):
             yield self.inputs[i].sig.eq(inputs[i])
 
-    def run_sim(self, test_case, *,
-                engine: Optional[str] = None,
-                base_path: _StrPath = "sim_test_out"):
+    def run_sim(self, test_case, *, engine=None, base_path="sim_test_out"):
         if engine is None:
             sim = Simulator(self.m)
         else: