hdl.rec: make Record inherit from UserValue.
[nmigen.git] / nmigen / back / pysim.py
1 import os
2 import tempfile
3 import warnings
4 import inspect
5 from contextlib import contextmanager
6 import itertools
7 from vcd import VCDWriter
8 from vcd.gtkw import GTKWSave
9
10 from .._utils import deprecated
11 from ..hdl.ast import *
12 from ..hdl.cd import *
13 from ..hdl.ir import *
14 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
15
16
class Command:
    """Base class of every command a simulation process may yield."""
19
20
class Settle(Command):
    """Command object whose textual form is ``(settle)``."""

    def __repr__(self):
        text = "(settle)"
        return text
24
25
class Delay(Command):
    """Command carrying an optional time interval.

    ``interval`` is stored as a ``float`` (or ``None`` when omitted).
    """

    def __init__(self, interval=None):
        if interval is not None:
            interval = float(interval)
        self.interval = interval

    def __repr__(self):
        # A missing interval is shown as an infinitesimal delay.
        if self.interval is not None:
            microseconds = self.interval * 1e6
            return "(delay {:.3}us)".format(microseconds)
        return "(delay ε)"
35
36
class Tick(Command):
    """Command naming a clock domain, either by string or by ``ClockDomain``.

    Raises ``TypeError`` if ``domain`` is neither a string nor a ``ClockDomain``.
    """

    def __init__(self, domain="sync"):
        acceptable = isinstance(domain, (str, ClockDomain))
        if not acceptable:
            raise TypeError("Domain must be a string or a ClockDomain instance, not {!r}"
                            .format(domain))
        assert domain != "comb"
        self.domain = domain

    def __repr__(self):
        template = "(tick {})"
        return template.format(self.domain)
47
48
class Passive(Command):
    """Command object whose textual form is ``(passive)``."""

    def __repr__(self):
        text = "(passive)"
        return text
52
53
class Active(Command):
    """Command object whose textual form is ``(active)``."""

    def __repr__(self):
        text = "(active)"
        return text
57
58
59 class _WaveformWriter:
60 def update(self, timestamp, signal, value):
61 raise NotImplementedError # :nocov:
62
63 def close(self, timestamp):
64 raise NotImplementedError # :nocov:
65
66
class _VCDWaveformWriter(_WaveformWriter):
    """Waveform writer producing a VCD file and, optionally, a GTKWave save file."""

    @staticmethod
    def timestamp_to_vcd(timestamp):
        """Convert a timestamp to VCD time units (the writer's timescale is 100 ps)."""
        return timestamp * (10 ** 10) # 1/(100 ps)

    @staticmethod
    def decode_to_vcd(signal, value):
        """Run ``signal.decoder`` on ``value`` and replace spaces/tabs with underscores."""
        return signal.decoder(value).expandtabs().replace(" ", "_")

    def __init__(self, signal_names, *, vcd_file, gtkw_file=None, traces=()):
        """Register a VCD variable for every known signal and requested trace.

        Arguments
        ---------
        signal_names : mapping of signal to set of hierarchical name tuples
            Every name of a signal gets its own VCD variable.
        vcd_file : str or file-like object
            A string is opened as a text file for writing.
        gtkw_file : None, str or file-like object
            A string is opened as a text file for writing.
        traces : iterable of signals
            Signals to list in the GTKWave save file; traces that are missing from
            ``signal_names`` are registered under the ``top`` scope.
        """
        if isinstance(vcd_file, str):
            vcd_file = open(vcd_file, "wt")
        if isinstance(gtkw_file, str):
            gtkw_file = open(gtkw_file, "wt")

        self.vcd_vars = SignalDict()
        self.vcd_file = vcd_file
        self.vcd_writer = vcd_file and VCDWriter(self.vcd_file,
            timescale="100 ps", comment="Generated by nMigen")

        self.gtkw_names = SignalDict()
        self.gtkw_file = gtkw_file
        self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file)

        self.traces = []

        trace_names = SignalDict()
        for trace in traces:
            if trace not in signal_names:
                trace_names[trace] = {("top", trace.name)}
            self.traces.append(trace)

        if self.vcd_writer is None:
            return

        for signal, names in itertools.chain(signal_names.items(), trace_names.items()):
            if signal.decoder:
                # Decoded signals are dumped as strings.
                var_type = "string"
                var_size = 1
                var_init = self.decode_to_vcd(signal, signal.reset)
            else:
                var_type = "wire"
                var_size = signal.width
                var_init = signal.reset

            for (*var_scope, var_name) in names:
                # On a KeyError from register_var, retry with `$1`, `$2`, ... appended
                # to the name until registration succeeds.
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        vcd_var = self.vcd_writer.register_var(
                            scope=var_scope, name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init)
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

                if signal not in self.vcd_vars:
                    self.vcd_vars[signal] = set()
                self.vcd_vars[signal].add(vcd_var)

                # Only the first registered name of a signal is used for GTKWave traces.
                if signal not in self.gtkw_names:
                    self.gtkw_names[signal] = (*var_scope, var_name_suffix)

    def update(self, timestamp, signal, value):
        """Write a value change for every VCD variable registered for ``signal``.

        Signals without a registered variable are ignored.
        """
        vcd_vars = self.vcd_vars.get(signal)
        if vcd_vars is None:
            return

        vcd_timestamp = self.timestamp_to_vcd(timestamp)
        if signal.decoder:
            var_value = self.decode_to_vcd(signal, value)
        else:
            var_value = value
        for vcd_var in vcd_vars:
            self.vcd_writer.change(vcd_var, vcd_timestamp, var_value)

    def close(self, timestamp):
        """Finalize the VCD at ``timestamp``, write the GTKWave save file, close both files.

        The underlying files are closed even if finalizing the VCD or writing the traces
        raises, so that file handles are not leaked on failure.
        """
        try:
            if self.vcd_writer is not None:
                self.vcd_writer.close(self.timestamp_to_vcd(timestamp))

            if self.gtkw_save is not None:
                self.gtkw_save.dumpfile(self.vcd_file.name)
                self.gtkw_save.dumpfile_size(self.vcd_file.tell())

                self.gtkw_save.treeopen("top")
                for signal in self.traces:
                    # Multi-bit undecoded signals need an explicit bit range in the name.
                    if len(signal) > 1 and not signal.decoder:
                        suffix = "[{}:0]".format(len(signal) - 1)
                    else:
                        suffix = ""
                    self.gtkw_save.trace(".".join(self.gtkw_names[signal]) + suffix)
        finally:
            try:
                if self.vcd_file is not None:
                    self.vcd_file.close()
            finally:
                if self.gtkw_file is not None:
                    self.gtkw_file.close()
