aa22a39678d586db01898a4fcc507816bf45e335
[gram.git] / gram / test / utils.py
1 import os
2 import re
3 import shutil
4 import subprocess
5 import textwrap
6 import traceback
7 import unittest
8 import warnings
9 from contextlib import contextmanager
10
11 from nmigen import *
12 from nmigen.sim.pysim import *
13 from nmigen.hdl.ir import Fragment
14 from nmigen.back import rtlil
15 from nmigen._toolchain import require_tool
16
17
# Public API exported by ``from <this module> import *``.
__all__ = [
    "FHDLTestCase",
    "runSimulation",
    "wb_read",
    "wb_write",
    "PulseCounter",
    "Delay",
]
19
def runSimulation(module, process, vcd_filename="anonymous.vcd", clock=1e-6):
    """Simulate *module* with one synchronous testbench *process*.

    Builds a ``Simulator`` for *module* and opens a VCD dump at
    *vcd_filename*. Within that dump it adds a clock whose period is *clock*
    (as passed to ``Simulator.add_clock``), registers *process* via
    ``add_sync_process``, and then calls ``run()``.
    """
    simulator = Simulator(module)
    # Clock and process are registered inside the write_vcd context, exactly
    # as before, so the waveform covers the whole run.
    with simulator.write_vcd(vcd_filename):
        simulator.add_clock(clock)
        simulator.add_sync_process(process)
        simulator.run()
26
27 class FHDLTestCase(unittest.TestCase):
28 def assertRepr(self, obj, repr_str):
29 if isinstance(obj, list):
30 obj = Statement.cast(obj)
31 def prepare_repr(repr_str):
32 repr_str = re.sub(r"\s+", " ", repr_str)
33 repr_str = re.sub(r"\( (?=\()", "(", repr_str)
34 repr_str = re.sub(r"\) (?=\))", ")", repr_str)
35 return repr_str.strip()
36 self.assertEqual(prepare_repr(repr(obj)), prepare_repr(repr_str))
37
38 @contextmanager
39 def assertRaises(self, exception, msg=None):
40 with super().assertRaises(exception) as cm:
41 yield
42 if msg is not None:
43 # WTF? unittest.assertRaises is completely broken.
44 self.assertEqual(str(cm.exception), msg)
45
46 @contextmanager
47 def assertRaisesRegex(self, exception, regex=None):
48 with super().assertRaises(exception) as cm:
49 yield
50 if regex is not None:
51 # unittest.assertRaisesRegex also seems broken...
52 self.assertRegex(str(cm.exception), regex)
53
54 @contextmanager
55 def assertWarns(self, category, msg=None):
56 with warnings.catch_warnings(record=True) as warns:
57 yield
58 self.assertEqual(len(warns), 1)
59 self.assertEqual(warns[0].category, category)
60 if msg is not None:
61 self.assertEqual(str(warns[0].message), msg)
62
63 def assertFormal(self, spec, mode="bmc", depth=1):
64 caller, *_ = traceback.extract_stack(limit=2)
65 spec_root, _ = os.path.splitext(caller.filename)
66 spec_dir = os.path.dirname(spec_root)
67 spec_name = "{}_{}".format(
68 os.path.basename(spec_root).replace("test_", "spec_"),
69 caller.name.replace("test_", "")
70 )
71
72 # The sby -f switch seems not fully functional when sby is
73 # reading from stdin.
74 if os.path.exists(os.path.join(spec_dir, spec_name)):
75 shutil.rmtree(os.path.join(spec_dir, spec_name))
76
77 config = textwrap.dedent("""\
78 [options]
79 mode {mode}
80 depth {depth}
81 wait on
82
83 [engines]
84 smtbmc
85
86 [script]
87 read_ilang top.il
88 prep
89
90 [file top.il]
91 {rtlil}
92 """).format(
93 mode=mode,
94 depth=depth,
95 rtlil=rtlil.convert(Fragment.get(spec, platform="formal"))
96 )
97 with subprocess.Popen([require_tool("sby"), "-f", "-d", spec_name],
98 cwd=spec_dir,
99 universal_newlines=True,
100 stdin=subprocess.PIPE,
101 stdout=subprocess.PIPE) as proc:
102 stdout, stderr = proc.communicate(config)
103 if proc.returncode != 0:
104 self.fail("Formal verification failed:\n" + stdout)
105
def wb_read(bus, addr, sel, timeout=32):
    """Simulation-process fragment performing one Wishbone read.

    Sequence:

    1. Drive ``cyc``/``stb`` high, put *addr* on ``adr`` and *sel* on
       ``sel``, then wait one clock (bare ``yield``).
    2. Sample ``ack`` once per clock until it is high.
    3. Sample ``dat_r`` and deassert ``cyc``/``stb``. This last step
       consumes no clock.

    The sampled data is the generator's return value, so call it as
    ``data = yield from wb_read(...)``.

    For an integer ``timeout >= 0``, ``ack`` is sampled at most
    ``timeout + 1`` times. If it is never high, ``RuntimeError`` is raised
    and the bus signals are left asserted.
    """
    yield bus.cyc.eq(1)
    yield bus.stb.eq(1)
    yield bus.adr.eq(addr)
    yield bus.sel.eq(sel)
    yield

    waited = 0
    while True:
        acked = yield bus.ack
        if acked:
            break
        yield
        # Checked after advancing the clock, before the next ack sample.
        if waited >= timeout:
            raise RuntimeError("Wishbone transaction timed out")
        waited += 1

    data = yield bus.dat_r
    yield bus.cyc.eq(0)
    yield bus.stb.eq(0)
    return data
