add PartitionedSignalTester
[ieee754fpu.git] / src / ieee754 / partitioned_signal_tester.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3 from enum import Enum
4 from typing import (Any, Callable, Dict, Generator, Iterable, List, Mapping,
5 Optional, Sequence, Tuple, Union, final, overload)
6 import shutil
7 from nmigen.hdl.ast import (AnyConst, Assert, Signal, Value, ValueCastable)
8 from nmigen.hdl.dsl import Module
9 from nmigen.hdl.ir import Elaboratable, Fragment
10 from nmigen.sim import Simulator, Delay
11 from ieee754.part.partsig import PartitionedSignal, PartitionPoints
12 import unittest
13 import textwrap
14 import subprocess
15 from hashlib import sha256
16 from nmigen.back import rtlil
17 from nmutil.get_test_path import get_test_path, _StrPath
18
19
# The operation under test: receives the tuple of PartitionedSignal inputs
# and returns the PartitionedSignal it computes.
_PartitionedSignalTestable = Callable[
    [Tuple[PartitionedSignal, ...]], PartitionedSignal]

# Things that Layout.get_width turns into a width.
_WidthCastable = Union["Layout", int]
# Things that Layout.cast turns into a Layout.
_LayoutCastable = Union["Layout", Mapping[int, Any], Iterable[int]]
# Values that this module passes through `Value.cast`.
_ValueCastableType = Union[Value, int, Enum, ValueCastable]
# Things that `formal` hands to `Fragment.get`.
_FragmentLike = Union[Elaboratable, Fragment]
27
28
def formal(test_case: unittest.TestCase, hdl: _FragmentLike, *,
           base_path: _StrPath = "formal_test_temp"):
    """Formally prove the assertions in ``hdl`` with SymbiYosys (``sby``).

    A working directory is derived from ``test_case`` and ``base_path``.
    Any previous contents of that directory are deleted. A ``config.sby``
    that embeds ``hdl`` as RTLIL (``mode prove``, ``depth 1``, ``smtbmc``)
    is written there, and ``sby`` is run in that directory.

    :param test_case: the running test; it names the working directory and
        is used to report failures.
    :param hdl: the design to check, converted with
        ``Fragment.get(hdl, platform="formal")``.
    :param base_path: root directory for the per-test working directories.
    :raises AssertionError: via ``test_case.fail`` if ``sby`` is not on
        ``PATH`` or exits with a non-zero status.
    """
    hdl = Fragment.get(hdl, platform="formal")
    path = get_test_path(test_case, base_path)
    shutil.rmtree(path, ignore_errors=True)
    path.mkdir(parents=True)
    sby_name = "config.sby"
    sby_file = path / sby_name

    # Dedent the template *before* inserting the RTLIL. If the RTLIL were
    # substituted first (as an f-string does), its unindented lines would
    # leave dedent no common prefix to strip, and every config line would
    # stay indented.
    config = textwrap.dedent("""\
        [options]
        mode prove
        depth 1
        wait on

        [engines]
        smtbmc

        [script]
        read_rtlil top.il
        prep

        [file top.il]
        {rtlil}
        """).format(rtlil=rtlil.convert(hdl))
    sby_file.write_text(config, encoding="utf-8")

    sby = shutil.which('sby')
    if sby is None:
        test_case.fail("sby (SymbiYosys) not found on PATH")
    # Merge stderr into stdout so that a failure report includes sby's
    # error output as well.
    result = subprocess.run(
        [sby, sby_name],
        cwd=path, text=True, encoding="utf-8",
        stdin=subprocess.DEVNULL, stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT)
    if result.returncode != 0:
        test_case.fail(f"Formal failed:\n{result.stdout}")
