switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / partitioned_signal_tester.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3 from enum import Enum
4 import shutil
5 from nmigen.hdl.ast import (AnyConst, Assert, Signal, Value, ValueCastable)
6 from nmigen.hdl.dsl import Module
7 from nmigen.hdl.ir import Elaboratable, Fragment
8 from nmigen.sim import Simulator, Delay
9 from ieee754.part.partsig import SimdSignal, PartitionPoints
10 import unittest
11 import textwrap
12 import subprocess
13 from hashlib import sha256
14 from nmigen.back import rtlil
15 from nmutil.get_test_path import get_test_path
16 from collections.abc import Sequence
17
18
def formal(test_case, hdl, *, base_path="formal_test_temp"):
    """Formally prove the assertions in ``hdl`` with SymbiYosys (``sby``).

    ``hdl`` is elaborated for the ``"formal"`` platform, converted to RTLIL
    and written, together with an ``sby`` configuration (``mode prove``,
    ``depth 1``, ``smtbmc`` engine), into a fresh directory obtained from
    ``get_test_path(test_case, base_path)``.  Any previous contents of that
    directory are deleted first.

    :param test_case: the ``unittest.TestCase`` used to name the output
        directory and to report failures.
    :param hdl: the elaboratable whose assertions are checked.
    :param base_path: root directory under which the per-test directory is
        created.

    Fails ``test_case`` if ``sby`` is not on ``PATH`` or if ``sby`` exits
    with a non-zero status (its combined stdout/stderr is included in the
    failure message).
    """
    hdl = Fragment.get(hdl, platform="formal")
    path = get_test_path(test_case, base_path)
    shutil.rmtree(path, ignore_errors=True)
    path.mkdir(parents=True)
    sby_name = "config.sby"
    sby_file = path / sby_name

    # Dedent the template on its own: if the RTLIL (whose lines start at
    # column 0) were interpolated first, `textwrap.dedent` would find no
    # common indentation and leave every config line indented.
    config = textwrap.dedent("""\
        [options]
        mode prove
        depth 1
        wait on

        [engines]
        smtbmc

        [script]
        read_rtlil top.il
        prep

        [file top.il]
        """)
    sby_file.write_text(config + rtlil.convert(hdl) + "\n", encoding="utf-8")

    sby = shutil.which('sby')
    if sby is None:
        # explicit check instead of `assert`, which is stripped under -O
        test_case.fail("sby (SymbiYosys) executable not found in PATH")
    result = subprocess.run(
        [sby, sby_name],
        cwd=path, text=True, encoding="utf-8",
        stdin=subprocess.DEVNULL, stdout=subprocess.PIPE,
        # merge stderr so the failure report contains all of sby's output
        stderr=subprocess.STDOUT)
    if result.returncode != 0:
        test_case.fail(f"Formal failed:\n{result.stdout}")
