c2e9367bcb1d0c36cc3ac8efd7bacb8c4ca00710
3 from contextlib
import contextmanager
6 from ..hdl
.ast
import SignalSet
7 from ..hdl
.xfrm
import ValueVisitor
, StatementVisitor
, LHSGroupFilter
11 __all__
= ["PyRTLProcess"]
14 class PyRTLProcess(Process
):
24 def append(self
, code
):
25 self
._buffer
.append(" " * self
._level
)
26 self
._buffer
.append(code
)
27 self
._buffer
.append("\n")
35 def flush(self
, indent
=""):
36 code
= "".join(self
._buffer
)
40 def gen_var(self
, prefix
):
41 name
= f
"{prefix}_{self._suffix}"
45 def def_var(self
, prefix
, value
):
46 name
= self
.gen_var(prefix
)
47 self
.append(f
"{name} = {value}")
52 def __init__(self
, state
, emitter
):
54 self
.emitter
= emitter
57 class _ValueCompiler(ValueVisitor
, _Compiler
):
59 "sign": lambda value
, sign
: value | sign
if value
& sign
else value
,
60 "zdiv": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
// rhs
,
61 "zmod": lambda lhs
, rhs
: 0 if rhs
== 0 else lhs
% rhs
,
64 def on_ClockSignal(self
, value
):
65 raise NotImplementedError # :nocov:
67 def on_ResetSignal(self
, value
):
68 raise NotImplementedError # :nocov:
70 def on_AnyConst(self
, value
):
71 raise NotImplementedError # :nocov:
73 def on_AnySeq(self
, value
):
74 raise NotImplementedError # :nocov:
76 def on_Sample(self
, value
):
77 raise NotImplementedError # :nocov:
79 def on_Initial(self
, value
):
80 raise NotImplementedError # :nocov:
83 class _RHSValueCompiler(_ValueCompiler
):
84 def __init__(self
, state
, emitter
, *, mode
, inputs
=None):
85 super().__init
__(state
, emitter
)
86 assert mode
in ("curr", "next")
88 # If not None, `inputs` gets populated with RHS signals.
91 def on_Const(self
, value
):
92 return f
"{value.value}"
94 def on_Signal(self
, value
):
95 if self
.inputs
is not None:
96 self
.inputs
.add(value
)
98 if self
.mode
== "curr":
99 return f
"slots[{self.state.get_signal(value)}].{self.mode}"
101 return f
"next_{self.state.get_signal(value)}"
103 def on_Operator(self
, value
):
105 value_mask
= (1 << len(value
)) - 1
106 return f
"({self(value)} & {value_mask})"
109 if value
.shape().signed
:
110 return f
"sign({mask(value)}, {-1 << (len(value) - 1)})"
114 if len(value
.operands
) == 1:
115 arg
, = value
.operands
116 if value
.operator
== "~":
117 return f
"(~{self(arg)})"
118 if value
.operator
== "-":
119 return f
"(-{sign(arg)})"
120 if value
.operator
== "b":
121 return f
"bool({mask(arg)})"
122 if value
.operator
== "r|":
123 return f
"({mask(arg)} != 0)"
124 if value
.operator
== "r&":
125 return f
"({mask(arg)} == {(1 << len(arg)) - 1})"
126 if value
.operator
== "r^":
127 # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
128 return f
"(format({mask(arg)}, 'b').count('1') % 2)"
129 if value
.operator
in ("u", "s"):
130 # These operators don't change the bit pattern, only its interpretation.
132 elif len(value
.operands
) == 2:
133 lhs
, rhs
= value
.operands
134 lhs_mask
= (1 << len(lhs
)) - 1
135 rhs_mask
= (1 << len(rhs
)) - 1
136 if value
.operator
== "+":
137 return f
"({sign(lhs)} + {sign(rhs)})"
138 if value
.operator
== "-":
139 return f
"({sign(lhs)} - {sign(rhs)})"
140 if value
.operator
== "*":
141 return f
"({sign(lhs)} * {sign(rhs)})"
142 if value
.operator
== "//":
143 return f
"zdiv({sign(lhs)}, {sign(rhs)})"
144 if value
.operator
== "%":
145 return f
"zmod({sign(lhs)}, {sign(rhs)})"
146 if value
.operator
== "&":
147 return f
"({self(lhs)} & {self(rhs)})"
148 if value
.operator
== "|":
149 return f
"({self(lhs)} | {self(rhs)})"
150 if value
.operator
== "^":
151 return f
"({self(lhs)} ^ {self(rhs)})"
152 if value
.operator
== "<<":
153 return f
"({sign(lhs)} << {sign(rhs)})"
154 if value
.operator
== ">>":
155 return f
"({sign(lhs)} >> {sign(rhs)})"
156 if value
.operator
== "==":
157 return f
"({sign(lhs)} == {sign(rhs)})"
158 if value
.operator
== "!=":
159 return f
"({sign(lhs)} != {sign(rhs)})"
160 if value
.operator
== "<":
161 return f
"({sign(lhs)} < {sign(rhs)})"
162 if value
.operator
== "<=":
163 return f
"({sign(lhs)} <= {sign(rhs)})"
164 if value
.operator
== ">":
165 return f
"({sign(lhs)} > {sign(rhs)})"
166 if value
.operator
== ">=":
167 return f
"({sign(lhs)} >= {sign(rhs)})"
168 elif len(value
.operands
) == 3:
169 if value
.operator
== "m":
170 sel
, val1
, val0
= value
.operands
171 return f
"({self(val1)} if {self(sel)} else {self(val0)})"
172 raise NotImplementedError("Operator '{}' not implemented".format(value
.operator
)) # :nocov:
174 def on_Slice(self
, value
):
175 return f
"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"
177 def on_Part(self
, value
):
178 offset_mask
= (1 << len(value
.offset
)) - 1
179 offset
= f
"(({self(value.offset)} & {offset_mask}) * {value.stride})"
180 return f
"({self(value.value)} >> {offset} & " \
181 f
"{(1 << value.width) - 1})"
183 def on_Cat(self
, value
):
186 for part
in value
.parts
:
187 part_mask
= (1 << len(part
)) - 1
188 gen_parts
.append(f
"(({self(part)} & {part_mask}) << {offset})")
191 return f
"({' | '.join(gen_parts)})"
194 def on_Repl(self
, value
):
195 part_mask
= (1 << len(value
.value
)) - 1
196 gen_part
= self
.emitter
.def_var("repl", f
"{self(value.value)} & {part_mask}")
199 for _
in range(value
.count
):
200 gen_parts
.append(f
"({gen_part} << {offset})")
201 offset
+= len(value
.value
)
203 return f
"({' | '.join(gen_parts)})"
206 def on_ArrayProxy(self
, value
):
207 index_mask
= (1 << len(value
.index
)) - 1
208 gen_index
= self
.emitter
.def_var("rhs_index", f
"{self(value.index)} & {index_mask}")
209 gen_value
= self
.emitter
.gen_var("rhs_proxy")
212 for index
, elem
in enumerate(value
.elems
):
214 self
.emitter
.append(f
"if {gen_index} == {index}:")
216 self
.emitter
.append(f
"elif {gen_index} == {index}:")
217 with self
.emitter
.indent():
218 self
.emitter
.append(f
"{gen_value} = {self(elem)}")
219 self
.emitter
.append(f
"else:")
220 with self
.emitter
.indent():
221 self
.emitter
.append(f
"{gen_value} = {self(value.elems[-1])}")
227 def compile(cls
, state
, value
, *, mode
):
228 emitter
= _PythonEmitter()
229 compiler
= cls(state
, emitter
, mode
=mode
)
230 emitter
.append(f
"result = {compiler(value)}")
231 return emitter
.flush()
234 class _LHSValueCompiler(_ValueCompiler
):
235 def __init__(self
, state
, emitter
, *, rhs
, outputs
=None):
236 super().__init
__(state
, emitter
)
237 # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
238 # the offset of a Part.
240 # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
241 # update of an lvalue.
242 self
.lrhs
= _RHSValueCompiler(state
, emitter
, mode
="next", inputs
=None)
243 # If not None, `outputs` gets populated with signals on LHS.
244 self
.outputs
= outputs
246 def on_Const(self
, value
):
247 raise TypeError # :nocov:
249 def on_Signal(self
, value
):
250 if self
.outputs
is not None:
251 self
.outputs
.add(value
)
254 value_mask
= (1 << len(value
)) - 1
255 if value
.shape().signed
:
256 value_sign
= f
"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
258 value_sign
= f
"{arg} & {value_mask}"
259 self
.emitter
.append(f
"next_{self.state.get_signal(value)} = {value_sign}")
262 def on_Operator(self
, value
):
263 raise TypeError # :nocov:
265 def on_Slice(self
, value
):
267 width_mask
= (1 << (value
.stop
- value
.start
)) - 1
268 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
269 f
"{~(width_mask << value.start)} | " \
270 f
"(({arg} & {width_mask}) << {value.start}))")
273 def on_Part(self
, value
):
275 width_mask
= (1 << value
.width
) - 1
276 offset_mask
= (1 << len(value
.offset
)) - 1
277 offset
= f
"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
278 self(value
.value
)(f
"({self.lrhs(value.value)} & " \
279 f
"~({width_mask} << {offset}) | " \
280 f
"(({arg} & {width_mask}) << {offset}))")
283 def on_Cat(self
, value
):
285 gen_arg
= self
.emitter
.def_var("cat", arg
)
288 for part
in value
.parts
:
289 part_mask
= (1 << len(part
)) - 1
290 self(part
)(f
"(({gen_arg} >> {offset}) & {part_mask})")
294 def on_Repl(self
, value
):
295 raise TypeError # :nocov:
297 def on_ArrayProxy(self
, value
):
299 index_mask
= (1 << len(value
.index
)) - 1
300 gen_index
= self
.emitter
.def_var("index", f
"{self.rrhs(value.index)} & {index_mask}")
303 for index
, elem
in enumerate(value
.elems
):
305 self
.emitter
.append(f
"if {gen_index} == {index}:")
307 self
.emitter
.append(f
"elif {gen_index} == {index}:")
308 with self
.emitter
.indent():
310 self
.emitter
.append(f
"else:")
311 with self
.emitter
.indent():
312 self(value
.elems
[-1])(arg
)
314 self
.emitter
.append(f
"pass")
318 class _StatementCompiler(StatementVisitor
, _Compiler
):
319 def __init__(self
, state
, emitter
, *, inputs
=None, outputs
=None):
320 super().__init
__(state
, emitter
)
321 self
.rhs
= _RHSValueCompiler(state
, emitter
, mode
="curr", inputs
=inputs
)
322 self
.lhs
= _LHSValueCompiler(state
, emitter
, rhs
=self
.rhs
, outputs
=outputs
)
324 def on_statements(self
, stmts
):
328 self
.emitter
.append("pass")
330 def on_Assign(self
, stmt
):
331 return self
.lhs(stmt
.lhs
)(self
.rhs(stmt
.rhs
))
333 def on_Switch(self
, stmt
):
334 gen_test
= self
.emitter
.def_var("test",
335 f
"{self.rhs(stmt.test)} & {(1 << len(stmt.test)) - 1}")
336 for index
, (patterns
, stmts
) in enumerate(stmt
.cases
.items()):
339 gen_checks
.append(f
"True")
341 for pattern
in patterns
:
343 mask
= int("".join("0" if b
== "-" else "1" for b
in pattern
), 2)
344 value
= int("".join("0" if b
== "-" else b
for b
in pattern
), 2)
345 gen_checks
.append(f
"({gen_test} & {mask}) == {value}")
347 value
= int(pattern
, 2)
348 gen_checks
.append(f
"{gen_test} == {value}")
350 self
.emitter
.append(f
"if {' or '.join(gen_checks)}:")
352 self
.emitter
.append(f
"elif {' or '.join(gen_checks)}:")
353 with self
.emitter
.indent():
356 def on_Assert(self
, stmt
):
357 raise NotImplementedError # :nocov:
359 def on_Assume(self
, stmt
):
360 raise NotImplementedError # :nocov:
362 def on_Cover(self
, stmt
):
363 raise NotImplementedError # :nocov:
366 def compile(cls
, state
, stmt
):
367 output_indexes
= [state
.get_signal(signal
) for signal
in stmt
._lhs
_signals
()]
368 emitter
= _PythonEmitter()
369 for signal_index
in output_indexes
:
370 emitter
.append(f
"next_{signal_index} = slots[{signal_index}].next")
371 compiler
= cls(state
, emitter
)
373 for signal_index
in output_indexes
:
374 emitter
.append(f
"slots[{signal_index}].set(next_{signal_index})")
375 return emitter
.flush()
378 class _FragmentCompiler
:
379 def __init__(self
, state
):
382 def __call__(self
, fragment
):
385 for domain_name
, domain_signals
in fragment
.drivers
.items():
386 domain_stmts
= LHSGroupFilter(domain_signals
)(fragment
.statements
)
387 domain_process
= PyRTLProcess(is_comb
=domain_name
is None)
389 emitter
= _PythonEmitter()
390 emitter
.append(f
"def run():")
393 if domain_name
is None:
394 for signal
in domain_signals
:
395 signal_index
= self
.state
.get_signal(signal
)
396 emitter
.append(f
"next_{signal_index} = {signal.reset}")
399 _StatementCompiler(self
.state
, emitter
, inputs
=inputs
)(domain_stmts
)
402 self
.state
.add_trigger(domain_process
, input)
405 domain
= fragment
.domains
[domain_name
]
406 clk_trigger
= 1 if domain
.clk_edge
== "pos" else 0
407 self
.state
.add_trigger(domain_process
, domain
.clk
, trigger
=clk_trigger
)
408 if domain
.rst
is not None and domain
.async_reset
:
410 self
.state
.add_trigger(domain_process
, domain
.rst
, trigger
=rst_trigger
)
412 for signal
in domain_signals
:
413 signal_index
= self
.state
.get_signal(signal
)
414 emitter
.append(f
"next_{signal_index} = slots[{signal_index}].next")
416 _StatementCompiler(self
.state
, emitter
)(domain_stmts
)
418 for signal
in domain_signals
:
419 signal_index
= self
.state
.get_signal(signal
)
420 emitter
.append(f
"slots[{signal_index}].set(next_{signal_index})")
422 # There shouldn't be any exceptions raised by the generated code, but if there are
423 # (almost certainly due to a bug in the code generator), use this environment variable
424 # to make backtraces useful.
425 code
= emitter
.flush()
426 if os
.getenv("NMIGEN_pysim_dump"):
427 file = tempfile
.NamedTemporaryFile("w", prefix
="nmigen_pysim_", delete
=False)
431 filename
= "<string>"
433 exec_locals
= {"slots": self
.state
.slots
, **_ValueCompiler
.helpers
}
434 exec(compile(code
, filename
, "exec"), exec_locals
)
435 domain_process
.run
= exec_locals
["run"]
437 processes
.add(domain_process
)
439 for subfragment_index
, (subfragment
, subfragment_name
) in enumerate(fragment
.subfragments
):
440 if subfragment_name
is None:
441 subfragment_name
= "U${}".format(subfragment_index
)
442 processes
.update(self(subfragment
))