speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / nmoperator.py
index 8d0aafff9e32467f95c3ee7f670c178de51fbe18..5163bff740245a55e96eabaf5eb8c0f8e56e1b33 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """ nmigen operator functions / utils
 
     This work is funded through NLnet under Grant 2019-02-012
@@ -41,6 +42,7 @@ class Visitor2:
         python object, enumerate them, find out the list of Signals that way,
         and assign them.
     """
+
     def iterator2(self, o, i):
         if isinstance(o, dict):
             yield from self.dict_iter2(o, i)
@@ -48,22 +50,30 @@ class Visitor2:
         if not isinstance(o, Sequence):
             o, i = [o], [i]
         for (ao, ai) in zip(o, i):
-            #print ("visit", fn, ao, ai)
+            # print ("visit", ao, ai)
+            # print ("    isinstance Record(ao)", isinstance(ao, Record))
+            # print ("    isinstance ArrayProxy(ao)",
+            #            isinstance(ao, ArrayProxy))
+            # print ("    isinstance Value(ai)",
+            #            isinstance(ai, Value))
             if isinstance(ao, Record):
                 yield from self.record_iter2(ao, ai)
             elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
                 yield from self.arrayproxy_iter2(ao, ai)
+            elif isinstance(ai, ArrayProxy) and not isinstance(ao, Value):
+                assert False, "whoops, input ArrayProxy not supported yet"
+                yield from self.arrayproxy_iter3(ao, ai)
             else:
                 yield (ao, ai)
 
     def dict_iter2(self, o, i):
         for (k, v) in o.items():
-            print ("d-iter", v, i[k])
+            print ("d-iter", v, i[k])
             yield (v, i[k])
         return res
 
     def _not_quite_working_with_all_unit_tests_record_iter2(self, ao, ai):
-        print ("record_iter2", ao, ai, type(ao), type(ai))
+        print ("record_iter2", ao, ai, type(ao), type(ai))
         if isinstance(ai, Value):
             if isinstance(ao, Sequence):
                 ao, ai = [ao], [ai]
@@ -75,10 +85,10 @@ class Visitor2:
                 val = ai.fields
             else:
                 val = ai
-            if hasattr(val, field_name): # check for attribute
+            if hasattr(val, field_name):  # check for attribute
                 val = getattr(val, field_name)
             else:
-                val = val[field_name] # dictionary-style specification
+                val = val[field_name]  # dictionary-style specification
             yield from self.iterator2(ao.fields[field_name], val)
 
     def record_iter2(self, ao, ai):
@@ -87,16 +97,23 @@ class Visitor2:
                 val = ai.fields
             else:
                 val = ai
-            if hasattr(val, field_name): # check for attribute
+            if hasattr(val, field_name):  # check for attribute
                 val = getattr(val, field_name)
             else:
-                val = val[field_name] # dictionary-style specification
+                val = val[field_name]  # dictionary-style specification
             yield from self.iterator2(ao.fields[field_name], val)
 
     def arrayproxy_iter2(self, ao, ai):
-        #print ("arrayproxy_iter2", ai.ports(), ai, ao)
+        # print ("arrayproxy_iter2", ai.ports(), ai, ao)
         for p in ai.ports():
-            #print ("arrayproxy - p", p, p.name, ao)
+            # print ("arrayproxy - p", p, p.name, ao)
+            op = getattr(ao, p.name)
+            yield from self.iterator2(op, p)
+
+    def arrayproxy_iter3(self, ao, ai):
+        # print ("arrayproxy_iter3", ao.ports(), ai, ao)
+        for p in ao.ports():
+            # print ("arrayproxy - p", p, p.name, ao)
             op = getattr(ao, p.name)
             yield from self.iterator2(op, p)
 
@@ -105,6 +122,7 @@ class Visitor:
     """ a helper class for iterating single-argument compound data structures.
         similar to Visitor2.
     """
+
     def iterate(self, i):
         """ iterate a compound structure recursively using yield
         """
@@ -126,10 +144,10 @@ class Visitor:
                 val = ai.fields
             else:
                 val = ai
-            if hasattr(val, field_name): # check for attribute
+            if hasattr(val, field_name):  # check for attribute
                 val = getattr(val, field_name)
             else:
-                val = val[field_name] # dictionary-style specification
+                val = val[field_name]  # dictionary-style specification
             #print ("recidx", idx, field_name, field_shape, val)
             yield from self.iterate(val)
 
@@ -166,8 +184,6 @@ def cat(i):
     """ flattens a compound structure recursively using Cat
     """
     from nmigen._utils import flatten
-    #res = list(flatten(i)) # works (as of nmigen commit f22106e5) HOWEVER...
-    res = list(Visitor().iterate(i)) # needed because input may be a sequence
+    # res = list(flatten(i)) # works (as of nmigen commit f22106e5) HOWEVER...
+    res = list(Visitor().iterate(i))  # needed because input may be a sequence
     return Cat(*res)
-
-