53
54
class Layout:
    """Set of partition points that divide a bit-vector into candidate lanes.

    ``part_indexes`` is the ascending tuple of bit positions at which the
    vector may be split; it always contains ``0`` and the total width.
    Every pair of partition points ``start < end`` describes a possible
    lane of ``end - start`` bits beginning at bit ``start``.
    """

    @staticmethod
    def cast(layout, width=None):
        """Return ``layout`` itself when it is already a `Layout`, otherwise
        build ``Layout(layout, width)``."""
        if not isinstance(layout, Layout):
            layout = Layout(layout, width)
        return layout

    def __init__(self, part_indexes, width=None):
        """Build a layout.

        :param part_indexes: iterable of non-negative ``int`` bit positions;
            duplicates are ignored.
        :param width: optional total width (an ``int`` or a `Layout`); when
            omitted, the largest partition index is the width.
        """
        points = set(part_indexes)
        for point in points:
            assert isinstance(point, int)
            assert 0 <= point
        if width is not None:
            width = Layout.get_width(width)
            for point in points:
                assert point <= width
            points.add(width)
        points.add(0)
        # bit indexes of partition points in ascending order; always
        # includes `0` and `self.width`
        self.part_indexes = tuple(sorted(points))

        # every (start, end) pair of partition points, ordered by start and
        # then by end
        spans = [(start, end)
                 for pos, start in enumerate(self.part_indexes)
                 for end in self.part_indexes[pos + 1:]]
        # lane size -> {lane start: None}.  Sizes are inserted in ascending
        # order and starts in ascending order, so iterating either level of
        # the mapping yields sorted values.
        self.__lane_starts_for_sizes = {
            size: {} for size in sorted(end - start for start, end in spans)}
        for start, end in spans:
            self.__lane_starts_for_sizes[end - start][start] = None

    @property
    def width(self):
        """Total width in bits (the largest partition index)."""
        return self.part_indexes[-1]

    @property
    def part_signal_count(self):
        """Number of interior partition points (excluding ``0`` and
        ``width``)."""
        return max(0, len(self.part_indexes) - 2)

    @staticmethod
    def get_width(width):
        """Return ``width`` as a non-negative ``int``; a `Layout` is
        replaced by its ``width``."""
        if isinstance(width, Layout):
            width = width.width
        assert isinstance(width, int)
        assert width >= 0
        return width

    def partition_points_signals(self, name=None,
                                 src_loc_at=0):
        """Create a `PartitionPoints` holding one `Signal` per interior
        partition point, keyed by bit index and named ``{name}_{index}``.

        When ``name`` is ``None`` it is taken from a throwaway `Signal`
        (presumably relying on nmigen's variable-name tracing).
        """
        if name is None:
            name = Signal(src_loc_at=1 + src_loc_at).name
        pps = {}
        # plain loop (not a comprehension) so `src_loc_at` frame counting
        # matches the caller's expectations
        for i in self.part_indexes[1:-1]:
            pps[i] = Signal(name=f"{name}_{i}", src_loc_at=1 + src_loc_at)
        return PartitionPoints(pps)

    def __repr__(self):
        return f"Layout({self.part_indexes}, width={self.width})"

    def __eq__(self, o):
        if not isinstance(o, Layout):
            return NotImplemented
        return self.part_indexes == o.part_indexes

    def __hash__(self):
        return hash(self.part_indexes)

    def is_lane_valid(self, start, size):
        """Whether ``[start, start + size)`` spans two partition points."""
        starts = self.__lane_starts_for_sizes.get(size, ())
        return start in starts

    def lane_sizes(self):
        """Distinct lane sizes in ascending order (a dict keys view)."""
        return self.__lane_starts_for_sizes.keys()

    def lane_starts_for_size(self, size):
        """Ascending start indexes of lanes of ``size`` bits (a dict keys
        view); raises ``KeyError`` for a size with no lanes."""
        return self.__lane_starts_for_sizes[size].keys()

    def lanes_for_size(self, size):
        """Yield every `Lane` of ``size`` bits, by ascending start."""
        for lane_start in self.lane_starts_for_size(size):
            yield Lane(lane_start, size, self)

    def lanes(self):
        """Yield every `Lane`, by ascending size and then ascending start."""
        for lane_size in self.lane_sizes():
            yield from self.lanes_for_size(lane_size)

    def is_compatible(self, other):
        """Whether ``other`` (cast to a `Layout`) has the same number of
        partition points, so lanes can be translated between the two."""
        return len(self.part_indexes) == len(Layout.cast(other).part_indexes)

    def translate_lane_to(self, lane, target_layout):
        """Map ``lane`` (a lane of this layout) to the lane of
        ``target_layout`` that spans the same partition-point positions."""
        assert lane.layout == self
        target_layout = Layout.cast(target_layout)
        assert self.is_compatible(target_layout)
        first = self.part_indexes.index(lane.start)
        last = self.part_indexes.index(lane.end)
        new_start = target_layout.part_indexes[first]
        new_end = target_layout.part_indexes[last]
        return Lane(new_start, new_end - new_start, target_layout)
