https://bugs.libre-soc.org/show_bug.cgi?id=446
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2 """
3 Links:
4 * https://libre-riscv.org/3d_gpu/architecture/6600scoreboard/
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=257
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
7 """
8
9 from soc.experiment.pimem import PortInterface
10
11 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
12 from nmutil.latch import SRLatch, latchregister
13 from nmigen.back.pysim import Simulator, Delay
14 from nmigen.cli import verilog, rtlil
15
16 from soc.scoreboard.addr_match import LenExpand
17 #from nmutil.queue import Queue
18
19
class LDData(Record):
    """Record pairing a data word with a single-bit error flag.

    Fields (in layout order):

    * ``err``  -- 1 bit
    * ``data`` -- ``dwidth`` bits
    """

    def __init__(self, dwidth, name=None):
        # layout is (field-name, width) pairs; err comes first, then data
        layout = (('err', 1), ('data', dwidth))
        super().__init__(layout, name=name)
23
24
class LDLatch(Elaboratable):
    """Holds one half of a split load/store: an LDData record plus valid.

    ``valid_i`` sets an internal SR latch; ``valid_o`` is the latch output
    qualified by ``valid_i``.  ``ld_i`` is passed to ``ld_o`` through
    ``latchregister`` under control of the latch output and ``valid_o``.

    ``addr_i`` and ``mask_i`` are declared and listed as ports, but
    ``elaborate`` does not read them.
    """

    def __init__(self, dwidth, awidth, mlen):
        # address and byte-mask inputs (port-only here, see class docstring)
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        # incoming request strobe
        self.valid_i = Signal(reset_less=True)
        # data/err record in, and its latched copy out
        self.ld_i = LDData(dwidth, "ld_i")
        self.ld_o = LDData(dwidth, "ld_o")
        # outgoing valid
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.submodules.in_l = latch = SRLatch(sync=False, name="in_l")

        # condition under which ld_i is transferred to ld_o
        capture = latch.q & self.valid_o

        m.d.comb += [
            # the latch is set by the incoming valid ...
            latch.s.eq(self.valid_i),
            # ... and the output is only valid while valid_i is still held
            self.valid_o.eq(latch.q & self.valid_i),
        ]
        latchregister(m, self.ld_i, self.ld_o, capture, "ld_i_r")

        return m

    def __iter__(self):
        yield from (self.addr_i, self.mask_i)
        # err then data, for the input record followed by the output record
        for rec in (self.ld_i, self.ld_o):
            yield rec.err
            yield rec.data
        yield from (self.valid_i, self.valid_o)

    def ports(self):
        """Return every port signal as a list (same order as iteration)."""
        return [*self]
