fhdl.ir: automatically flatten hierarchy to resolve driver conflicts.
authorwhitequark <whitequark@whitequark.org>
Fri, 14 Dec 2018 22:47:58 +0000 (22:47 +0000)
committerwhitequark <whitequark@whitequark.org>
Fri, 14 Dec 2018 22:48:17 +0000 (22:48 +0000)
Fixes #5.

.coveragerc
nmigen/fhdl/ir.py
nmigen/test/test_fhdl_ir.py
nmigen/test/tools.py

index e70df5002a16597198ff5ea864f12f068d43d1dd..6435aa3cc52fed9b1fba433582b70b07ebe19928 100644 (file)
@@ -9,3 +9,5 @@ omit =
 [report]
 exclude_lines =
        :nocov:
+partial_branches =
+  :nobr:
index dbac8179a506d488ab8ab35607690d9425f7712a..844cc7b8abee82154bfcca87ec621860bb269196 100644 (file)
@@ -1,3 +1,4 @@
+import warnings
 from collections import defaultdict, OrderedDict
 
 from ..tools import *
@@ -5,7 +6,11 @@ from .ast import *
 from .cd import *
 
 
-__all__ = ["Fragment"]
+__all__ = ["Fragment", "DriverConflict"]
+
+
+class DriverConflict(UserWarning):
+    pass
 
 
 class Fragment:
@@ -73,6 +78,77 @@ class Fragment:
         assert isinstance(subfragment, Fragment)
         self.subfragments.append((subfragment, name))
 
+    def _resolve_driver_conflicts(self, hierarchy=("top",), mode="warn"):
+        assert mode in ("silent", "warn", "error")
+
+        driver_subfrags = ValueDict()
+
+        # For each signal driven by this fragment and/or its subfragments, determine which
+        # subfragments also drive it.
+        for domain, signal in self.iter_drivers():
+            if signal not in driver_subfrags:
+                driver_subfrags[signal] = set()
+            driver_subfrags[signal].add((None, hierarchy))
+
+        for i, (subfrag, name) in enumerate(self.subfragments):
+            # First, recurse into subfragments and let them detect driver conflicts as well.
+            if name is None:
+                name = "<unnamed #{}>".format(i)
+            subfrag_hierarchy = hierarchy + (name,)
+            subfrag_drivers = subfrag._resolve_driver_conflicts(subfrag_hierarchy, mode)
+
+            # Second, classify subfragments by domains they define.
+            for signal in subfrag_drivers:
+                if signal not in driver_subfrags:
+                    driver_subfrags[signal] = set()
+                driver_subfrags[signal].add((subfrag, subfrag_hierarchy))
+
+        # Find out the set of subfragments that needs to be flattened into this fragment
+        # to resolve driver-driver conflicts.
+        flatten_subfrags = set()
+        for signal, subfrags in driver_subfrags.items():
+            if len(subfrags) > 1:
+                flatten_subfrags.update((f, h) for f, h in subfrags if f is not None)
+
+                # While we're at it, show a message.
+                subfrag_names = ", ".join(sorted(".".join(h) for f, h in subfrags))
+                message = ("Signal '{}' is driven from multiple fragments: {}"
+                           .format(signal, subfrag_names))
+                if mode == "error":
+                    raise DriverConflict(message)
+                elif mode == "warn":
+                    message += "; hierarchy will be flattened"
+                    warnings.warn_explicit(message, DriverConflict, *signal.src_loc)
+
+        for subfrag, subfrag_hierarchy in sorted(flatten_subfrags, key=lambda x: x[1]):
+            # Merge subfragment's everything except clock domains into this fragment.
+            # Flattening is done after clock domain propagation, so we can assume the domains
+            # are already the same in every involved fragment in the first place.
+            self.ports.update(subfrag.ports)
+            for domain, signal in subfrag.iter_drivers():
+                self.add_driver(signal, domain)
+            self.statements += subfrag.statements
+            self.subfragments += subfrag.subfragments
+
+            # Remove the merged subfragment.
+            for i, (check_subfrag, check_name) in enumerate(self.subfragments): # :nobr:
+                if subfrag == check_subfrag:
+                    del self.subfragments[i]
+                    break
+
+        # If we flattened anything, we might be in a situation where we have a driver conflict
+        # again, e.g. if we had a tree of fragments like A --- B --- C where only fragments
+        # A and C were driving a signal S. In that case, since B is not driving S itself,
+        # processing B will not result in any flattening, but since B is transitively driving S,
+        # processing A will flatten B into it. Afterwards, we have a tree like AB --- C, which
+        # has another conflict.
+        if any(flatten_subfrags):
+            # Try flattening again.
+            return self._resolve_driver_conflicts(hierarchy, mode)
+
+        # Nothing was flattened, we're done!
+        return ValueSet(driver_subfrags.keys())
+
     def _propagate_domains_up(self, hierarchy=("top",)):
         from .xfrm import DomainRenamer
 
@@ -193,6 +269,7 @@ class Fragment:
 
         fragment = FragmentTransformer()(self)
         fragment._propagate_domains(ensure_sync_exists)
+        fragment._resolve_driver_conflicts()
         fragment = fragment._insert_domain_resets()
         fragment = fragment._lower_domain_signals()
         fragment._propagate_ports(ports)
