hdl.ast: add Past, Stable, Rose, Fell.
authorwhitequark <whitequark@whitequark.org>
Thu, 17 Jan 2019 04:31:27 +0000 (04:31 +0000)
committerwhitequark <whitequark@whitequark.org>
Thu, 17 Jan 2019 04:31:27 +0000 (04:31 +0000)
nmigen/hdl/ast.py
nmigen/hdl/dsl.py
nmigen/hdl/ir.py
nmigen/hdl/xfrm.py
nmigen/test/test_hdl_dsl.py
nmigen/test/test_sim.py

index bee4a1dfa6ba7219bc5c4acd74c9814662e8a2a8..4b413811000dc953410b118103b5dbe510b2bf6d 100644 (file)
@@ -10,7 +10,8 @@ from ..tools import *
 
 __all__ = [
     "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
-    "Array", "ArrayProxy", "Sample",
+    "Array", "ArrayProxy",
+    "Sample", "Past", "Stable", "Rose", "Fell",
     "Signal", "ClockSignal", "ResetSignal",
     "Statement", "Assign", "Assert", "Assume", "Switch", "Delay", "Tick",
     "Passive", "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict",
@@ -471,10 +472,10 @@ class Cat(Value):
         return sum(len(part) for part in self.parts), False
 
     def _lhs_signals(self):
-        return union(part._lhs_signals() for part in self.parts)
+        return union((part._lhs_signals() for part in self.parts), start=ValueSet())
 
     def _rhs_signals(self):
-        return union(part._rhs_signals() for part in self.parts)
+        return union((part._rhs_signals() for part in self.parts), start=ValueSet())
 
     def _as_const(self):
         value = 0
@@ -832,9 +833,15 @@ class ArrayProxy(Value):
 
 
 class Sample(Value):
-    def __init__(self, value, clocks, domain):
+    """Value from the past.
+
+    A ``Sample`` of an expression is equal to the value of the expression ``clocks`` clock edges
+    of the ``domain`` clock back. If that moment is before the beginning of time, it is equal
+    to the value of the expression calculated as if each signal had its reset value.
+    """
+    def __init__(self, expr, clocks, domain):
         super().__init__(src_loc_at=1)
-        self.value  = Value.wrap(value)
+        self.value  = Value.wrap(expr)
         self.clocks = int(clocks)
         self.domain = domain
         if not isinstance(self.value, (Const, Signal)):
@@ -855,6 +862,22 @@ class Sample(Value):
             self.value, "<default>" if self.domain is None else self.domain, self.clocks)
 
 
+def Past(expr, clocks=1, domain=None):
+    return Sample(expr, clocks, domain)
+
+
+def Stable(expr, clocks=0, domain=None):
+    return Sample(expr, clocks + 1, domain) == Sample(expr, clocks, domain)
+
+
+def Rose(expr, clocks=0, domain=None):
+    return ~Sample(expr, clocks + 1, domain) & Sample(expr, clocks, domain)
+
+
+def Fell(expr, clocks=0, domain=None):
+    return Sample(expr, clocks + 1, domain) & ~Sample(expr, clocks, domain)
+
+
 class _StatementList(list):
     def __repr__(self):
         return "({})".format(" ".join(map(repr, self)))
index 918e8e351589a15bc3584dcd5416b635280b487c..23655e0228d368bade498b6a4465229fd749a5ce 100644 (file)
@@ -353,6 +353,7 @@ class Module(_ModuleBuilderRoot):
                     "Only assignments, asserts, and assumes may be appended to d.{}"
                     .format(domain_name(domain)))
 
+            assign = SampleDomainInjector(domain)(assign)
             for signal in assign._lhs_signals():
                 if signal not in self._driving:
                     self._driving[signal] = domain
@@ -384,7 +385,8 @@ class Module(_ModuleBuilderRoot):
         fragment = Fragment()
         for submodule, name in self._submodules:
             fragment.add_subfragment(submodule.get_fragment(platform), name)
-        fragment.add_statements(self._statements)
+        statements = SampleDomainInjector("sync")(self._statements)
+        fragment.add_statements(statements)
         for signal, domain in self._driving.items():
             fragment.add_driver(signal, domain)
         fragment.add_domains(self._domains)
index bf66135a6d898b11341e194fed41ac066ba375d4..7f20b19f9ec933e03ee4c006673a8deb0f2fe165 100644 (file)
@@ -363,9 +363,9 @@ class Fragment:
                 SignalSet(self.iter_ports("io")))
 
     def prepare(self, ports=(), ensure_sync_exists=True):
-        from .xfrm import FragmentTransformer
+        from .xfrm import SampleLowerer
 
