move ashift out to common area
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2
3 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
4 from nmutil.latch import SRLatch, latchregister
5 from nmigen.back.pysim import Simulator, Delay
6 from nmigen.cli import verilog, rtlil
7
8 from soc.scoreboard.addr_match import LenExpand
9 #from nmutil.queue import Queue
10
class LDData(Record):
    """Record carrying a one-bit ``err`` flag alongside a ``data`` field.

    dwidth: width (in bits) of the ``data`` field
    name:   optional record name, handed straight through to Record
    """

    def __init__(self, dwidth, name=None):
        layout = (('err', 1), ('data', dwidth))
        super().__init__(layout, name=name)
14
15
class LDLatch(Elaboratable):
    """Holds the data record for one half of a split transfer.

    Constructor arguments:
    * dwidth: data width of the ld_i / ld_o records
    * awidth: width of addr_i
    * mlen:   width of mask_i

    addr_i and mask_i are declared as ports but are not referenced by
    elaborate() in this class.
    """

    def __init__(self, dwidth, awidth, mlen):
        # address / mask of this half (ports only, see class docstring)
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        # incoming valid strobe
        self.valid_i = Signal(reset_less=True)
        # incoming and outgoing data records
        self.ld_i = LDData(dwidth, "ld_i")
        self.ld_o = LDData(dwidth, "ld_o")
        # outgoing valid
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # SR latch, set by valid_i.  Its reset input is not driven here.
        latch = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = latch

        m.d.comb += [
            latch.s.eq(self.valid_i),
            # output is valid only while the latch is set AND valid_i is high
            self.valid_o.eq(latch.q & self.valid_i),
        ]

        # hand ld_i -> ld_o to latchregister, enabled by (q AND valid_o)
        capture = latch.q & self.valid_o
        latchregister(m, self.ld_i, self.ld_o, capture, "ld_i_r")

        return m

    def __iter__(self):
        # port order: addr, mask, ld_i fields, ld_o fields, valid_i, valid_o
        yield from (self.addr_i, self.mask_i)
        for rec in (self.ld_i, self.ld_o):
            yield rec.err
            yield rec.data
        yield from (self.valid_i, self.valid_o)

    def ports(self):
        """Return every signal produced by iterating this object, as a list."""
        return [*self]
