speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / nmoperator.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 """ nmigen operator functions / utils
3
4 This work is funded through NLnet under Grant 2019-02-012
5
6 License: LGPLv3+
7
8
9 eq:
10 --
11
12 a strategically very important function that is identical in function
13 to nmigen's Signal.eq function, except it may take objects, or a list
14 of objects, or a tuple of objects, and where objects may also be
15 Records.
16 """
17
18 from nmigen import Signal, Cat, Value
19 from nmigen.hdl.ast import ArrayProxy
20 from nmigen.hdl.rec import Record, Layout
21
22 from abc import ABCMeta, abstractmethod
23 from collections.abc import Sequence, Iterable
24 import inspect
25
26
class Visitor2:
    """ a helper class for iterating twin-argument compound data structures.

    Record is a special (unusual, recursive) case, where the input may be
    specified as a dictionary (which may contain further dictionaries,
    recursively), where the field names of the dictionary must match
    the Record's field spec. Alternatively, an object with the same
    member names as the Record may be assigned: it does not have to
    *be* a Record.

    ArrayProxy is also special-cased, it's a bit messy: whilst ArrayProxy
    has an eq function, the object being assigned to it (e.g. a python
    object) might not. despite the *input* having an eq function,
    that doesn't help us, because it's the *ArrayProxy* that's being
    assigned to. so.... we cheat. use the ports() function of the
    python object, enumerate them, find out the list of Signals that way,
    and assign them.

    If the output is a dict, each of its values is paired with the
    input's entry under the same key.
    """

    def iterator2(self, o, i):
        """ yield (output, input) leaf pairs for the structures o and i.

        o may be a dict (paired by key against i), a Sequence (paired
        element-wise against i via zip, so the shorter of the two wins),
        or a single object (treated as a one-element sequence).  Records
        and ArrayProxy outputs are expanded further; everything else is
        yielded as an (ao, ai) pair.
        """
        if isinstance(o, dict):
            # dicts are fully handled by dict_iter2: stop here, otherwise
            # the dict itself would also be yielded below as a leaf.
            yield from self.dict_iter2(o, i)
            return

        if not isinstance(o, Sequence):
            o, i = [o], [i]
        for (ao, ai) in zip(o, i):
            if isinstance(ao, Record):
                yield from self.record_iter2(ao, ai)
            elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
                yield from self.arrayproxy_iter2(ao, ai)
            elif isinstance(ai, ArrayProxy) and not isinstance(ao, Value):
                assert False, "whoops, input ArrayProxy not supported yet"
                yield from self.arrayproxy_iter3(ao, ai)
            else:
                yield (ao, ai)

    def dict_iter2(self, o, i):
        """ yield (value, i[key]) for every (key, value) of the dict o.

        i must support lookup by every key of o (e.g. another dict).
        Values are paired as-is and are not recursed into.
        """
        for (k, v) in o.items():
            yield (v, i[k])

    def _not_quite_working_with_all_unit_tests_record_iter2(self, ao, ai):
        # NOTE(review): alternative record_iter2, not called anywhere in
        # this file (as its name says, it did not pass all unit tests).
        # The `isinstance(ao, Sequence)` test below looks inverted relative
        # to iterator2's `not isinstance(...)` - confirm before reviving.
        if isinstance(ai, Value):
            if isinstance(ao, Sequence):
                ao, ai = [ao], [ai]
            for o, i in zip(ao, ai):
                yield (o, i)
            return
        for (field_name, field_shape, _) in ao.layout:
            if isinstance(field_shape, Layout):
                val = ai.fields
            else:
                val = ai
            if hasattr(val, field_name):  # check for attribute
                val = getattr(val, field_name)
            else:
                val = val[field_name]  # dictionary-style specification
            yield from self.iterator2(ao.fields[field_name], val)

    def record_iter2(self, ao, ai):
        """ pair every field of the Record ao with the same-named part of ai.

        ai may be a Record, an object with attributes named after ao's
        fields, or a dict keyed by field name (nested dicts may be used
        for nested Layouts).  Each pair is recursed into via iterator2.
        """
        for (field_name, field_shape, _) in ao.layout:
            if isinstance(field_shape, Layout):
                # nested Layout: use ai.fields when ai has one (a Record).
                # A plain dict/object has no .fields, so look the field up
                # on ai itself rather than raising AttributeError.
                # NOTE(review): if ai.fields is a dict, a field named like a
                # dict method (e.g. "items") is found by hasattr below.
                val = getattr(ai, "fields", ai)
            else:
                val = ai
            if hasattr(val, field_name):  # check for attribute
                val = getattr(val, field_name)
            else:
                val = val[field_name]  # dictionary-style specification
            yield from self.iterator2(ao.fields[field_name], val)

    def arrayproxy_iter2(self, ao, ai):
        """ pair each port of ai with the same-named attribute of ao.

        ai must provide ports(); each port's .name selects the attribute
        of the ArrayProxy ao to assign to, and the pair is recursed into.
        """
        for p in ai.ports():
            op = getattr(ao, p.name)
            yield from self.iterator2(op, p)

    def arrayproxy_iter3(self, ao, ai):
        # Only reachable from iterator2 after an `assert False`, i.e. not at
        # all unless assertions are disabled.
        # NOTE(review): this iterates ao.ports() and reads attributes from
        # ao, never using ai - looks unfinished; confirm before enabling.
        for p in ao.ports():
            op = getattr(ao, p.name)
            yield from self.iterator2(op, p)