167
168
169 class _Process:
170 __slots__ = ("runnable", "passive")
171
172 def reset(self):
173 raise NotImplementedError # :nocov:
174
175 def run(self):
176 raise NotImplementedError # :nocov:
177
178 @property
179 def name(self):
180 raise NotImplementedError # :nocov:
181
182
183 class _SignalState:
184 __slots__ = ("signal", "curr", "next", "waiters", "pending")
185
186 def __init__(self, signal, pending):
187 self.signal = signal
188 self.pending = pending
189 self.waiters = dict()
190 self.reset()
191
192 def reset(self):
193 self.curr = self.next = self.signal.reset
194
195 def set(self, value):
196 if self.next == value:
197 return
198 self.next = value
199 self.pending.add(self)
200
201 def wait(self, task, *, trigger=None):
202 assert task not in self.waiters
203 self.waiters[task] = trigger
204
205 def commit(self):
206 if self.curr == self.next:
207 return False
208 self.curr = self.next
209 return True
210
211 def wakeup(self):
212 awoken_any = False
213 for process, trigger in self.waiters.items():
214 if trigger is None or trigger == self.curr:
215 process.runnable = awoken_any = True
216 return awoken_any
217
218
219 class _SimulatorState:
220 def __init__(self):
221 self.signals = SignalDict()
222 self.pending = set()
223
224 self.timestamp = 0.0
225 self.deadlines = dict()
226
227 self.waveform_writer = None
228
229 def reset(self):
230 for signal_state in self.signals.values():
231 signal_state.reset()
232 self.pending.clear()
233
234 self.timestamp = 0.0
235 self.deadlines.clear()
236
237 def for_signal(self, signal):
238 try:
239 return self.signals[signal]
240 except KeyError:
241 signal_state = _SignalState(signal, self.pending)
242 self.signals[signal] = signal_state
243 return signal_state
244
245 def commit(self):
246 awoken_any = False
247 for signal_state in self.pending:
248 if signal_state.commit():
249 if signal_state.wakeup():
250 awoken_any = True
251 if self.waveform_writer is not None:
252 self.waveform_writer.update(self.timestamp,
253 signal_state.signal, signal_state.curr)
254 self.pending.clear()
255 return awoken_any
256
257 def advance(self):
258 nearest_processes = set()
259 nearest_deadline = None
260 for process, deadline in self.deadlines.items():
261 if deadline is None:
262 if nearest_deadline is not None:
263 nearest_processes.clear()
264 nearest_processes.add(process)
265 nearest_deadline = self.timestamp
266 break
267 elif nearest_deadline is None or deadline <= nearest_deadline:
268 assert deadline >= self.timestamp
269 if nearest_deadline is not None and deadline < nearest_deadline:
270 nearest_processes.clear()
271 nearest_processes.add(process)
272 nearest_deadline = deadline
273
274 if not nearest_processes:
275 return False
276
277 for process in nearest_processes:
278 process.runnable = True
279 del self.deadlines[process]
280 self.timestamp = nearest_deadline
281
282 return True
283
284 def start_waveform(self, waveform_writer):
285 if self.timestamp != 0.0:
286 raise ValueError("Cannot start writing waveforms after advancing simulation time")
287 if self.waveform_writer is not None:
288 raise ValueError("Already writing waveforms to {!r}"
289 .format(self.waveform_writer))
290 self.waveform_writer = waveform_writer
291
292 def finish_waveform(self):
293 if self.waveform_writer is None:
294 return
295 self.waveform_writer.close(self.timestamp)
296 self.waveform_writer = None
297
298
299 class _EvalContext:
300 __slots__ = ("state", "indexes", "slots")
301
302 def __init__(self, state):
303 self.state = state
304 self.indexes = SignalDict()
305 self.slots = []
306
307 def get_signal(self, signal):
308 try:
309 return self.indexes[signal]
310 except KeyError:
311 index = len(self.slots)
312 self.slots.append(self.state.for_signal(signal))
313 self.indexes[signal] = index
314 return index
315
316 def get_in_signal(self, signal, *, trigger=None):
317 index = self.get_signal(signal)
318 self.slots[index].waiters[self] = trigger
319 return index
320
321 def get_out_signal(self, signal):
322 return self.get_signal(signal)
323
324
325 class _Emitter:
326 def __init__(self):
327 self._buffer = []
328 self._suffix = 0
329 self._level = 0
330
331 def append(self, code):
332 self._buffer.append(" " * self._level)
333 self._buffer.append(code)
334 self._buffer.append("\n")
335
336 @contextmanager
337 def indent(self):
338 self._level += 1
339 yield
340 self._level -= 1
341
342 def flush(self, indent=""):
343 code = "".join(self._buffer)
344 self._buffer.clear()
345 return code
346
347 def gen_var(self, prefix):
348 name = f"{prefix}_{self._suffix}"
349 self._suffix += 1
350 return name
351
352 def def_var(self, prefix, value):
353 name = self.gen_var(prefix)
354 self.append(f"{name} = {value}")
355 return name
356
357
358 class _Compiler:
359 def __init__(self, context, emitter):
360 self.context = context
361 self.emitter = emitter
362
363
class _ValueCompiler(ValueVisitor, _Compiler):
    """Base of the value compilers; rejects value kinds that cannot be simulated."""

    # Functions made available by name to the generated code.
    helpers = {
        # OR in the sign mask when the sign bit is set.
        "sign": lambda value, sign: (value | sign) if (value & sign) else value,
        # Division and modulo that yield 0 for a zero divisor.
        "zdiv": lambda lhs, rhs: (lhs // rhs) if rhs != 0 else 0,
        "zmod": lambda lhs, rhs: (lhs % rhs) if rhs != 0 else 0,
    }

    def on_ClockSignal(self, value):
        raise NotImplementedError  # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError  # :nocov:

    def on_AnyConst(self, value):
        raise NotImplementedError  # :nocov:

    def on_AnySeq(self, value):
        raise NotImplementedError  # :nocov:

    def on_Sample(self, value):
        raise NotImplementedError  # :nocov:

    def on_Initial(self, value):
        raise NotImplementedError  # :nocov:
388
389
class _RHSValueCompiler(_ValueCompiler):
    """Translates a value used as an rvalue into a Python expression string.

    Visiting a value returns the expression text; statements needed to compute it
    (temporaries, ``if`` chains for array proxies) are appended to the emitter.
    """

    def __init__(self, context, emitter, *, mode, inputs=None):
        """
        Arguments
        ---------
        mode : "curr" or "next"
            With ``"curr"`` signals are read from ``slots[i].curr``; with ``"next"`` they
            are read from the local variable ``next_i``.
        inputs : None or set-like
            If not None, gets populated with every signal read.
        """
        super().__init__(context, emitter)
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.context.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.context.get_signal(value)}"

    def on_Operator(self, value):
        def mask(value):
            # Truncate the operand to its declared width.
            value_mask = (1 << len(value)) - 1
            return f"({self(value)} & {value_mask})"

        def sign(value):
            # Truncate the operand and, for signed shapes, sign-extend it.
            if value.shape().signed:
                return f"sign({mask(value)}, {-1 << (len(value) - 1)})"
            else: # unsigned
                return mask(value)

        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                return f"(~{self(arg)})"
            if value.operator == "-":
                return f"(-{self(arg)})"
            if value.operator == "b":
                return f"bool({mask(arg)})"
            if value.operator == "r|":
                return f"({mask(arg)} != 0)"
            if value.operator == "r&":
                return f"({mask(arg)} == {(1 << len(arg)) - 1})"
            if value.operator == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({mask(arg)}, 'b').count('1') % 2)"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "+":
                return f"({sign(lhs)} + {sign(rhs)})"
            if value.operator == "-":
                return f"({sign(lhs)} - {sign(rhs)})"
            if value.operator == "*":
                return f"({sign(lhs)} * {sign(rhs)})"
            if value.operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if value.operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
            if value.operator == "&":
                return f"({self(lhs)} & {self(rhs)})"
            if value.operator == "|":
                return f"({self(lhs)} | {self(rhs)})"
            if value.operator == "^":
                return f"({self(lhs)} ^ {self(rhs)})"
            if value.operator == "<<":
                return f"({sign(lhs)} << {sign(rhs)})"
            if value.operator == ">>":
                return f"({sign(lhs)} >> {sign(rhs)})"
            if value.operator == "==":
                return f"({sign(lhs)} == {sign(rhs)})"
            if value.operator == "!=":
                return f"({sign(lhs)} != {sign(rhs)})"
            if value.operator == "<":
                return f"({sign(lhs)} < {sign(rhs)})"
            if value.operator == "<=":
                return f"({sign(lhs)} <= {sign(rhs)})"
            if value.operator == ">":
                return f"({sign(lhs)} > {sign(rhs)})"
            if value.operator == ">=":
                return f"({sign(lhs)} >= {sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                # Intermediate results are not truncated (e.g. `~x` compiles to a negative
                # Python integer), so the selector must be masked to its width before its
                # truthiness is tested; otherwise `Mux(~s, a, b)` would always pick `a`.
                return f"({self(val1)} if {mask(sel)} else {self(val0)})"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
        # In Python `>>` binds tighter than `&`, so this shifts first and then masks.
        return f"({self(value.value)} >> {offset} & " \
               f"{(1 << value.width) - 1})"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        # The replicated value is computed once into a temporary.
        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.append(f"if {gen_index} == {index}:")
                else:
                    self.emitter.append(f"elif {gen_index} == {index}:")
                with self.emitter.indent():
                    self.emitter.append(f"{gen_value} = {self(elem)}")
            # An out-of-range index selects the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self(value.elems[-1])}")
            return gen_value
        else:
            return f"0"

    @classmethod
    def compile(cls, context, value, *, mode, inputs=None):
        """Return source code that evaluates ``value`` and stores it in ``result``."""
        emitter = _Emitter()
        compiler = cls(context, emitter, mode=mode, inputs=inputs)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
