back.pysim: Reuse clock simulation commands
[nmigen.git] / nmigen / back / pysim.py
1 import os
2 import tempfile
3 import warnings
4 import inspect
5 from contextlib import contextmanager
6 import itertools
7 from vcd import VCDWriter
8 from vcd.gtkw import GTKWSave
9
10 from .._utils import deprecated
11 from ..hdl.ast import *
12 from ..hdl.cd import *
13 from ..hdl.ir import *
14 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
15
16
class Command:
    """Base class for the commands a simulator process can ``yield``.

    Subclasses carry no behaviour of their own; the simulator dispatches on their type.
    """
19
20
class Settle(Command):
    """Command: resume the process once pending changes have settled, without advancing time."""

    def __repr__(self):
        text = "(settle)"
        return text
24
25
class Delay(Command):
    """Command: suspend the process for ``interval`` seconds.

    With no interval the process is resumed at the current simulation time after the
    current delta cycles have settled.
    """

    def __init__(self, interval=None):
        if interval is not None:
            interval = float(interval)
        self.interval = interval

    def __repr__(self):
        if self.interval is not None:
            # Shown in microseconds with three significant digits.
            return "(delay {:.3}us)".format(self.interval * 1e6)
        return "(delay ε)"
35
36
class Tick(Command):
    """Command: suspend the process until the next clock edge of ``domain``.

    ``domain`` is a domain name (looked up in the simulated fragment) or a
    :class:`ClockDomain`. For domains with an asynchronous reset, reset assertion also
    resumes the process.
    """

    def __init__(self, domain="sync"):
        acceptable = isinstance(domain, (str, ClockDomain))
        if not acceptable:
            raise TypeError("Domain must be a string or a ClockDomain instance, not {!r}"
                            .format(domain))
        # The combinational "domain" has no clock to wait for.
        assert domain != "comb"
        self.domain = domain

    def __repr__(self):
        return f"(tick {self.domain})"
47
48
class Passive(Command):
    """Command: mark the yielding process passive.

    Passive processes do not keep :meth:`Simulator.run` going.
    """

    def __repr__(self):
        text = "(passive)"
        return text
52
53
class Active(Command):
    """Command: mark the yielding process active.

    The simulation keeps running while any process is active.
    """

    def __repr__(self):
        text = "(active)"
        return text
57
58
59 class _WaveformWriter:
60 def update(self, timestamp, signal, value):
61 raise NotImplementedError # :nocov:
62
63 def close(self, timestamp):
64 raise NotImplementedError # :nocov:
65
66
class _VCDWaveformWriter(_WaveformWriter):
    """Waveform writer producing a VCD file and, optionally, a GTKWave save file.

    ``signal_names`` maps each signal to a set of hierarchical names; the last element of
    each name is the VCD variable name and the preceding elements form its scope. Every
    signal in ``traces`` is listed in the GTKWave save file. Traces missing from
    ``signal_names`` are registered under the ``top`` scope.
    """

    @staticmethod
    def timestamp_to_vcd(timestamp):
        # Simulation time is kept in seconds; the VCD timescale below is 100 ps.
        return timestamp * (10 ** 10) # 1/(100 ps)

    @staticmethod
    def decode_to_vcd(signal, value):
        # Whitespace would split the VCD value token, so tabs and spaces become underscores.
        decoded = signal.decoder(value)
        return decoded.expandtabs().replace(" ", "_")

    def __init__(self, signal_names, *, vcd_file, gtkw_file=None, traces=()):
        # Filenames are opened here. Both files (opened here or passed in) are closed in close().
        if isinstance(vcd_file, str):
            vcd_file = open(vcd_file, "wt")
        if isinstance(gtkw_file, str):
            gtkw_file = open(gtkw_file, "wt")

        self.vcd_vars = SignalDict()
        self.vcd_file = vcd_file
        self.vcd_writer = vcd_file and VCDWriter(self.vcd_file,
            timescale="100 ps", comment="Generated by nMigen")

        self.gtkw_names = SignalDict()
        self.gtkw_file = gtkw_file
        self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file)

        self.traces = list(traces)

        extra_names = SignalDict()
        for trace in self.traces:
            if trace not in signal_names:
                extra_names[trace] = {("top", trace.name)}

        if self.vcd_writer is None:
            return

        every_name = itertools.chain(signal_names.items(), extra_names.items())
        for signal, names in every_name:
            if signal.decoder:
                # Decoded signals are dumped as strings.
                var_type, var_size = "string", 1
                var_init = self.decode_to_vcd(signal, signal.reset)
            else:
                var_type, var_size, var_init = "wire", signal.width, signal.reset

            for (*var_scope, var_name) in names:
                vcd_var, unique_name = self._register_unique(
                    var_scope, var_name, var_type, var_size, var_init)

                if signal not in self.vcd_vars:
                    self.vcd_vars[signal] = set()
                self.vcd_vars[signal].add(vcd_var)

                # The first registered name is the one referenced from the GTKWave save file.
                if signal not in self.gtkw_names:
                    self.gtkw_names[signal] = (*var_scope, unique_name)

    def _register_unique(self, var_scope, var_name, var_type, var_size, var_init):
        """Register a VCD variable, retrying with ``$1``, ``$2``, ... suffixes on KeyError.

        Returns the registered variable and the name it was registered under.
        """
        # register_var is expected to raise KeyError when the name is already taken in the
        # scope (this is what the retry relies on).
        attempt = 0
        while True:
            candidate = var_name if attempt == 0 else "{}${}".format(var_name, attempt)
            try:
                vcd_var = self.vcd_writer.register_var(
                    scope=var_scope, name=candidate,
                    var_type=var_type, size=var_size, init=var_init)
            except KeyError:
                attempt += 1
            else:
                return vcd_var, candidate

    def update(self, timestamp, signal, value):
        if signal not in self.vcd_vars:
            return

        when = self.timestamp_to_vcd(timestamp)
        shown = self.decode_to_vcd(signal, value) if signal.decoder else value
        for vcd_var in self.vcd_vars[signal]:
            self.vcd_writer.change(vcd_var, when, shown)

    def _write_gtkw_save(self):
        """Fill the GTKWave save file with the dump location and the requested traces."""
        # NOTE(review): assumes a VCD file was also given; gtkw_names is only populated
        # when a VCD writer exists.
        save = self.gtkw_save
        save.dumpfile(self.vcd_file.name)
        save.dumpfile_size(self.vcd_file.tell())

        save.treeopen("top")
        for signal in self.traces:
            multibit = len(signal) > 1 and not signal.decoder
            suffix = "[{}:0]".format(len(signal) - 1) if multibit else ""
            save.trace(".".join(self.gtkw_names[signal]) + suffix)

    def close(self, timestamp):
        if self.vcd_writer is not None:
            self.vcd_writer.close(self.timestamp_to_vcd(timestamp))

        if self.gtkw_save is not None:
            self._write_gtkw_save()

        for owned_file in (self.vcd_file, self.gtkw_file):
            if owned_file is not None:
                owned_file.close()