119
120
class Visitor:
    """ a helper class for iterating single-argument compound data structures.

    The single-structure counterpart of Visitor2: it walks one structure
    and yields its leaves rather than (output, input) pairs.
    """

    def iterate(self, i):
        """ yield the leaves of i, expanding Records recursively.

        A non-Sequence i is treated as a one-element sequence.  Records
        are expanded field by field; every other element is yielded as-is.
        """
        items = i if isinstance(i, Sequence) else [i]
        for item in items:
            if isinstance(item, Record):
                yield from self.record_iter(item)
                continue
            # NOTE(review): this tests the *same* object for being an
            # ArrayProxy and not a Value.  nmigen's ArrayProxy appears to
            # subclass Value, in which case array_iter is never reached
            # from here and ArrayProxy elements are yielded unchanged.
            if isinstance(item, ArrayProxy) and not isinstance(item, Value):
                yield from self.array_iter(item)
                continue
            yield item

    def record_iter(self, ai):
        """ yield the leaves of every field of the Record ai, in layout order.
        """
        for name, fshape, _ in ai.layout:
            holder = ai.fields if isinstance(fshape, Layout) else ai
            if hasattr(holder, name):
                member = getattr(holder, name)  # attribute-style lookup
            else:
                member = holder[name]  # dictionary-style lookup
            yield from self.iterate(member)

    def array_iter(self, ai):
        """ yield the leaves of every port returned by ai.ports(). """
        for port in ai.ports():
            yield from self.iterate(port)
157
158
def eq(o, i):
    """ build the assignments that make o equal to i.

    o and i may be single objects (Signals, Records, ...) or lists/tuples
    of them.  Visitor2 breaks them into (destination, source) leaf pairs,
    and each destination's eq() is called with its source.  Whether an
    eq() call returns one statement or a sequence of them, the results
    are flattened into one list.

    :returns: list of all statements returned by the eq() calls, in order.
    """
    statements = []
    for dest, src in Visitor2().iterator2(o, i):
        produced = dest.eq(src)
        if isinstance(produced, Sequence):
            statements.extend(produced)
        else:
            statements.append(produced)
    return statements
171
172
def shape(i):
    """ return the combined shape of the parts of i as (width, False).

    Each part's shape() must unpack into (width, signed).  The widths are
    summed and each part's signedness is discarded, so the total is always
    reported as unsigned.
    """
    total = sum(width for width, _signed in (part.shape() for part in i))
    return total, False
181
182
def cat(i):
    """ flattens a compound structure recursively using Cat

    i may be a single value or a Sequence of values.  Visitor().iterate()
    walks it, expanding any Records field by field, and the resulting
    leaves are concatenated, in order, with nmigen's Cat.
    """
    # nmigen's private flatten() helper (nmigen._utils) was found to work
    # as of nmigen commit f22106e5, but Visitor is used instead because the
    # input may be a sequence.
    leaves = list(Visitor().iterate(i))
    return Cat(*leaves)