539
540
class _LHSValueCompiler(_ValueCompiler):
    """Translates a value used as an assignment target.

    Visiting an lvalue returns a function ``gen(arg)``; calling it with the text of
    an expression emits the statements that assign that expression to the lvalue.
    """

    def __init__(self, context, emitter, *, rhs, outputs=None):
        super().__init__(context, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(context, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            # Truncate to the signal width and sign-extend signed signals before storing.
            value_mask = (1 << len(value)) - 1
            if value.shape().signed:
                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{arg} & {value_mask}"
            self.emitter.append(f"next_{self.context.get_out_signal(value)} = {value_sign}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            # Read-modify-write: clear the slice in the old value, then OR in the new bits.
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"{~(width_mask << value.start)} | " \
                f"(({arg} & {width_mask}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            # Same as a slice, except that the offset is computed at run time.
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"~({width_mask} << {offset}) | " \
                f"(({arg} & {width_mask}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate the assigned expression once, then hand each part its own bits.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.append(f"if {gen_index} == {index}:")
                    else:
                        self.emitter.append(f"elif {gen_index} == {index}:")
                    with self.emitter.indent():
                        self(elem)(arg)
                # An out-of-range index assigns to the last element.
                self.emitter.append(f"else:")
                with self.emitter.indent():
                    self(value.elems[-1])(arg)
            else:
                self.emitter.append(f"pass")
        return gen

    @classmethod
    def compile(cls, context, stmt, *, inputs=None, outputs=None):
        """Visit ``stmt`` with a fresh emitter and return the emitted source code.

        The constructor requires an RHS compiler (``rhs=``) and has no ``inputs``
        parameter, so one is built here in ``"curr"`` mode and ``inputs`` is forwarded
        to it; forwarding ``inputs`` to the constructor raised ``TypeError``.

        NOTE(review): visiting an lvalue only returns its ``gen`` function, which is not
        called here, so little or nothing is emitted; confirm the intended use.
        """
        emitter = _Emitter()
        rhs = _RHSValueCompiler(context, emitter, mode="curr", inputs=inputs)
        compiler = cls(context, emitter, rhs=rhs, outputs=outputs)
        compiler(stmt)
        return emitter.flush()
630
631
class _StatementCompiler(StatementVisitor, _Compiler):
    """Translates statements into Python source appended to the emitter."""

    def __init__(self, context, emitter, *, inputs=None, outputs=None):
        super().__init__(context, emitter)
        rhs_compiler = _RHSValueCompiler(context, emitter, mode="curr", inputs=inputs)
        self.rhs = rhs_compiler
        self.lhs = _LHSValueCompiler(context, emitter, rhs=rhs_compiler, outputs=outputs)

    def on_statements(self, stmts):
        for stmt in stmts:
            self(stmt)
        # An empty statement list still needs a body in the generated code.
        if not stmts:
            self.emitter.append("pass")

    def on_Assign(self, stmt):
        assign_to = self.lhs(stmt.lhs)
        return assign_to(self.rhs(stmt.rhs))

    def on_Switch(self, stmt):
        test_mask = (1 << len(stmt.test)) - 1
        gen_test = self.emitter.def_var("test", f"{self.rhs(stmt.test)} & {test_mask}")
        for position, (patterns, body) in enumerate(stmt.cases.items()):
            checks = []
            if patterns:
                for pattern in patterns:
                    if "-" not in pattern:
                        checks.append(f"{gen_test} == {int(pattern, 2)}")
                        continue
                    # `-` marks a don't-care bit: excluded from the mask, zero in the value.
                    care_bits = "".join("0" if bit == "-" else "1" for bit in pattern)
                    want_bits = "".join("0" if bit == "-" else bit for bit in pattern)
                    checks.append(f"({gen_test} & {int(care_bits, 2)}) == {int(want_bits, 2)}")
            else:
                # A case without patterns always matches.
                checks.append(f"True")
            keyword = "if" if position == 0 else "elif"
            self.emitter.append(f"{keyword} {' or '.join(checks)}:")
            with self.emitter.indent():
                self(body)

    def on_Assert(self, stmt):
        raise NotImplementedError  # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError  # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError  # :nocov:

    @classmethod
    def compile(cls, context, stmt, *, inputs=None, outputs=None):
        """Return source code that executes ``stmt`` against the context's slots.

        Every signal on the left-hand side is first loaded into a ``next_i`` local and
        written back with ``slots[i].set(next_i)`` after the statement has run.
        """
        targets = [context.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _Emitter()
        for slot in targets:
            emitter.append(f"next_{slot} = slots[{slot}].next")
        cls(context, emitter, inputs=inputs, outputs=outputs)(stmt)
        for slot in targets:
            emitter.append(f"slots[{slot}].set(next_{slot})")
        return emitter.flush()
690
691
class _CompiledProcess(_Process):
    """Process whose ``run`` callable is generated code installed after construction."""

    __slots__ = ("context", "comb", "name", "run")

    def __init__(self, state, *, comb, name):
        self.name = name
        self.comb = comb
        self.context = _EvalContext(state)
        self.run = None # set by _FragmentCompiler
        self.reset()

    def reset(self):
        """Mark the process passive; it starts out runnable only if combinatorial."""
        self.passive = True
        self.runnable = self.comb
705
706
class _FragmentCompiler:
    """Compiles a fragment hierarchy into a set of ``_CompiledProcess`` objects."""

    def __init__(self, state, signal_names):
        self.state = state
        # Mapping of signal to set of hierarchical names; populated during compilation.
        self.signal_names = signal_names

    def __call__(self, fragment, *, hierarchy=("top",)):
        """Compile ``fragment`` and, recursively, its subfragments.

        One process is created per driver domain of the fragment. Returns the set of all
        processes created for the fragment and everything below it.
        """
        processes = set()

        def add_signal_name(signal):
            hierarchical_signal_name = (*hierarchy, signal.name)
            if signal not in self.signal_names:
                self.signal_names[signal] = {hierarchical_signal_name}
            else:
                self.signal_names[signal].add(hierarchical_signal_name)

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = _CompiledProcess(self.state, comb=domain_name is None,
                name=".".join((*hierarchy, "<{}>".format(domain_name or "comb"))))

            emitter = _Emitter()
            emitter.append(f"def run():")
            emitter._level += 1

            if domain_name is None:
                # Combinatorial signals start every run from their reset value.
                for signal in domain_signals:
                    signal_index = domain_process.context.get_signal(signal)
                    emitter.append(f"next_{signal_index} = {signal.reset}")

                inputs = SignalSet()
                _StatementCompiler(domain_process.context, emitter, inputs=inputs)(domain_stmts)

                # Re-run the process whenever any signal it reads changes.
                for input_signal in inputs:
                    self.state.for_signal(input_signal).wait(domain_process)

            else:
                domain = fragment.domains[domain_name]
                add_signal_name(domain.clk)
                if domain.rst is not None:
                    add_signal_name(domain.rst)

                # Run on the active clock edge and, for asynchronous resets, on reset assertion.
                clk_trigger = 1 if domain.clk_edge == "pos" else 0
                self.state.for_signal(domain.clk).wait(domain_process, trigger=clk_trigger)
                if domain.rst is not None and domain.async_reset:
                    rst_trigger = 1
                    self.state.for_signal(domain.rst).wait(domain_process, trigger=rst_trigger)

                gen_asserts = []
                clk_index = domain_process.context.get_signal(domain.clk)
                gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}")
                if domain.rst is not None and domain.async_reset:
                    rst_index = domain_process.context.get_signal(domain.rst)
                    gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}")
                emitter.append(f"assert {' or '.join(gen_asserts)}")

                for signal in domain_signals:
                    signal_index = domain_process.context.get_signal(signal)
                    emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                _StatementCompiler(domain_process.context, emitter)(domain_stmts)

            for signal in domain_signals:
                signal_index = domain_process.context.get_signal(signal)
                emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            # There shouldn't be any exceptions raised by the generated code, but if there are
            # (almost certainly due to a bug in the code generator), use this environment variable
            # to make backtraces useful.
            code = emitter.flush()
            if os.getenv("NMIGEN_pysim_dump"):
                # Close the file right away so that the source is flushed to disk before
                # any traceback tries to read it, and so that the handle is not leaked.
                # `delete=False` keeps the file around after closing.
                with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_",
                                                 delete=False) as file:
                    file.write(code)
                    filename = file.name
            else:
                filename = "<string>"

            exec_locals = {"slots": domain_process.context.slots, **_ValueCompiler.helpers}
            exec(compile(code, filename, "exec"), exec_locals)
            domain_process.run = exec_locals["run"]

            processes.add(domain_process)

            for used_signal in domain_process.context.indexes:
                add_signal_name(used_signal)

        for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
            if subfragment_name is None:
                subfragment_name = "U${}".format(subfragment_index)
            processes.update(self(subfragment, hierarchy=(*hierarchy, subfragment_name)))

        return processes
798
799
class _CoroutineProcess(_Process):
    """Process driven by a generator that yields values, statements and commands."""

    def __init__(self, state, domains, constructor, *, default_cmd=None):
        self.state = state
        # Mapping of domain name to domain, used to resolve `Tick("name")`.
        self.domains = domains
        # Zero-argument callable that creates the generator; called again on every reset.
        self.constructor = constructor
        # Command substituted when the generator yields None.
        self.default_cmd = default_cmd
        self.reset()

    def reset(self):
        """Restart the generator and discard all evaluation state."""
        self.runnable = True
        self.passive = False
        self.coroutine = self.constructor()
        self.eval_context = _EvalContext(self.state)
        # Globals for the code compiled from yielded values and statements.
        self.exec_locals = {
            "slots": self.eval_context.slots,
            "result": None,
            **_ValueCompiler.helpers
        }
        # Signal states this process is currently registered on as a waiter.
        self.waits_on = set()

    @property
    def name(self):
        """Return ``file:line`` of the innermost generator currently being executed."""
        coroutine = self.coroutine
        # Follow `yield from` delegation down to the innermost generator.
        while coroutine.gi_yieldfrom is not None:
            coroutine = coroutine.gi_yieldfrom
        if inspect.isgenerator(coroutine):
            frame = coroutine.gi_frame
        if inspect.iscoroutine(coroutine):
            frame = coroutine.cr_frame
        return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def get_in_signal(self, signal, *, trigger=None):
        """Register this process as a waiter on ``signal`` and remember the registration."""
        signal_state = self.state.for_signal(signal)
        assert self not in signal_state.waiters
        signal_state.waiters[self] = trigger
        self.waits_on.add(signal_state)
        return signal_state

    def run(self):
        """Resume the generator and handle commands until one suspends the process.

        Values and statements are executed immediately and the loop continues; ``Tick``,
        ``Settle`` and ``Delay`` arrange a wake-up condition and return.
        """
        if self.coroutine is None:
            return

        # Whatever woke the process, drop all of its previous wait registrations.
        if self.waits_on:
            for signal_state in self.waits_on:
                del signal_state.waiters[self]
            self.waits_on.clear()

        response = None
        while True:
            try:
                command = self.coroutine.send(response)
                if command is None:
                    command = self.default_cmd
                response = None

                if isinstance(command, Value):
                    # Evaluate the value and send the normalized result back on the next resume.
                    exec(_RHSValueCompiler.compile(self.eval_context, command, mode="curr"),
                        self.exec_locals)
                    response = Const.normalize(self.exec_locals["result"], command.shape())

                elif isinstance(command, Statement):
                    exec(_StatementCompiler.compile(self.eval_context, command),
                        self.exec_locals)

                elif type(command) is Tick:
                    domain = command.domain
                    if isinstance(domain, ClockDomain):
                        pass
                    elif domain in self.domains:
                        domain = self.domains[domain]
                    else:
                        raise NameError("Received command {!r} that refers to a nonexistent "
                                        "domain {!r} from process {!r}"
                                        .format(command, command.domain, self.name))
                    # Wake up on the active clock edge, or on assertion of an asynchronous reset.
                    self.get_in_signal(domain.clk, trigger=1 if domain.clk_edge == "pos" else 0)
                    if domain.rst is not None and domain.async_reset:
                        self.get_in_signal(domain.rst, trigger=1)
                    return

                elif type(command) is Settle:
                    # A deadline of None is stored for both Settle and an interval-less Delay.
                    self.state.deadlines[self] = None
                    return

                elif type(command) is Delay:
                    if command.interval is None:
                        self.state.deadlines[self] = None
                    else:
                        self.state.deadlines[self] = self.state.timestamp + command.interval
                    return

                elif type(command) is Passive:
                    self.passive = True

                elif type(command) is Active:
                    self.passive = False

                elif command is None: # only possible if self.default_cmd is None
                    raise TypeError("Received default command from process {!r} that was added "
                                    "with add_process(); did you mean to add this process with "
                                    "add_sync_process() instead?"
                                    .format(self.name))

                else:
                    raise TypeError("Received unsupported command {!r} from process {!r}"
                                    .format(command, self.name))

            except StopIteration:
                # The generator finished: the process no longer keeps the simulation alive.
                self.passive = True
                self.coroutine = None
                return

            except Exception as exn:
                # Errors raised while handling a command (including the NameError/TypeError
                # above) are thrown into the generator at its current yield point.
                self.coroutine.throw(exn)
913
914
915 class _WaveformContextManager:
916 def __init__(self, state, waveform_writer):
917 self._state = state
918 self._waveform_writer = waveform_writer
919
920 def __enter__(self):
921 try:
922 self._state.start_waveform(self._waveform_writer)
923 except:
924 self._waveform_writer.close(0)
925 raise
926
927 def __exit__(self, *args):
928 self._state.finish_waveform()
929
930
class Simulator:
    """Simulator for a prepared fragment together with user-supplied processes."""

    def __init__(self, fragment, **kwargs):
        self._state = _SimulatorState()
        self._signal_names = SignalDict()
        self._fragment = Fragment.get(fragment, platform=None).prepare()
        compile_fragment = _FragmentCompiler(self._state, self._signal_names)
        self._processes = compile_fragment(self._fragment)
        if kwargs: # :nocov:
            # TODO(nmigen-0.3): remove
            legacy_writer = _VCDWaveformWriter(self._signal_names, **kwargs)
            self._state.start_waveform(legacy_writer)
        self._clocked = set()

    def _check_process(self, process):
        """Return ``process`` as a generator function, or raise ``TypeError``.

        A generator (or coroutine) object is accepted with a deprecation warning and
        wrapped in a generator function.
        """
        if inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process):
            return process
        if not (inspect.isgenerator(process) or inspect.iscoroutine(process)):
            raise TypeError("Cannot add a process {!r} because it is not a generator function"
                            .format(process))
        warnings.warn("instead of generators, use generator functions as processes; "
                      "this allows the simulator to be repeatedly reset",
                      DeprecationWarning, stacklevel=3)
        def wrapper():
            yield from process
        return wrapper

    def _add_coroutine_process(self, process, *, default_cmd):
        coroutine_process = _CoroutineProcess(self._state, self._fragment.domains, process,
                                              default_cmd=default_cmd)
        self._processes.add(coroutine_process)

    def add_process(self, process):
        """Add a bench process; yielding ``None`` from it is an error."""
        process = self._check_process(process)
        def wrapper():
            # Only start a bench process after comb settling, so that the reset values are correct.
            yield Settle()
            yield from process()
        self._add_coroutine_process(wrapper, default_cmd=None)

    def add_sync_process(self, process, *, domain="sync"):
        """Add a process for which yielding ``None`` means ``Tick(domain)``."""
        process = self._check_process(process)
        def wrapper():
            # Only start a sync process after the first clock edge (or reset edge, if the domain
            # uses an asynchronous reset). This matches the behavior of synchronous FFs.
            yield Tick(domain)
            yield from process()
        return self._add_coroutine_process(wrapper, default_cmd=Tick(domain))

    def add_clock(self, period, *, phase=None, domain="sync", if_exists=False):
        """Add a clock process.

        The added process drives the clock signal of ``domain`` with a 50% duty cycle.

        Arguments
        ---------
        period : float
            Clock period; the clock signal of ``domain`` is toggled every ``period / 2``
            seconds.
        phase : None or float
            Time, in seconds, to wait before the first clock transition. Defaults to
            ``period / 2`` when not given.
        domain : str or ClockDomain
            Clock domain to drive. A string is looked up among the domains of the root
            fragment of the simulation.
        if_exists : bool
            Only matters when ``domain`` is a string that names no domain of the root
            fragment: with ``False`` (the default) an error is raised, with ``True``
            nothing is done.
        """
        if not isinstance(domain, ClockDomain):
            known_domains = self._fragment.domains
            if domain in known_domains:
                domain = known_domains[domain]
            elif if_exists:
                return
            else:
                raise ValueError("Domain {!r} is not present in simulation"
                                 .format(domain))
        if domain in self._clocked:
            raise ValueError("Domain {!r} already has a clock driving it"
                             .format(domain.name))

        half_period = period / 2
        # By default, delay the first edge by half period. This causes any synchronous activity
        # to happen at a non-zero time, distinguishing it from the reset values in the waveform
        # viewer.
        if phase is None:
            phase = half_period
        def clk_process():
            yield Passive()
            yield Delay(phase)
            # Behave correctly if the process is added after the clock signal is manipulated, or if
            # its reset state is high.
            initial = (yield domain.clk)
            steps = (
                domain.clk.eq(~initial),
                Delay(half_period),
                domain.clk.eq(initial),
                Delay(half_period),
            )
            while True:
                yield from iter(steps)
        self._add_coroutine_process(clk_process, default_cmd=None)
        self._clocked.add(domain)

    def reset(self):
        """Reset the simulation.

        Every signal gets its reset value back and every user process is restarted.
        """
        self._state.reset()
        for process in self._processes:
            process.reset()

    def _delta(self):
        """Perform a delta cycle.

        A delta cycle has two phases:
          1. every runnable process is run once, which queues signal changes;
          2. the queued changes are committed, which wakes up waiting processes.
        """
        for process in self._processes:
            if not process.runnable:
                continue
            process.runnable = False
            process.run()

        return self._state.commit()

    def _settle(self):
        """Settle the simulation.

        Repeat delta cycles until a fixed point is reached. Never returns if there is
        an unstable combinatorial loop.
        """
        settled = False
        while not settled:
            settled = not self._delta()

    def step(self):
        """Step the simulation.

        Settle the simulation, then advance time to the closest deadline (if any). Never
        returns if there is an unstable combinatorial loop.

        Returns ``True`` if there are any active processes, ``False`` otherwise.
        """
        self._settle()
        self._state.advance()
        for process in self._processes:
            if not process.passive:
                return True
        return False

    def run(self):
        """Run the simulation while any processes are active.

        Processes added with :meth:`add_process` and :meth:`add_sync_process` start out
        active and may switch status with ``yield Passive()`` and ``yield Active()``.
        Processes compiled from HDL and added with :meth:`add_clock` are always passive.
        """
        while True:
            if not self.step():
                break

    def run_until(self, deadline, *, run_passive=False):
        """Run the simulation until it advances to ``deadline``.

        With ``run_passive`` set to ``False``, the simulation additionally stops once no
        process is active, like :meth:`run`. Otherwise it stops only after it has advanced
        to or past ``deadline``.

        If the simulation stops advancing, this function will never return.
        """
        assert self._state.timestamp <= deadline
        while True:
            # A step is always taken before either stop condition is examined.
            if not (self.step() or run_passive):
                break
            if not self._state.timestamp < deadline:
                break

    def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()):
        """Write waveforms to a Value Change Dump file, optionally populating a GTKWave save file.

        The return value is a context manager, to be used like this: ::

            sim = Simulator(frag)
            sim.add_clock(1e-6)
            with sim.write_vcd("dump.vcd", "dump.gtkw"):
                sim.run_until(1e-3)

        Arguments
        ---------
        vcd_file : str or file-like object
            Verilog Value Change Dump file or filename.
        gtkw_file : str or file-like object
            GTKWave save file or filename.
        traces : iterable of Signal
            Signals to display traces for.
        """
        writer = _VCDWaveformWriter(self._signal_names,
                                    vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces)
        return _WaveformContextManager(self._state, writer)

    # TODO(nmigen-0.3): remove
    @deprecated("instead of `with Simulator(fragment, ...) as sim:`, use "
                "`sim = Simulator(fragment); with sim.write_vcd(...):`")
    def __enter__(self): # :nocov:
        return self

    # TODO(nmigen-0.3): remove
    def __exit__(self, *args): # :nocov:
        self._state.finish_waveform()