160
161
class Lane:
    """A lane of a `Layout`: the bits ``[start, start + size)``, where both
    ``start`` and ``start + size`` are partition points of ``layout``."""

    def __init__(self, start, size, layout):
        self.layout = Layout.cast(layout)
        assert self.layout.is_lane_valid(start, size), \
            f"not a lane of {self.layout}: start={start}, size={size}"
        self.start = start
        self.size = size

    def __repr__(self):
        return (f"Lane(start={self.start}, size={self.size}, "
                f"layout={self.layout})")

    def __eq__(self, o):
        if isinstance(o, Lane):
            return self.start == o.start and self.size == o.size \
                and self.layout == o.layout
        return NotImplemented

    def __hash__(self):
        return hash((self.start, self.size, self.layout))

    def as_slice(self):
        """Return the ``slice`` selecting this lane's bits."""
        return slice(self.start, self.end)

    @property
    def end(self):
        """Bit index one past the lane's last bit."""
        return self.start + self.size

    def translate_to(self, target_layout):
        """Return the lane of ``target_layout`` that spans the same
        partition-point positions as this lane."""
        return self.layout.translate_lane_to(self, target_layout)

    def is_active(self, partition_points):
        """Return whether this lane is one of the currently active lanes.

        A lane is active when the partition points at both of its ends are
        set and none of the partition points strictly inside it are set.
        The outermost points (bit ``0`` and the layout width) always count
        as set.

        :param partition_points: either a `Sequence` indexed by position in
            ``self.layout.part_indexes`` (entries for the two outer
            positions are ignored), or a mapping keyed by bit index of the
            interior partition points.  Entries are Python ``bool``/``int``
            constants or anything ``Value.cast`` accepts.
        :returns: a Python ``bool`` when every consulted entry is a Python
            ``bool``/``int``, otherwise an nmigen expression.
        """
        part_indexes = self.layout.part_indexes
        last_index = len(part_indexes) - 1

        def get_partition_point(index, invert):
            # the outer ends of the layout are always lane boundaries
            if index == 0 or index == last_index:
                return True
            if isinstance(partition_points, Sequence):
                point = partition_points[index]
            else:
                point = partition_points[part_indexes[index]]
            # Python ints (bools included) are constants: keep the result a
            # Python bool so callers (e.g. `SimdSignalTester.run_sim`, which
            # does `if active:`) can branch on it instead of receiving an
            # nmigen `Const`.
            if isinstance(point, int):
                return not point if invert else bool(point)
            point = Value.cast(point)
            return ~point if invert else point

        start_index = part_indexes.index(self.start)
        end_index = part_indexes.index(self.end)
        retval = get_partition_point(start_index, False) \
            & get_partition_point(end_index, False)
        # no partition may be set strictly inside the lane
        for i in range(start_index + 1, end_index):
            retval &= get_partition_point(i, True)
        return retval