-        fragment = FragmentTransformer()(self)
+        fragment = SampleLowerer()(self)
         fragment._propagate_domains(ensure_sync_exists)
         fragment._resolve_hierarchy_conflicts()
         fragment = fragment._insert_domain_resets()
index b335d0ceb3f173b5a208fdda4c468e9145dddee6..40ce767ef7a5ba805b26d32e9847814712f5cfda 100644 (file)
@@ -13,7 +13,8 @@ from .rec import *
 __all__ = ["ValueVisitor", "ValueTransformer",
            "StatementVisitor", "StatementTransformer",
            "FragmentTransformer",
-           "DomainRenamer", "DomainLowerer", "SampleLowerer",
+           "DomainRenamer", "DomainLowerer",
+           "SampleDomainInjector", "SampleLowerer",
            "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
            "ResetInserter", "CEInserter"]
 
@@ -340,6 +341,19 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
         return cd.rst
 
 
+class SampleDomainInjector(ValueTransformer, StatementTransformer):
+    def __init__(self, domain):
+        self.domain = domain
+
+    def on_Sample(self, value):
+        if value.domain is not None:
+            return value
+        return Sample(value.value, value.clocks, self.domain)
+
+    def __call__(self, stmts):
+        return self.on_statement(stmts)
+
+
 class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
     def __init__(self):
         self.sample_cache = ValueDict()
index bbcae25426ca66729a6f7594e8e9b7dcdf00e2ef..6f89ac826ba963a162afa2a0b63b0a465fda18e5 100644 (file)
@@ -113,6 +113,24 @@ class DSLTestCase(FHDLTestCase):
         )
         """)
 
+    def test_sample_domain(self):
+        m = Module()
+        i = Signal()
+        o1 = Signal()
+        o2 = Signal()
+        o3 = Signal()
+        m.d.sync += o1.eq(Past(i))
+        m.d.pix  += o2.eq(Past(i))
+        m.d.pix  += o3.eq(Past(i, domain="sync"))
+        f = m.lower(platform=None)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig o1) (sample (sig i) @ sync[1]))
+            (eq (sig o2) (sample (sig i) @ pix[1]))
+            (eq (sig o3) (sample (sig i) @ sync[1]))
+        )
+        """)
+
     def test_If(self):
         m = Module()
         with m.If(self.s1):
index a8782258896e96fde5c2aa6bb0c972b05994fef7..f063246dcf6667d867c6292daa557b699372fe05 100644 (file)
@@ -545,6 +545,52 @@ class SimulatorIntegrationTestCase(FHDLTestCase):
             sim.add_clock(1e-6)
             sim.add_sync_process(process)
 
+    def test_sample_helpers(self):
+        m = Module()
+        s = Signal(2)
+        def mk(x):
+            y = Signal.like(x)
+            m.d.comb += y.eq(x)
+            return y
+        p0, r0, f0, s0 = mk(Past(s, 0)), mk(Rose(s)),    mk(Fell(s)),    mk(Stable(s))
+        p1, r1, f1, s1 = mk(Past(s)),    mk(Rose(s, 1)), mk(Fell(s, 1)), mk(Stable(s, 1))
+        p2, r2, f2, s2 = mk(Past(s, 2)), mk(Rose(s, 2)), mk(Fell(s, 2)), mk(Stable(s, 2))
+        p3, r3, f3, s3 = mk(Past(s, 3)), mk(Rose(s, 3)), mk(Fell(s, 3)), mk(Stable(s, 3))
+        with self.assertSimulation(m) as sim:
+            def process_gen():
+                yield s.eq(0b10)
+                yield
+                yield
+                yield s.eq(0b01)
+                yield
+            def process_check():
+                yield
+                yield
+                yield
+
+                self.assertEqual((yield p0), 0b01)
+                self.assertEqual((yield p1), 0b10)
+                self.assertEqual((yield p2), 0b10)
+                self.assertEqual((yield p3), 0b00)
+
+                self.assertEqual((yield s0), 0b0)
+                self.assertEqual((yield s1), 0b1)
+                self.assertEqual((yield s2), 0b0)
+                self.assertEqual((yield s3), 0b1)
+
+                self.assertEqual((yield r0), 0b01)
+                self.assertEqual((yield r1), 0b00)
+                self.assertEqual((yield r2), 0b10)
+                self.assertEqual((yield r3), 0b00)
+
+                self.assertEqual((yield f0), 0b10)
+                self.assertEqual((yield f1), 0b00)
+                self.assertEqual((yield f2), 0b00)
+                self.assertEqual((yield f3), 0b00)
+            sim.add_clock(1e-6)
+            sim.add_sync_process(process_gen)
+            sim.add_sync_process(process_check)
+
     def test_wrong_not_run(self):
         with self.assertWarns(UserWarning,
                 msg="Simulation created, but not run"):