7d8ead9e2d3e3b673b6e41a7074ffaab29f324f3
[soc.git] / src / regfile / regfile.py
1 from nmigen.compat.sim import run_simulation
2 from nmigen.cli import verilog, rtlil
3
4 from nmigen import Cat, Const, Array, Signal, Elaboratable, Module
5 from nmutil.iocontrol import RecordObject
6
7 from math import log
8 from functools import reduce
9 import operator
10
11
class Register(Elaboratable):
    """A single register with any number of read and write ports.

    Read ports are ``RecordObject``s with ``ren`` (1 bit) and ``data_o``;
    write ports have ``wen`` (1 bit) and ``data_i``.  Reads are
    combinatorial with write-through: while a read port's ``ren`` is set
    and some write port's ``wen`` is set in the same cycle, the read
    returns that port's ``data_i`` instead of the stored value.  Writes
    are registered in the ``sync`` domain.
    """

    def __init__(self, width):
        self.width = width   # bit width of the stored value
        self._rdports = []   # records handed out by read_port()
        self._wrports = []   # records handed out by write_port()

    def read_port(self, name=None):
        """Create, register and return a new read port (``ren``, ``data_o``)."""
        port = RecordObject([("ren", 1),
                             ("data_o", self.width)],
                            name=name)
        self._rdports.append(port)
        return port

    def write_port(self, name=None):
        """Create, register and return a new write port (``wen``, ``data_i``)."""
        port = RecordObject([("wen", 1),
                             ("data_i", self.width)],
                            name=name)
        self._wrports.append(port)
        return port

    def elaborate(self, platform):
        m = Module()
        self.reg = reg = Signal(self.width, name="reg")

        # read ports. has write-through detection (returns data written).
        # data_o is only assigned while ren is set; RegFileArray relies on
        # this when it ORs together the data_o of all its registers.
        for rp in self._rdports:
            wr_detect = Signal(reset_less=False)
            with m.If(rp.ren):
                m.d.comb += wr_detect.eq(0)
                for wp in self._wrports:
                    with m.If(wp.wen):
                        m.d.comb += rp.data_o.eq(wp.data_i)
                        m.d.comb += wr_detect.eq(1)
                # no write in progress: return the stored value
                with m.If(~wr_detect):
                    m.d.comb += rp.data_o.eq(reg)

        # write ports.  if several are enabled in the same cycle, the port
        # created last wins (nmigen gives priority to the last assignment).
        for wp in self._wrports:
            with m.If(wp.wen):
                m.d.sync += reg.eq(wp.data_i)

        return m

    def __iter__(self):
        """Yield every signal of every read port, then of every write port."""
        for p in self._rdports:
            yield from p
        for p in self._wrports:
            yield from p

    def ports(self):
        """Return all port signals as a list (e.g. for rtlil/verilog convert)."""
        return list(self)
63
def treereduce(tree):
    """OR together the ``data_o`` fields of a list of ports, as a balanced tree.

    The list is split in half recursively, so the resulting OR expression
    is about log2(len(tree)) levels deep rather than a linear chain.

    :param tree: a list of objects that each have a ``data_o`` attribute
                 supporting ``|``.  Anything that is not a ``list`` is
                 returned unchanged.
    :returns: the OR of every element's ``data_o``; 0 for an empty list.
    """
    if not isinstance(tree, list):
        return tree
    if not tree:
        # the OR of nothing is 0.  without this guard an empty list is split
        # into two empty lists forever and raises RecursionError.
        return 0
    if len(tree) == 1:
        return tree[0].data_o
    splitpoint = len(tree) // 2
    return treereduce(tree[:splitpoint]) | treereduce(tree[splitpoint:])