58
59
class LDSTSplitter(Elaboratable):
    """Splits one (possibly misaligned) LD/ST into up to two line accesses.

    The low ``dlen`` bits of the address plus ``len_i`` are expanded by
    ``LenExpand`` into a mask.  The low ``1 << dlen`` bits of that mask
    belong to the first access (ld1), the remaining bits to the second
    access (ld2).  The second access uses the line address plus one and is
    ignored whenever its mask is zero.

    Parameters:

    * ``dwidth`` -- width of the data records (``LDData``)
    * ``awidth`` -- width of the address; the per-line address given to
      each ``LDLatch`` is ``awidth - dlen`` bits
    * ``dlen``   -- number of low address bits selecting the position
      within a line; also the width of ``len_i``

    Main signals:

    * ``addr_i``, ``is_ld_i``, ``is_st_i`` -- aliases of PortInterface fields
    * ``len_i``, ``valid_i``               -- request length and strobe
    * ``valid_o``                          -- request completed
    * ``ld_data_o`` / ``st_data_i``        -- recombined load data / store data
    * ``sld_*`` / ``sst_*``                -- per-half (index 0 and 1) valid
      and data for the two split loads / stores
    * ``exc``                              -- set when incrementing the line
      address for the second access carries out of the top bit
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        # cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth  # TODO: make this bytes not bits

        self.pi = PortInterface()

        self.addr_i = self.pi.addr.data  # Signal(awidth, reset_less=True)
        # no match in PortInterface
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        self.is_ld_i = self.pi.is_ld_i  # Signal(reset_less=True)
        self.is_st_i = self.pi.is_st_i  # Signal(reset_less=True)

        self.ld_data_o = LDData(dwidth, "ld_data_o")  # port.ld
        self.st_data_i = LDData(dwidth, "st_data_i")  # port.st

        self.exc = Signal(reset_less=True)  # pi.exc TODO

        # TODO : create/connect two outgoing port interfaces

        # split-load side: one valid bit and one data record per half
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        # split-store side: one valid bit and one data record per half
        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        mzero = Const(0, mlen)
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        ld2_value = self.addr_i[dlen:] + 1
        comb += ld2.addr_i.eq(ld2_value)
        # exception if rolls: the +1 result is one bit wider than the line
        # address, so its top bit is the carry-out of the increment
        with m.If(ld2_value[self.awidth-dlen]):
            comb += self.exc.eq(1)

        # data needs recombining / splitting via shifting.
        # ashift1 is the offset within the line, ashift2 its complement
        # (truncated to dlen bits)
        ashift1 = Signal(self.dlen, reset_less=True)
        ashift2 = Signal(self.dlen, reset_less=True)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        comb += ashift2.eq((1 << dlen)-ashift1)

        with m.If(self.is_ld_i):
            # set up connections to LD-split.  note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sld_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)

                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
                                               (ld2.ld_o.data << ashift2))

        with m.If(self.is_st_i):
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & (mask != mzero))
                # drive the *store* valid outputs: valid_o below is derived
                # from sst_valid_o, which would otherwise have no driver and
                # valid_o could never assert for a store
                comb += self.sst_valid_o[i].eq(ld.valid_o)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)

            # NOTE(review): the whole st_data_i record (err and data) is
            # shifted here and ANDed with the len-mask -- confirm intended
            comb += ld1.ld_i.eq((self.st_data_i << ashift1) & mask1)
            comb += ld2.ld_i.eq((self.st_data_i >> ashift2) & mask2)

            # sort out valid: mask2 zero we ignore 2nd ST
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sst_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sst_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.st_data_i.err.eq(ld1.ld_o.err | ld2.ld_o.err)

        return m

    def __iter__(self):
        # only the load-side signals are listed as ports
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data

    def ports(self):
        """Return the port signals as a list (same order as iteration)."""
        return list(self)
193
194
def sim(dut):
    """Simulate one misaligned load through *dut* and check the result.

    Two sync processes are run:

    * ``send_ld`` issues the load request (address, length, valid), waits
      for ``dut.valid_o`` and asserts the recombined data equals the
      expected test data.
    * ``lds`` waits for ``dut.valid_i`` and then supplies the two split-load
      halves on ``sld_data_i[0]`` / ``sld_data_i[1]``, one cycle apart.

    A VCD trace is written to ``ldst_splitter.vcd``.
    """

    simulator = Simulator(dut)
    simulator.add_clock(1e-6)
    data = 0b11010011
    dlen = 4  # 4 bits
    addr = 0b1100
    ld_len = 8
    ldm = ((1 << ld_len)-1)
    dlm = ((1 << dlen)-1)
    line_bits = 1 << dlen  # width of one line mask; split point for halves
    data = data & ldm  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, bin(data & ldm))
    print("dlm", dlm, bin(addr & dlm))
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> line_bits
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << line_bits)-1)
    print("dmask", bin(dmask))

    def send_ld():
        print("send_ld")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print("waiting")
        # poll once per clock until the splitter reports completion
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        print("lds")
        # wait for the request to be presented
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # second half: shift by the line width derived from dlen, matching
        # how dmask1 was computed above (previously a hard-coded 16)
        data2 = (shfdata >> line_bits) & dmask1
        print("ld data2", 1 << dlen, bin(data >> (1 << dlen)), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_ld)

    prefix = "ldst_splitter"
    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
262
263
if __name__ == '__main__':
    # build the splitter with dwidth=32, awidth=48, dlen=4
    dut = LDSTSplitter(32, 48, 4)

    # dump the design as RTLIL next to the script
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    # then run the load simulation on the same instance
    sim(dut)