Flatten the output of RecordObject.ports()
[nmutil.git] / src / nmutil / iocontrol.py
1 """ IO Control API
2
3 Associated development bugs:
4 * http://bugs.libre-riscv.org/show_bug.cgi?id=148
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=64
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=57
7
8 Important: see Stage API (stageapi.py) in combination with below
9
10 Main classes: PrevControl and NextControl.
11
12 These classes manage the data and the synchronisation state
13 to the previous and next stage, respectively. ready/valid
14 signals are used by the Pipeline classes to tell if data
15 may be safely passed from stage to stage.
16
17 The connection from one stage to the next is carried out with
18 NextControl.connect_to_next. It is *not* necessary to have
19 a PrevControl.connect_to_prev because it is functionally
20 directly equivalent to prev->next->connect_to_next.
21 """
22
23 from nmigen import Signal, Cat, Const, Module, Value, Elaboratable
24 from nmigen.cli import verilog, rtlil
25 from nmigen.hdl.rec import Record
26
27 from collections.abc import Sequence, Iterable
28 from collections import OrderedDict
29
30 from nmutil import nmoperator
31
32
class Object:
    """ dynamic attribute container.

        attributes assigned after construction (other than private names,
        "fields"/"name"/"src_loc" and names of Object's own members) are
        recorded, in assignment order, in ``self.fields``.

        * iterating yields the field values, expanding any iterable field
          one level
        * eq(inp) calls ``.eq()`` on each field with the same-named
          attribute of ``inp`` and concatenates the results
        * ports() returns the iterated values as a list
    """

    def __init__(self):
        self.fields = OrderedDict()

    def __setattr__(self, k, v):
        # private names, the bookkeeping attributes, Object's own members and
        # anything assigned before self.fields exists are stored as ordinary
        # attributes; everything else becomes a field
        if (k.startswith('_') or k in ["fields", "name", "src_loc"] or
                k in dir(Object) or "fields" not in self.__dict__):
            return object.__setattr__(self, k, v)
        self.fields[k] = v

    def __getattr__(self, k):
        # only invoked after normal attribute lookup has failed
        d = self.__dict__
        if k in d:
            return d[k]
        # read "fields" straight from __dict__: an instance created without
        # __init__ (e.g. by copy.copy) has no "fields" yet, and going through
        # self.fields would re-enter __getattr__ and recurse forever
        fields = d.get("fields")
        if fields is None:
            raise AttributeError(k)
        try:
            return fields[k]
        except KeyError as e:
            raise AttributeError(e) from None

    def __iter__(self):
        for x in self.fields.values(): # OrderedDict so order is preserved
            if isinstance(x, Iterable):
                yield from x
            else:
                yield x

    def eq(self, inp):
        """ returns the concatenated results of ``field.eq(inp.<name>)``.

            a Sequence result is spliced into the list; any other result
            is appended as a single item.
        """
        res = []
        for (k, o) in self.fields.items():
            rres = o.eq(getattr(inp, k))
            if isinstance(rres, Sequence):
                res += rres
            else:
                res.append(rres)
        return res

    def ports(self): # being called "keys" would be much better
        return list(self)
74
75
class RecordObject(Record):
    """ a Record whose fields can be added by plain attribute assignment.

        once Record.__init__ has created ``self.fields``, assigning a
        non-private name that is not one of Record's own attributes stores
        the value in ``self.fields`` and adds a matching entry to
        ``self.layout.fields``.

        iteration yields the field values: a nested Record contributes its
        own field values (one level), other iterables are expanded, and
        anything else is yielded as-is.  ports() flattens further by
        calling ``ports()`` on each member that provides one.
    """
    def __init__(self, layout=None, name=None):
        # no layout given: start empty, fields are added by assignment later
        Record.__init__(self, layout=layout or [], name=name)

    def __setattr__(self, k, v):
        plain_attr = (k.startswith('_')
                      or k in ("fields", "name", "src_loc")
                      or k in dir(Record)
                      or "fields" not in self.__dict__)
        if plain_attr:
            return object.__setattr__(self, k, v)

        self.fields[k] = v
        if isinstance(v, Record):
            field_shape = v.layout
        elif isinstance(v, Value):
            field_shape = v.shape()
        else:
            field_shape = nmoperator.shape(v)
        # NOTE(review): the entry stored here is (name, shape); nmigen's
        # Layout appears to keep (shape, direction) per field - confirm
        # against the nmigen version in use before relying on layout.fields
        self.layout.fields.update({k: (k, field_shape)})

    def __iter__(self):
        # remember: fields is an OrderedDict, so assignment order is kept
        for member in self.fields.values():
            if isinstance(member, Record):
                # only one level: the nested record's own field values
                yield from member.fields.values()
            elif isinstance(member, Iterable):
                yield from member # a bit like flatten (nmigen.tools)
            else:
                yield member

    def ports(self): # would be better being called "keys"
        members = list(self)
        flat = []
        # members that expose ports() are expanded through it
        for member in members:
            sub_ports = getattr(member, "ports", None)
            if not callable(sub_ports):
                flat.append(member)
                continue
            flat.extend(sub_ports())
        return flat