74
75
class RegFileArray(Elaboratable):
    """ an array-based register file (register having write-through capability)
    that has no "address" decoder, instead it has individual write-en
    and read-en signals (per port).

    Each aggregate port's ``ren``/``wen`` field is ``depth`` bits wide and
    bit i enables register i.  Read data is the OR of every register's
    ``data_o``, so a read port is expected to set at most one ``ren`` bit;
    setting several returns the bitwise OR of the selected registers.
    """
    def __init__(self, width, depth):
        self.width = width
        self.depth = depth
        self.regs = Array(Register(width) for _ in range(self.depth))
        # lists of (per-register port Array, aggregate port) pairs
        self._rdports = []
        self._wrports = []

    def read_port(self, name=None):
        """Return a new aggregate read port (``ren`` per register, ``data_o``).

        A matching read port is also created on every Register.
        """
        regs = Array([reg.read_port(name) for reg in self.regs])
        port = RecordObject([("ren", self.depth),
                             ("data_o", self.width)], name)
        self._rdports.append((regs, port))
        return port

    def write_port(self, name=None):
        """Return a new aggregate write port (``wen`` per register, ``data_i``).

        A matching write port is also created on every Register.
        """
        regs = Array([reg.write_port(name) for reg in self.regs])
        # pass name here too, as read_port does (it was previously dropped)
        port = RecordObject([("wen", self.depth),
                             ("data_i", self.width)], name)
        self._wrports.append((regs, port))
        return port

    def _get_en_sig(self, port, typ):
        """Concatenate field ``typ`` of each per-register port (register 0 = LSB)."""
        return Cat(*(p[typ] for p in port))

    def elaborate(self, platform):
        m = Module()
        for i, reg in enumerate(self.regs):
            setattr(m.submodules, "reg_%d" % i, reg)

        # read: bit i of ren drives register i's ren, and the outputs of all
        # registers are ORed (Register only drives data_o while ren is set).
        for (regs, p) in self._rdports:
            m.d.comb += self._get_en_sig(regs, 'ren').eq(p.ren)
            ror = treereduce(list(regs))
            m.d.comb += p.data_o.eq(ror)

        # write: bit i of wen drives register i's wen; every register sees
        # the same data_i.
        for (regs, p) in self._wrports:
            m.d.comb += self._get_en_sig(regs, 'wen').eq(p.wen)
            for r in regs:
                m.d.comb += r.data_i.eq(p.data_i)

        return m

    def __iter__(self):
        """Yield the per-register port signals of every Register.

        The aggregate records returned by read_port/write_port are not
        included.
        """
        for r in self.regs:
            yield from r

    def ports(self):
        """Return the signals yielded by ``__iter__`` as a list."""
        return list(self)
139
140
class RegFile(Elaboratable):
    """An address-decoded register file of ``depth`` registers of ``width`` bits.

    Read ports (``raddr``, ``ren``, ``data_o``) are combinatorial with
    write-through: if an enabled write port targets the address being
    read, its ``data_i`` is returned instead of the stored value.  Write
    ports (``waddr``, ``wen``, ``data_i``) update in the ``sync`` domain.
    Writes to address 0 are ignored (and not forwarded), so register 0
    always reads as its reset value, 0.
    """

    def __init__(self, width, depth):
        self.width = width   # bits per register
        self.depth = depth   # number of registers
        self._rdports = []   # records handed out by read_port()
        self._wrports = []   # records handed out by write_port()

    def _addr_width(self):
        """Return the number of address bits needed to select one of
        ``depth`` registers (at least 1).

        Derived from ``depth``, not ``width``: basing it on the data width
        gives too few bits whenever ``depth > width``.  ``bit_length`` also
        avoids float rounding and rounds up for non-power-of-two depths.
        """
        return max(1, (self.depth - 1).bit_length())

    def read_port(self, name=None):
        """Create, register and return a read port (``raddr``, ``ren``, ``data_o``)."""
        port = RecordObject([("raddr", self._addr_width()),
                             ("ren", 1),
                             ("data_o", self.width)], name=name)
        self._rdports.append(port)
        return port

    def write_port(self, name=None):
        """Create, register and return a write port (``waddr``, ``wen``, ``data_i``)."""
        port = RecordObject([("waddr", self._addr_width()),
                             ("wen", 1),
                             ("data_i", self.width)], name=name)
        self._wrports.append(port)
        return port

    def elaborate(self, platform):
        m = Module()
        bsz = self._addr_width()
        regs = Array(Signal(self.width, name="reg") for _ in range(self.depth))

        # per write port: the write takes effect only when enabled and not
        # aimed at address 0 (register 0 is read-only zero).
        wr_ok = [wp.wen & (wp.waddr != Const(0, bsz)) for wp in self._wrports]

        # read ports. has write-through detection (returns data written).
        # writes to address 0 are not forwarded, since they never reach
        # the register array either.
        for rp in self._rdports:
            wr_detect = Signal(reset_less=False)
            with m.If(rp.ren):
                m.d.comb += wr_detect.eq(0)
                for wp, ok in zip(self._wrports, wr_ok):
                    addrmatch = Signal(reset_less=False)
                    m.d.comb += addrmatch.eq(wp.waddr == rp.raddr)
                    with m.If(ok & addrmatch):
                        m.d.comb += rp.data_o.eq(wp.data_i)
                        m.d.comb += wr_detect.eq(1)
                # no matching write: return the stored value
                with m.If(~wr_detect):
                    m.d.comb += rp.data_o.eq(regs[rp.raddr])

        # write ports, don't allow write to address 0 (ignore it).
        # if several ports write the same address, the last one wins.
        for wp, ok in zip(self._wrports, wr_ok):
            with m.If(ok):
                m.d.sync += regs[wp.waddr].eq(wp.data_i)

        return m

    def __iter__(self):
        """Yield the read port records, then the write port records."""
        yield from self._rdports
        yield from self._wrports

    def ports(self):
        """Yield every individual signal of every port (records are flattened)."""
        res = list(self)
        for r in res:
            if isinstance(r, RecordObject):
                yield from r
            else:
                yield r
