Add memory loads and stores to simulator
[soc.git] / src / soc / scoreboard / addr_match.py
1 """ Load / Store partial address matcher
2
3 Related bugreports:
4 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
5
6 Loads and Stores do not need a full match (CAM), they need "good enough"
7 avoidance. Around 11 bits on a 64-bit address is "good enough".
8
9 The simplest way to use this module is to ignore not only the top bits,
10 but also the bottom bits as well: in this case (this RV64 processor),
11 enough to cover a DWORD (64-bit). that means ignore the bottom 4 bits,
12 due to the possibility of 64-bit LD/ST being misaligned.
13
14 To reiterate: the use of this module is an *optimisation*. All it has
15 to do is cover the cases that are *definitely* matches (by checking 11
16 bits or so), and if a few opportunities for parallel LD/STs are missed
17 because the top (or bottom) bits weren't checked, so what: all that
18 happens is: the mis-matched addresses are LD/STd on single-cycles. Big Deal.
19
20 However, if we wanted to enhance this algorithm (without using a CAM and
21 without using expensive comparators) probably the best way to do so would
22 be to turn the last 16 bits into a byte-level bitmap. LD/ST on a byte
23 would have 1 of the 16 bits set. LD/ST on a DWORD would have 8 of the 16
24 bits set (offset if the LD/ST was misaligned). TODO.
25
26 Notes:
27
28 > I have used bits <11:6> as they are not translated (4KB pages)
29 > and larger than a cache line (64 bytes).
30 > I have used bits <11:4> when the L1 cache was QuadW sized and
31 > the L2 cache was Line sized.
32 """
33
34 from nmigen.compat.sim import run_simulation
35 from nmigen.cli import verilog, rtlil
36 from nmigen import Module, Signal, Const, Array, Cat, Elaboratable
37 from nmigen.lib.coding import Decoder
38
39 from nmutil.latch import latchregister, SRLatch
40
41
class PartialAddrMatch(Elaboratable):
    """A partial address matcher

    Holds n_adr addresses of bitwid bits each and compares every address
    against every other one (see is_match, which subclasses override).

    Inputs:

    * addrs_i   - Array of n_adr addresses, each bitwid bits wide
    * addr_en_i - n_adr bits, wired to the "set" input of the SR latch
                  (one bit per address: address latched in)
    * addr_rs_i - n_adr bits, wired to the "reset" input of the SR latch
                  (one bit per address: address deactivated)

    Outputs:

    * addr_nomatch_a_o - Array of n_adr vectors, each n_adr bits wide.
                  bit j of entry i is the inverse of is_match(i, j).
                  bit i of entry i is always set (no match against self).
    * addr_nomatch_o - n_adr bits.  bit i is set when latch output i is
                  set *and* entry i reports "no match" for every position
                  whose latch output is set.
    """
    def __init__(self, n_adr, bitwid):
        self.n_adr = n_adr
        self.bitwid = bitwid
        # inputs
        self.addrs_i = Array(Signal(bitwid, name="addr") for i in range(n_adr))
        self.addr_en_i = Signal(n_adr, reset_less=True) # address latched in
        self.addr_rs_i = Signal(n_adr, reset_less=True) # address deactivated

        # output: a nomatch for each address plus individual nomatch signals
        self.addr_nomatch_o = Signal(n_adr, name="nomatch_o", reset_less=True)
        self.addr_nomatch_a_o = Array(Signal(n_adr, reset_less=True,
                                             name="nomatch_array_o")
                                      for i in range(n_adr))

    def elaborate(self, platform):
        m = Module()
        return self._elaborate(m, platform)

    def _elaborate(self, m, platform):
        """Adds the latches and the match matrix to an existing Module m.

        Stores the SR latch as self.l and the latched addresses as
        self.adrs_r: is_match (and subclasses) rely on both attributes.
        """
        comb = m.d.comb

        # array of address-latches
        m.submodules.l = self.l = l = SRLatch(llen=self.n_adr, sync=False)
        self.adrs_r = adrs_r = Array(Signal(self.bitwid, reset_less=True,
                                            name="a_r")
                                     for i in range(self.n_adr))

        # latch set/reset
        comb += l.s.eq(self.addr_en_i)
        comb += l.r.eq(self.addr_rs_i)

        # copy in addresses (and "enable" signals)
        # NOTE(review): latchregister comes from nmutil.latch; its exact
        # capture/pass-through behaviour is not visible here - confirm there.
        for i in range(self.n_adr):
            latchregister(m, self.addrs_i[i], adrs_r[i], l.q[i])

        # is there a clash, yes/no
        matchgrp = []
        for i in range(self.n_adr):
            # one bit per "other" address j: does address i match address j?
            match = []
            for j in range(self.n_adr):
                match.append(self.is_match(i, j))
            comb += self.addr_nomatch_a_o[i].eq(~Cat(*match))
            # only positions with the latch output set are of interest:
            # mask with l.q, then require every remaining bit to be "nomatch"
            matchgrp.append((self.addr_nomatch_a_o[i] & l.q) == l.q)
        # an address whose own latch output is clear never reports "nomatch"
        comb += self.addr_nomatch_o.eq(Cat(*matchgrp) & l.q)

        return m

    def is_match(self, i, j):
        """Returns a 1-bit expression: latched address i equals address j.
        """
        if i == j:
            return Const(0) # don't match against self!
        return self.adrs_r[i] == self.adrs_r[j]

    def __iter__(self):
        yield from self.addrs_i
        yield self.addr_en_i
        # the reset strobe is an input too: without it a standalone
        # conversion of this module has no way of clearing a latched address
        yield self.addr_rs_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