115
116
class PrevControl(Elaboratable):
    """ contains signals that come *from* the previous stage (both in and out)
        * valid_i: previous stage indicating all incoming data is valid.
                   may be a multi-bit signal, where all bits are required
                   to be asserted to indicate "valid".
        * ready_o: output to previous stage indicating readiness to accept data
        * data_i : an input - MUST be added by the USER of this class

        optional extras:
        * mask_i, stop_i: maskwid-bit inputs, created only when maskwid != 0
        * s_ready_o: stage-driven ready, created only when stage_ctl is set;
                     ready_o returns it in place of the internal _ready_o
        * trigger: driven in elaborate() as valid_i_test & ready_o
    """

    def __init__(self, i_width=1, stage_ctl=False, maskwid=0, offs=0):
        # NOTE: offs is accepted but not used by this class
        self.stage_ctl = stage_ctl
        self.maskwid = maskwid
        if maskwid:
            self.mask_i = Signal(maskwid)                # prev   >>in  self
            self.stop_i = Signal(maskwid)                # prev   >>in  self
        self.valid_i = Signal(i_width, name="p_valid_i") # prev   >>in  self
        self._ready_o = Signal(name="p_ready_o")         # prev   <<out self
        self.data_i = None # XXX MUST BE ADDED BY USER
        if stage_ctl:
            self.s_ready_o = Signal(name="p_s_o_rdy")    # prev   <<out self
        self.trigger = Signal(reset_less=True)

    @property
    def ready_o(self):
        """ public-facing API: indicates (externally) that stage is ready
        """
        if self.stage_ctl:
            return self.s_ready_o # set dynamically by stage
        return self._ready_o      # return this when not under dynamic control

    def _connect_in(self, prev, direct=False, fn=None,
                    do_data=True, do_stop=True):
        """ internal helper function to connect stage to an input source.
            do not use to connect stage-to-stage!

            * prev:    the outer source; its valid_i (or valid_i_test) and
                       data_i drive ours, our ready_o drives its ready_o
            * direct:  use prev.valid_i as-is instead of prev.valid_i_test
            * fn:      optional transform applied to prev.data_i first
            * do_data: when falsy, data_i is left unconnected
            * do_stop: when falsy, stop_i is left unconnected (maskwid only)

            returns a list of the .eq() statements.
        """
        valid_i = prev.valid_i if direct else prev.valid_i_test
        res = [self.valid_i.eq(valid_i),
               prev.ready_o.eq(self.ready_o)]
        if self.maskwid:
            res.append(self.mask_i.eq(prev.mask_i))
            if do_stop:
                res.append(self.stop_i.eq(prev.stop_i))
        # "not do_data" (rather than "is False") matches the do_data checks
        # in NextControl, so any falsy value skips the data connection
        if not do_data:
            return res
        data_i = fn(prev.data_i) if fn is not None else prev.data_i
        return res + [nmoperator.eq(self.data_i, data_i)]

    @property
    def valid_i_test(self):
        """ valid_i reduced to a single "valid" condition.

            multi-bit valid_i counts as valid only when every bit is set;
            with stage_ctl it is additionally gated by s_ready_o.
        """
        vlen = len(self.valid_i)
        if vlen > 1:
            # multi-bit case: valid only when valid_i is all 1s
            all1s = Const(-1, (vlen, False))
            valid_i = (self.valid_i == all1s)
        else:
            # single-bit valid_i case
            valid_i = self.valid_i

        # when stage indicates not ready, incoming data
        # must "appear" to be not ready too
        if self.stage_ctl:
            valid_i = valid_i & self.s_ready_o

        return valid_i

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.trigger.eq(self.valid_i_test & self.ready_o)
        return m

    def eq(self, i):
        """ returns statements copying data_i, ready_o, valid_i (and mask_i
            when maskwid is set) from another PrevControl-like object i.
        """
        # NOTE(review): unlike _connect_in, stop_i is not copied here -
        # confirm this is intended
        res = [nmoperator.eq(self.data_i, i.data_i),
               self.ready_o.eq(i.ready_o),
               self.valid_i.eq(i.valid_i)]
        if self.maskwid:
            res.append(self.mask_i.eq(i.mask_i))
        return res

    def __iter__(self):
        """ yields the control signals, then the data (expanded through
            data_i.ports() when available, or iterated if iterable)
        """
        yield self.valid_i
        yield self.ready_o
        if self.maskwid:
            yield self.mask_i
            yield self.stop_i
        if hasattr(self.data_i, "ports"):
            yield from self.data_i.ports()
        elif isinstance(self.data_i, Iterable): # covers Sequence too
            yield from self.data_i
        else:
            yield self.data_i

    def ports(self):
        return list(self)
