hdl.ir, back.rtlil: allow specifying attributes on instances.
[nmigen.git] / nmigen / hdl / ir.py
1 from abc import ABCMeta, abstractmethod
2 from collections import defaultdict, OrderedDict
3 from functools import reduce
4 import warnings
5 import traceback
6 import sys
7
8 from ..tools import *
9 from .ast import *
10 from .cd import *
11
12
# Names exported by ``from <this module> import *``.
__all__ = [
    "Elaboratable",
    "DriverConflict",
    "Fragment",
    "Instance",
]
14
15
class Elaboratable(metaclass=ABCMeta):
    """Base class for design objects that ``Fragment.get`` turns into fragments.

    Every instance remembers the call stack at which it was constructed. If an instance is
    finalized without ever having been marked as used (``Fragment.get`` marks it right before
    calling its ``elaborate`` method), a warning including that stack is printed to stderr.
    """

    # Inside this class body Python mangles double-underscore names, so ``__silence`` here and
    # ``.__traceback`` / ``.__used`` below are stored as ``_Elaboratable__silence``,
    # ``_Elaboratable__traceback`` and ``_Elaboratable__used`` -- the spellings that code
    # outside this class uses.
    # When true (the module's excepthook sets it), no "never used" warnings are printed.
    __silence = False

    def __new__(cls, *args, **kwargs):
        # Doing the bookkeeping in ``__new__`` (rather than ``__init__``) means it also happens
        # for subclasses whose ``__init__`` never calls ``super().__init__()``.
        instance = super().__new__(cls)
        # Drop the innermost frame, which is this ``__new__`` itself.
        instance.__traceback = traceback.extract_stack()[:-1]
        instance.__used = False
        return instance

    def __del__(self):
        if self.__silence:
            return
        # Instances lacking the flag (e.g. if ``__new__`` above did not complete) count as used.
        if not getattr(self, "_Elaboratable__used", True):
            frames = traceback.format_list(self.__traceback)
            print("Warning: elaboratable created but never used\n",
                  "Constructor traceback (most recent call last):\n",
                  *frames,
                  file=sys.stderr, sep="")
33
34
# The hook that was active when this module was imported; the replacement installed below
# delegates to it.
_old_excepthook = sys.excepthook


def _silence_elaboratable(exc_type, exc_value, exc_tb):
    """Uncaught-exception hook: silence "never used" warnings, then report the exception.

    When the interpreter is going down because of an uncaught exception, warnings about
    unused elaboratables would only obscure the exception traceback instead of helping, so
    they are switched off before handing the exception to the previously installed hook.
    """
    Elaboratable._Elaboratable__silence = True
    _old_excepthook(exc_type, exc_value, exc_tb)


sys.excepthook = _silence_elaboratable
42
43
class DriverConflict(UserWarning):
    """A signal is driven, or a memory accessed, from more than one fragment.

    ``Fragment._resolve_hierarchy_conflicts`` issues this as a warning in ``"warn"`` mode
    and raises it as an exception in ``"error"`` mode.
    """
