remove unneeded imports
[nmutil.git] / src / nmutil / nmoperator.py
1 """ nmigen operator functions / utils
2
3 eq:
4 --
5
6 a strategically very important function that is identical in function
7 to nmigen's Signal.eq function, except it may take objects, or a list
8 of objects, or a tuple of objects, and where objects may also be
9 Records.
10 """
11
12 from nmigen import Signal, Cat, Value
13 from nmigen.hdl.ast import ArrayProxy
14 from nmigen.hdl.rec import Record, Layout
15
16 from abc import ABCMeta, abstractmethod
17 from collections.abc import Sequence, Iterable
18 import inspect
19
20
class Visitor2:
    """ a helper class for iterating twin-argument compound data structures.

        Each iterator walks an "output" structure (o) and an "input"
        structure (i) in lock-step and yields (output, input) pairs.

        Record is a special (unusual, recursive) case, where the input may be
        specified as a dictionary (which may contain further dictionaries,
        recursively), where the field names of the dictionary must match
        the Record's field spec.  Alternatively, an object with the same
        member names as the Record may be assigned: it does not have to
        *be* a Record.

        ArrayProxy is also special-cased, it's a bit messy: whilst ArrayProxy
        has an eq function, the object being assigned to it (e.g. a python
        object) might not.  despite the *input* having an eq function,
        that doesn't help us, because it's the *ArrayProxy* that's being
        assigned to.  so.... we cheat.  use the ports() function of the
        python object, enumerate them, find out the list of Signals that way,
        and assign them.
    """
    def iterator2(self, o, i):
        """ recursively yield (output, input) pairs from o and i.

            * a dict o is handed to dict_iter2 and nothing else is yielded
            * a non-Sequence o is treated as a one-element list (as is i)
            * Record outputs are expanded by record_iter2
            * ArrayProxy outputs whose input is not a Value are expanded
              by arrayproxy_iter2
            * anything else is yielded as-is
        """
        if isinstance(o, dict):
            yield from self.dict_iter2(o, i)
            # a dict is fully handled above.  without this return the dict
            # itself would fall through and be yielded again as (o, i).
            return

        if not isinstance(o, Sequence):
            o, i = [o], [i]
        for (ao, ai) in zip(o, i):
            if isinstance(ao, Record):
                yield from self.record_iter2(ao, ai)
            elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
                yield from self.arrayproxy_iter2(ao, ai)
            else:
                yield (ao, ai)

    def dict_iter2(self, o, i):
        """ yield (o[k], i[k]) for every key k of the dict o.

            i is indexed with the keys of o, so it must be subscriptable
            by those same keys.
        """
        for (k, v) in o.items():
            yield (v, i[k])

    def _not_quite_working_with_all_unit_tests_record_iter2(self, ao, ai):
        """ experimental alternative to record_iter2.

            nothing in this class calls it; as the name says, it is kept
            only for reference.  the debug print is deliberately left in.
        """
        print ("record_iter2", ao, ai, type(ao), type(ai))
        if isinstance(ai, Value):
            # NOTE(review): wrapping when ao *is* a Sequence looks inverted
            # compared to iterator2 (which wraps when it is *not*) - confirm
            # before ever enabling this method.
            if isinstance(ao, Sequence):
                ao, ai = [ao], [ai]
            for o, i in zip(ao, ai):
                yield (o, i)
            return
        for field_name, field_shape, _ in ao.layout:
            if isinstance(field_shape, Layout):
                val = ai.fields
            else:
                val = ai
            if hasattr(val, field_name): # check for attribute
                val = getattr(val, field_name)
            else:
                val = val[field_name] # dictionary-style specification
            yield from self.iterator2(ao.fields[field_name], val)

    def record_iter2(self, ao, ai):
        """ walk the layout of Record ao, pairing each field with the
            same-named member of ai, and recurse through iterator2.

            for a field whose shape is itself a Layout, the member is looked
            up on ai.fields, otherwise on ai directly.  in both cases an
            attribute of that name is tried first, then item lookup.
        """
        for field_name, field_shape, _ in ao.layout:
            if isinstance(field_shape, Layout):
                val = ai.fields
            else:
                val = ai
            if hasattr(val, field_name): # check for attribute
                val = getattr(val, field_name)
            else:
                val = val[field_name] # dictionary-style specification
            yield from self.iterator2(ao.fields[field_name], val)

    def arrayproxy_iter2(self, ao, ai):
        """ for each port p returned by ai.ports(), pair the attribute of
            ao named p.name with p, and recurse through iterator2.
        """
        for p in ai.ports():
            op = getattr(ao, p.name)
            yield from self.iterator2(op, p)
97
98
class Visitor:
    """ single-argument counterpart of Visitor2: walks one compound data
        structure and yields its leaf members one at a time.
    """
    def iterate(self, i):
        """ recursively yield the members of i.

            anything that is not a Sequence is treated as a one-element
            list.  Records are expanded field by field (record_iter);
            everything else is yielded unchanged.
        """
        members = i if isinstance(i, Sequence) else [i]
        for member in members:
            if isinstance(member, Record):
                yield from self.record_iter(member)
                continue
            # NOTE(review): ArrayProxy appears to derive from Value in
            # nmigen, in which case this test can never succeed and
            # array_iter is never reached from here - confirm.
            if isinstance(member, ArrayProxy) and not isinstance(member, Value):
                yield from self.array_iter(member)
                continue
            yield member

    def record_iter(self, ai):
        """ expand ai according to ai.layout, recursing into every field.

            a field whose shape is a Layout is looked up on ai.fields,
            any other field on ai itself; attribute access is tried first,
            item access second.
        """
        for field_name, field_shape, _ in ai.layout:
            holder = ai.fields if isinstance(field_shape, Layout) else ai
            if hasattr(holder, field_name):
                # member is available as an attribute
                member = getattr(holder, field_name)
            else:
                # dictionary-style specification
                member = holder[field_name]
            yield from self.iterate(member)

    def array_iter(self, ai):
        """ recurse into each entry returned by ai.ports()
        """
        for port in ai.ports():
            yield from self.iterate(port)
134
135
def eq(o, i):
    """ assigns i to o, member by member.

        o and i may be single objects, lists or tuples of objects, or
        Records: Visitor2 pairs up their members, and each output member's
        own eq function is called with the matching input member.

        returns a flat list of everything those eq calls returned.
    """
    assignments = []
    for dest, src in Visitor2().iterator2(o, i):
        stmt = dest.eq(src)
        if isinstance(stmt, Sequence):
            # this eq returned several statements: splice them in
            assignments.extend(stmt)
        else:
            assignments.append(stmt)
    return assignments
148
149
def shape(i):
    """ combined shape of every member of the iterable i.

        calls shape() on each member and adds up the first element of
        each result (the second element is ignored).

        returns a tuple (total, False).
    """
    total = 0
    members = list(i)
    for member in members:
        width, _ignored = member.shape()
        total += width
    return (total, False)
158
159
def cat(i):
    """ flattens a compound structure recursively using Cat

        i may be a single object or a sequence of objects; Visitor expands
        it into its individual members, which are then all passed, in
        order, to Cat.

        returns the resulting Cat.
    """
    # Visitor is used for the flattening (rather than nmigen's own private
    # flatten helper) because the input may be a sequence.  the unused
    # function-local import of that private helper has been dropped, so
    # cat() no longer depends on nmigen internals.
    res = list(Visitor().iterate(i))
    return Cat(*res)
167
168