165 self.gtkw_file.close()
166
167
168 class _Process:
169 __slots__ = ("runnable", "passive")
170
171 def reset(self):
172 raise NotImplementedError # :nocov:
173
174 def run(self):
175 raise NotImplementedError # :nocov:
176
177 @property
178 def name(self):
179 raise NotImplementedError # :nocov:
180
181
182 class _SignalState:
183 __slots__ = ("signal", "curr", "next", "waiters", "pending")
184
185 def __init__(self, signal, pending):
186 self.signal = signal
187 self.pending = pending
188 self.waiters = dict()
189 self.reset()
190
191 def reset(self):
192 self.curr = self.next = self.signal.reset
193
194 def set(self, value):
195 if self.next == value:
196 return
197 self.next = value
198 self.pending.add(self)
199
200 def wait(self, task, *, trigger=None):
201 assert task not in self.waiters
202 self.waiters[task] = trigger
203
204 def commit(self):
205 if self.curr == self.next:
206 return False
207 self.curr = self.next
208 return True
209
210 def wakeup(self):
211 awoken_any = False
212 for process, trigger in self.waiters.items():
213 if trigger is None or trigger == self.curr:
214 process.runnable = awoken_any = True
215 return awoken_any
216
217
218 class _SimulatorState:
219 def __init__(self):
220 self.signals = SignalDict()
221 self.pending = set()
222
223 self.timestamp = 0.0
224 self.deadlines = dict()
225
226 self.waveform_writer = None
227
228 def reset(self):
229 for signal_state in self.signals.values():
230 signal_state.reset()
231 self.pending.clear()
232
233 self.timestamp = 0.0
234 self.deadlines.clear()
235
236 def for_signal(self, signal):
237 try:
238 return self.signals[signal]
239 except KeyError:
240 signal_state = _SignalState(signal, self.pending)
241 self.signals[signal] = signal_state
242 return signal_state
243
244 def commit(self):
245 awoken_any = False
246 for signal_state in self.pending:
247 if signal_state.commit():
248 if signal_state.wakeup():
249 awoken_any = True
250 if self.waveform_writer is not None:
251 self.waveform_writer.update(self.timestamp,
252 signal_state.signal, signal_state.curr)
253 return awoken_any
254
255 def advance(self):
256 nearest_processes = set()
257 nearest_deadline = None
258 for process, deadline in self.deadlines.items():
259 if deadline is None:
260 if nearest_deadline is not None:
261 nearest_processes.clear()
262 nearest_processes.add(process)
263 nearest_deadline = self.timestamp
264 break
265 elif nearest_deadline is None or deadline <= nearest_deadline:
266 assert deadline >= self.timestamp
267 if nearest_deadline is not None and deadline < nearest_deadline:
268 nearest_processes.clear()
269 nearest_processes.add(process)
270 nearest_deadline = deadline
271
272 if not nearest_processes:
273 return False
274
275 for process in nearest_processes:
276 process.runnable = True
277 del self.deadlines[process]
278 self.timestamp = nearest_deadline
279
280 return True
281
282 def start_waveform(self, waveform_writer):
283 if self.timestamp != 0.0:
284 raise ValueError("Cannot start writing waveforms after advancing simulation time")
285 if self.waveform_writer is not None:
286 raise ValueError("Already writing waveforms to {!r}"
287 .format(self.waveform_writer))
288 self.waveform_writer = waveform_writer
289
290 def finish_waveform(self):
291 if self.waveform_writer is None:
292 return
293 self.waveform_writer.close(self.timestamp)
294 self.waveform_writer = None
295
296
297 class _EvalContext:
298 __slots__ = ("state", "indexes", "slots")
299
300 def __init__(self, state):
301 self.state = state
302 self.indexes = SignalDict()
303 self.slots = []
304
305 def get_signal(self, signal):
306 try:
307 return self.indexes[signal]
308 except KeyError:
309 index = len(self.slots)
310 self.slots.append(self.state.for_signal(signal))
311 self.indexes[signal] = index
312 return index
313
314 def get_in_signal(self, signal, *, trigger=None):
315 index = self.get_signal(signal)
316 self.slots[index].waiters[self] = trigger
317 return index
318
319 def get_out_signal(self, signal):
320 return self.get_signal(signal)
321
322
323 class _Emitter:
324 def __init__(self):
325 self._buffer = []
326 self._suffix = 0
327 self._level = 0
328
329 def append(self, code):
330 self._buffer.append(" " * self._level)
331 self._buffer.append(code)
332 self._buffer.append("\n")
333
334 @contextmanager
335 def indent(self):
336 self._level += 1
337 yield
338 self._level -= 1
339
340 def flush(self, indent=""):
341 code = "".join(self._buffer)
342 self._buffer.clear()
343 return code
344
345 def gen_var(self, prefix):
346 name = f"{prefix}_{self._suffix}"
347 self._suffix += 1
348 return name
349
350 def def_var(self, prefix, value):
351 name = self.gen_var(prefix)
352 self.append(f"{name} = {value}")
353 return name
354
355
356 class _Compiler:
357 def __init__(self, context, emitter):
358 self.context = context
359 self.emitter = emitter
360
361
class _ValueCompiler(ValueVisitor, _Compiler):
    """Base for visitors that translate nMigen values into Python source fragments."""

    # Helper functions placed into the namespace in which generated code is executed.
    helpers = dict(
        # Sign-extend `value`; `sign` masks the sign bit and every bit above it.
        sign=lambda value, sign: (value | sign) if (value & sign) else value,
        # Division and modulo by zero evaluate to zero instead of raising.
        zdiv=lambda lhs, rhs: 0 if rhs == 0 else lhs // rhs,
        zmod=lambda lhs, rhs: 0 if rhs == 0 else lhs % rhs,
    )

    def _not_supported(self, value):
        raise NotImplementedError # :nocov:

    # Value kinds this simulator does not handle.
    on_ClockSignal = _not_supported
    on_ResetSignal = _not_supported
    on_AnyConst = _not_supported
    on_AnySeq = _not_supported
    on_Sample = _not_supported
    on_Initial = _not_supported

    def on_Record(self, value):
        # A record is compiled as the concatenation of its fields.
        fields_as_bits = Cat(value.fields.values())
        return self(fields_as_bits)
