# Commit: c2e9367bcb1d0c36cc3ac8efd7bacb8c4ca00710
# File: [nmigen.git] / nmigen / sim / _pyrtl.py
1 import os
2 import tempfile
3 from contextlib import contextmanager
4
5 from ..hdl import *
6 from ..hdl.ast import SignalSet
7 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
8 from ._core import *
9
10
__all__ = ["PyRTLProcess"]


class PyRTLProcess(Process):
    """Simulator process whose behavior is Python code generated from RTL.

    Instances are created by `_FragmentCompiler`, which assigns the generated function to the
    ``run`` attribute after compiling it.
    """
16
17
18 class _PythonEmitter:
19 def __init__(self):
20 self._buffer = []
21 self._suffix = 0
22 self._level = 0
23
24 def append(self, code):
25 self._buffer.append(" " * self._level)
26 self._buffer.append(code)
27 self._buffer.append("\n")
28
29 @contextmanager
30 def indent(self):
31 self._level += 1
32 yield
33 self._level -= 1
34
35 def flush(self, indent=""):
36 code = "".join(self._buffer)
37 self._buffer.clear()
38 return code
39
40 def gen_var(self, prefix):
41 name = f"{prefix}_{self._suffix}"
42 self._suffix += 1
43 return name
44
45 def def_var(self, prefix, value):
46 name = self.gen_var(prefix)
47 self.append(f"{name} = {value}")
48 return name
49
50
51 class _Compiler:
52 def __init__(self, state, emitter):
53 self.state = state
54 self.emitter = emitter
55
56
class _ValueCompiler(ValueVisitor, _Compiler):
    """Common base of the rvalue and lvalue translators.

    ``helpers`` holds the functions that generated code may call by name; they are placed into
    the namespace the generated code is executed in. Value kinds that the Python RTL simulator
    does not translate raise `NotImplementedError`.
    """

    helpers = dict(
        # sign(value, sign): `sign` is `-1 << (width - 1)`; if that bit is set in `value`, OR in
        # all higher bits to obtain the negative Python integer, otherwise return `value`.
        sign=lambda value, sign: (value | sign) if (value & sign) else value,
        # Division and modulo that yield 0 instead of raising when the divisor is zero.
        zdiv=lambda lhs, rhs: lhs // rhs if rhs != 0 else 0,
        zmod=lambda lhs, rhs: lhs % rhs if rhs != 0 else 0,
    )

    def on_ClockSignal(self, value):
        raise NotImplementedError() # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError() # :nocov:

    def on_AnyConst(self, value):
        raise NotImplementedError() # :nocov:

    def on_AnySeq(self, value):
        raise NotImplementedError() # :nocov:

    def on_Sample(self, value):
        raise NotImplementedError() # :nocov:

    def on_Initial(self, value):
        raise NotImplementedError() # :nocov:
