add routing of store data through splitters
[soc.git] / src / soc / scoreboard / addr_split.py
# LDST Address Splitter: splits misaligned load/store accesses that cross a cache line boundary.
2
3 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
4 from nmutil.latch import SRLatch, latchregister
5 from nmigen.back.pysim import Simulator, Delay
6 from nmigen.cli import verilog, rtlil
7
8 from soc.scoreboard.addr_match import LenExpand
9 #from nmutil.queue import Queue
10
class LDData(Record):
    """Load/store data record: a 1-bit ``err`` flag followed by ``data``.

    :param dwidth: width in bits of the ``data`` field.
    :param name: optional record name, passed straight through to Record.
    """

    def __init__(self, dwidth, name=None):
        layout = (('err', 1), ('data', dwidth))
        Record.__init__(self, layout, name=name)
14
15
class LDLatch(Elaboratable):
    """Captures one half of a split load/store.

    ``valid_o`` is ``valid_i`` gated by the output of an asynchronous
    SRLatch ("in_l") that ``valid_i`` sets.  Whenever ``in_l.q & valid_o``
    is high, ``ld_i`` is handed to ``latchregister`` (register "ld_i_r"),
    which drives ``ld_o``.

    :param dwidth: width of the data field of ``ld_i``/``ld_o``.
    :param awidth: width of ``addr_i``.
    :param mlen: width of ``mask_i``.
    """

    def __init__(self, dwidth, awidth, mlen):
        # addr_i / mask_i are exposed as ports but not used in elaborate()
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        latch = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = latch

        drive = m.d.comb
        # NOTE(review): latch.r is never driven here, so nothing in this
        # block ever clears the latch -- confirm this is intended.
        drive += latch.s.eq(self.valid_i)
        drive += self.valid_o.eq(latch.q & self.valid_i)
        # enable includes latch.q again; redundant with valid_o but kept
        latchregister(m, self.ld_i, self.ld_o, latch.q & self.valid_o,
                      "ld_i_r")

        return m

    def __iter__(self):
        yield from (self.addr_i,
                    self.mask_i,
                    self.ld_i.err,
                    self.ld_i.data,
                    self.ld_o.err,
                    self.ld_o.data,
                    self.valid_i,
                    self.valid_o)

    def ports(self):
        return list(self)
