hmm start adding st in (half done)
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2
3 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
4 from nmutil.latch import SRLatch, latchregister
5 from nmigen.back.pysim import Simulator, Delay
6 from nmigen.cli import verilog, rtlil
7
8 from soc.scoreboard.addr_match import LenExpand
9 #from nmutil.queue import Queue
10
class LDData(Record):
    """Record bundling a 1-bit ``err`` flag with a ``data`` field.

    dwidth: width in bits of the ``data`` field
    name:   optional record name, handed straight to Record
    """

    def __init__(self, dwidth, name=None):
        # field order matters: 'err' first, then 'data'
        layout = (('err', 1), ('data', dwidth))
        Record.__init__(self, layout, name=name)
14
15
class LDLatch(Elaboratable):
    """Holds one half of a split load behind an SR latch.

    dwidth: width of the data field of ld_i / ld_o
    awidth: width of addr_i
    mlen:   width of mask_i

    NOTE(review): addr_i and mask_i are declared and exported as ports
    but are not read anywhere in elaborate() -- confirm whether that is
    intentional.
    """

    def __init__(self, dwidth, awidth, mlen):
        # inputs
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        # outputs
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        latch = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = latch

        # the latch is set by valid_i; its reset input is never driven here
        # NOTE(review): presumably the reset hookup is still to be added
        m.d.comb += [
            latch.s.eq(self.valid_i),
            self.valid_o.eq(latch.q & self.valid_i),
        ]

        # ld_i is passed to ld_o through a latchregister, gated by the
        # latch output together with valid_o
        capture = latch.q & self.valid_o
        latchregister(m, self.ld_i, self.ld_o, capture, "ld_i_r")

        return m

    def __iter__(self):
        yield from (
            self.addr_i,
            self.mask_i,
            self.ld_i.err,
            self.ld_i.data,
            self.ld_o.err,
            self.ld_o.data,
            self.valid_i,
            self.valid_o,
        )

    def ports(self):
        """Return every signal yielded by __iter__ as a list."""
        return list(self)
49
class LDSTSplitter(Elaboratable):
    """LD/ST address splitter for accesses that may cross a line boundary.

    The incoming (addr_i, len_i) pair is expanded into a mask by LenExpand;
    the low mlen bits of that mask go with the first access (ld1), the
    remaining bits with the second access (ld2), whose address is the
    first one plus one.

    dwidth: width of the data field of all LDData records
    awidth: width of addr_i
    dlen:   width of len_i; the low dlen bits of addr_i are the offset
            within a line, and mlen = 1 << dlen is the per-access mask width

    NOTE(review): the store side is only partly wired up: sst_valid_o and
    sst_data_o are declared but never driven in elaborate(), and the
    store-related signals are not yielded by __iter__ (so they are not
    part of ports()).  Confirm before relying on the store path.
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        # operation select
        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        # recombined load result / incoming store data
        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.st_data_i = LDData(dwidth, "st_data_i")

        # split-load side: one valid bit and one data record per access
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(dwidth, "ld_data_i1"),
                                 LDData(dwidth, "ld_data_i2")))

        # split-store side: one valid bit and one data record per access
        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(dwidth, "st_data_i1"),
                                 LDData(dwidth, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        comb += ld2.addr_i.eq(self.addr_i[dlen:] + 1)  # TODO exception if rolls

        # shift amounts used to recombine (LD) or split (ST) the data.
        # these are driven unconditionally, outside the is_ld_i / is_st_i
        # branches, because both branches read them: driving them only
        # under the load condition would leave them undriven for stores.
        ashift1 = Signal(self.dlen)
        ashift2 = Signal(self.dlen)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        comb += ashift2.eq((1 << dlen)-ashift1)

        # all-zero mask, shared by the load and store branches
        mzero = Const(0, mlen)

        with m.If(self.is_ld_i):
            # set up connections to LD-split.  note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sld_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)
                # data needs recombining via shifting.
                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
                                               (ld2.ld_o.data << ashift2))

        with m.If(self.is_st_i):
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & (mask != mzero))
                # NOTE(review): drives sld_valid_o, not sst_valid_o --
                # looks like a leftover from the load branch; confirm
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # NOTE(review): st_data_i is the whole LDData record (err and
            # data), masked as a single value -- confirm this is intended
            comb += ld1.ld_i.eq((self.st_data_i & mask1) << ashift1)
            comb += ld2.ld_i.eq((self.st_data_i & mask2) >> ashift2)

        return m

    def __iter__(self):
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data

    def ports(self):
        """Return every signal yielded by __iter__ as a list."""
        return list(self)
156
def sim(dut):
    """Run a single misaligned-load simulation against *dut*.

    Two sync processes are registered: send_in() issues the load request
    and checks the recombined result, lds() plays the part of the two
    split loads and supplies their data.  A VCD trace is written to
    "ldst_splitter.vcd".

    dut: the splitter under test; its ports() are used as VCD traces
    """

    sim = Simulator(dut)
    sim.add_clock(1e-6)
    data = 0b11010011
    dlen = 4  # 4 bits
    addr = 0b1100
    ld_len = 8
    ldm = ((1 << ld_len)-1)
    dlm = ((1 << dlen)-1)
    data = data & ldm  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, bin(data & ldm))
    print("dlm", dlm, bin(addr & dlm))
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    # part of the mask that spills past the first (1<<dlen) positions
    dmask1 = dmask >> (1 << dlen)
    print("dmask1", bin(dmask1))
    # part of the mask that stays within the first (1<<dlen) positions
    dmask = dmask & ((1 << (1 << dlen))-1)
    print("dmask", bin(dmask))

    def send_in():
        print("send_in")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print("waiting")
        # poll once per clock until the splitter reports a valid result
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        print("lds")
        # wait for the request to be raised by send_in
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # second half: same split point as dmask1 above, derived from dlen
        # rather than hard-coded so the two cannot drift apart
        data2 = (shfdata >> (1 << dlen)) & dmask1
        print("ld data2", 1 << dlen, bin(data >> (1 << dlen)), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    sim.add_sync_process(lds)
    sim.add_sync_process(send_in)

    prefix = "ldst_splitter"
    with sim.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        sim.run()
224
225
if __name__ == '__main__':
    # dwidth=32, awidth=48, dlen=4
    dut = LDSTSplitter(32, 48, 4)

    # dump the design as RTLIL before simulating it
    rtlil_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(rtlil_text)

    sim(dut)