49
class LDSTSplitter(Elaboratable):
    """Splits one LD/ST request into two requests on adjacent lines.

    Per the file header this targets misaligned addresses that cross a
    cache line boundary.  The low ``dlen`` bits of addr_i select the
    shift applied to the data; the remaining upper bits form the address
    given to the first LDLatch, and that address plus one goes to the
    second.

    Constructor arguments:
    * dwidth: data width of the LDData records
    * awidth: width of addr_i
    * dlen:   width of len_i (and number of low address bits used as shift)
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth = dwidth
        self.awidth = awidth
        self.dlen = dlen
        # cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth  # TODO: make this bytes not bits

        # request: address, length, valid handshake
        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        # operation select
        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        # full-width data: load result out, store data in
        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.st_data_i = LDData(dwidth, "st_data_i")

        # split-load side: one valid bit and one data record per half
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array(LDData(cline_wid, "ld_data_i%d" % n)
                                for n in (1, 2))

        # split-store side: one valid bit and one data record per half
        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array(LDData(cline_wid, "st_data_i%d" % n)
                                for n in (1, 2))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        latch_awidth = self.awidth - dlen

        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, latch_awidth, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, latch_awidth, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
        comb += [
            lenexp.addr_i.eq(self.addr_i),
            lenexp.len_i.eq(self.len_i),
        ]
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        upper_addr = self.addr_i[dlen:]
        comb += ld1.addr_i.eq(upper_addr)
        comb += ld2.addr_i.eq(upper_addr + 1)  # TODO exception if rolls

        # data needs recombining / splitting via shifting.
        # NOTE(review): ashift2 is only dlen bits wide, so when ashift1 is
        # zero the value (1<<dlen) does not fit -- confirm this is intended.
        ashift1 = Signal(self.dlen, reset_less=True)
        ashift2 = Signal(self.dlen, reset_less=True)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        comb += ashift2.eq(mlen - ashift1)

        # shared by the load and store paths below
        mzero = Const(0, mlen)
        halves = ((ld1, mask1), (ld2, mask2))

        with m.If(self.is_ld_i):
            # set up connections to LD-split.  note: not active if mask is zero
            for idx, (latch, mask) in enumerate(halves):
                ld_valid = Signal(name="ldvalid_i%d" % idx, reset_less=True)
                comb += [
                    ld_valid.eq(self.valid_i & self.sld_valid_i[idx]),
                    latch.valid_i.eq(ld_valid & (mask != mzero)),
                    latch.ld_i.eq(self.sld_data_i[idx]),
                    self.sld_valid_o[idx].eq(latch.valid_o),
                ]

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sld_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                any_err = ld1.ld_o.err | ld2.ld_o.err
                comb += self.ld_data_o.err.eq(any_err)

                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                lo_part = ld1.ld_o.data >> ashift1
                hi_part = ld2.ld_o.data << ashift2
                comb += self.ld_data_o.data.eq(lo_part | hi_part)

        with m.If(self.is_st_i):
            # NOTE(review): the store path reports completion on sld_valid_o,
            # and sst_valid_o / sst_data_o are never driven in this method --
            # confirm whether that is deliberate.
            for idx, (latch, mask) in enumerate(halves):
                st_valid = Signal(name="stvalid_i%d" % idx, reset_less=True)
                comb += [
                    st_valid.eq(self.valid_i & self.sst_valid_i[idx]),
                    latch.valid_i.eq(st_valid & (mask != mzero)),
                    self.sld_valid_o[idx].eq(latch.valid_o),
                ]

            # masked store data, shifted into position for each half
            comb += ld1.ld_i.eq((self.st_data_i & mask1) << ashift1)
            comb += ld2.ld_i.eq((self.st_data_i & mask2) >> ashift2)

        return m

    def __iter__(self):
        # request and load-result ports
        yield from (self.addr_i, self.len_i, self.is_ld_i)
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield from (self.valid_i, self.valid_o, self.sld_valid_i)
        # per-half incoming load data
        for half in range(2):
            rec = self.sld_data_i[half]
            yield rec.err
            yield rec.data

    def ports(self):
        """Return every signal produced by iterating this object, as a list."""
        return [*self]
160
def sim(dut):
    """Simulate one load through *dut* and check the recombined data.

    Two sync processes are registered:
    * send_in: raises is_ld_i / valid_i with a fixed address and length,
      waits for dut.valid_o, then asserts that ld_data_o.data equals the
      test data.
    * lds: waits for dut.valid_i, then supplies the two halves of the
      shifted test data on sld_data_i[0] and sld_data_i[1], one per clock.

    A VCD trace is written to "ldst_splitter.vcd".
    """
    simulator = Simulator(dut)
    simulator.add_clock(1e-6)

    data = 0b11010011
    dlen = 4  # 4 bits
    addr = 0b1100
    ld_len = 8
    # number of bit positions covered by one half; used for every split below
    line_bits = 1 << dlen

    ldm = (1 << ld_len) - 1
    dlm = (1 << dlen) - 1
    data = data & ldm  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, bin(data & ldm))
    print("dlm", dlm, bin(addr & dlm))
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> line_bits
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << line_bits) - 1)
    print("dmask", bin(dmask))

    def send_in():
        print("send_in")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print("waiting")
        # spin one clock at a time until the splitter reports valid
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        print("lds")
        # wait for the request to be raised by send_in
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # upper half: shift by the same amount that dmask1 was derived with
        # (previously a literal 16, which only matched when dlen == 4)
        data2 = (shfdata >> line_bits) & dmask1
        print("ld data2", line_bits, bin(data >> line_bits), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_in)

    prefix = "ldst_splitter"
    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
228
229
if __name__ == '__main__':
    # Build a splitter (dwidth=32, awidth=48, dlen=4), dump it as RTLIL,
    # then run the simulation on the same instance.
    dut = LDSTSplitter(32, 48, 4)
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    sim(dut)