Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2 """
3 Links:
4 * https://libre-riscv.org/3d_gpu/architecture/6600scoreboard/
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=257
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
7 """
8
9 #from soc.experiment.pimem import PortInterface
10
11 from nmigen import Elaboratable, Module, Signal, Record, Const, Cat
12 from nmutil.latch import SRLatch, latchregister
13 from nmigen.back.pysim import Simulator, Delay
14 from nmigen.cli import verilog, rtlil
15
16 from soc.scoreboard.addr_match import LenExpand
17 #from nmutil.queue import Queue
18
19
class LDData(Record):
    """Record pairing a data word with a single-bit error flag.

    Fields, in layout order:

    * ``err``  -- 1 bit
    * ``data`` -- ``dwidth`` bits
    """

    def __init__(self, dwidth, name=None):
        # layout is (field-name, width) pairs; ``err`` comes first
        layout = (('err', 1), ('data', dwidth))
        Record.__init__(self, layout, name=name)
23
24
class LDLatch(Elaboratable):
    """Latch an incoming LDData record (``ld_i``) through to ``ld_o``.

    ``i_valid`` sets an internal SR latch; ``o_valid`` is raised while the
    latch output and ``i_valid`` are both high, and the same condition is
    used as the capture-enable handed to ``latchregister``.

    ``addr_i`` and ``mask_i`` are declared as ports but are not referenced
    inside ``elaborate``.
    """

    def __init__(self, dwidth, awidth, mlen):
        # inputs
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.i_valid = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        # outputs
        self.ld_o = LDData(dwidth, "ld_o")
        self.o_valid = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        in_l = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = in_l

        m.d.comb += [
            # the latch is set by the incoming valid (its reset is not
            # driven in this module)
            in_l.s.eq(self.i_valid),
            self.o_valid.eq(in_l.q & self.i_valid),
        ]

        # capture/pass-through ld_i -> ld_o while latched and valid
        capture = in_l.q & self.o_valid
        latchregister(m, self.ld_i, self.ld_o, capture, "ld_i_r")

        return m

    def __iter__(self):
        yield from (
            self.addr_i,
            self.mask_i,
            self.ld_i.err,
            self.ld_i.data,
            self.ld_o.err,
            self.ld_o.data,
            self.i_valid,
            self.o_valid,
        )

    def ports(self):
        """Return every port yielded by ``__iter__`` as a list."""
        return list(self)
58
def byteExpand(signal):
    """Expand a per-byte mask into a per-bit mask (each bit becomes 8 bits).

    * For a Python integer (including ``bool`` and other ``int``
      subclasses) the result is an integer in which every set bit ``n`` of
      the input becomes ``0xFF`` at bit position ``8*n``.  Zero and
      negative inputs yield 0.
    * For anything else (expected to support ``len()`` and integer
      indexing, e.g. an nmigen value) the result is a ``Cat`` in which
      each input bit is repeated eight times, LSB first.
    """
    if isinstance(signal, int):
        expanded = 0
        shift = 0
        while signal > 0:
            if signal & 1:
                expanded |= 0xFF << shift
            signal >>= 1
            shift += 8
        return expanded

    lanes = []
    for i in range(len(signal)):
        # the same bit object is replicated across all 8 positions
        lanes.extend([signal[i]] * 8)
    return Cat(*lanes)