388 raise NotImplementedError # :nocov:
389
390
class _RHSValueCompiler(_ValueCompiler):
    """Translate a value into a Python expression that reads it.

    ``mode`` selects which value of each signal is read. ``"curr"`` reads the committed
    ``slots[i].curr``. ``"next"`` reads the ``next_<i>`` local that holds the pending value
    inside generated code.

    Generated expressions are not always reduced to the width of the value they compute
    (e.g. ``~x`` yields a negative Python integer). Any operation whose result depends on
    bits above an operand's width therefore truncates and, for signed operands,
    sign-extends that operand first (see ``mask``/``sign`` in :meth:`on_Operator`).
    """

    def __init__(self, context, emitter, *, mode, inputs=None):
        super().__init__(context, emitter)
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.context.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.context.get_signal(value)}"

    def on_Operator(self, value):
        def mask(value):
            # Truncate to the width of `value`.
            value_mask = (1 << len(value)) - 1
            return f"({self(value)} & {value_mask})"

        def sign(value):
            # Truncate to the width of `value`, then sign-extend it if it is signed.
            if value.shape().signed:
                return f"sign({mask(value)}, {-1 << (len(value) - 1)})"
            else: # unsigned
                return mask(value)

        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                # NOT commutes with truncation and keeps the operand's width, so the raw
                # operand is sufficient.
                return f"(~{self(arg)})"
            if value.operator == "-":
                # The result is wider than the operand, so stray high bits of a raw
                # intermediate (e.g. from `~x`) would leak into it; canonicalize first.
                return f"(-{sign(arg)})"
            if value.operator == "b":
                return f"bool({mask(arg)})"
            if value.operator == "r|":
                return f"({mask(arg)} != 0)"
            if value.operator == "r&":
                return f"({mask(arg)} == {(1 << len(arg)) - 1})"
            if value.operator == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({mask(arg)}, 'b').count('1') % 2)"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "+":
                return f"({sign(lhs)} + {sign(rhs)})"
            if value.operator == "-":
                return f"({sign(lhs)} - {sign(rhs)})"
            if value.operator == "*":
                return f"({sign(lhs)} * {sign(rhs)})"
            if value.operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if value.operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
            # Bitwise operands are zero- or sign-extended to the result width, so stray high
            # bits of a raw narrower operand must not reach the result.
            if value.operator == "&":
                return f"({sign(lhs)} & {sign(rhs)})"
            if value.operator == "|":
                return f"({sign(lhs)} | {sign(rhs)})"
            if value.operator == "^":
                return f"({sign(lhs)} ^ {sign(rhs)})"
            if value.operator == "<<":
                return f"({sign(lhs)} << {sign(rhs)})"
            if value.operator == ">>":
                return f"({sign(lhs)} >> {sign(rhs)})"
            if value.operator == "==":
                return f"({sign(lhs)} == {sign(rhs)})"
            if value.operator == "!=":
                return f"({sign(lhs)} != {sign(rhs)})"
            if value.operator == "<":
                return f"({sign(lhs)} < {sign(rhs)})"
            if value.operator == "<=":
                return f"({sign(lhs)} <= {sign(rhs)})"
            if value.operator == ">":
                return f"({sign(lhs)} > {sign(rhs)})"
            if value.operator == ">=":
                return f"({sign(lhs)} >= {sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                # The selector is tested for truth, so it is truncated to its width; the
                # selected value is extended to the result width.
                return f"({sign(val1)} if {mask(sel)} else {sign(val0)})"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
        return f"({self(value.value)} >> {offset} & " \
               f"{(1 << value.width) - 1})"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        # Evaluate the replicated value once into a temporary.
        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.append(f"if {gen_index} == {index}:")
                else:
                    self.emitter.append(f"elif {gen_index} == {index}:")
                with self.emitter.indent():
                    self.emitter.append(f"{gen_value} = {self(elem)}")
            # Out-of-range indexes read the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self(value.elems[-1])}")
            return gen_value
        else:
            return f"0"

    @classmethod
    def compile(cls, context, value, *, mode, inputs=None):
        """Return source code that stores the value of ``value`` into ``result``."""
        emitter = _Emitter()
        compiler = cls(context, emitter, mode=mode, inputs=inputs)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