index 3068af4e0d206e372a0caa506d335cdc92a1e28f..ace2e129f82de5563f8c32da42f5b69bbe1f4180 100644 (file)
@@ -245,3 +245,123 @@ class FragmentDomainsTestCase(FHDLTestCase):
         self.assertEqual(f1.domains.keys(), {"sync"})
         self.assertEqual(f2.domains.keys(), {"sync"})
         self.assertEqual(f1.domains["sync"], f2.domains["sync"])
+
+
+class FragmentDriverConflictTestCase(FHDLTestCase):
+    def setUp_self_sub(self):
+        self.s1 = Signal()
+        self.c1 = Signal()
+        self.c2 = Signal()
+
+        self.f1 = Fragment()
+        self.f1.add_statements(self.c1.eq(0))
+        self.f1.add_driver(self.s1)
+        self.f1.add_driver(self.c1, "sync")
+
+        self.f1a = Fragment()
+        self.f1.add_subfragment(self.f1a, "f1a")
+
+        self.f2 = Fragment()
+        self.f2.add_statements(self.c2.eq(1))
+        self.f2.add_driver(self.s1)
+        self.f2.add_driver(self.c2, "sync")
+        self.f1.add_subfragment(self.f2)
+
+        self.f1b = Fragment()
+        self.f1.add_subfragment(self.f1b, "f1b")
+
+        self.f2a = Fragment()
+        self.f2.add_subfragment(self.f2a, "f2a")
+
+    def test_conflict_self_sub(self):
+        self.setUp_self_sub()
+
+        self.f1._resolve_driver_conflicts(mode="silent")
+        self.assertEqual(self.f1.subfragments, [
+            (self.f1a, "f1a"),
+            (self.f1b, "f1b"),
+            (self.f2a, "f2a"),
+        ])
+        self.assertRepr(self.f1.statements, """
+        (
+            (eq (sig c1) (const 1'd0))
+            (eq (sig c2) (const 1'd1))
+        )
+        """)
+        self.assertEqual(self.f1.drivers, {
+            None:   ValueSet((self.s1,)),
+            "sync": ValueSet((self.c1, self.c2)),
+        })
+
+    def test_conflict_self_sub_error(self):
+        self.setUp_self_sub()
+
+        with self.assertRaises(DriverConflict,
+                msg="Signal '(sig s1)' is driven from multiple fragments: top, top.<unnamed #1>"):
+            self.f1._resolve_driver_conflicts(mode="error")
+
+    def test_conflict_self_sub_warning(self):
+        self.setUp_self_sub()
+
+        with self.assertWarns(DriverConflict,
+                msg="Signal '(sig s1)' is driven from multiple fragments: top, top.<unnamed #1>; "
+                    "hierarchy will be flattened"):
+            self.f1._resolve_driver_conflicts(mode="warn")
+
+    def setUp_sub_sub(self):
+        self.s1 = Signal()
+        self.c1 = Signal()
+        self.c2 = Signal()
+
+        self.f1 = Fragment()
+
+        self.f2 = Fragment()
+        self.f2.add_driver(self.s1)
+        self.f2.add_statements(self.c1.eq(0))
+        self.f1.add_subfragment(self.f2)
+
+        self.f3 = Fragment()
+        self.f3.add_driver(self.s1)
+        self.f3.add_statements(self.c2.eq(1))
+        self.f1.add_subfragment(self.f3)
+
+    def test_conflict_sub_sub(self):
+        self.setUp_sub_sub()
+
+        self.f1._resolve_driver_conflicts(mode="silent")
+        self.assertEqual(self.f1.subfragments, [])
+        self.assertRepr(self.f1.statements, """
+        (
+            (eq (sig c1) (const 1'd0))
+            (eq (sig c2) (const 1'd1))
+        )
+        """)
+
+    def setUp_self_subsub(self):
+        self.s1 = Signal()
+        self.c1 = Signal()
+        self.c2 = Signal()
+
+        self.f1 = Fragment()
+        self.f1.add_driver(self.s1)
+
+        self.f2 = Fragment()
+        self.f2.add_statements(self.c1.eq(0))
+        self.f1.add_subfragment(self.f2)
+
+        self.f3 = Fragment()
+        self.f3.add_driver(self.s1)
+        self.f3.add_statements(self.c2.eq(1))
+        self.f2.add_subfragment(self.f3)
+
+    def test_conflict_self_subsub(self):
+        self.setUp_self_subsub()
+
+        self.f1._resolve_driver_conflicts(mode="silent")
+        self.assertEqual(self.f1.subfragments, [])
+        self.assertRepr(self.f1.statements, """
+        (
+            (eq (sig c1) (const 1'd0))
+            (eq (sig c2) (const 1'd1))
+        )
+        """)
index 65cf0ff797ba235cb0f5a8066e0370f8fc32dcb3..297e7f97399c66fba537c658a21f99d4ea88dc7b 100644 (file)
@@ -1,5 +1,6 @@
 import re
 import unittest
+import warnings
 from contextlib import contextmanager
 
 from ..fhdl.ast import *
@@ -23,3 +24,12 @@ class FHDLTestCase(unittest.TestCase):
         if msg is not None:
             # WTF? unittest.assertRaises is completely broken.
             self.assertEqual(str(cm.exception), msg)
+
+    @contextmanager
+    def assertWarns(self, category, msg=None):
+        with warnings.catch_warnings(record=True) as warns:
+            yield
+        self.assertEqual(len(warns), 1)
+        self.assertEqual(warns[0].category, category)
+        if msg is not None:
+            self.assertEqual(str(warns[0].message), msg)