64
65
@final
class Layout:
    """Positions of the partition points of a partitioned value.

    A lane is any span ``[start, end)`` whose start and end are both
    partition points. The two need not be adjacent points.
    """

    # lane size -> ordered set (dict with ``None`` values) of valid lane
    # starts for that size; sizes are inserted in ascending order
    __lane_starts_for_sizes: Dict[int, Dict[int, None]]

    # sorted bit indexes of the partition points; always contains both
    # ``0`` and ``self.width``
    part_indexes: Tuple[int, ...]

    @staticmethod
    def cast(layout: _LayoutCastable,
             width: Optional[_WidthCastable] = None) -> "Layout":
        """Return ``layout`` itself if it is already a ``Layout``.

        Otherwise build ``Layout(layout, width)``. ``width`` is ignored
        when ``layout`` is already a ``Layout``.
        """
        if not isinstance(layout, Layout):
            layout = Layout(layout, width)
        return layout

    def __init__(self,
                 part_indexes: Union[Mapping[int, Any], Iterable[int]],
                 width: Optional[_WidthCastable] = None):
        """Build a layout from partition-point indexes.

        :param part_indexes: non-negative ``int`` indexes; for a mapping
            only its keys are used.
        :param width: optional total width; every index must be ``<=`` it.
            When omitted, the largest index is the width.
        """
        points = set(part_indexes)
        assert all(isinstance(p, int) and p >= 0 for p in points)
        if width is not None:
            width = Layout.get_width(width)
            assert all(p <= width for p in points)
            points.add(width)
        points.add(0)
        self.part_indexes = tuple(sorted(points))

        # every (start, end) pair of partition points, start < end, in
        # order of ascending start and then ascending end
        spans = [(start, end)
                 for pos, start in enumerate(self.part_indexes)
                 for end in self.part_indexes[pos + 1:]]
        # insert the size keys in ascending order first, then fill in starts
        self.__lane_starts_for_sizes = {
            size: {} for size in sorted(end - start for start, end in spans)
        }
        for start, end in spans:
            self.__lane_starts_for_sizes[end - start][start] = None

    @property
    def width(self) -> int:
        """Total width in bits (the last partition index)."""
        return self.part_indexes[-1]

    @property
    def part_signal_count(self) -> int:
        """Number of partition points strictly between ``0`` and ``width``."""
        inner = len(self.part_indexes) - 2
        return inner if inner > 0 else 0

    @staticmethod
    def get_width(width: _WidthCastable) -> int:
        """Convert a ``Layout`` or a non-negative ``int`` into a width."""
        value = width.width if isinstance(width, Layout) else width
        assert isinstance(value, int)
        assert value >= 0
        return value

    def partition_points_signals(self, name: Optional[str] = None,
                                 src_loc_at: int = 0) -> PartitionPoints:
        """Create one ``Signal`` named ``f"{name}_{index}"`` per inner
        partition point.

        When ``name`` is omitted, it comes from the name nmigen's tracer
        assigns to a throw-away ``Signal``.
        """
        if name is None:
            name = Signal(src_loc_at=1 + src_loc_at).name
        signals = {
            index: Signal(name=f"{name}_{index}", src_loc_at=1 + src_loc_at)
            for index in self.part_indexes[1:-1]
        }
        return PartitionPoints(signals)

    def __repr__(self) -> str:
        return "Layout({}, width={})".format(self.part_indexes, self.width)

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, Layout):
            return NotImplemented
        return self.part_indexes == o.part_indexes

    def __hash__(self) -> int:
        return hash(self.part_indexes)

    def is_lane_valid(self, start: int, size: int) -> bool:
        """Whether ``[start, start + size)`` is a lane of this layout."""
        starts = self.__lane_starts_for_sizes.get(size)
        return starts is not None and start in starts

    def lane_sizes(self) -> Iterable[int]:
        """All distinct lane sizes, in ascending order."""
        return self.__lane_starts_for_sizes.keys()

    def lane_starts_for_size(self, size: int) -> Iterable[int]:
        """Valid starts of lanes of ``size`` (``KeyError`` if none)."""
        return self.__lane_starts_for_sizes[size].keys()

    def lanes_for_size(self, size: int) -> Iterable["Lane"]:
        """Lazily yield every lane of the given ``size``."""
        for lane_start in self.lane_starts_for_size(size):
            yield Lane(lane_start, size, self)

    def lanes(self) -> Iterable["Lane"]:
        """Lazily yield every lane, grouped by ascending size."""
        for size in self.lane_sizes():
            for lane in self.lanes_for_size(size):
                yield lane

    def is_compatible(self, other: _LayoutCastable) -> bool:
        """Whether ``other`` has the same number of partition points."""
        other_layout = Layout.cast(other)
        return len(other_layout.part_indexes) == len(self.part_indexes)

    def translate_lane_to(self, lane: "Lane",
                          target_layout: _LayoutCastable) -> "Lane":
        """Map ``lane`` (of this layout) to the lane of ``target_layout``
        that spans the same partition-point positions."""
        assert lane.layout == self
        target = Layout.cast(target_layout)
        assert self.is_compatible(target)
        first = target.part_indexes[self.part_indexes.index(lane.start)]
        last = target.part_indexes[self.part_indexes.index(lane.end)]
        return Lane(first, last - first, target)