217
218
class SimdSignalTester:
    """Check a `SimdSignal` operation against a per-lane reference, either by
    simulation (`run_sim`) or by formal proof (`run_formal`).

    One `SimdSignal` input is created per entry of ``layouts``.  Every
    input's partition points are driven from the first input's, so all
    inputs are always partitioned identically.  ``operation`` receives the
    tuple of inputs and must return a `SimdSignal`.  ``reference`` is called
    once per lane of the first layout as ``reference(lane, input_slices)``,
    and its result is cast to a `Value`.
    """

    def __init__(self, m, operation, reference, *layouts,
                 src_loc_at=0, additional_case_count=30,
                 special_cases=(), seed=""):
        """
        :param m: the `Module` the inputs and reference logic are added to.
        :param operation: callable mapping a tuple of input `SimdSignal`s to
            the output `SimdSignal` under test.
        :param reference: callable ``(lane, input_slices) -> value`` giving
            the expected output for one lane.
        :param layouts: one layout (or anything `Layout.cast` accepts) per
            input; all must be compatible with the first.
        :param additional_case_count: number of pseudo-random simulation
            cases run after the special cases.
        :param special_cases: explicit ``(part_starts, inputs)`` cases that
            are simulated first.
        :param seed: string mixed into the pseudo-random case generator.
        """
        self.m = m
        self.operation = operation
        self.reference = reference
        self.layouts = []
        self.inputs = []
        for layout in layouts:
            layout = Layout.cast(layout)
            if len(self.layouts) > 0:
                assert self.layouts[0].is_compatible(layout)
            self.layouts.append(layout)
            name = f"input_{len(self.inputs)}"
            ps = SimdSignal(
                layout.partition_points_signals(name=name,
                                                src_loc_at=1 + src_loc_at),
                layout.width,
                name=name)
            ps.set_module(m)
            self.inputs.append(ps)
        assert len(self.layouts) != 0, "must have at least one input layout"
        # drive every other input's partition points from input 0's
        for i in range(1, len(self.inputs)):
            for j in range(1, len(self.layouts[0].part_indexes) - 1):
                lhs_part_point = self.layouts[i].part_indexes[j]
                rhs_part_point = self.layouts[0].part_indexes[j]
                lhs = self.inputs[i].partpoints[lhs_part_point]
                rhs = self.inputs[0].partpoints[rhs_part_point]
                m.d.comb += lhs.eq(rhs)
        self.special_cases = list(special_cases)
        self.case_count = additional_case_count + len(self.special_cases)
        self.seed = seed
        self.case_number = Signal(64)
        self.test_output = operation(tuple(self.inputs))
        assert isinstance(self.test_output, SimdSignal)
        self.test_output_layout = Layout(
            self.test_output.partpoints, self.test_output.sig.width)
        assert self.test_output_layout.is_compatible(self.layouts[0])
        # expected value for every lane, computed from the matching slice of
        # each input
        self.reference_output_values = {}
        for lane in self.layouts[0].lanes():
            in_t = []
            for inp, layout in zip(self.inputs, self.layouts):
                in_t.append(inp.sig[lane.translate_to(layout).as_slice()])
            v = Value.cast(reference(lane, tuple(in_t)))
            self.reference_output_values[lane] = v
        # named signals carrying the reference values (visible in traces)
        self.reference_outputs = {}
        for lane, value in self.reference_output_values.items():
            s = Signal(value.shape(),
                       name=f"reference_output_{lane.start}_{lane.size}")
            self.reference_outputs[lane] = s
            m.d.comb += s.eq(value)

    def __hash_256(self, v):
        """256-bit integer from SHA-256 of ``self.seed + v``."""
        return int.from_bytes(
            sha256(bytes(self.seed + v, encoding='utf-8')).digest(),
            byteorder='little'
        )

    def __hash(self, v, bits):
        """Deterministic pseudo-random integer of ``bits`` bits derived from
        the string ``v`` and ``self.seed``."""
        retval = 0
        for i in range(0, bits, 256):
            retval <<= 256
            retval |= self.__hash_256(f" {v} {i}")
        return retval & ((1 << bits) - 1)

    def __get_case(self, case_number):
        """Return the ``(part_starts, inputs)`` test case ``case_number``.

        The first ``len(self.special_cases)`` case numbers return the
        special cases verbatim.  Later cases are derived deterministically
        from ``self.seed``:

        * ``part_starts`` holds one ``bool`` per entry of
          ``self.layouts[0].part_indexes``.  The first and last are always
          ``True``, and each interior one is an independent pseudo-random
          bit.
        * ``inputs`` holds one pseudo-random ``int`` per input, as wide as
          that input's layout.
        """
        if case_number < len(self.special_cases):
            return self.special_cases[case_number]
        part_count = len(self.layouts[0].part_indexes)
        trial = 0
        interior = self.__hash(f"{case_number} trial {trial}",
                               self.layouts[0].part_signal_count)
        # bit 0 and bit `part_count - 1` are the always-set outer ends; the
        # random bits occupy the interior positions 1 .. part_count - 2.
        # (Previously the random bits were OR-ed in both shifted and
        # unshifted, correlating neighbouring points, and the end bit was
        # placed at `part_count` so the last entry came out False.)
        bits = 1 | (1 << (part_count - 1)) | (interior << 1)
        part_starts = tuple(((bits >> i) & 1) != 0
                            for i in range(part_count))
        inputs = tuple(self.__hash(f"{case_number} input {i}", layout.width)
                       for i, layout in enumerate(self.layouts))
        return part_starts, inputs

    def __format_case(self, case):
        """Human-readable description of a case, used in subtest names."""
        part_starts, inputs = case
        str_inputs = [hex(i) for i in inputs]
        return f"part_starts={part_starts}, inputs={str_inputs}"

    def __setup_case(self, case_number, case=None):
        """Simulator commands that apply ``case`` (generated from
        ``case_number`` if omitted) to the partition points and inputs."""
        if case is None:
            case = self.__get_case(case_number)
        yield self.case_number.eq(case_number)
        part_starts, inputs = case
        part_indexes = self.layouts[0].part_indexes
        assert len(part_starts) == len(part_indexes)
        # only interior points are signals; input 0's drive all the others
        for i in range(1, len(part_starts) - 1):
            yield self.inputs[0].partpoints[part_indexes[i]].eq(part_starts[i])
        for i in range(len(self.inputs)):
            yield self.inputs[i].sig.eq(inputs[i])

    def run_sim(self, test_case, *, engine=None, base_path="sim_test_out"):
        """Simulate every case and check it.

        For each case, every output partition point must equal the input's,
        and every active lane's output must equal its reference.  A VCD and
        a GTKWave save file are written next to
        ``get_test_path(test_case, base_path)``.
        """
        if engine is None:
            sim = Simulator(self.m)
        else:
            sim = Simulator(self.m, engine=engine)

        def check_active_lane(lane):
            reference = yield self.reference_outputs[lane]
            output = yield self.test_output.sig[
                lane.translate_to(self.test_output_layout).as_slice()]
            test_case.assertEqual(hex(reference), hex(output))

        def check_case(case):
            part_starts, inputs = case
            for i in range(1, len(self.layouts[0].part_indexes) - 1):
                part_point = yield self.test_output.partpoints[
                    self.test_output_layout.part_indexes[i]]
                test_case.assertEqual(part_point, part_starts[i])
            for lane in self.layouts[0].lanes():
                with test_case.subTest(lane=lane):
                    active = lane.is_active(part_starts)
                    if active:
                        yield from check_active_lane(lane)

        def process():
            for case_number in range(self.case_count):
                with test_case.subTest(case_number=str(case_number)):
                    case = self.__get_case(case_number)
                    with test_case.subTest(case=self.__format_case(case)):
                        yield from self.__setup_case(case_number, case)
                        yield Delay(1e-6)
                        yield from check_case(case)
        sim.add_process(process)
        path = get_test_path(test_case, base_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        vcd_path = path.with_suffix(".vcd")
        gtkw_path = path.with_suffix(".gtkw")
        traces = [self.case_number]
        for i in self.layouts[0].part_indexes[1:-1]:
            traces.append(self.inputs[0].partpoints[i])
        for inp in self.inputs:
            traces.append(inp.sig)
        traces.extend(self.reference_outputs.values())
        traces.append(self.test_output.sig)
        with sim.write_vcd(vcd_path.open("wt", encoding="utf-8"),
                           gtkw_path.open("wt", encoding="utf-8"),
                           traces=traces):
            sim.run()

    def run_formal(self, test_case, **kwargs):
        """Prove the operation matches the reference for all inputs.

        Input 0's partition points and every input signal are driven from
        ``AnyConst``.  Output partition points are asserted equal to the
        input's, and each lane's output is asserted equal to its reference
        whenever that lane is active.  The proof is run with `formal`, and
        ``kwargs`` are forwarded to it.

        NOTE(review): this permanently adds comb drivers to ``self.m``;
        presumably a tester instance is used for either `run_sim` or
        `run_formal`, not both.
        """
        for part_point in self.inputs[0].partpoints.values():
            self.m.d.comb += part_point.eq(AnyConst(1))
        for inp in self.inputs:
            s = inp.sig
            self.m.d.comb += s.eq(AnyConst(s.shape()))
        for i in range(1, len(self.layouts[0].part_indexes) - 1):
            in_part_point = self.inputs[0].partpoints[
                self.layouts[0].part_indexes[i]]
            out_part_point = self.test_output.partpoints[
                self.test_output_layout.part_indexes[i]]
            self.m.d.comb += Assert(in_part_point == out_part_point)

        for lane in self.layouts[0].lanes():
            with test_case.subTest(lane=lane):
                reference = self.reference_outputs[lane]
                output = self.test_output.sig[
                    lane.translate_to(self.test_output_layout).as_slice()]
                with self.m.If(lane.is_active(self.inputs[0].partpoints)):
                    self.m.d.comb += Assert(reference == output)
        formal(test_case, self.m, **kwargs)