change name of test function in addr_split
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2 """
3 Links:
4 * https://libre-riscv.org/3d_gpu/architecture/6600scoreboard/
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=257
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
7 """
8
9 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
10 from nmutil.latch import SRLatch, latchregister
11 from nmigen.back.pysim import Simulator, Delay
12 from nmigen.cli import verilog, rtlil
13
14 from soc.scoreboard.addr_match import LenExpand
15 #from nmutil.queue import Queue
16
class LDData(Record):
    """Record carrying a one-bit ``err`` flag alongside a ``data`` field.

    dwidth: width given to the ``data`` field of the record layout
    name:   optional record name, passed straight through to Record
    """

    def __init__(self, dwidth, name=None):
        # field order matters: 'err' comes first, then 'data'
        layout = (('err', 1), ('data', dwidth))
        Record.__init__(self, layout, name=name)
20
21
class LDLatch(Elaboratable):
    """Holds one LDData record behind an SR latch set by ``valid_i``.

    dwidth: width of the ``data`` field of ld_i / ld_o
    awidth: width of addr_i
    mlen:   width of mask_i

    addr_i and mask_i are declared as ports but are not read inside
    elaborate(); they are only exposed through __iter__ / ports().
    """

    def __init__(self, dwidth, awidth, mlen):
        # inputs
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        # outputs
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        in_l = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = in_l

        # the latch is set by valid_i; its reset input is never driven here
        comb += in_l.s.eq(self.valid_i)
        # valid_o requires both the latch output and valid_i
        comb += self.valid_o.eq(in_l.q & self.valid_i)

        # NOTE(review): latchregister comes from nmutil; presumably it passes
        # ld_i through to ld_o (and captures it) while the enable is high.
        # Confirm against nmutil.latch.
        capture = in_l.q & self.valid_o
        latchregister(m, self.ld_i, self.ld_o, capture, "ld_i_r")

        return m

    def __iter__(self):
        yield from (self.addr_i, self.mask_i)
        yield from (self.ld_i.err, self.ld_i.data)
        yield from (self.ld_o.err, self.ld_o.data)
        yield from (self.valid_i, self.valid_o)

    def ports(self):
        """Return every signal produced by iterating this object, as a list."""
        return [*self]
55
class LDSTSplitter(Elaboratable):
    """Splits one LD/ST request into up to two accesses on adjacent lines.

    dwidth: width of the ``data`` field of every LDData record used here
    awidth: width of addr_i
    dlen:   width of len_i; the low ``dlen`` bits of addr_i are used as the
            shift amount and the remaining upper bits as the line address

    The request address and length are fed to LenExpand.  The low
    ``1<<dlen`` bits of its expanded output become mask1 (first access) and
    the remaining bits become mask2 (second access, at line address + 1).
    An access whose mask is zero is never made valid.

    Load side:  sld_valid_i / sld_data_i come in, sld_valid_o goes out, and
                the recombined result appears on ld_data_o.
    Store side: sst_valid_i comes in, sst_valid_o / sst_data_o go out, with
                st_data_i as the data to split.
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        # cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth  # TODO: make this bytes not bits
        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.st_data_i = LDData(dwidth, "st_data_i")

        # per-access load handshake and data (index 0: first, 1: second)
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        # per-access store handshake and data (index 0: first, 1: second)
        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        mzero = Const(0, mlen)
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        comb += ld2.addr_i.eq(self.addr_i[dlen:] + 1)  # TODO exception if rolls

        # data needs recombining / splitting via shifting.
        # ashift2 is only dlen bits wide, so (1<<dlen)-0 wraps to zero.
        ashift1 = Signal(self.dlen, reset_less=True)
        ashift2 = Signal(self.dlen, reset_less=True)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        comb += ashift2.eq((1 << dlen)-ashift1)

        with m.If(self.is_ld_i):
            # set up connections to LD-split.  note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sld_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)

                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
                                               (ld2.ld_o.data << ashift2))

        with m.If(self.is_st_i):
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & (mask != mzero))
                # the store handshake must drive sst_valid_o: it is what the
                # store-side valid_o below is computed from.
                comb += self.sst_valid_o[i].eq(ld.valid_o)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)

            # NOTE(review): these shift and mask the whole st_data_i record
            # (err and data together), not just .data -- confirm intended.
            comb += ld1.ld_i.eq((self.st_data_i << ashift1) & mask1)
            comb += ld2.ld_i.eq((self.st_data_i >> ashift2) & mask2)

            # sort out valid: mask2 zero we ignore 2nd ST
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sst_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sst_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                # NOTE(review): this drives the err field of a record named
                # as an input (st_data_i) -- confirm that is the intent.
                comb += self.st_data_i.err.eq(ld1.ld_o.err | ld2.ld_o.err)

        return m

    def __iter__(self):
        # NOTE(review): only load-side signals are yielded; is_st_i,
        # st_data_i, sld_valid_o and the sst_* signals are not listed here.
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data

    def ports(self):
        """Return every signal produced by iterating this object, as a list."""
        return list(self)
177
def sim(dut):
    """Run one load through *dut* in simulation and check the result.

    Two sync processes are registered: ``send_ld`` issues the load request
    and waits for valid_o, then asserts that ld_data_o.data equals the test
    data; ``lds`` waits for valid_i and then supplies the two per-line data
    words, one per clock.  A VCD trace is written to "ldst_splitter.vcd".

    dut: the LDSTSplitter instance to simulate (must provide ports()).
    """
    simulator = Simulator(dut)
    simulator.add_clock(1e-6)
    data = 0b11010011
    dlen = 4  # 4 bits
    addr = 0b1100
    ld_len = 8
    # number of mask bits covered by the first line; everything above this
    # belongs to the second line.
    line_bits = 1 << dlen
    ldm = ((1 << ld_len)-1)
    dlm = ((1 << dlen)-1)
    data = data & ldm  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, bin(data & ldm))
    print("dlm", dlm, bin(addr & dlm))
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> line_bits
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << line_bits)-1)
    print("dmask", bin(dmask))

    def send_ld():
        print("send_ld")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print("waiting")
        # poll valid_o once per clock until the splitter reports completion
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        print("lds")
        # wait for the request to be presented before answering it
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # second-line data: the part of the shifted value above the first
        # line (derived from dlen rather than a fixed 16)
        data2 = (shfdata >> line_bits) & dmask1
        print("ld data2", 1 << dlen, bin(data >> (1 << dlen)), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_ld)

    prefix = "ldst_splitter"
    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
245
246
if __name__ == '__main__':
    # build a splitter with dwidth=32, awidth=48, dlen=4
    dut = LDSTSplitter(32, 48, 4)

    # dump the design as RTLIL, then run the simulation on the same instance
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    sim(dut)