122
def wb_write(bus, addr, data, sel, timeout=32):
    """Simulation-process fragment performing one Wishbone write.

    Sequence:

    1. Drive ``cyc``/``stb``/``we`` high, put *addr* on ``adr``, *sel* on
       ``sel`` and *data* on ``dat_w``, then wait one clock (bare ``yield``).
    2. Sample ``ack`` once per clock until it is high.
    3. Deassert ``cyc``/``stb``/``we``. This last step consumes no clock.

    Use as ``yield from wb_write(...)``.

    For an integer ``timeout >= 0``, ``ack`` is sampled at most
    ``timeout + 1`` times. If it is never high, ``RuntimeError`` is raised
    and the bus signals are left asserted.
    """
    yield bus.cyc.eq(1)
    yield bus.stb.eq(1)
    yield bus.adr.eq(addr)
    yield bus.we.eq(1)
    yield bus.sel.eq(sel)
    yield bus.dat_w.eq(data)
    yield

    elapsed = 0
    while True:
        if (yield bus.ack):
            break
        yield
        # Checked after advancing the clock, before the next ack sample.
        if elapsed >= timeout:
            raise RuntimeError("Wishbone transaction timed out")
        elapsed += 1

    yield bus.cyc.eq(0)
    yield bus.stb.eq(0)
    yield bus.we.eq(0)
140
class PulseCounter(Elaboratable):
    """Counts the ``sync`` clock cycles on which ``i`` is high.

    ``cnt`` (a ``Signal(range(max))``) is incremented in the ``sync`` domain
    on every cycle where ``i`` is asserted. ``rst`` clears it to 0 on the
    next edge and takes priority over ``i``.
    """

    def __init__(self, max=16):
        # ``max`` shadows the builtin but is kept for keyword callers.
        # Each signal is bound with ``self.x = Signal(...)`` so nMigen's
        # name tracer still names them i / rst / cnt.
        self.i = Signal()    # counted input
        self.rst = Signal()  # clear; wins over ``i``
        # NOTE(review): cnt does not saturate. ``cnt + 1`` is truncated to
        # cnt's width, so it wraps to 0 after its all-ones value. For a
        # non-power-of-two ``max`` it can also exceed ``max - 1``.
        self.cnt = Signal(range(max))

    def elaborate(self, platform):
        module = Module()

        with module.If(self.rst):
            module.d.sync += self.cnt.eq(0)
        with module.Elif(self.i):
            module.d.sync += self.cnt.eq(self.cnt + 1)

        return module