108
109
class LenExpand(Elaboratable):
    """LenExpand: expands binary length (and LSBs of an address) into unary

    the result is a bitmap with one bit per *byte* that is to be read
    (written) in memory: len_i consecutive 1s, shifted up by addr_i.
    examples:

    (bit_len=4) len=4, addr=0b0011 => 0b1111 << addr
                                   => 0b1111000
    (bit_len=4) len=8, addr=0b0101 => 0b11111111 << addr
                                   => 0b1111111100000

    Inputs:

    * len_i  - bit_len bits: number of bytes
    * addr_i - bit_len bits: LSBs of the address (the shift amount)

    Outputs:

    * lexp_o - 1<<(bit_len+1) bits: the shifted unary byte-map
    """

    def __init__(self, bit_len):
        self.bit_len = bit_len
        self.len_i = Signal(bit_len, reset_less=True)
        self.addr_i = Signal(bit_len, reset_less=True)
        self.lexp_o = Signal(1 << (bit_len + 1), reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # intermediate: len_i in unary, before the address shift is applied
        unary_width = (1 << self.bit_len) + 1
        binlen = Signal(unary_width, reset_less=True)
        one = Const(1, self.bit_len + 1)

        m.d.comb += [
            # (1 << len) - 1 gives "len" consecutive 1s starting at bit 0
            binlen.eq((one << (self.len_i)) - 1),
            # move the run of 1s up to the byte offset within the line
            self.lexp_o.eq(binlen << self.addr_i),
        ]

        return m

    def ports(self):
        return [self.len_i, self.addr_i, self.lexp_o]
141
142
class TwinPartialAddrBitmap(PartialAddrMatch):
    """TwinPartialAddrBitMap

    designed to be connected to via LDSTSplitter, which generates
    *pairs* of addresses and covers the misalignment across cache
    line boundaries *in the splitter*.  Also LDSTSplitter takes
    care of expanding the LSBs of each address into a bitmap, itself.

    the key difference between this and PartialAddrMatch is that the
    knowledge (fact) that pairs of addresses from the same LDSTSplitter
    are 1 apart is *guaranteed* to be a miss for those two addresses.
    therefore is_match specially takes that into account.

    Inputs (in addition to addr_en_i / addr_rs_i from PartialAddrMatch):

    * faddrs_i - Array of n_adr full addresses, each bitlen bits wide.
                 bits lsbwid and above are compared in binary.
    * lexp_i   - Array of n_adr already-expanded (unary) byte-maps,
                 each 1<<(lsbwid+1) bits wide.
    """
    def __init__(self, n_adr, lsbwid, bitlen):
        self.lsbwid = lsbwid # number of bits to turn into unary
        self.midlen = bitlen-lsbwid
        PartialAddrMatch.__init__(self, n_adr, self.midlen)

        # input: length of the LOAD/STORE
        expwid = 1+self.lsbwid # XXX assume LD/ST no greater than 8
        self.lexp_i = Array(Signal(1<<expwid, reset_less=True,
                                   name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = Array(Signal(bitlen, reset_less=True,
                                     name="fadr") for i in range(n_adr))

        # registers for expanded len.  these hold a copy of lexp_i, so they
        # must be exactly as wide as lexp_i (1<<expwid bits): anything
        # narrower truncates the byte-map and is_match would then only
        # ever compare the few bits that survived.
        self.len_r = Array(Signal(1<<expwid, reset_less=True, name="l_r")
                           for i in range(self.n_adr))

    def elaborate(self, platform):
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # the SR latch created by the base class (one bit per address)
        l = self.l

        for i in range(self.n_adr):
            # copy the upper bits (lsbwid and above) of the full address:
            # these are the bits that get compared in binary
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

            # copy in expanded-lengths and latch them
            latchregister(m, self.lexp_i[i], self.len_r[i], l.q[i])

        return m

    # TODO make this a module.  too much.
    def is_match(self, i, j):
        """Returns a 1-bit expression: entries i and j overlap.

        An overlap requires the latched upper address bits to be equal
        *and* at least one common bit set in the compared portion of the
        two latched byte-maps.
        """
        if i == j:
            return Const(0) # don't match against self!
        # we know that pairs have addr and addr+1 therefore it is
        # guaranteed that they will not match.
        if (i // 2) == (j // 2):
            return Const(0) # don't match against twin, either.

        # the bitmask contains data for *two* cache lines (16 bytes).
        # however len==8 only covers *half* a cache line so we only
        # need to compare half the bits
        cmpwid = 1<<self.lsbwid
        # XXX open question left by the original author: should the
        # compared width be halved again when i or j is odd?

        # straight compare: binary top bits of addr, *unary* compare on bottom
        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
                      (self.len_r[i][:cmpwid] & self.len_r[j][:cmpwid]).bool()
        return straight_eq

    def __iter__(self):
        yield from self.faddrs_i
        yield from self.lexp_i
        yield self.addr_en_i
        # the reset strobe is an input too: without it a standalone
        # conversion of this module has no way of clearing a latched address
        yield self.addr_rs_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
220
221
class PartialAddrBitmap(PartialAddrMatch):
    """PartialAddrBitMap

    makes two comparisons for each address, with each (addr,len)
    being extended to an unary byte-map.

    two comparisons are needed because when an address is misaligned,
    the byte-map is split into two halves.  example:

    address = 0b1011011, len=8 => 0b101 and shift of 11 (0b1011)
                                  len in unary is 0b0000 0000 1111 1111
                                  when shifted becomes TWO addresses:

    * 0b101   and a byte-map of 0b1111 1000 0000 0000 (len-mask shifted by 11)
    * 0b101+1 and a byte-map of 0b0000 0000 0000 0111 (overlaps onto next 16)

    therefore, because this now covers two addresses, we need *two*
    comparisons per address *not* one.

    Inputs (in addition to addr_en_i / addr_rs_i from PartialAddrMatch):

    * faddrs_i - Array of n_adr full addresses, each bitlen bits wide.
                 the bottom lsbwid bits go to the LenExpand shift input,
                 the remaining upper bits are compared in binary.
    * len_i    - Array of n_adr lengths, each lsbwid bits wide
    """
    def __init__(self, n_adr, lsbwid, bitlen):
        self.lsbwid = lsbwid # number of bits to turn into unary
        self.midlen = bitlen-lsbwid
        PartialAddrMatch.__init__(self, n_adr, self.midlen)

        # input: length of the LOAD/STORE
        self.len_i = Array(Signal(lsbwid, reset_less=True,
                                  name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = Array(Signal(bitlen, reset_less=True,
                                     name="fadr") for i in range(n_adr))

        # intermediary: address + 1
        self.addr1s = Array(Signal(self.midlen, reset_less=True,
                                   name="adr1")
                            for i in range(n_adr))

        # expanded lengths, needed in match
        expwid = 1+self.lsbwid # XXX assume LD/ST no greater than 8
        self.lexp = Array(Signal(1<<expwid, reset_less=True,
                                 name="a_l")
                          for i in range(self.n_adr))

    def elaborate(self, platform):
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # intermediaries
        l = self.l # the SR latch created by the base class
        len_r = Array(Signal(self.lsbwid, reset_less=True,
                             name="l_r")
                      for i in range(self.n_adr))

        for i in range(self.n_adr):
            # create a bit-expander for each address
            be = LenExpand(self.lsbwid)
            setattr(m.submodules, "le%d" % i, be)
            # copy the upper bits (lsbwid and above) of the full address:
            # these are the bits that get compared in binary
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

            # copy in lengths and latch them
            latchregister(m, self.len_i[i], len_r[i], l.q[i])

            # add one to intermediate addresses
            comb += self.addr1s[i].eq(self.adrs_r[i]+1)

            # put the bottom bits of each address into each LenExpander.
            # NOTE(review): the shift amount is taken straight from faddrs_i
            # rather than from a latched copy - confirm that is intended.
            comb += be.len_i.eq(len_r[i])
            comb += be.addr_i.eq(self.faddrs_i[i][:self.lsbwid])
            # connect expander output
            comb += self.lexp[i].eq(be.lexp_o)

        return m

    # TODO make this a module.  too much.
    def is_match(self, i, j):
        """Returns a 1-bit expression: entries i and j overlap.

        Three cases are ORed together: same upper address with overlapping
        low byte-maps, and the two "spill-over" cases where one entry's
        upper address plus one equals the other entry's upper address.
        """
        if i == j:
            return Const(0) # don't match against self!
        # the bitmask contains data for *two* cache lines (16 bytes).
        # however len==8 only covers *half* a cache line so we only
        # need to compare half the bits
        expwid = 1<<self.lsbwid
        hexp = expwid >> 1
        expwid2 = expwid + hexp
        # straight compare: binary top bits of addr, *unary* compare on bottom
        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
                      (self.lexp[i][:expwid] & self.lexp[j][:expwid]).bool()
        # compare i (addr+1) to j (addr), but top unary against bottom unary
        i1_eq_j = (self.addr1s[i] == self.adrs_r[j]) & \
                  (self.lexp[i][expwid:expwid2] & self.lexp[j][:hexp]).bool()
        # compare i (addr) to j (addr+1), but bottom unary against top unary
        i_eq_j1 = (self.adrs_r[i] == self.addr1s[j]) & \
                  (self.lexp[i][:hexp] & self.lexp[j][expwid:expwid2]).bool()
        return straight_eq | i1_eq_j | i_eq_j1

    def __iter__(self):
        yield from self.faddrs_i
        yield from self.len_i
        yield self.addr_en_i
        # the reset strobe is an input too: without it a standalone
        # conversion of this module has no way of clearing a latched address
        yield self.addr_rs_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
327
328
def part_addr_sim(dut):
    """Simulation stimulus for a PartialAddrMatch with at least 3 slots.

    Only drives signals that PartialAddrMatch actually has: addrs_i,
    addr_en_i and addr_rs_i.  Each bare "yield" advances the simulation
    by one clock.

    Sequence: slot 0 and slot 1 are given the same address, slot 2 a
    different one, then slot 1 is released through its reset strobe.
    """
    # slot 0: first address
    yield dut.addrs_i[0].eq(0b0000000101)
    yield dut.addr_en_i[0].eq(1)
    yield
    yield dut.addr_en_i[0].eq(0)
    yield
    # slot 1: identical to slot 0, intended to produce a clash
    yield dut.addrs_i[1].eq(0b0000000101)
    yield dut.addr_en_i[1].eq(1)
    yield
    yield dut.addr_en_i[1].eq(0)
    yield
    # slot 2: a different address, intended not to clash with 0 or 1
    yield dut.addrs_i[2].eq(0b0000001001)
    yield dut.addr_en_i[2].eq(1)
    yield
    yield dut.addr_en_i[2].eq(0)
    yield
    # release slot 1 again
    yield dut.addr_rs_i[1].eq(1)
    yield
    yield dut.addr_rs_i[1].eq(0)
    yield
348
def part_addr_bit(dut):
    """Simulation stimulus for PartialAddrBitmap.

    Loads four (length, full address) pairs in turn - the last two both
    into slot 2 - pulsing the slot's addr_en_i bit for one clock each
    time, then pulses addr_rs_i for slot 1.  Each bare "yield" advances
    the simulation by one clock.
    """
    def pulse(sig):
        # high for one clock, then low for one clock
        yield sig.eq(1)
        yield
        yield sig.eq(0)
        yield

    # (slot, len, full address)        expected byte-maps for upper address:
    #                                  0b110               | 0b101
    stimulus = (
        (0, 8, 0b1011011),           # 0000 0000 0000 0111 | 1111 1000 0000 0000
        (1, 2, 0b1100010),           # 0000 0000 0000 1100 | 0000 0000 0000 0000
        (2, 2, 0b1011010),           # 0000 0000 0000 0000 | 0000 1100 0000 0000
        (2, 2, 0b1011001),           # 0000 0000 0000 0000 | 0000 0110 0000 0000
    )

    for slot, ldst_len, full_addr in stimulus:
        yield dut.len_i[slot].eq(ldst_len)
        yield dut.faddrs_i[slot].eq(full_addr)
        yield from pulse(dut.addr_en_i[slot])

    # finally deactivate slot 1
    yield from pulse(dut.addr_rs_i[1])
386
def test_part_addr():
    """Converts each module in this file to RTLIL and runs the simulations.

    Writes one .il file per module into the current directory, plus a
    .vcd trace for each of the two simulated modules.
    """
    def write_il(dut, il_name):
        # convert with the module's own port list, save, hand the dut back
        il_text = rtlil.convert(dut, ports=dut.ports())
        with open(il_name, "w") as f:
            f.write(il_text)
        return dut

    # conversion only
    write_il(LenExpand(4), "test_len_expand.il")
    write_il(TwinPartialAddrBitmap(3, 4, 10), "test_twin_part_bit.il")

    # conversion followed by simulation
    bitmap_dut = write_il(PartialAddrBitmap(3, 4, 10), "test_part_bit.il")
    run_simulation(bitmap_dut, part_addr_bit(bitmap_dut),
                   vcd_name='test_part_bit.vcd')

    match_dut = write_il(PartialAddrMatch(3, 10), "test_part_addr.il")
    run_simulation(match_dut, part_addr_sim(match_dut),
                   vcd_name='test_part_addr.vcd')
411
# when executed as a script: run the RTLIL conversions and simulations
if __name__ == '__main__':
    test_part_addr()