46
47
class Fragment:
    """A node of the design hierarchy produced by elaboration.

    A fragment records:

    * ``ports`` -- signal-keyed mapping from each port signal to its direction
      (``"i"``, ``"o"`` or ``"io"``);
    * ``drivers`` -- for each clock domain name (``None`` meaning combinational), the set of
      signals this fragment drives in that domain;
    * ``statements`` -- the list of statements of this fragment;
    * ``domains`` -- ``ClockDomain`` objects by name;
    * ``subfragments`` -- list of ``(fragment, name)`` pairs; ``name`` may be ``None``;
    * ``attrs`` -- attributes by name;
    * ``generated`` -- objects retrievable with :meth:`find_generated`;
    * ``flatten`` -- when true, this fragment is always merged into its parent.
    """

    @staticmethod
    def get(obj, platform):
        """Elaborate ``obj`` for ``platform`` repeatedly until a :class:`Fragment` results.

        ``Elaboratable`` objects are marked as used before being elaborated. Objects that have
        an ``elaborate`` method but do not inherit from ``Elaboratable`` are elaborated as
        well, with a ``RuntimeWarning`` issued for each of them.

        Raises ``AttributeError`` if some object in the chain cannot be elaborated.
        """
        while True:
            if isinstance(obj, Fragment):
                return obj
            elif isinstance(obj, Elaboratable):
                obj._Elaboratable__used = True
                obj = obj.elaborate(platform)
            elif hasattr(obj, "elaborate"):
                warnings.warn(
                    message="Class {!r} is an elaboratable that does not explicitly inherit from "
                            "Elaboratable; doing so would improve diagnostics"
                            .format(type(obj)),
                    category=RuntimeWarning,
                    stacklevel=2)
                obj = obj.elaborate(platform)
            else:
                raise AttributeError("Object '{!r}' cannot be elaborated".format(obj))

    def __init__(self):
        self.ports = SignalDict()
        self.drivers = OrderedDict()
        self.statements = []
        self.domains = OrderedDict()
        self.subfragments = []
        self.attrs = OrderedDict()
        self.generated = OrderedDict()
        self.flatten = False

    def add_ports(self, *ports, dir):
        """Record every signal in ``ports`` (flattened) as a port with direction ``dir``.

        ``dir`` must be ``"i"``, ``"o"`` or ``"io"``; re-adding a port overwrites its direction.
        """
        assert dir in ("i", "o", "io")
        for port in flatten(ports):
            self.ports[port] = dir

    def iter_ports(self, dir=None):
        """Yield port signals; if ``dir`` is given, only those with that direction."""
        if dir is None:
            yield from self.ports
        else:
            for port, port_dir in self.ports.items():
                if port_dir == dir:
                    yield port

    def add_driver(self, signal, domain=None):
        """Record that this fragment drives ``signal`` in ``domain`` (``None``: combinational)."""
        if domain not in self.drivers:
            self.drivers[domain] = SignalSet()
        self.drivers[domain].add(signal)

    def iter_drivers(self):
        """Yield ``(domain, signal)`` for every driven signal, combinational ones included."""
        for domain, signals in self.drivers.items():
            for signal in signals:
                yield domain, signal

    def iter_comb(self):
        """Yield the signals driven combinationally (domain ``None``)."""
        if None in self.drivers:
            yield from self.drivers[None]

    def iter_sync(self):
        """Yield ``(domain, signal)`` for every signal driven in a named clock domain."""
        for domain, signals in self.drivers.items():
            if domain is None:
                continue
            for signal in signals:
                yield domain, signal

    def iter_signals(self):
        """Return a ``SignalSet`` of ports, driven signals, and the clock/reset signals of
        every clock domain this fragment drives something in."""
        signals = SignalSet()
        signals |= self.ports.keys()
        for domain, domain_signals in self.drivers.items():
            if domain is not None:
                cd = self.domains[domain]
                signals.add(cd.clk)
                if cd.rst is not None:
                    signals.add(cd.rst)
            signals |= domain_signals
        return signals

    def add_domains(self, *domains):
        """Add ``ClockDomain`` objects (flattened); each domain name must not exist yet."""
        for domain in flatten(domains):
            assert isinstance(domain, ClockDomain)
            assert domain.name not in self.domains
            self.domains[domain.name] = domain

    def iter_domains(self):
        """Yield the names of the clock domains known to this fragment."""
        yield from self.domains

    def add_statements(self, *stmts):
        """Append ``stmts`` (wrapped by ``Statement.wrap``) to this fragment's statements."""
        self.statements += Statement.wrap(stmts)

    def add_subfragment(self, subfragment, name=None):
        """Append ``subfragment`` as a child, optionally under ``name``."""
        assert isinstance(subfragment, Fragment)
        self.subfragments.append((subfragment, name))

    def find_subfragment(self, name_or_index):
        """Return the subfragment at an integer index, or the first one with the given name.

        Negative indices count from the end, as for Python lists.

        Raises ``NameError`` if no subfragment matches.
        """
        if isinstance(name_or_index, int):
            # Check both ends, so that an out-of-range negative index is reported as a
            # NameError like every other miss instead of leaking an IndexError.
            if -len(self.subfragments) <= name_or_index < len(self.subfragments):
                subfragment, name = self.subfragments[name_or_index]
                return subfragment
            raise NameError("No subfragment at index #{}".format(name_or_index))
        else:
            for subfragment, name in self.subfragments:
                if name == name_or_index:
                    return subfragment
            raise NameError("No subfragment with name '{}'".format(name_or_index))

    def find_generated(self, *path):
        """Look up a generated object.

        All but the last element of ``path`` select nested subfragments (as in
        :meth:`find_subfragment`); the last element is the key in that fragment's
        ``generated`` mapping.
        """
        if len(path) > 1:
            path_component, *path = path
            return self.find_subfragment(path_component).find_generated(*path)
        else:
            item, = path
            return self.generated[item]

    def elaborate(self, platform):
        """A fragment elaborates to itself."""
        return self

    def _merge_subfragment(self, subfragment):
        # Merge subfragment's everything except clock domains into this fragment.
        # Flattening is done after clock domain propagation, so we can assume the domains
        # are already the same in every involved fragment in the first place.
        self.ports.update(subfragment.ports)
        for domain, signal in subfragment.iter_drivers():
            self.add_driver(signal, domain)
        self.statements += subfragment.statements
        self.subfragments += subfragment.subfragments

        # Remove the merged subfragment.
        found = False
        for i, (check_subfrag, check_name) in enumerate(self.subfragments): # :nobr:
            if subfragment == check_subfrag:
                del self.subfragments[i]
                found = True
                break
        assert found

    def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"):
        """Flatten subfragments until no signal is driven, and no memory is accessed, from
        more than one fragment.

        ``hierarchy`` is the path of this fragment, used in messages. ``mode`` selects how
        conflicts are reported: ``"silent"`` (not at all), ``"warn"`` (a ``DriverConflict``
        warning) or ``"error"`` (``DriverConflict`` is raised). Subfragments whose ``flatten``
        flag is set are always merged; ``Instance`` subfragments never are.

        Returns ``(signals, memories)``: the ``SignalSet`` of signals driven by this fragment
        or its subfragments, and the set of memories accessed through ``$memrd``/``$memwr``
        instances.
        """
        assert mode in ("silent", "warn", "error")

        driver_subfrags = SignalDict()
        memory_subfrags = OrderedDict()
        def add_subfrag(registry, entity, entry):
            if entity not in registry:
                registry[entity] = set()
            registry[entity].add(entry)

        # For each signal driven by this fragment and/or its subfragments, determine which
        # subfragments also drive it.
        for domain, signal in self.iter_drivers():
            add_subfrag(driver_subfrags, signal, (None, hierarchy))

        flatten_subfrags = set()
        for i, (subfrag, name) in enumerate(self.subfragments):
            if name is None:
                name = "<unnamed #{}>".format(i)
            subfrag_hierarchy = hierarchy + (name,)

            if subfrag.flatten:
                # Always flatten subfragments that explicitly request it.
                flatten_subfrags.add((subfrag, subfrag_hierarchy))

            if isinstance(subfrag, Instance):
                # For memories (which are subfragments, but semantically a part of superfragment),
                # record that this fragment is driving it.
                if subfrag.type in ("$memrd", "$memwr"):
                    memory = subfrag.parameters["MEMID"]
                    add_subfrag(memory_subfrags, memory, (None, hierarchy))

                # Never flatten instances.
                continue

            # First, recurse into subfragments and let them detect driver conflicts as well.
            subfrag_drivers, subfrag_memories = \
                subfrag._resolve_hierarchy_conflicts(subfrag_hierarchy, mode)

            # Second, classify subfragments by signals they drive and memories they use.
            for signal in subfrag_drivers:
                add_subfrag(driver_subfrags, signal, (subfrag, subfrag_hierarchy))
            for memory in subfrag_memories:
                add_subfrag(memory_subfrags, memory, (subfrag, subfrag_hierarchy))

        # Find out the set of subfragments that needs to be flattened into this fragment
        # to resolve driver-driver conflicts.
        def flatten_subfrags_if_needed(subfrags):
            if len(subfrags) == 1:
                return []
            flatten_subfrags.update((f, h) for f, h in subfrags if f is not None)
            return sorted(".".join(h) for f, h in subfrags)

        for signal, subfrags in driver_subfrags.items():
            subfrag_names = flatten_subfrags_if_needed(subfrags)
            if not subfrag_names:
                continue

            # While we're at it, show a message.
            message = ("Signal '{}' is driven from multiple fragments: {}"
                       .format(signal, ", ".join(subfrag_names)))
            if mode == "error":
                raise DriverConflict(message)
            elif mode == "warn":
                message += "; hierarchy will be flattened"
                warnings.warn_explicit(message, DriverConflict, *signal.src_loc)

        for memory, subfrags in memory_subfrags.items():
            subfrag_names = flatten_subfrags_if_needed(subfrags)
            if not subfrag_names:
                continue

            # While we're at it, show a message.
            message = ("Memory '{}' is accessed from multiple fragments: {}"
                       .format(memory.name, ", ".join(subfrag_names)))
            if mode == "error":
                raise DriverConflict(message)
            elif mode == "warn":
                message += "; hierarchy will be flattened"
                warnings.warn_explicit(message, DriverConflict, *memory.src_loc)

        # Flatten hierarchy.
        for subfrag, subfrag_hierarchy in sorted(flatten_subfrags, key=lambda x: x[1]):
            self._merge_subfragment(subfrag)

        # If we flattened anything, we might be in a situation where we have a driver conflict
        # again, e.g. if we had a tree of fragments like A --- B --- C where only fragments
        # A and C were driving a signal S. In that case, since B is not driving S itself,
        # processing B will not result in any flattening, but since B is transitively driving S,
        # processing A will flatten B into it. Afterwards, we have a tree like AB --- C, which
        # has another conflict.
        if flatten_subfrags:
            # Try flattening again.
            return self._resolve_hierarchy_conflicts(hierarchy, mode)

        # Nothing was flattened, we're done!
        return (SignalSet(driver_subfrags.keys()),
                set(memory_subfrags.keys()))

    def _propagate_domains_up(self, hierarchy=("top",)):
        """Collect the clock domains defined by subfragments (recursively) into this fragment.

        A domain defined by several subfragments is renamed in each of them to
        ``"<subfragment name>_<domain>"``. Raises ``DomainError`` if any of those subfragments
        is unnamed, or if their names are not distinct.
        """
        from .xfrm import DomainRenamer

        domain_subfrags = defaultdict(set)

        # For each domain defined by a subfragment, determine which subfragments define it.
        for i, (subfrag, name) in enumerate(self.subfragments):
            # First, recurse into subfragments and let them propagate domains up as well.
            hier_name = name
            if hier_name is None:
                hier_name = "<unnamed #{}>".format(i)
            subfrag._propagate_domains_up(hierarchy + (hier_name,))

            # Second, classify subfragments by domains they define.
            for domain in subfrag.iter_domains():
                domain_subfrags[domain].add((subfrag, name, i))

        # For each domain defined by more than one subfragment, rename the domain in each
        # of the subfragments such that they no longer conflict.
        for domain, subfrags in domain_subfrags.items():
            if len(subfrags) == 1:
                continue

            names = [n for f, n, i in subfrags]
            if not all(names):
                names = sorted("<unnamed #{}>".format(i) if n is None else "'{}'".format(n)
                               for f, n, i in subfrags)
                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}'; "
                                  "it is necessary to either rename subfragment domains "
                                  "explicitly, or give names to subfragments"
                                  .format(domain, ", ".join(names), ".".join(hierarchy)))

            if len(names) != len(set(names)):
                names = sorted("#{}".format(i) for f, n, i in subfrags)
                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}', "
                                  "some of which have identical names; it is necessary to either "
                                  "rename subfragment domains explicitly, or give distinct names "
                                  "to subfragments"
                                  .format(domain, ", ".join(names), ".".join(hierarchy)))

            for subfrag, name, i in subfrags:
                self.subfragments[i] = \
                    (DomainRenamer({domain: "{}_{}".format(name, domain)})(subfrag), name)

        # Finally, collect the (now unique) subfragment domains, and merge them into our domains.
        for subfrag, name in self.subfragments:
            for domain in subfrag.iter_domains():
                self.add_domains(subfrag.domains[domain])

    def _propagate_domains_down(self):
        """Make every clock domain of this fragment also known to all subfragments, recursively."""
        # For each domain defined in this fragment, ensure it also exists in all subfragments.
        for subfrag, name in self.subfragments:
            for domain in self.iter_domains():
                if domain in subfrag.domains:
                    assert self.domains[domain] is subfrag.domains[domain]
                else:
                    subfrag.add_domains(self.domains[domain])

            subfrag._propagate_domains_down()

    def _propagate_domains(self, ensure_sync_exists):
        """Propagate clock domains up, then down, the hierarchy.

        If ``ensure_sync_exists`` is true and no domain exists after propagating up, a default
        ``ClockDomain()`` is added first. Returns a tuple of the domains created that way.
        """
        self._propagate_domains_up()
        if ensure_sync_exists and not self.domains:
            cd_sync = ClockDomain()
            self.add_domains(cd_sync)
            new_domains = (cd_sync,)
        else:
            new_domains = ()
        self._propagate_domains_down()
        return new_domains

    def _insert_domain_resets(self):
        """Return the result of applying ``ResetInserter`` for every domain that has a reset."""
        from .xfrm import ResetInserter

        resets = {cd.name: cd.rst for cd in self.domains.values() if cd.rst is not None}
        return ResetInserter(resets)(self)

    def _lower_domain_signals(self):
        """Return the result of applying ``DomainLowerer`` with this fragment's domains."""
        from .xfrm import DomainLowerer

        return DomainLowerer(self.domains)(self)

    def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
        """Walk the hierarchy below this fragment, filling the mappings passed in.

        * ``parent`` / ``level`` -- parent fragment and depth of each non-Instance subfragment;
        * ``uses`` -- signal -> set of fragments reading it (statement RHS, clock/reset of
          domains with synchronous drivers, Instance input ports);
        * ``defs`` -- signal -> the single fragment driving it (statement LHS, Instance output
          ports);
        * ``ios`` -- signal -> the fragment connecting it to an Instance inout port.

        Ports of Instance subfragments are populated from their ``named_ports`` on the way.
        ``top`` is currently unused.
        """
        def add_uses(*sigs, self=self):
            for sig in flatten(sigs):
                if sig not in uses:
                    uses[sig] = set()
                uses[sig].add(self)

        def add_defs(*sigs):
            for sig in flatten(sigs):
                if sig not in defs:
                    defs[sig] = self
                else:
                    assert defs[sig] is self

        def add_io(*sigs):
            for sig in flatten(sigs):
                if sig not in ios:
                    ios[sig] = self
                else:
                    assert ios[sig] is self

        # Collect all signals we're driving (on LHS of statements), and signals we're using
        # (on RHS of statements, or in clock domains).
        for stmt in self.statements:
            add_uses(stmt._rhs_signals())
            add_defs(stmt._lhs_signals())

        for domain, _ in self.iter_sync():
            cd = self.domains[domain]
            add_uses(cd.clk)
            if cd.rst is not None:
                add_uses(cd.rst)

        # Repeat for subfragments.
        for subfrag, name in self.subfragments:
            if isinstance(subfrag, Instance):
                for port_name, (value, dir) in subfrag.named_ports.items():
                    if dir == "i":
                        subfrag.add_ports(value._rhs_signals(), dir=dir)
                        add_uses(value._rhs_signals())
                    if dir == "o":
                        subfrag.add_ports(value._lhs_signals(), dir=dir)
                        add_defs(value._lhs_signals())
                    if dir == "io":
                        subfrag.add_ports(value._lhs_signals(), dir=dir)
                        add_io(value._lhs_signals())
            else:
                parent[subfrag] = self
                level[subfrag] = level[self] + 1

                subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)

    def _propagate_ports(self, ports, all_undef_as_ports):
        """Add the ports needed for signals to cross fragment boundaries.

        ``ports`` are signals to expose on this (top) fragment; if ``all_undef_as_ports`` is
        true, every signal that is used but driven nowhere is exposed as well. Signals wired
        to Instance inout ports get ``"io"`` ports all the way up to this fragment.
        """
        # Take this fragment graph:
        #
        #    __ B (def: q, use: p r)
        #   /
        # A (def: p, use: q r)
        #   \
        #    \_ C (def: r, use: p q)
        #
        # We need to consider three cases.
        #   1. Signal p requires an input port in B;
        #   2. Signal r requires an output port in C;
        #   3. Signal r requires an output port in C and an input port in B.
        #
        # Adding these ports can be in general done in three steps for each signal:
        #   1. Find the least common ancestor of all uses and defs.
        #   2. Going upwards from the single def, add output ports.
        #   3. Going upwards from all uses, add input ports.

        parent = {self: None}
        level = {self: 0}
        uses = SignalDict()
        defs = SignalDict()
        ios = SignalDict()
        self._prepare_use_def_graph(parent, level, uses, defs, ios, self)

        ports = SignalSet(ports)
        if all_undef_as_ports:
            for sig in uses:
                if sig in defs:
                    continue
                ports.add(sig)
        for sig in ports:
            if sig not in uses:
                uses[sig] = set()
            uses[sig].add(self)

        @memoize
        def lca_of(fragu, fragv):
            # Normalize fragu to be at least as deep as fragv.
            if level[fragu] < level[fragv]:
                fragu, fragv = fragv, fragu
            # Find ancestor of fragu on the same level as fragv.
            for _ in range(level[fragu] - level[fragv]):
                fragu = parent[fragu]
            # If fragv was the ancestor of fragu, we're done.
            if fragu == fragv:
                return fragu
            # Otherwise, they are at the same level but in different branches. Step both fragu
            # and fragv until we find the common ancestor.
            while parent[fragu] != parent[fragv]:
                fragu = parent[fragu]
                fragv = parent[fragv]
            return parent[fragu]

        for sig in uses:
            if sig in defs:
                lca = reduce(lca_of, uses[sig], defs[sig])
            else:
                lca = reduce(lca_of, uses[sig])

            for frag in uses[sig]:
                if sig in defs and frag is defs[sig]:
                    continue
                while frag != lca:
                    frag.add_ports(sig, dir="i")
                    frag = parent[frag]

            if sig in defs:
                frag = defs[sig]
                while frag != lca:
                    frag.add_ports(sig, dir="o")
                    frag = parent[frag]

        for sig in ios:
            frag = ios[sig]
            while frag is not None:
                frag.add_ports(sig, dir="io")
                frag = parent[frag]

        for sig in ports:
            if sig in ios:
                continue
            if sig in defs:
                self.add_ports(sig, dir="o")
            else:
                self.add_ports(sig, dir="i")

    def prepare(self, ports=None, ensure_sync_exists=True):
        """Return a transformed copy of this fragment ready for a backend.

        Lowers ``Sample``, propagates clock domains, resolves driver conflicts by flattening,
        inserts domain resets, lowers domain signals, and finally adds ports. With
        ``ports=None`` every used-but-undriven signal becomes a port; otherwise ``ports`` plus
        the clock/reset of any domain created because of ``ensure_sync_exists`` are exposed.
        """
        from .xfrm import SampleLowerer

        fragment = SampleLowerer()(self)
        new_domains = fragment._propagate_domains(ensure_sync_exists)
        fragment._resolve_hierarchy_conflicts()
        fragment = fragment._insert_domain_resets()
        fragment = fragment._lower_domain_signals()
        if ports is None:
            fragment._propagate_ports(ports=(), all_undef_as_ports=True)
        else:
            new_ports = []
            for cd in new_domains:
                new_ports.append(cd.clk)
                if cd.rst is not None:
                    new_ports.append(cd.rst)
            fragment._propagate_ports(ports=(*ports, *new_ports), all_undef_as_ports=False)
        return fragment