75
class LDSTSplitter(Elaboratable):
    """Split one (possibly misaligned) LD/ST into two line-aligned accesses.

    The request address is divided into a line address (``addr_i[dlen:]``)
    and a byte offset within the line (``addr_i[:dlen]``).  The first
    access uses the line address as-is, the second uses line address + 1.
    A ``LenExpand`` turns (offset, length) into a byte mask whose low
    ``1 << dlen`` bits select bytes in the first line and whose remaining
    bits select bytes in the second line.  If the second mask is zero the
    second access is ignored when computing ``o_valid``.

    Parameters:

    * ``dwidth`` -- data width; multiplied by 8 to size the data records
      (the source comments describe this as a bytes-to-bits conversion)
    * ``awidth`` -- width of ``addr_i`` in bits
    * ``dlen``   -- width of ``len_i`` and number of low address bits used
      as the in-line byte offset
    * ``pi``     -- accepted but currently unused
    """

    def __init__(self, dwidth, awidth, dlen, pi=None):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        # cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth*8  # convert bytes to bits

        self.addr_i = Signal(awidth, reset_less=True)
        # no match in PortInterface
        self.len_i = Signal(dlen, reset_less=True)
        self.i_valid = Signal(reset_less=True)
        self.o_valid = Signal(reset_less=True)

        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        self.ld_data_o = LDData(dwidth*8, "ld_data_o")  # port.ld
        self.st_data_i = LDData(dwidth*8, "st_data_i")  # port.st

        self.exc = Signal(reset_less=True)  # pi.exc TODO
        # TODO : create/connect two outgoing port interfaces

        # split-load side: per-access valid in/out and incoming data
        self.sld_o_valid = Signal(2, reset_less=True)
        self.sld_i_valid = Signal(2, reset_less=True)
        self.sld_data_i = tuple((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        # split-store side: per-access valid in/out and outgoing data
        self.sst_o_valid = Signal(2, reset_less=True)
        self.sst_i_valid = Signal(2, reset_less=True)
        self.sst_data_o = tuple((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        mzero = Const(0, mlen)
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth*8, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth*8, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        #comb += self.pi.addr_ok_o.eq(self.addr_i < 65536) #FIXME 64k limit
        #comb += self.pi.busy_o.eq(busy)

        # FIXME bytes not bits
        # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        ld2_value = self.addr_i[dlen:] + 1
        comb += ld2.addr_i.eq(ld2_value)
        # exception if rolls: the extra (carry) bit of the +1 is set
        with m.If(ld2_value[self.awidth-dlen]):
            comb += self.exc.eq(1)

        # data needs recombining / splitting via shifting.
        ashift1 = Signal(self.dlen, reset_less=True)
        ashift2 = Signal(self.dlen, reset_less=True)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        # NOTE(review): ashift2 is only dlen bits wide, so when ashift1 is 0
        # the value (1 << dlen) truncates to 0 -- confirm this is intended.
        comb += ashift2.eq((1 << dlen)-ashift1)

        # expand the byte masks (one bit per byte) to bit masks
        mask1 = byteExpand(mask1)
        mask2 = byteExpand(mask2)
        mzero = byteExpand(mzero)

        with m.If(self.is_ld_i):
            # set up connections to LD-split.  note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldi_valid%d" % i, reset_less=True)
                comb += ld_valid.eq(self.i_valid & self.sld_i_valid[i])
                comb += ld.i_valid.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_o_valid[i].eq(ld.o_valid)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.o_valid.eq(self.sld_o_valid[0])
            with m.Else():
                comb += self.o_valid.eq(self.sld_o_valid.all())
            ## debug output -- output mask2 and mzero
            ## guess second port is invalid

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.o_valid):
                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)

                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> (ashift1*8)) |
                                               (ld2.ld_o.data << (ashift2*8)))

        with m.If(self.is_st_i):
            # set busy flag -- required for unit test
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="sti_valid%d" % i, reset_less=True)
                comb += valid.eq(self.i_valid & self.sst_i_valid[i])
                comb += ld.i_valid.eq(valid & (mask != mzero))
                # drive the *store* valid outputs: o_valid below is derived
                # from sst_o_valid, which was previously never driven
                # (the load-side sld_o_valid was assigned here instead)
                comb += self.sst_o_valid[i].eq(ld.o_valid)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)

            # NOTE(review): these shift the whole st_data_i record (err and
            # data together) and assign it to the whole ld_i record; it looks
            # like only the .data fields were meant -- confirm before changing.
            comb += ld1.ld_i.eq((self.st_data_i << (ashift1*8)) & mask1)
            comb += ld2.ld_i.eq((self.st_data_i >> (ashift2*8)) & mask2)

            # sort out valid: mask2 zero we ignore 2nd ST
            with m.If(mask2 == mzero):
                comb += self.o_valid.eq(self.sst_o_valid[0])
            with m.Else():
                comb += self.o_valid.eq(self.sst_o_valid.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.o_valid):
                # errors cause error condition
                # NOTE(review): this drives the err field of the *input*
                # record st_data_i -- verify which signal should carry it.
                comb += self.st_data_i.err.eq(ld1.ld_o.err | ld2.ld_o.err)

        return m

    def __iter__(self):
        # only the load-side ports are listed here; store-side signals and
        # ``exc`` are not yielded
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.i_valid
        yield self.o_valid
        yield self.sld_i_valid
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data

    def ports(self):
        """Return every port yielded by ``__iter__`` as a list."""
        return list(self)