178
179
@final
class Lane:
    """Bits ``[start, start + size)`` of a ``Layout``.

    The lane must begin and end on partition points of that layout.
    """

    def __init__(self, start: int, size: int, layout: _LayoutCastable):
        self.layout = Layout.cast(layout)
        assert self.layout.is_lane_valid(start, size)
        self.start = start
        self.size = size

    def __repr__(self) -> str:
        return "Lane(start={}, size={}, layout={})".format(
            self.start, self.size, self.layout)

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, Lane):
            return NotImplemented
        return ((self.start, self.size, self.layout)
                == (o.start, o.size, o.layout))

    def __hash__(self) -> int:
        return hash((self.start, self.size, self.layout))

    def as_slice(self) -> slice:
        """The ``slice`` selecting this lane's bits."""
        return slice(self.start, self.end)

    @property
    def end(self) -> int:
        """One past the last bit of the lane."""
        return self.start + self.size

    def translate_to(self, target_layout: _LayoutCastable) -> "Lane":
        """The corresponding lane in a compatible ``target_layout``."""
        return self.layout.translate_lane_to(self, target_layout)

    @overload
    def is_active(self, partition_points: Sequence[bool]) -> bool: ...

    @overload
    def is_active(self, partition_points: Sequence[_ValueCastableType]
                  ) -> Union[Value, bool]: ...

    @overload
    def is_active(self, partition_points: Mapping[int, bool]) -> bool: ...

    @overload
    def is_active(self, partition_points: Mapping[int, _ValueCastableType]
                  ) -> Union[Value, bool]: ...

    def is_active(self, partition_points):
        """Whether this lane is one whole partition.

        That is the case when partitions begin at both of its ends and at
        no partition point strictly inside it.

        ``partition_points`` is either a sequence indexed by position in
        ``self.layout.part_indexes``, or a mapping keyed by bit index. The
        result is a ``bool`` when every consulted point is a ``bool``, and
        an nmigen ``Value`` otherwise.
        """
        part_indexes = self.layout.part_indexes
        last_index = len(part_indexes) - 1

        def point(index: int, invert: bool):
            # the outermost points always start or end a partition
            if index in (0, last_index):
                return True
            if isinstance(partition_points, Sequence):
                value = partition_points[index]
            else:
                value = partition_points[part_indexes[index]]
            if isinstance(value, bool):
                return (not value) if invert else value
            value = Value.cast(value)
            return ~value if invert else value

        start_index = part_indexes.index(self.start)
        end_index = part_indexes.index(self.end)
        result = point(start_index, False) & point(end_index, False)
        for inner in range(start_index + 1, end_index):
            result &= point(inner, True)
        return result
248
249
# Reference model: given a lane and, for each input, the slice of that input
# covering the lane, return the expected output bits for the lane.
_PartitionedSignalTestReference = Callable[
    [Lane, Tuple[Value, ...]], _ValueCastableType]

# One flag per entry of Layout.part_indexes: whether a partition starts there.
_PartitionedSignalTestCasePartMode = Tuple[bool, ...]
# One integer value per input signal.
_PartitionedSignalTestCaseInputs = Tuple[int, ...]
# A complete test case: (partition mode, input values).
_PartitionedSignalTestCase = Tuple[
    _PartitionedSignalTestCasePartMode, _PartitionedSignalTestCaseInputs]