521
522
class Instance(Fragment):
    """A fragment standing for an instance of an external cell or module of type ``type``.

    Attributes, parameters and ports are given either as positional ``(kind, name, value)``
    tuples, where ``kind`` is ``"a"`` (attribute), ``"p"`` (parameter), or ``"i"``, ``"o"``,
    ``"io"`` (port of that direction), or as keyword arguments named ``a_<name>``,
    ``p_<name>``, ``i_<name>``, ``o_<name>`` or ``io_<name>``.

    Attributes go to ``attrs``, parameters to ``parameters``, and ports to ``named_ports`` as
    ``(value, direction)`` pairs.

    Raises ``NameError`` for an unrecognized kind or keyword prefix.
    """

    def __init__(self, type, *args, **kwargs):
        super().__init__()

        self.type = type
        self.parameters = OrderedDict()
        self.named_ports = OrderedDict()

        for (kind, name, value) in args:
            if kind == "a":
                self.attrs[name] = value
            elif kind == "p":
                self.parameters[name] = value
            elif kind in ("i", "o", "io"):
                self.named_ports[name] = (value, kind)
            else:
                # The message lists every accepted kind, including attributes ("a").
                raise NameError("Instance argument {!r} should be a tuple (kind, name, value) "
                                "where kind is one of \"a\", \"p\", \"i\", \"o\", or \"io\""
                                .format((kind, name, value)))

        for kw, arg in kwargs.items():
            if kw.startswith("a_"):
                self.attrs[kw[2:]] = arg
            elif kw.startswith("p_"):
                self.parameters[kw[2:]] = arg
            elif kw.startswith("i_"):
                self.named_ports[kw[2:]] = (arg, "i")
            elif kw.startswith("o_"):
                self.named_ports[kw[2:]] = (arg, "o")
            elif kw.startswith("io_"):
                self.named_ports[kw[3:]] = (arg, "io")
            else:
                # The message lists every accepted prefix, including attributes ("a_").
                raise NameError("Instance keyword argument {}={!r} does not start with one of "
                                "\"a_\", \"p_\", \"i_\", \"o_\", or \"io_\""
                                .format(kw, arg))