539 return emitter.flush()
540
541
class _LHSValueCompiler(_ValueCompiler):
    """Translate an assignable value into a writer of ``next_<i>`` variables.

    Visiting a value returns ``gen(arg)``, a function that takes the source text of an
    expression and emits statements storing it into the signals that make up the value.
    Partial updates (slices, parts) are emitted as read-modify-write of ``next_<i>``.
    """

    def __init__(self, context, emitter, *, rhs, outputs=None):
        super().__init__(context, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(context, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            # Truncate to the signal's width and sign-extend signed signals, so stored values
            # are always canonical.
            value_mask = (1 << len(value)) - 1
            if value.shape().signed:
                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{arg} & {value_mask}"
            self.emitter.append(f"next_{self.context.get_out_signal(value)} = {value_sign}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            # Clear the slice's bits in the pending value, then OR in the new bits.
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"{~(width_mask << value.start)} | " \
                f"(({arg} & {width_mask}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            # Same as a slice, but at a bit offset computed at run time.
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"~({width_mask} << {offset}) | " \
                f"(({arg} & {width_mask}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate the source once, then hand each part its own bit range.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.append(f"if {gen_index} == {index}:")
                    else:
                        self.emitter.append(f"elif {gen_index} == {index}:")
                    with self.emitter.indent():
                        self(elem)(arg)
                # Out-of-range indexes write the last element.
                self.emitter.append(f"else:")
                with self.emitter.indent():
                    self(value.elems[-1])(arg)
            else:
                self.emitter.append(f"pass")
        return gen

    @classmethod
    def compile(cls, context, stmt, *, inputs=None, outputs=None):
        """Visit ``stmt`` (an assignable value) and return the source emitted while doing so."""
        emitter = _Emitter()
        # __init__ requires an RHS compiler for the rvalue parts of the lvalue and has no
        # `inputs` parameter; build that compiler here so `inputs` collects what it reads.
        rhs = _RHSValueCompiler(context, emitter, mode="curr", inputs=inputs)
        compiler = cls(context, emitter, rhs=rhs, outputs=outputs)
        # NOTE(review): the writer returned by visiting `stmt` is not invoked here, so the
        # result only contains code emitted during the visit itself. No caller is visible
        # in this file; confirm the intended use.
        compiler(stmt)
        return emitter.flush()
631
632
class _StatementCompiler(StatementVisitor, _Compiler):
    """Translate statements into Python code that updates ``next_<i>`` variables.

    Reads go through an RHS compiler in ``"curr"`` mode (recording read signals into
    ``inputs`` if given). Writes go through an LHS compiler (recording written signals into
    ``outputs`` if given).
    """

    def __init__(self, context, emitter, *, inputs=None, outputs=None):
        super().__init__(context, emitter)
        self.rhs = _RHSValueCompiler(context, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(context, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        for stmt in stmts:
            self(stmt)
        if not stmts:
            # An empty block still needs a body in the generated code.
            self.emitter.append("pass")

    def on_Assign(self, stmt):
        write = self.lhs(stmt.lhs)
        return write(self.rhs(stmt.rhs))

    @staticmethod
    def _pattern_check(gen_test, pattern):
        """Return a condition matching ``gen_test`` against a binary pattern ("-" = don't care)."""
        if "-" not in pattern:
            return f"{gen_test} == {int(pattern, 2)}"
        care_mask = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
        care_value = int("".join("0" if bit == "-" else bit for bit in pattern), 2)
        return f"({gen_test} & {care_mask}) == {care_value}"

    def on_Switch(self, stmt):
        test_mask = (1 << len(stmt.test)) - 1
        gen_test = self.emitter.def_var("test", f"{self.rhs(stmt.test)} & {test_mask}")
        for position, (patterns, stmts) in enumerate(stmt.cases.items()):
            if patterns:
                checks = [self._pattern_check(gen_test, pattern) for pattern in patterns]
            else:
                # A case without patterns is the default case.
                checks = ["True"]
            keyword = "if" if position == 0 else "elif"
            self.emitter.append(f"{keyword} {' or '.join(checks)}:")
            with self.emitter.indent():
                self(stmts)

    def _unsupported(self, stmt):
        raise NotImplementedError # :nocov:

    # Formal verification statements are not simulated.
    on_Assert = _unsupported
    on_Assume = _unsupported
    on_Cover = _unsupported

    @classmethod
    def compile(cls, context, stmt, *, inputs=None, outputs=None):
        """Return standalone source that executes ``stmt`` against ``slots``.

        The signals assigned by ``stmt`` are loaded from their pending values first and
        scheduled with ``set()`` afterwards.
        """
        output_indexes = [context.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _Emitter()
        for slot_index in output_indexes:
            emitter.append(f"next_{slot_index} = slots[{slot_index}].next")
        cls(context, emitter, inputs=inputs, outputs=outputs)(stmt)
        for slot_index in output_indexes:
            emitter.append(f"slots[{slot_index}].set(next_{slot_index})")
        return emitter.flush()
690 return emitter.flush()
691
692
class _CompiledProcess(_Process):
    """A process whose ``run`` is Python code generated from a fragment's statements.

    Combinational processes (``comb=True``) start out runnable after a reset. Compiled
    processes are always passive.
    """
    __slots__ = ("context", "comb", "name", "run")

    def __init__(self, state, *, comb, name):
        self.context = _EvalContext(state)
        self.comb = comb
        self.name = name
        # Filled in by _FragmentCompiler once the generated code is compiled.
        self.run = None
        self.reset()

    def reset(self):
        self.passive = True
        self.runnable = self.comb
705 self.passive = True
706
707
class _FragmentCompiler:
    """Compile a fragment hierarchy into simulator processes.

    Each fragment gets one process per driving domain:
    - a combinational process (domain ``None``) is woken by any change of a signal it reads;
    - a synchronous process is woken by its domain's clock edge, and by reset assertion
      when the domain uses an asynchronous reset.

    ``signal_names`` is filled with the hierarchical names of the signals encountered.
    """

    def __init__(self, state, signal_names):
        self.state = state
        self.signal_names = signal_names

    @staticmethod
    def _load_run(code, slots):
        """Compile generated ``code`` and return the ``run`` function it defines."""
        # There shouldn't be any exceptions raised by the generated code, but if there are
        # (almost certainly due to a bug in the code generator), use this environment variable
        # to make backtraces useful.
        if os.getenv("NMIGEN_pysim_dump"):
            # Close the dump file so its contents are on disk before any traceback reads it
            # (and the descriptor is not leaked); delete=False keeps the file afterwards.
            with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_", delete=False) as file:
                file.write(code)
                filename = file.name
        else:
            filename = "<string>"

        exec_locals = {"slots": slots, **_ValueCompiler.helpers}
        exec(compile(code, filename, "exec"), exec_locals)
        return exec_locals["run"]

    def __call__(self, fragment, *, hierarchy=("top",)):
        processes = set()

        def add_signal_name(signal):
            hierarchical_signal_name = (*hierarchy, signal.name)
            if signal not in self.signal_names:
                self.signal_names[signal] = {hierarchical_signal_name}
            else:
                self.signal_names[signal].add(hierarchical_signal_name)

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = _CompiledProcess(self.state, comb=domain_name is None,
                name=".".join((*hierarchy, "<{}>".format(domain_name or "comb"))))

            emitter = _Emitter()
            emitter.append(f"def run():")
            emitter._level += 1

            if domain_name is None:
                # Combinational logic recomputes every driven signal from its reset value.
                for signal in domain_signals:
                    signal_index = domain_process.context.get_signal(signal)
                    emitter.append(f"next_{signal_index} = {signal.reset}")

                inputs = SignalSet()
                _StatementCompiler(domain_process.context, emitter, inputs=inputs)(domain_stmts)

                for input_signal in inputs:
                    self.state.for_signal(input_signal).wait(domain_process)

            else:
                domain = fragment.domains[domain_name]
                add_signal_name(domain.clk)
                if domain.rst is not None:
                    add_signal_name(domain.rst)

                clk_trigger = 1 if domain.clk_edge == "pos" else 0
                self.state.for_signal(domain.clk).wait(domain_process, trigger=clk_trigger)
                if domain.rst is not None and domain.async_reset:
                    rst_trigger = 1
                    self.state.for_signal(domain.rst).wait(domain_process, trigger=rst_trigger)

                # The generated code checks that it only runs on one of its trigger conditions.
                gen_asserts = []
                clk_index = domain_process.context.get_signal(domain.clk)
                gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}")
                if domain.rst is not None and domain.async_reset:
                    rst_index = domain_process.context.get_signal(domain.rst)
                    gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}")
                emitter.append(f"assert {' or '.join(gen_asserts)}")

                # Synchronous signals keep their value unless assigned.
                for signal in domain_signals:
                    signal_index = domain_process.context.get_signal(signal)
                    emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                _StatementCompiler(domain_process.context, emitter)(domain_stmts)

            for signal in domain_signals:
                signal_index = domain_process.context.get_signal(signal)
                emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            code = emitter.flush()
            domain_process.run = self._load_run(code, domain_process.context.slots)

            processes.add(domain_process)

            for used_signal in domain_process.context.indexes:
                add_signal_name(used_signal)

        for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
            if subfragment_name is None:
                subfragment_name = "U${}".format(subfragment_index)
            processes.update(self(subfragment, hierarchy=(*hierarchy, subfragment_name)))

        return processes
799
800
class _CoroutineProcess(_Process):
    """A user process: a generator that yields commands to the simulator.

    ``constructor`` is called with no arguments on every reset to obtain a fresh
    generator. A bare ``yield`` (``None``) is replaced by ``default_cmd``.
    """

    def __init__(self, state, domains, constructor, *, default_cmd=None):
        self.state = state
        self.domains = domains
        self.constructor = constructor
        self.default_cmd = default_cmd
        self.reset()

    def reset(self):
        self.runnable = True
        self.passive = False
        self.coroutine = self.constructor()
        self.eval_context = _EvalContext(self.state)
        self.exec_locals = {
            "slots": self.eval_context.slots,
            "result": None,
            **_ValueCompiler.helpers
        }
        # Signal states outlive a reset. Unregister the previous coroutine's waits, otherwise
        # the stale entries wake the new coroutine spuriously and trip the assertion in
        # get_in_signal() the next time it waits on the same signal.
        for signal_state in getattr(self, "waits_on", ()):
            signal_state.waiters.pop(self, None)
        self.waits_on = set()
        # Likewise forget any Settle/Delay deadline left by the previous coroutine.
        self.state.deadlines.pop(self, None)

    @property
    def name(self):
        """``file:line`` of the innermost generator currently suspended, for diagnostics."""
        coroutine = self.coroutine
        while True:
            if inspect.isgenerator(coroutine):
                inner = coroutine.gi_yieldfrom
            elif inspect.iscoroutine(coroutine):
                inner = coroutine.cr_await
            else:
                inner = None
            # Stop at delegations to plain iterators (e.g. `yield from iter(steps)`).
            if not (inspect.isgenerator(inner) or inspect.iscoroutine(inner)):
                break
            coroutine = inner
        if inspect.isgenerator(coroutine):
            frame = coroutine.gi_frame
        elif inspect.iscoroutine(coroutine):
            frame = coroutine.cr_frame
        else:
            frame = None
        if frame is None:
            # Finished (or absent) coroutine: there is no location to report.
            return repr(self.constructor)
        return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def get_in_signal(self, signal, *, trigger=None):
        """Wait for ``signal`` to change (to ``trigger``, if given) and return its state."""
        signal_state = self.state.for_signal(signal)
        assert self not in signal_state.waiters
        signal_state.waiters[self] = trigger
        self.waits_on.add(signal_state)
        return signal_state

    def run(self):
        """Resume the coroutine and execute its commands until it suspends or finishes."""
        if self.coroutine is None:
            return

        # Whatever the process was waiting on has happened; drop all its registrations.
        if self.waits_on:
            for signal_state in self.waits_on:
                del signal_state.waiters[self]
            self.waits_on.clear()

        response = None
        while True:
            try:
                command = self.coroutine.send(response)
                if command is None:
                    command = self.default_cmd
                response = None

                if isinstance(command, Value):
                    # Reading a value: evaluate it and send the result back.
                    exec(_RHSValueCompiler.compile(self.eval_context, command, mode="curr"),
                        self.exec_locals)
                    response = Const.normalize(self.exec_locals["result"], command.shape())

                elif isinstance(command, Statement):
                    exec(_StatementCompiler.compile(self.eval_context, command),
                        self.exec_locals)

                elif type(command) is Tick:
                    domain = command.domain
                    if isinstance(domain, ClockDomain):
                        pass
                    elif domain in self.domains:
                        domain = self.domains[domain]
                    else:
                        raise NameError("Received command {!r} that refers to a nonexistent "
                                        "domain {!r} from process {!r}"
                                        .format(command, command.domain, self.name))
                    self.get_in_signal(domain.clk, trigger=1 if domain.clk_edge == "pos" else 0)
                    if domain.rst is not None and domain.async_reset:
                        self.get_in_signal(domain.rst, trigger=1)
                    return

                elif type(command) is Settle:
                    self.state.deadlines[self] = None
                    return

                elif type(command) is Delay:
                    if command.interval is None:
                        self.state.deadlines[self] = None
                    else:
                        self.state.deadlines[self] = self.state.timestamp + command.interval
                    return

                elif type(command) is Passive:
                    self.passive = True

                elif type(command) is Active:
                    self.passive = False

                elif command is None: # only possible if self.default_cmd is None
                    raise TypeError("Received default command from process {!r} that was added "
                                    "with add_process(); did you mean to add this process with "
                                    "add_sync_process() instead?"
                                    .format(self.name))

                else:
                    raise TypeError("Received unsupported command {!r} from process {!r}"
                                    .format(command, self.name))

            except StopIteration:
                self.passive = True
                self.coroutine = None
                return

            except Exception as exn:
                # Re-raise the error inside the process so the traceback points at its yield.
                # NOTE(review): if the process catches the error and yields another command,
                # throw() returns that command and it is discarded here.
                self.coroutine.throw(exn)
913 self.coroutine.throw(exn)
914
915
916 class _WaveformContextManager:
917 def __init__(self, state, waveform_writer):
918 self._state = state
919 self._waveform_writer = waveform_writer
920
921 def __enter__(self):
922 try:
923 self._state.start_waveform(self._waveform_writer)
924 except:
925 self._waveform_writer.close(0)
926 raise
927
928 def __exit__(self, *args):
929 self._state.finish_waveform()
930
931
class Simulator:
    """Simulate a fragment together with user-supplied testbench processes.

    The fragment's statements are compiled into Python processes. Testbench processes
    are generator functions that ``yield`` values (to read them), statements (to
    execute them) or commands such as :class:`Tick`, :class:`Delay` and :class:`Settle`.
    """

    def __init__(self, fragment, **kwargs):
        self._state = _SimulatorState()
        self._signal_names = SignalDict()
        self._fragment = Fragment.get(fragment, platform=None).prepare()
        fragment_compiler = _FragmentCompiler(self._state, self._signal_names)
        self._processes = fragment_compiler(self._fragment)
        if kwargs: # :nocov:
            # TODO(nmigen-0.3): remove
            # Legacy keyword arguments start VCD output immediately.
            self._state.start_waveform(_VCDWaveformWriter(self._signal_names, **kwargs))
        # Clock domains that already have an add_clock() process driving them.
        self._clocked = set()

    def _check_process(self, process):
        # Generator (and coroutine) functions can be called again on every reset.
        if inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process):
            return process
        if not (inspect.isgenerator(process) or inspect.iscoroutine(process)):
            raise TypeError("Cannot add a process {!r} because it is not a generator function"
                            .format(process))
        # Deprecated: a generator object can only run once, so wrap it in a function.
        # stacklevel=3 attributes the warning to the caller of add_*process().
        warnings.warn("instead of generators, use generator functions as processes; "
                      "this allows the simulator to be repeatedly reset",
                      DeprecationWarning, stacklevel=3)
        def wrapper():
            yield from process
        return wrapper

    def _add_coroutine_process(self, process, *, default_cmd):
        coroutine_process = _CoroutineProcess(self._state, self._fragment.domains, process,
                                              default_cmd=default_cmd)
        self._processes.add(coroutine_process)

    def add_process(self, process):
        """Add a testbench process.

        ``process`` is a generator function. It starts once the reset values have settled.
        It has no default command, so a bare ``yield`` is an error.
        """
        process = self._check_process(process)
        def wrapper():
            # Only start a bench process after comb settling, so that the reset values are correct.
            yield Settle()
            yield from process()
        self._add_coroutine_process(wrapper, default_cmd=None)

    def add_sync_process(self, process, *, domain="sync"):
        """Add a testbench process synchronized to ``domain``.

        The process starts after the domain's first clock edge (or asynchronous reset
        edge). A bare ``yield`` waits for the next edge, like ``yield Tick(domain)``.
        """
        process = self._check_process(process)
        def wrapper():
            # Only start a sync process after the first clock edge (or reset edge, if the domain
            # uses an asynchronous reset). This matches the behavior of synchronous FFs.
            yield Tick(domain)
            yield from process()
        return self._add_coroutine_process(wrapper, default_cmd=Tick(domain))

    def add_clock(self, period, *, phase=None, domain="sync", if_exists=False):
        """Add a process that toggles the clock of ``domain`` with a 50% duty cycle.

        Arguments
        ---------
        period : float
            Clock period in seconds; the clock signal is inverted every ``period / 2``.
        phase : None or float
            Delay, in seconds, before the first clock transition. Defaults to
            ``period / 2``.
        domain : str or ClockDomain
            Domain whose clock is driven. A string is looked up among the domains of the
            simulated (root) fragment.
        if_exists : bool
            When ``domain`` is a string naming no domain of the root fragment, do nothing
            if ``True``, or raise ``ValueError`` if ``False`` (the default).
        """
        if isinstance(domain, ClockDomain):
            clock_domain = domain
        elif domain in self._fragment.domains:
            clock_domain = self._fragment.domains[domain]
        elif if_exists:
            return
        else:
            raise ValueError("Domain {!r} is not present in simulation"
                             .format(domain))
        if clock_domain in self._clocked:
            raise ValueError("Domain {!r} already has a clock driving it"
                             .format(clock_domain.name))

        half_period = period / 2
        if phase is None:
            # Delaying the first edge by half a period makes synchronous activity happen at
            # a non-zero time, so it is distinguishable from reset values in a waveform viewer.
            phase = half_period

        def clk_process():
            yield Passive()
            yield Delay(phase)
            # Start from the clock's current level, so the clock still toggles correctly if it
            # was changed before this process started or its reset value is high.
            initial = (yield clock_domain.clk)
            steps = (
                clock_domain.clk.eq(~initial),
                Delay(half_period),
                clock_domain.clk.eq(initial),
                Delay(half_period),
            )
            while True:
                yield from steps

        self._add_coroutine_process(clk_process, default_cmd=None)
        self._clocked.add(clock_domain)

    def reset(self):
        """Reset the simulation.

        Every signal returns to its reset value, time returns to zero, and every process
        returns to its initial state (user processes restart from the beginning).
        """
        self._state.reset()
        for process in self._processes:
            process.reset()

    def _delta(self):
        """Run one delta cycle; return whether committing it woke any process.

        First every runnable process runs once and queues its signal changes. Then all
        queued changes are committed, which may make waiting processes runnable.
        """
        for process in self._processes:
            if not process.runnable:
                continue
            process.runnable = False
            process.run()

        return self._state.commit()

    def _settle(self):
        """Run delta cycles until no process is woken.

        An unstable combinational loop makes this method loop forever.
        """
        converged = False
        while not converged:
            converged = not self._delta()

    def step(self):
        """Settle the simulation, then advance time to the nearest deadline, if there is one.

        An unstable combinational loop makes this method loop forever. Returns ``True`` while
        at least one process is active (not passive).
        """
        self._settle()
        self._state.advance()
        return not all(process.passive for process in self._processes)

    def run(self):
        """Step the simulation for as long as any process is active.

        User processes start active and can switch with ``yield Passive()`` and
        ``yield Active()``. Processes compiled from the design and clock processes added by
        :meth:`add_clock` are always passive.
        """
        while self.step():
            pass

    def run_until(self, deadline, *, run_passive=False):
        """Step the simulation until its time reaches ``deadline``.

        Each step may move time past ``deadline``. Unless ``run_passive`` is set, stepping
        also stops once no process is active, as in :meth:`run`. If simulation time stops
        advancing, this method never returns.
        """
        assert self._state.timestamp <= deadline
        while True:
            keep_going = self.step() or run_passive
            if not keep_going:
                break
            if not self._state.timestamp < deadline:
                break

    def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()):
        """Record waveforms into a Value Change Dump, optionally with a GTKWave save file.

        Returns a context manager; waveforms are recorded while it is active: ::

            sim = Simulator(frag)
            sim.add_clock(1e-6)
            with sim.write_vcd("dump.vcd", "dump.gtkw"):
                sim.run_until(1e-3)

        Arguments
        ---------
        vcd_file : str or file-like object
            Value Change Dump file, or the name of one.
        gtkw_file : str or file-like object
            GTKWave save file, or the name of one.
        traces : iterable of Signal
            Signals to show in the GTKWave save file.
        """
        vcd_writer = _VCDWaveformWriter(self._signal_names,
            vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces)
        return _WaveformContextManager(self._state, vcd_writer)

    # TODO(nmigen-0.3): remove
    @deprecated("instead of `with Simulator(fragment, ...) as sim:`, use "
                "`sim = Simulator(fragment); with sim.write_vcd(...):`")
    def __enter__(self): # :nocov:
        return self

    # TODO(nmigen-0.3): remove
    def __exit__(self, *args): # :nocov:
        self._state.finish_waveform()