81
82
class _RHSValueCompiler(_ValueCompiler):
    """Translate an rvalue into a Python expression, returned as a string.

    Values that cannot be written inline (`Repl`, `ArrayProxy`) first emit auxiliary statements
    through the emitter and are then referred to by the variable those statements define.

    The integer computed by a translated expression is not necessarily truncated to the width of
    the value it represents (e.g. ``~x`` yields a negative Python integer). Every translation that
    depends on the exact bits of an operand therefore normalizes it with `_mask` or `_sign`.
    """

    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        # "curr" reads a signal's committed value from its slot; "next" reads the local
        # `next_<index>` variable holding its pending value.
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def _mask(self, value):
        # Translate `value` and truncate the result to its width.
        value_mask = (1 << len(value)) - 1
        return f"({self(value)} & {value_mask})"

    def _sign(self, value):
        # Translate `value`, truncate it to its width, and sign-extend it if it is signed, so the
        # resulting Python integer is the value under its shape's interpretation.
        if value.shape().signed and len(value) > 0:
            return f"sign({self._mask(value)}, {-1 << (len(value) - 1)})"
        else: # unsigned, or zero-width (whose only value is 0)
            return self._mask(value)

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.state.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.state.get_signal(value)}"

    def on_Operator(self, value):
        mask, sign = self._mask, self._sign

        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                return f"(~{self(arg)})"
            if value.operator == "-":
                return f"(-{sign(arg)})"
            if value.operator == "b":
                return f"bool({mask(arg)})"
            if value.operator == "r|":
                return f"({mask(arg)} != 0)"
            if value.operator == "r&":
                return f"({mask(arg)} == {(1 << len(arg)) - 1})"
            if value.operator == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({mask(arg)}, 'b').count('1') % 2)"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "+":
                return f"({sign(lhs)} + {sign(rhs)})"
            if value.operator == "-":
                return f"({sign(lhs)} - {sign(rhs)})"
            if value.operator == "*":
                return f"({sign(lhs)} * {sign(rhs)})"
            if value.operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if value.operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
            # Bitwise operators need both operands zero- or sign-extended according to their own
            # shapes; a raw translation may carry stray high bits (e.g. from `~`) that would
            # otherwise end up in the result.
            if value.operator == "&":
                return f"({sign(lhs)} & {sign(rhs)})"
            if value.operator == "|":
                return f"({sign(lhs)} | {sign(rhs)})"
            if value.operator == "^":
                return f"({sign(lhs)} ^ {sign(rhs)})"
            if value.operator == "<<":
                return f"({sign(lhs)} << {sign(rhs)})"
            if value.operator == ">>":
                return f"({sign(lhs)} >> {sign(rhs)})"
            if value.operator == "==":
                return f"({sign(lhs)} == {sign(rhs)})"
            if value.operator == "!=":
                return f"({sign(lhs)} != {sign(rhs)})"
            if value.operator == "<":
                return f"({sign(lhs)} < {sign(rhs)})"
            if value.operator == "<=":
                return f"({sign(lhs)} <= {sign(rhs)})"
            if value.operator == ">":
                return f"({sign(lhs)} > {sign(rhs)})"
            if value.operator == ">=":
                return f"({sign(lhs)} >= {sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                # Test only the bits of `sel` (a raw `~x` is always truthy), and normalize the
                # arms so that a narrower or differently signed arm extends correctly.
                return f"({sign(val1)} if {mask(sel)} else {sign(val0)})"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
        return f"({self(value.value)} >> {offset} & " \
               f"{(1 << value.width) - 1})"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        # The replicated value is computed once into a variable and then shifted into place.
        part_mask = (1 << len(value.value)) - 1
        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        # The index is always translated (and thus recorded as an input), even with no elements.
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.append(f"if {gen_index} == {index}:")
                else:
                    self.emitter.append(f"elif {gen_index} == {index}:")
                with self.emitter.indent():
                    # Elements may differ in width and signedness from the proxy as a whole, so
                    # each one is normalized to its own shape.
                    self.emitter.append(f"{gen_value} = {self._sign(elem)}")
            # An out-of-range index selects the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self._sign(value.elems[-1])}")
            return gen_value
        else:
            return f"0"

    @classmethod
    def compile(cls, state, value, *, mode):
        """Return standalone code that assigns the translation of ``value`` to ``result``."""
        emitter = _PythonEmitter()
        compiler = cls(state, emitter, mode=mode)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
232
233
class _LHSValueCompiler(_ValueCompiler):
    """Translate an lvalue into a code generator for assignments to it.

    Visiting an lvalue returns a function ``gen(arg)`` that, given a Python expression ``arg``
    for the assigned value, emits statements storing it into the ``next_<index>`` variables of
    the affected signals. Partial updates (`Slice`, `Part`) are emitted as read-modify-write of
    those variables.
    """

    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            # Truncate the assigned value to the signal's width, and sign-extend it for signed
            # signals, so the stored Python integer matches the signal's shape.
            value_mask = (1 << len(value)) - 1
            if value.shape().signed and len(value) > 0:
                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
            else: # unsigned, or zero-width (whose only value is 0)
                value_sign = f"{arg} & {value_mask}"
            self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            # Clear the sliced bits of the underlying value, then OR in the new bits.
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"{~(width_mask << value.start)} | " \
                f"(({arg} & {width_mask}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            # Same as `on_Slice`, but the bit offset is computed at run time from `value.offset`.
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"~({width_mask} << {offset}) | " \
                f"(({arg} & {width_mask}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate the assigned value once, then hand each part its own bit range.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.append(f"if {gen_index} == {index}:")
                    else:
                        self.emitter.append(f"elif {gen_index} == {index}:")
                    with self.emitter.indent():
                        self(elem)(arg)
                # An out-of-range index assigns to the last element.
                self.emitter.append(f"else:")
                with self.emitter.indent():
                    self(value.elems[-1])(arg)
            else:
                self.emitter.append(f"pass")
        return gen
