# SPDX-License-Identifier: LGPL-3.0-or-later
# See Notices.txt for copyright information
4 from typing
import (Any
, Callable
, Dict
, Generator
, Iterable
, List
, Mapping
,
5 Optional
, Sequence
, Tuple
, Union
, final
, overload
)
7 from nmigen
.hdl
.ast
import (AnyConst
, Assert
, Signal
, Value
, ValueCastable
)
8 from nmigen
.hdl
.dsl
import Module
9 from nmigen
.hdl
.ir
import Elaboratable
, Fragment
10 from nmigen
.sim
import Simulator
, Delay
11 from ieee754
.part
.partsig
import PartitionedSignal
, PartitionPoints
15 from hashlib
import sha256
16 from nmigen
.back
import rtlil
17 from nmutil
.get_test_path
import get_test_path
, _StrPath
20 _PartitionedSignalTestable
= Callable
[[Tuple
[PartitionedSignal
, ...]],
# Anything accepted where a bit width is expected (``Layout.get_width``):
# an existing ``Layout`` or a plain ``int`` width.
_WidthCastable = Union["Layout", int]

# Anything accepted by ``Layout.cast``: an existing ``Layout``, a mapping
# whose keys are partition bit indexes, or an iterable of partition bit
# indexes.
_LayoutCastable = Union["Layout", Mapping[int, Any], Iterable[int]]

# Partition-point value types accepted by ``Lane.is_active``.
_ValueCastableType = Union[Value, int, Enum, ValueCastable]

# What ``formal`` accepts as its design: something ``Fragment.get`` can
# turn into a fragment.
_FragmentLike = Union[Elaboratable, Fragment]
29 def formal(test_case
: unittest
.TestCase
, hdl
: _FragmentLike
, *,
30 base_path
: _StrPath
= "formal_test_temp"):
31 hdl
= Fragment
.get(hdl
, platform
="formal")
32 path
= get_test_path(test_case
, base_path
)
33 shutil
.rmtree(path
, ignore_errors
=True)
34 path
.mkdir(parents
=True)
35 sby_name
= "config.sby"
36 sby_file
= path
/ sby_name
38 sby_file
.write_text(textwrap
.dedent(f
"""\
53 """), encoding
="utf-8")
54 sby
= shutil
.which('sby')
55 assert sby
is not None
56 with subprocess
.Popen(
58 cwd
=path
, text
=True, encoding
="utf-8",
59 stdin
=subprocess
.DEVNULL
, stdout
=subprocess
.PIPE
61 stdout
, stderr
= p
.communicate()
63 test_case
.fail(f
"Formal failed:\n{stdout}")
68 __lane_starts_for_sizes
: Dict
[int, Dict
[int, None]]
69 """keys are in sorted order"""
71 part_indexes
: Tuple
[int, ...]
72 """bit indexes of partition points in sorted order, always includes
73 `0` and `self.width`"""
76 def cast(layout
: _LayoutCastable
,
77 width
: Optional
[_WidthCastable
] = None) -> "Layout":
78 if isinstance(layout
, Layout
):
80 return Layout(layout
, width
)
83 part_indexes
: Union
[Mapping
[int, Any
], Iterable
[int]],
84 width
: Optional
[_WidthCastable
] = None):
85 part_indexes
= set(part_indexes
)
86 for p
in part_indexes
:
87 assert isinstance(p
, int)
90 width
= Layout
.get_width(width
)
91 for p
in part_indexes
:
93 part_indexes
.add(width
)
95 part_indexes
= list(part_indexes
)
97 self
.part_indexes
= tuple(part_indexes
)
99 for start_index
in range(len(self
.part_indexes
)):
100 start
= self
.part_indexes
[start_index
]
101 for end
in self
.part_indexes
[start_index
+ 1:]:
102 sizes
.append(end
- start
)
104 # build in sorted order
105 self
.__lane
_starts
_for
_sizes
= {size
: {} for size
in sizes
}
106 for start_index
in range(len(self
.part_indexes
)):
107 start
= self
.part_indexes
[start_index
]
108 for end
in self
.part_indexes
[start_index
+ 1:]:
109 self
.__lane
_starts
_for
_sizes
[end
- start
][start
] = None
def width(self) -> int:
    """Total bit width of the layout.

    ``part_indexes`` is kept in sorted order and always ends with the
    width, so the width is simply its final entry.
    """
    indexes = self.part_indexes
    return indexes[-1]
def part_signal_count(self) -> int:
    """Number of interior partition points, i.e. partition-point signals.

    The first and last ``part_indexes`` entries (``0`` and the width) are
    fixed boundaries without a signal. The result is never negative.
    """
    interior = len(self.part_indexes) - 2
    return interior if interior > 0 else 0
120 def get_width(width
: _WidthCastable
) -> int:
121 if isinstance(width
, Layout
):
123 assert isinstance(width
, int)
127 def partition_points_signals(self
, name
: Optional
[str] = None,
128 src_loc_at
: int = 0) -> PartitionPoints
:
130 name
= Signal(src_loc_at
=1 + src_loc_at
).name
131 return PartitionPoints({
132 i
: Signal(name
=f
"{name}_{i}", src_loc_at
=1 + src_loc_at
)
133 for i
in self
.part_indexes
[1:-1]
def __repr__(self) -> str:
    """Debug form, e.g. ``Layout((0, 8, 16), width=16)``."""
    rendered = f"Layout({self.part_indexes}, width={self.width})"
    return rendered
def __eq__(self, o: object) -> bool:
    """Two layouts are equal when their ``part_indexes`` are equal.

    The width is the last partition index, so it is compared implicitly.
    Comparison with anything that is not a ``Layout`` defers to the other
    operand via ``NotImplemented``.
    """
    if not isinstance(o, Layout):
        return NotImplemented
    return self.part_indexes == o.part_indexes
def __hash__(self) -> int:
    """Hash only ``part_indexes``: the same data ``__eq__`` compares."""
    key = self.part_indexes
    return hash(key)
def is_lane_valid(self, start: int, size: int) -> bool:
    """Check whether ``size`` bits starting at bit ``start`` form a lane.

    A lane runs from one partition index to a later one. Unknown sizes
    simply yield ``False``.
    """
    starts = self.__lane_starts_for_sizes.get(size)
    if starts is None:
        return False
    return start in starts
def lane_sizes(self) -> Iterable[int]:
    """Every lane size this layout supports.

    The order is the sorted key order maintained by the size table.
    """
    table = self.__lane_starts_for_sizes
    return table.keys()
def lane_starts_for_size(self, size: int) -> Iterable[int]:
    """Start bit indexes of every lane exactly ``size`` bits wide.

    The starts are returned in increasing order. A size that no lane of
    this layout has raises ``KeyError``.
    """
    starts_by_size = self.__lane_starts_for_sizes
    return starts_by_size[size].keys()
def lanes_for_size(self, size: int) -> Iterable["Lane"]:
    """Lazily yield a ``Lane`` for every valid start of a ``size``-bit lane."""
    yield from (Lane(begin, size, self)
                for begin in self.lane_starts_for_size(size))
def lanes(self) -> Iterable["Lane"]:
    """Yield every lane of the layout, grouped by size.

    Sizes come in the order ``lane_sizes()`` returns them.
    """
    for lane_size in self.lane_sizes():
        for lane in self.lanes_for_size(lane_size):
            yield lane
def is_compatible(self, other: _LayoutCastable) -> bool:
    """Check whether ``other``, cast to a ``Layout``, has as many partition
    indexes as this layout.

    That equal count is what ``translate_lane_to`` requires to map lanes
    between the two layouts.
    """
    other_layout = Layout.cast(other)
    own_count = len(self.part_indexes)
    return own_count == len(other_layout.part_indexes)
def translate_lane_to(self, lane: "Lane",
                      target_layout: _LayoutCastable) -> "Lane":
    """Map ``lane``, a lane of this layout, onto ``target_layout``.

    The lane's start and end are located by position within this layout's
    ``part_indexes``. The partition indexes at those same positions in the
    (compatible) target layout bound the translated lane.
    """
    assert lane.layout == self
    target = Layout.cast(target_layout)
    assert self.is_compatible(target)
    first = self.part_indexes.index(lane.start)
    last = self.part_indexes.index(lane.end)
    new_start = target.part_indexes[first]
    new_end = target.part_indexes[last]
    return Lane(new_start, new_end - new_start, target)
182 def __init__(self
, start
: int, size
: int, layout
: _LayoutCastable
):
183 self
.layout
= Layout
.cast(layout
)
184 assert self
.layout
.is_lane_valid(start
, size
)
def __repr__(self) -> str:
    """Debug form listing the lane's start, size and owning layout."""
    text = (f"Lane(start={self.start}, size={self.size}, "
            f"layout={self.layout})")
    return text
def __eq__(self, o: object) -> bool:
    """Lanes are equal when start, size and layout all match.

    Comparison with anything that is not a ``Lane`` defers to the other
    operand via ``NotImplemented``.
    """
    if not isinstance(o, Lane):
        return NotImplemented
    return (self.start == o.start
            and self.size == o.size
            and self.layout == o.layout)
def __hash__(self) -> int:
    """Hash the same three fields ``__eq__`` compares."""
    identity = (self.start, self.size, self.layout)
    return hash(identity)
def as_slice(self) -> slice:
    """Return the half-open ``slice`` ``[start, end)`` covering the lane's bits."""
    lo, hi = self.start, self.end
    return slice(lo, hi)
def end(self) -> int:
    """Bit index one past the lane's last bit, i.e. ``start + size``."""
    first_bit, bit_count = self.start, self.size
    return first_bit + bit_count
def translate_to(self, target_layout: _LayoutCastable) -> "Lane":
    """Translate this lane onto ``target_layout``.

    Delegates to the owning layout's ``translate_lane_to``.
    """
    owner = self.layout
    return owner.translate_lane_to(self, target_layout)
# Typed signature stubs for ``is_active``. A sequence is indexed by position
# within the layout's ``part_indexes``; a mapping is keyed by partition bit
# index.

def is_active(self, partition_points: Sequence[bool]) -> bool:
    """Positional plain-``bool`` partition points give a plain ``bool``."""


def is_active(self, partition_points: Sequence[_ValueCastableType]
              ) -> Union[Value, bool]:
    """Positional value-castable partition points give a ``Value`` or ``bool``."""


def is_active(self, partition_points: Mapping[int, bool]) -> bool:
    """Bit-index-keyed plain-``bool`` partition points give a plain ``bool``."""


def is_active(self, partition_points: Mapping[int, _ValueCastableType]
              ) -> Union[Value, bool]:
    """Bit-index-keyed value-castable points give a ``Value`` or ``bool``."""
225 def is_active(self
, partition_points
):
226 def get_partition_point(index
: int, invert
: bool):
227 if index
== 0 or index
== len(self
.layout
.part_indexes
) - 1:
229 if isinstance(partition_points
, Sequence
):
230 retval
= partition_points
[index
]
232 retval
= partition_points
[self
.layout
.part_indexes
[index
]]
233 if isinstance(retval
, bool):
237 retval
= Value
.cast(retval
)
241 start_index
= self
.layout
.part_indexes
.index(self
.start
)
242 end_index
= self
.layout
.part_indexes
.index(self
.end
)
243 retval
= get_partition_point(start_index
, False) \
244 & get_partition_point(end_index
, False)
245 for i
in range(start_index
+ 1, end_index
):
246 retval
&= get_partition_point(i
, True)
250 _PartitionedSignalTestReference
= Callable
[[Lane
, Tuple
[Value
, ...]],
# One flag per position in the layout's ``part_indexes``; each interior
# flag is the value driven onto the matching partition point.
_PartitionedSignalTestCasePartMode = Tuple[bool, ...]

# One integer value per tested input signal.
_PartitionedSignalTestCaseInputs = Tuple[int, ...]

# A full test case: (partition-point flags, input values).
_PartitionedSignalTestCase = Tuple[
    _PartitionedSignalTestCasePartMode,
    _PartitionedSignalTestCaseInputs,
]
259 class PartitionedSignalTester
:
260 layouts
: List
[Layout
]
261 inputs
: List
[PartitionedSignal
]
265 operation
: _PartitionedSignalTestable
,
266 reference
: _PartitionedSignalTestReference
,
267 *layouts
: _LayoutCastable
,
269 additional_case_count
: int = 30,
270 special_cases
: Iterable
[_PartitionedSignalTestCase
] = (),
273 self
.operation
= operation
274 self
.reference
= reference
277 for layout
in layouts
:
278 layout
= Layout
.cast(layout
)
279 if len(self
.layouts
) > 0:
280 assert self
.layouts
[0].is_compatible(layout
)
281 self
.layouts
.append(layout
)
282 name
= f
"input_{len(self.inputs)}"
283 ps
= PartitionedSignal(
284 layout
.partition_points_signals(name
=name
,
285 src_loc_at
=1 + src_loc_at
),
289 self
.inputs
.append(ps
)
290 assert len(self
.layouts
) != 0, "must have at least one input layout"
291 for i
in range(1, len(self
.inputs
)):
292 for j
in range(1, len(self
.layouts
[0].part_indexes
) - 1):
293 lhs_part_point
= self
.layouts
[i
].part_indexes
[j
]
294 rhs_part_point
= self
.layouts
[0].part_indexes
[j
]
295 lhs
= self
.inputs
[i
].partpoints
[lhs_part_point
]
296 rhs
= self
.inputs
[0].partpoints
[rhs_part_point
]
297 m
.d
.comb
+= lhs
.eq(rhs
)
298 self
.special_cases
= list(special_cases
)
299 self
.case_count
= additional_case_count
+ len(self
.special_cases
)
301 self
.case_number
= Signal(64)
302 self
.test_output
= operation(tuple(self
.inputs
))
303 assert isinstance(self
.test_output
, PartitionedSignal
)
304 self
.test_output_layout
= Layout(
305 self
.test_output
.partpoints
, self
.test_output
.sig
.width
)
306 assert self
.test_output_layout
.is_compatible(self
.layouts
[0])
307 self
.reference_output_values
= {
308 lane
: Value
.cast(reference(lane
, tuple(
309 inp
.sig
[lane
.translate_to(layout
).as_slice()]
310 for inp
, layout
in zip(self
.inputs
, self
.layouts
))))
311 for lane
in self
.layouts
[0].lanes()
313 self
.reference_outputs
= {
314 lane
: Signal(value
.shape(),
315 name
=f
"reference_output_{lane.start}_{lane.size}")
316 for lane
, value
in self
.reference_output_values
.items()
318 for lane
, value
in self
.reference_output_values
.items():
319 m
.d
.comb
+= self
.reference_outputs
[lane
].eq(value
)
321 def __hash_256(self
, v
: str) -> int:
322 return int.from_bytes(
323 sha256(bytes(self
.seed
+ v
, encoding
='utf-8')).digest(),
327 def __hash(self
, v
: str, bits
: int) -> int:
329 for i
in range(0, bits
, 256):
331 retval |
= self
.__hash
_256(f
" {v} {i}")
332 return retval
& ((1 << bits
) - 1)
334 def __get_case(self
, case_number
: int) -> _PartitionedSignalTestCase
:
335 if case_number
< len(self
.special_cases
):
336 return self
.special_cases
[case_number
]
338 bits
= self
.__hash
(f
"{case_number} trial {trial}",
339 self
.layouts
[0].part_signal_count
)
340 bits |
= 1 |
(1 << len(self
.layouts
[0].part_indexes
)) |
(bits
<< 1)
342 (bits
& (1 << i
)) != 0
343 for i
in range(len(self
.layouts
[0].part_indexes
)))
344 inputs
= tuple(self
.__hash
(f
"{case_number} input {i}",
345 self
.layouts
[i
].width
)
346 for i
in range(len(self
.layouts
)))
347 return part_starts
, inputs
def __format_case(self, case: _PartitionedSignalTestCase) -> str:
    """Render a test case as a ``subTest`` label.

    Partition-point flags are shown as-is; input values are shown in hex.
    """
    part_starts, inputs = case
    str_inputs = list(map(hex, inputs))
    return f"part_starts={part_starts}, inputs={str_inputs}"
354 def __setup_case(self
, case_number
: int,
355 case
: Optional
[_PartitionedSignalTestCase
] = None
356 ) -> Generator
[Any
, int, None]:
358 case
= self
.__get
_case
(case_number
)
359 yield self
.case_number
.eq(case_number
)
360 part_starts
, inputs
= case
361 part_indexes
= self
.layouts
[0].part_indexes
362 assert len(part_starts
) == len(part_indexes
)
363 for i
in range(1, len(part_starts
) - 1):
364 yield self
.inputs
[0].partpoints
[part_indexes
[i
]].eq(part_starts
[i
])
365 for i
in range(len(self
.inputs
)):
366 yield self
.inputs
[i
].sig
.eq(inputs
[i
])
368 def run_sim(self
, test_case
: unittest
.TestCase
, *,
369 engine
: Optional
[str] = None,
370 base_path
: _StrPath
= "sim_test_out"):
372 sim
= Simulator(self
.m
)
374 sim
= Simulator(self
.m
, engine
=engine
)
376 def check_active_lane(lane
: Lane
):
377 reference
= yield self
.reference_outputs
[lane
]
378 output
= yield self
.test_output
.sig
[
379 lane
.translate_to(self
.test_output_layout
).as_slice()]
380 test_case
.assertEqual(hex(reference
), hex(output
))
382 def check_case(case
: _PartitionedSignalTestCase
):
383 part_starts
, inputs
= case
384 for i
in range(1, len(self
.layouts
[0].part_indexes
) - 1):
385 part_point
= yield self
.test_output
.partpoints
[
386 self
.test_output_layout
.part_indexes
[i
]]
387 test_case
.assertEqual(part_point
, part_starts
[i
])
388 for lane
in self
.layouts
[0].lanes():
389 with test_case
.subTest(lane
=lane
):
390 active
= lane
.is_active(part_starts
)
392 yield from check_active_lane(lane
)
395 for case_number
in range(self
.case_count
):
396 with test_case
.subTest(case_number
=str(case_number
)):
397 case
= self
.__get
_case
(case_number
)
398 with test_case
.subTest(case
=self
.__format
_case
(case
)):
399 yield from self
.__setup
_case
(case_number
, case
)
401 yield from check_case(case
)
402 sim
.add_process(process
)
403 path
= get_test_path(test_case
, base_path
)
404 path
.parent
.mkdir(parents
=True, exist_ok
=True)
405 vcd_path
= path
.with_suffix(".vcd")
406 gtkw_path
= path
.with_suffix(".gtkw")
407 traces
= [self
.case_number
]
408 for i
in self
.layouts
[0].part_indexes
[1:-1]:
409 traces
.append(self
.inputs
[0].partpoints
[i
])
410 for inp
in self
.inputs
:
411 traces
.append(inp
.sig
)
412 traces
.extend(self
.reference_outputs
.values())
413 traces
.append(self
.test_output
.sig
)
414 with sim
.write_vcd(vcd_path
.open("wt", encoding
="utf-8"),
415 gtkw_path
.open("wt", encoding
="utf-8"),
419 def run_formal(self
, test_case
: unittest
.TestCase
, **kwargs
):
420 for part_point
in self
.inputs
[0].partpoints
.values():
421 self
.m
.d
.comb
+= part_point
.eq(AnyConst(1))
422 for i
in range(len(self
.inputs
)):
423 s
= self
.inputs
[i
].sig
424 self
.m
.d
.comb
+= s
.eq(AnyConst(s
.shape()))
425 for i
in range(1, len(self
.layouts
[0].part_indexes
) - 1):
426 in_part_point
= self
.inputs
[0].partpoints
[
427 self
.layouts
[0].part_indexes
[i
]]
428 out_part_point
= self
.test_output
.partpoints
[
429 self
.test_output_layout
.part_indexes
[i
]]
430 self
.m
.d
.comb
+= Assert(in_part_point
== out_part_point
)
432 def check_active_lane(lane
: Lane
) -> Assert
:
433 reference
= self
.reference_outputs
[lane
]
434 output
= self
.test_output
.sig
[
435 lane
.translate_to(self
.test_output_layout
).as_slice()]
436 yield Assert(reference
== output
)
438 for lane
in self
.layouts
[0].lanes():
439 with test_case
.subTest(lane
=lane
):
440 a
= check_active_lane(lane
)
441 with self
.m
.If(lane
.is_active(self
.inputs
[0].partpoints
)):
443 formal(test_case
, self
.m
, **kwargs
)