211
212
class NextControl(Elaboratable):
    """ contains the signals that go *to* the next stage (both in and out)
        * valid_o: output indicating to next stage that data is valid
        * ready_i: input from next stage indicating that it can accept data
        * data_o : an output - MUST be added by the USER of this class

        optional extras:
        * mask_o, stop_o: maskwid-bit outputs, created only when maskwid != 0
        * d_valid: internal "data valid" flag (reset=1); ready_i_test only
                   gates ready_i with it when stage_ctl is set
        * trigger: driven in elaborate() as ready_i_test & valid_o
    """
    def __init__(self, stage_ctl=False, maskwid=0):
        self.stage_ctl = stage_ctl
        self.maskwid = maskwid
        if maskwid:
            self.mask_o = Signal(maskwid)       # self out>>  next
            self.stop_o = Signal(maskwid)       # self out>>  next
        self.valid_o = Signal(name="n_valid_o") # self out>>  next
        self.ready_i = Signal(name="n_ready_i") # self <<in   next
        self.data_o = None # XXX MUST BE ADDED BY USER
        # created unconditionally, although only read when stage_ctl is set
        self.d_valid = Signal(reset=1) # INTERNAL (data valid)
        self.trigger = Signal(reset_less=True)

    @property
    def ready_i_test(self):
        """ ready_i, gated by d_valid when under stage control """
        if self.stage_ctl:
            return self.ready_i & self.d_valid
        return self.ready_i

    def connect_to_next(self, nxt, do_data=True, do_stop=True):
        """ helper function to connect to the next stage data/valid/ready.
            data/valid is passed *TO* nxt, and ready comes *IN* from nxt.
            use this when connecting stage-to-stage

            * do_data: when falsy, data is left unconnected
            * do_stop: when falsy, stop is left unconnected (maskwid only)

            returns a list of the .eq() statements.

            note: a "connect_from_prev" is completely unnecessary: it's
            just nxt.connect_to_next(self)
        """
        res = [nxt.valid_i.eq(self.valid_o),
               self.ready_i.eq(nxt.ready_o)]
        if self.maskwid:
            res.append(nxt.mask_i.eq(self.mask_o))
            if do_stop:
                res.append(nxt.stop_i.eq(self.stop_o))
        if do_data:
            res.append(nmoperator.eq(nxt.data_i, self.data_o))
        return res

    def _connect_out(self, nxt, direct=False, fn=None,
                     do_data=True, do_stop=True):
        """ internal helper function to connect stage to an output source.
            do not use to connect stage-to-stage!

            * nxt:     the outer object; our valid_o/data_o drive its
                       valid_o/data_o, its ready_i (or ready_i_test)
                       drives our ready_i
            * direct:  use nxt.ready_i as-is instead of nxt.ready_i_test
            * fn:      optional transform applied to nxt.data_o (the target)
            * do_data: when falsy, data is left unconnected
            * do_stop: when falsy, stop_o is left unconnected (maskwid only)

            returns a list of the .eq() statements.
        """
        ready_i = nxt.ready_i if direct else nxt.ready_i_test
        res = [nxt.valid_o.eq(self.valid_o),
               self.ready_i.eq(ready_i)]
        if self.maskwid:
            res.append(nxt.mask_o.eq(self.mask_o))
            if do_stop:
                res.append(nxt.stop_o.eq(self.stop_o))
        if not do_data:
            return res
        data_o = fn(nxt.data_o) if fn is not None else nxt.data_o
        return res + [nmoperator.eq(data_o, self.data_o)]

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.trigger.eq(self.ready_i_test & self.valid_o)
        return m

    def __iter__(self):
        """ yields the control signals, then the data (expanded through
            data_o.ports() when available, or iterated if iterable)
        """
        yield self.ready_i
        yield self.valid_o
        if self.maskwid:
            yield self.mask_o
            yield self.stop_o
        if hasattr(self.data_o, "ports"):
            yield from self.data_o.ports()
        elif isinstance(self.data_o, Iterable): # covers Sequence too
            yield from self.data_o
        else:
            yield self.data_o

    def ports(self):
        return list(self)
295