316
317
class _StatementCompiler(StatementVisitor, _Compiler):
    """Translate statements into Python code that updates ``next_<index>`` variables.

    Right-hand sides read committed signal values (``mode="curr"``); assignment targets are
    handled by an `_LHSValueCompiler`. ``inputs``/``outputs``, if given, collect the signals
    read and written.
    """

    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        for stmt in stmts:
            self(stmt)
        if not stmts:
            # A Python block needs at least one statement (e.g. an empty case body).
            self.emitter.append("pass")

    def on_Assign(self, stmt):
        # Normalize the RHS to its own shape before storing it. A translated rvalue may carry
        # bits beyond its width (e.g. `~x` is a negative Python integer); without this they would
        # leak into a wider target instead of the zero/sign extension dictated by the RHS shape.
        gen_rhs = f"({self.rhs(stmt.rhs)} & {(1 << len(stmt.rhs)) - 1})"
        if stmt.rhs.shape().signed and len(stmt.rhs) > 0:
            gen_rhs = f"sign({gen_rhs}, {-1 << (len(stmt.rhs) - 1)})"
        return self.lhs(stmt.lhs)(gen_rhs)

    def on_Switch(self, stmt):
        gen_test = self.emitter.def_var("test",
            f"{self.rhs(stmt.test)} & {(1 << len(stmt.test)) - 1}")
        for index, (patterns, stmts) in enumerate(stmt.cases.items()):
            gen_checks = []
            if not patterns:
                # A case without patterns matches unconditionally.
                gen_checks.append(f"True")
            else:
                for pattern in patterns:
                    if "-" in pattern:
                        # "-" is a don't-care bit: compare only the remaining bits.
                        mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
                        value = int("".join("0" if b == "-" else b for b in pattern), 2)
                        gen_checks.append(f"({gen_test} & {mask}) == {value}")
                    else:
                        value = int(pattern, 2)
                        gen_checks.append(f"{gen_test} == {value}")
            if index == 0:
                self.emitter.append(f"if {' or '.join(gen_checks)}:")
            else:
                self.emitter.append(f"elif {' or '.join(gen_checks)}:")
            with self.emitter.indent():
                self(stmts)

    def on_Assert(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError # :nocov:

    @classmethod
    def compile(cls, state, stmt):
        """Return standalone code executing ``stmt``.

        The ``next_<index>`` variables of the signals ``stmt`` assigns are loaded from their
        slots first and written back with ``set`` at the end.
        """
        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _PythonEmitter()
        for signal_index in output_indexes:
            emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
        compiler = cls(state, emitter)
        compiler(stmt)
        for signal_index in output_indexes:
            emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
        return emitter.flush()
376
377
class _FragmentCompiler:
    """Compile a fragment hierarchy into a set of `PyRTLProcess` objects.

    One process is created per entry of ``fragment.drivers`` (i.e. per domain, with ``None``
    meaning combinational). Its ``run`` attribute is a function generated as Python source and
    compiled with ``exec``.
    """

    def __init__(self, state):
        self.state = state

    def __call__(self, fragment):
        """Return the processes for ``fragment`` and, recursively, for its subfragments."""
        processes = set()

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = PyRTLProcess(is_comb=domain_name is None)

            emitter = _PythonEmitter()
            emitter.append(f"def run():")
            with emitter.indent():
                if domain_name is None:
                    # Combinational: each driven signal starts from its reset value, and every
                    # signal read by the statements is registered as a trigger of the process.
                    for signal in domain_signals:
                        signal_index = self.state.get_signal(signal)
                        emitter.append(f"next_{signal_index} = {signal.reset}")

                    inputs = SignalSet()
                    _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)

                    for input_signal in inputs:
                        self.state.add_trigger(domain_process, input_signal)

                else:
                    # Synchronous: triggered by the domain clock edge (and by an asynchronous
                    # reset, if any); a signal keeps its pending value unless it is assigned.
                    domain = fragment.domains[domain_name]
                    clk_trigger = 1 if domain.clk_edge == "pos" else 0
                    self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
                    if domain.rst is not None and domain.async_reset:
                        rst_trigger = 1
                        self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)

                    for signal in domain_signals:
                        signal_index = self.state.get_signal(signal)
                        emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                    _StatementCompiler(self.state, emitter)(domain_stmts)

                for signal in domain_signals:
                    signal_index = self.state.get_signal(signal)
                    emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            # There shouldn't be any exceptions raised by the generated code, but if there are
            # (almost certainly due to a bug in the code generator), use this environment variable
            # to make backtraces useful.
            code = emitter.flush()
            if os.getenv("NMIGEN_pysim_dump"):
                # Close (and thereby flush) the file before compiling so the source is actually on
                # disk when a traceback needs it; `delete=False` keeps the file afterwards.
                with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_", delete=False,
                                                 encoding="utf-8") as file:
                    file.write(code)
                filename = file.name
            else:
                filename = "<string>"

            exec_locals = {"slots": self.state.slots, **_ValueCompiler.helpers}
            exec(compile(code, filename, "exec"), exec_locals)
            domain_process.run = exec_locals["run"]

            processes.add(domain_process)

        for subfragment, _subfragment_name in fragment.subfragments:
            processes.update(self(subfragment))

        return processes