257
258
class PartitionedSignalTester:
    """Compare a ``PartitionedSignal`` operation with a per-lane reference.

    Construction adds the following to ``m``:

    * one ``PartitionedSignal`` input per layout (the partition points of
      inputs 1.. are combinationally driven from those of input 0),
    * the output of ``operation``,
    * one ``reference_output_<start>_<size>`` signal per lane of the first
      layout.

    ``run_sim`` checks special and pseudo-random cases in simulation.
    ``run_formal`` proves the same properties with SymbiYosys.
    """

    # one layout per input, each compatible with ``layouts[0]``
    layouts: List[Layout]
    # the input signals, in the same order as ``layouts``
    inputs: List[PartitionedSignal]

    def __init__(self,
                 m: Module,
                 operation: _PartitionedSignalTestable,
                 reference: _PartitionedSignalTestReference,
                 *layouts: _LayoutCastable,
                 src_loc_at: int = 0,
                 additional_case_count: int = 30,
                 special_cases: Iterable[_PartitionedSignalTestCase] = (),
                 seed: str = ""):
        """
        :param m: module that all generated logic is added to.
        :param operation: builds the output from the tuple of inputs; it
            must return a ``PartitionedSignal`` whose layout is compatible
            with the first input layout.
        :param reference: computes one lane's expected output from the
            corresponding slice of each input.
        :param layouts: one layout per input (at least one); all must be
            compatible with the first.
        :param src_loc_at: extra stack depth for nmigen source locations.
        :param additional_case_count: number of pseudo-random cases that
            ``run_sim`` runs after ``special_cases``.
        :param special_cases: explicit ``(part_starts, inputs)`` cases,
            which are run first.
        :param seed: mixed into the hash that derives pseudo-random cases.
        """
        self.m = m
        self.operation = operation
        self.reference = reference
        self.layouts = []
        self.inputs = []
        for layout in layouts:
            layout = Layout.cast(layout)
            if len(self.layouts) > 0:
                assert self.layouts[0].is_compatible(layout)
            self.layouts.append(layout)
            name = f"input_{len(self.inputs)}"
            ps = PartitionedSignal(
                layout.partition_points_signals(name=name,
                                                src_loc_at=1 + src_loc_at),
                layout.width,
                name=name)
            ps.set_module(m)
            self.inputs.append(ps)
        assert len(self.layouts) != 0, "must have at least one input layout"
        # Every input follows input 0's partitioning. Points are matched by
        # position, because compatible layouts may put them at different
        # bit indexes.
        for i in range(1, len(self.inputs)):
            for j in range(1, len(self.layouts[0].part_indexes) - 1):
                lhs_part_point = self.layouts[i].part_indexes[j]
                rhs_part_point = self.layouts[0].part_indexes[j]
                lhs = self.inputs[i].partpoints[lhs_part_point]
                rhs = self.inputs[0].partpoints[rhs_part_point]
                m.d.comb += lhs.eq(rhs)
        self.special_cases = list(special_cases)
        self.case_count = additional_case_count + len(self.special_cases)
        self.seed = seed
        self.case_number = Signal(64)
        self.test_output = operation(tuple(self.inputs))
        assert isinstance(self.test_output, PartitionedSignal)
        self.test_output_layout = Layout(
            self.test_output.partpoints, self.test_output.sig.width)
        assert self.test_output_layout.is_compatible(self.layouts[0])
        self.reference_output_values = {
            lane: Value.cast(reference(lane, tuple(
                inp.sig[lane.translate_to(layout).as_slice()]
                for inp, layout in zip(self.inputs, self.layouts))))
            for lane in self.layouts[0].lanes()
        }
        self.reference_outputs = {
            lane: Signal(value.shape(),
                         name=f"reference_output_{lane.start}_{lane.size}")
            for lane, value in self.reference_output_values.items()
        }
        for lane, value in self.reference_output_values.items():
            m.d.comb += self.reference_outputs[lane].eq(value)

    def __hash_256(self, v: str) -> int:
        """256-bit integer from SHA-256 of ``seed + v``."""
        return int.from_bytes(
            sha256(bytes(self.seed + v, encoding='utf-8')).digest(),
            byteorder='little'
        )

    def __hash(self, v: str, bits: int) -> int:
        """Deterministic ``bits``-bit integer derived from ``v`` and ``seed``."""
        retval = 0
        for i in range(0, bits, 256):
            retval <<= 256
            retval |= self.__hash_256(f" {v} {i}")
        return retval & ((1 << bits) - 1)

    def __get_case(self, case_number: int) -> _PartitionedSignalTestCase:
        """Return test case ``case_number``.

        The first cases come from ``special_cases``. Later ones are derived
        deterministically from ``seed`` and ``case_number``.
        """
        if case_number < len(self.special_cases):
            return self.special_cases[case_number]
        part_count = len(self.layouts[0].part_indexes)
        trial = 0
        inner_bits = self.__hash(f"{case_number} trial {trial}",
                                 self.layouts[0].part_signal_count)
        # Bit i of ``bits`` becomes part_starts[i]. The inner points
        # (positions 1 .. part_count - 2) take the hashed bits unchanged, so
        # every partitioning is equally likely. The first and last
        # positions are always partition boundaries.
        bits = 1 | (inner_bits << 1) | (1 << (part_count - 1))
        part_starts = tuple(
            (bits & (1 << i)) != 0
            for i in range(part_count))
        inputs = tuple(self.__hash(f"{case_number} input {i}",
                                   self.layouts[i].width)
                       for i in range(len(self.layouts)))
        return part_starts, inputs

    def __format_case(self, case: _PartitionedSignalTestCase) -> str:
        """Human-readable description of ``case`` for subTest labels."""
        part_starts, inputs = case
        str_inputs = [hex(i) for i in inputs]
        return f"part_starts={part_starts}, inputs={str_inputs}"

    def __setup_case(self, case_number: int,
                     case: Optional[_PartitionedSignalTestCase] = None
                     ) -> Generator[Any, int, None]:
        """Simulator commands that apply ``case`` to the inputs.

        ``case`` is looked up from ``case_number`` when not given.
        """
        if case is None:
            case = self.__get_case(case_number)
        yield self.case_number.eq(case_number)
        part_starts, inputs = case
        part_indexes = self.layouts[0].part_indexes
        assert len(part_starts) == len(part_indexes)
        # only input 0's points are set; the other inputs follow them
        for i in range(1, len(part_starts) - 1):
            yield self.inputs[0].partpoints[part_indexes[i]].eq(part_starts[i])
        for i in range(len(self.inputs)):
            yield self.inputs[i].sig.eq(inputs[i])

    def run_sim(self, test_case: unittest.TestCase, *,
                engine: Optional[str] = None,
                base_path: _StrPath = "sim_test_out"):
        """Simulate every case and check the output.

        For each case, the output's partition points must equal the
        inputs', and each active lane of the output must equal the
        reference. The waveform is written to ``.vcd`` and ``.gtkw`` files
        based on ``get_test_path(test_case, base_path)``.
        """
        if engine is None:
            sim = Simulator(self.m)
        else:
            sim = Simulator(self.m, engine=engine)

        def check_active_lane(lane: Lane):
            reference = yield self.reference_outputs[lane]
            output = yield self.test_output.sig[
                lane.translate_to(self.test_output_layout).as_slice()]
            test_case.assertEqual(hex(reference), hex(output))

        def check_case(case: _PartitionedSignalTestCase):
            part_starts, inputs = case
            for i in range(1, len(self.layouts[0].part_indexes) - 1):
                part_point = yield self.test_output.partpoints[
                    self.test_output_layout.part_indexes[i]]
                test_case.assertEqual(part_point, part_starts[i])
            for lane in self.layouts[0].lanes():
                with test_case.subTest(lane=lane):
                    active = lane.is_active(part_starts)
                    if active:
                        yield from check_active_lane(lane)

        def process():
            for case_number in range(self.case_count):
                with test_case.subTest(case_number=str(case_number)):
                    case = self.__get_case(case_number)
                    with test_case.subTest(case=self.__format_case(case)):
                        yield from self.__setup_case(case_number, case)
                        yield Delay(1e-6)
                        yield from check_case(case)
        sim.add_process(process)
        path = get_test_path(test_case, base_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        vcd_path = path.with_suffix(".vcd")
        gtkw_path = path.with_suffix(".gtkw")
        traces = [self.case_number]
        for i in self.layouts[0].part_indexes[1:-1]:
            traces.append(self.inputs[0].partpoints[i])
        for inp in self.inputs:
            traces.append(inp.sig)
        traces.extend(self.reference_outputs.values())
        traces.append(self.test_output.sig)
        with sim.write_vcd(vcd_path.open("wt", encoding="utf-8"),
                           gtkw_path.open("wt", encoding="utf-8"),
                           traces=traces):
            sim.run()

    def run_formal(self, test_case: unittest.TestCase, **kwargs):
        """Prove the same properties as ``run_sim`` for all inputs.

        Input 0's partition points and every input value are driven from
        ``AnyConst`` (these ``comb`` drivers are added to ``self.m``). The
        proof requires that the output partition points equal the inputs'
        and that every active lane equals the reference. ``kwargs`` are
        forwarded to ``formal``.
        """
        for part_point in self.inputs[0].partpoints.values():
            self.m.d.comb += part_point.eq(AnyConst(1))
        for i in range(len(self.inputs)):
            s = self.inputs[i].sig
            self.m.d.comb += s.eq(AnyConst(s.shape()))
        for i in range(1, len(self.layouts[0].part_indexes) - 1):
            in_part_point = self.inputs[0].partpoints[
                self.layouts[0].part_indexes[i]]
            out_part_point = self.test_output.partpoints[
                self.test_output_layout.part_indexes[i]]
            self.m.d.comb += Assert(in_part_point == out_part_point)

        def check_active_lane(lane: Lane) -> Assert:
            # the output bits for ``lane`` must match its reference output
            reference = self.reference_outputs[lane]
            output = self.test_output.sig[
                lane.translate_to(self.test_output_layout).as_slice()]
            return Assert(reference == output)

        for lane in self.layouts[0].lanes():
            with test_case.subTest(lane=lane):
                a = check_active_lane(lane)
                # only lanes that form a whole partition are checked
                with self.m.If(lane.is_active(self.inputs[0].partpoints)):
                    self.m.d.comb += a
        formal(test_case, self.m, **kwargs)