219
220
def sim(dut):
    """Simulate one misaligned load through *dut* and write a VCD trace.

    An 8-byte load is issued at byte offset ``0b1110``.  One process
    (``send_ld``) presents the request and waits for ``o_valid``; the other
    (``lds``) plays the role of the two split load ports and supplies the
    data for each line.  The recombined ``ld_data_o.data`` must equal the
    original (length-masked) data and ``exc`` must be 0.

    ``dlen`` is the number of in-line offset bits (taken from the DUT), so
    a line is ``1 << dlen`` bytes.  It was previously hard-coded to 16 while
    still being used as ``1 << dlen``, which made the second-line mask and
    data always zero, so the split was never actually exercised.
    """
    sim = Simulator(dut)
    sim.add_clock(1e-6)
    data = 0x0102030405060708A1A2A3A4A5A6A7A8
    dlen = dut.dlen             # number of byte-offset bits within a line
    line_bytes = 1 << dlen      # line size in bytes
    line_bits = line_bytes * 8  # line size in bits
    addr = 0b1110
    ld_len = 8
    ldm = ((1 << ld_len)-1)
    ldme = byteExpand(ldm)
    dlm = ((1 << dlen)-1)
    data = data & ldme  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, hex(data & ldme))
    print("dlm", dlm, bin(addr & dlm))

    # byte mask of the access, then split at the line boundary:
    # dmask covers the first line, dmask1 the spill-over into the second
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> line_bytes
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << line_bytes)-1)
    print("dmask", bin(dmask))
    dmask1 = byteExpand(dmask1)
    dmask = byteExpand(dmask)

    def send_ld():
        print("send_ld")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.i_valid.eq(1)
        print("waiting")
        while True:
            o_valid = yield dut.o_valid
            if o_valid:
                break
            yield
        exc = yield dut.exc
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(exc)
        assert exc==0
        print(hex(ld_data_o), hex(data))
        assert ld_data_o == data

    def lds():
        print("lds")
        # wait for the request to be presented
        while True:
            i_valid = yield dut.i_valid
            if i_valid:
                break
            yield

        shf = (addr & dlm)*8  # shift bytes not bits
        print("shf",shf/8.0)
        shfdata = (data << shf)
        # first line: the bytes that fall below the line boundary
        data1 = shfdata & dmask
        print("ld data1", hex(data), hex(data1), shf,shf/8.0, hex(dmask))

        # second line: the bytes that spilled past the boundary
        data2 = (shfdata >> line_bits) & dmask1
        print("ld data2", line_bytes, hex(data >> line_bytes), hex(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_i_valid[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_i_valid[1].eq(1)
        yield

    sim.add_sync_process(lds)
    sim.add_sync_process(send_ld)

    prefix = "ldst_splitter"
    with sim.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        sim.run()
296
297
if __name__ == '__main__':
    # build the splitter, dump its RTLIL, then run the load simulation
    dut = LDSTSplitter(32, 48, 4)
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    sim(dut)