49
class LDSTSplitter(Elaboratable):
    """Splits a (possibly misaligned) load/store into two line accesses.

    ``ld1`` covers the line addressed by ``addr_i[dlen:]`` and ``ld2`` the
    following line (``addr_i[dlen:] + 1``).  ``LenExpand`` turns
    ``addr_i``/``len_i`` into a 2-line mask: its low ``1 << dlen`` bits
    (``mask1``) select units in the first line, the high bits (``mask2``)
    units in the second.  When ``mask2`` is zero only ``ld1`` is used.

    Shift amounts are taken directly from ``addr_i[:dlen]`` and applied to
    the data, so a "unit" is currently one data bit (see the cline_wid TODO).

    Load (``is_ld_i``): line data arriving on ``sld_data_i[0..1]`` (qualified
    by ``sld_valid_i``) is shifted back and merged into ``ld_data_o``.

    Store (``is_st_i``): ``st_data_i.data`` is shifted/masked into the two
    halves, presented on ``sst_data_o[0..1]`` qualified by ``sst_valid_o``.

    :param dwidth: data width (also used as the line width for now).
    :param awidth: address width.
    :param dlen: width of ``len_i``; one line spans ``1 << dlen`` units.
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        #cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth # TODO: make this bytes not bits
        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.st_data_i = LDData(dwidth, "st_data_i")

        # load-side handshake/data towards the two line accesses
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        # store-side handshake/data towards the two line accesses
        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        mzero = Const(0, mlen)
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # len-expander: len to mask.  ld1 gets the low mlen bits of the
        # expanded mask (first line), ld2 gets the high bits (next line)
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen]) # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:]) # Hi bits of expanded len-mask

        # the 2nd (next-line) access only takes part if the mask spills over.
        # ld2 is never enabled otherwise, so its outputs may hold values from
        # an earlier access and must not be merged in.
        ld2_used = Signal(reset_less=True)
        comb += ld2_used.eq(mask2 != mzero)

        # new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        comb += ld2.addr_i.eq(self.addr_i[dlen:] + 1) # TODO exception if rolls

        # data needs recombining / splitting via shifting.
        # NOTE: ashift2 is dlen bits wide, so it wraps to 0 when ashift1 is 0;
        # presumably mask2 is always zero then (len_i < 1<<dlen) -- confirm.
        ashift1 = Signal(self.dlen, reset_less=True)
        ashift2 = Signal(self.dlen, reset_less=True)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        comb += ashift2.eq((1<<dlen)-ashift1)

        with m.If(self.is_ld_i):
            # connections to LD-split. note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # valid: when ld2 is unused only the 1st LD counts
            with m.If(ld2_used):
                comb += self.valid_o.eq(self.sld_valid_o.all())
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o[0])

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # data from LD1 is in *cache-line* position; LD2's data is
                # known to start at the beginning of its line
                with m.If(ld2_used):
                    comb += self.ld_data_o.err.eq(ld1.ld_o.err |
                                                  ld2.ld_o.err)
                    comb += self.ld_data_o.data.eq(
                                        (ld1.ld_o.data >> ashift1) |
                                        (ld2.ld_o.data << ashift2))
                with m.Else():
                    comb += self.ld_data_o.err.eq(ld1.ld_o.err)
                    comb += self.ld_data_o.data.eq(ld1.ld_o.data >> ashift1)

        with m.If(self.is_st_i):
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & (mask != mzero))
                # store handshake must go out on sst_valid_o: the valid_o
                # logic below is built from it
                comb += self.sst_valid_o[i].eq(ld.valid_o)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)

            # shift/mask the data field only.  Shifting the whole record
            # would misalign the mask by the err bit and feed st_data_i.err
            # (driven below) back into itself combinationally.
            comb += ld1.ld_i.data.eq((self.st_data_i.data << ashift1) & mask1)
            comb += ld2.ld_i.data.eq((self.st_data_i.data >> ashift2) & mask2)

            # valid: when ld2 is unused only the 1st ST counts
            with m.If(ld2_used):
                comb += self.valid_o.eq(self.sst_valid_o.all())
            with m.Else():
                comb += self.valid_o.eq(self.sst_valid_o[0])

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # NOTE(review): this drives the err field of the st_data_i
                # input record (used as an error response).  The store path
                # does not drive ld1/ld2 ld_i.err, so no cache-side store
                # error source exists yet -- confirm intended error reporting.
                with m.If(ld2_used):
                    comb += self.st_data_i.err.eq(ld1.ld_o.err |
                                                  ld2.ld_o.err)
                with m.Else():
                    comb += self.st_data_i.err.eq(ld1.ld_o.err)

        return m

    def __iter__(self):
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data
        # previously missing: load handshake out and the whole store side
        yield self.sld_valid_o
        yield self.is_st_i
        yield self.st_data_i.err
        yield self.st_data_i.data
        yield self.sst_valid_i
        yield self.sst_valid_o
        for i in range(2):
            yield self.sst_data_o[i].data

    def ports(self):
        return list(self)
171
def sim(dut):
    """Simulate one misaligned load through ``dut`` and check the result.

    A load of ``ld_len`` units at ``addr`` (which crosses into the next
    line) is issued; the ``lds`` process answers with the two line-sized
    halves on ``sld_data_i``; ``send_in`` waits for ``valid_o`` and asserts
    the recombined ``ld_data_o.data`` equals the original data.  A VCD
    trace is written to ``ldst_splitter.vcd``.

    :param dut: an LDSTSplitter built with ``dlen == 4``.
    :raises AssertionError: if ``dut.dlen`` differs from the stimulus
        ``dlen`` or the recombined data is wrong.
    """
    data = 0b11010011
    dlen = 4 # 4 bits
    # the expected masks below are derived from this dlen: it must match
    assert dut.dlen == dlen, "sim() stimulus assumes dut.dlen == %d" % dlen
    line_units = 1 << dlen # units per cache line (was hard-coded as 16)
    addr = 0b1100
    ld_len = 8
    ldm = ((1<<ld_len)-1)
    dlm = ((1<<dlen)-1)
    data = data & ldm # truncate data to be tested, mask to within ld len
    print ("ldm", ldm, bin(data&ldm))
    print ("dlm", dlm, bin(addr&dlm))
    dmask = ldm << (addr & dlm)
    print ("dmask", bin(dmask))
    dmask1 = dmask >> line_units # part spilling into the 2nd line
    print ("dmask1", bin(dmask1))
    dmask = dmask & ((1<<line_units)-1) # part within the 1st line
    print ("dmask", bin(dmask))

    simulator = Simulator(dut)
    simulator.add_clock(1e-6)

    def send_in():
        # issue the load and wait for the splitter to report completion
        print ("send_in")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print ("waiting")
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print (bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        # model the memory side: answer each half once valid_i is seen
        print ("lds")
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print ("ld data1", bin(data), bin(data1), shf, bin(dmask))

        data2 = (shfdata >> line_units) & dmask1
        print ("ld data2", 1<<dlen, bin(data >> (1<<dlen)), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_in)

    prefix = "ldst_splitter"
    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
239
240
if __name__ == '__main__':
    # 32-bit data, 48-bit address, 4-bit len_i: emit RTLIL, then simulate
    dut = LDSTSplitter(32, 48, 4)
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    sim(dut)