201
def regfile_sim(dut, rp, wp):
    """Simulation process exercising one read and one write port of a RegFile.

    :param dut: the RegFile under test (not referenced in this body)
    :param rp: read port record (raddr, ren, data_o)
    :param wp: write port record (waddr, wen, data_i)

    A bare ``yield`` advances the simulation by one clock cycle.
    """
    # write 2 to register 1, then read it back on the following cycle
    yield wp.waddr.eq(1)
    yield wp.data_i.eq(2)
    yield wp.wen.eq(1)
    yield
    yield wp.wen.eq(0)
    yield rp.ren.eq(1)
    yield rp.raddr.eq(1)
    yield
    data = yield rp.data_o
    print (data)
    assert data == 2

    # read and write register 5 in the same cycle (write-through case)
    yield wp.waddr.eq(5)
    yield rp.raddr.eq(5)
    yield rp.ren.eq(1)
    yield wp.wen.eq(1)
    yield wp.data_i.eq(6)
    # NOTE(review): sampled before the next clock; whether this already
    # reflects the assignments above depends on simulator settle timing
    data = yield rp.data_o
    print (data)
    yield
    yield wp.wen.eq(0)
    yield rp.ren.eq(0)
    # NOTE(review): ren was just cleared, yet 6 is expected - presumably
    # because the combinatorial output has not re-settled yet; confirm
    data = yield rp.data_o
    print (data)
    assert data == 6
    yield
    data = yield rp.data_o
    print (data)
231
def regfile_array_sim(dut, rp1, rp2, wp):
    """Simulation process for a RegFileArray with two read ports and one
    write port.

    ren/wen select registers one bit per register: ``1<<i`` selects
    register i.  ``dut`` is not referenced in this body.  A bare
    ``yield`` advances the simulation by one clock cycle.
    """
    # write 2 to register 1, then read it back through rp1
    yield wp.data_i.eq(2)
    yield wp.wen.eq(1<<1)
    yield
    yield wp.wen.eq(0)
    yield rp1.ren.eq(1<<1)
    yield
    data = yield rp1.data_o
    print (data)
    assert data == 2

    # rp1 reads register 5 while it is being written (write-through);
    # rp2 reads register 1 at the same time
    yield rp1.ren.eq(1<<5)
    yield rp2.ren.eq(1<<1)
    yield wp.wen.eq(1<<5)
    yield wp.data_i.eq(6)
    # NOTE(review): sampled before the next clock; value depends on
    # simulator settle timing
    data = yield rp1.data_o
    print (data)
    yield
    yield wp.wen.eq(0)
    yield rp1.ren.eq(0)
    yield rp2.ren.eq(0)
    # NOTE(review): enables were just cleared, yet rp1 is expected to show
    # 6 - presumably the outputs have not re-settled yet; confirm
    data1 = yield rp1.data_o
    print (data1)
    data2 = yield rp2.data_o
    print (data2)
    assert data1 == 6
    yield
    data = yield rp1.data_o
    print (data)
261
def test_regfile():
    """Export RegFile and RegFileArray to RTLIL, then simulate each one.

    Side effects: writes test_regfile.il, test_regfile.vcd,
    test_regfile_array.il and test_regfile_array.vcd into the working
    directory, and prints the RegFileArray port list.
    """
    def emit_rtlil(design, port_list, filename):
        # convert the design to RTLIL text and save it under filename
        text = rtlil.convert(design, ports=port_list)
        with open(filename, "w") as handle:
            handle.write(text)

    # address-decoded register file: 8 registers of 32 bits
    decoded = RegFile(32, 8)
    rd = decoded.read_port()
    wr = decoded.write_port()
    emit_rtlil(decoded, decoded.ports(), "test_regfile.il")
    run_simulation(decoded, regfile_sim(decoded, rd, wr),
                   vcd_name='test_regfile.vcd')

    # per-register-enable register file: 8 registers of 32 bits
    unary = RegFileArray(32, 8)
    rd1 = unary.read_port("read1")
    rd2 = unary.read_port("read2")
    wr = unary.write_port("write")
    port_list = unary.ports()
    print ("ports", port_list)
    emit_rtlil(unary, port_list, "test_regfile_array.il")
    run_simulation(unary, regfile_array_sim(unary, rd1, rd2, wr),
                   vcd_name='test_regfile_array.vcd')
284
if __name__ == "__main__":
    # run directly as a script: export RTLIL and run both simulations
    test_regfile()