1 """ Load / Store partial address matcher
2
3 Related bugreports:
4 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
5
6 Loads and Stores do not need a full match (CAM), they need "good enough"
7 avoidance. Around 11 bits on a 64-bit address is "good enough".
8
9 The simplest way to use this module is to ignore not only the top bits,
10 but also the bottom bits as well: in this case (this RV64 processor),
11 enough to cover a DWORD (64-bit). that means ignore the bottom 4 bits,
12 due to the possibility of 64-bit LD/ST being misaligned.
13
14 To reiterate: the use of this module is an *optimisation*. All it has
15 to do is cover the cases that are *definitely* matches (by checking 11
16 bits or so), and if a few opportunities for parallel LD/STs are missed
17 because the top (or bottom) bits weren't checked, so what: all that
18 happens is: the mis-matched addresses are LD/STd on single-cycles. Big Deal.
19
20 However, if we wanted to enhance this algorithm (without using a CAM and
21 without using expensive comparators) probably the best way to do so would
22 be to turn the last 16 bits into a byte-level bitmap. LD/ST on a byte
23 would have 1 of the 16 bits set. LD/ST on a DWORD would have 8 of the 16
24 bits set (offset if the LD/ST was misaligned). TODO.
25
26 Notes:
27
28 > I have used bits <11:6> as they are not translated (4KB pages)
29 > and larger than a cache line (64 bytes).
30 > I have used bits <11:4> when the L1 cache was QuadW sized and
31 > the L2 cache was Line sized.
32 """
33
from nmigen.compat.sim import run_simulation, Settle
from nmigen.cli import verilog, rtlil
from nmigen import Module, Signal, Const, Array, Cat, Elaboratable, Repl
from nmigen.lib.coding import Decoder
from nmigen.utils import log2_int

from nmutil.latch import latchregister, SRLatch


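# A minimal pure-Python reference sketch (illustrative only, not part of
# the hardware) of the "good enough" partial match described in the
# module docstring: two addresses are treated as potentially clashing
# when their middle bits (e.g. <11:4>) are equal.  The helper name and
# default bit range are assumptions for illustration.
def _partial_match_model(addr1, addr2, lo=4, hi=12):
    """True if addr1 and addr2 *may* refer to overlapping memory."""
    mask = (1 << (hi - lo)) - 1
    return ((addr1 >> lo) & mask) == ((addr2 >> lo) & mask)

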
class PartialAddrMatch(Elaboratable):
    """A partial address matcher
    """
    def __init__(self, n_adr, bitwid):
        self.n_adr = n_adr
        self.bitwid = bitwid
        # inputs
        self.addrs_i = Array(Signal(bitwid, name="addr") for i in range(n_adr))
        #self.addr_we_i = Signal(n_adr, reset_less=True) # write-enable
        self.addr_en_i = Signal(n_adr, reset_less=True) # address latched in
        self.addr_rs_i = Signal(n_adr, reset_less=True) # address deactivated

        # output: a nomatch for each address plus individual nomatch signals
        self.addr_nomatch_o = Signal(n_adr, name="nomatch_o", reset_less=True)
        self.addr_nomatch_a_o = Array(Signal(n_adr, reset_less=True,
                                             name="nomatch_array_o")
                                      for i in range(n_adr))

    def elaborate(self, platform):
        m = Module()
        return self._elaborate(m, platform)

    def _elaborate(self, m, platform):
        comb = m.d.comb
        sync = m.d.sync

        # array of address-latches
        m.submodules.l = self.l = l = SRLatch(llen=self.n_adr, sync=False)
        self.adrs_r = adrs_r = Array(Signal(self.bitwid, reset_less=True,
                                            name="a_r")
                                     for i in range(self.n_adr))

        # latch set/reset
        comb += l.s.eq(self.addr_en_i)
        comb += l.r.eq(self.addr_rs_i)

        # copy in addresses (and "enable" signals)
        for i in range(self.n_adr):
            latchregister(m, self.addrs_i[i], adrs_r[i], l.q[i])

        # is there a clash, yes/no
        matchgrp = []
        for i in range(self.n_adr):
            match = []
            for j in range(self.n_adr):
                match.append(self.is_match(i, j))
            comb += self.addr_nomatch_a_o[i].eq(~Cat(*match))
            matchgrp.append((self.addr_nomatch_a_o[i] & l.q) == l.q)
        comb += self.addr_nomatch_o.eq(Cat(*matchgrp) & l.q)

        return m

    def is_match(self, i, j):
        if i == j:
            return Const(0) # don't match against self!
        return self.adrs_r[i] == self.adrs_r[j]

    def __iter__(self):
        yield from self.addrs_i
        #yield self.addr_we_i
        yield self.addr_en_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)


class LenExpand(Elaboratable):
    """LenExpand: expands binary length (and LSBs of an address) into unary

    this basically produces a bitmap of which *bytes* are to be read (written)
    in memory.  examples:

    (bit_len=4) len=4, addr=0b0011 => 0b1111 << addr
                                   => 0b1111000
    (bit_len=4) len=8, addr=0b0101 => 0b11111111 << addr
                                   => 0b1111111100000

    note: by setting cover=8 this can also be used as a shift-mask.  the
    bit-mask is replicated (expanded out), each bit expanded to "cover" bits.
    """

    def __init__(self, bit_len, cover=1):
        self.bit_len = bit_len
        self.cover = cover
        self.len_i = Signal(bit_len, reset_less=True)
        self.addr_i = Signal(bit_len, reset_less=True)
        self.lexp_o = Signal(self.llen(1), reset_less=True)
        if cover > 1:
            self.rexp_o = Signal(self.llen(cover), reset_less=True)
        print("LenExpand", bit_len, cover, self.lexp_o.shape())

    def llen(self, cover):
        cl = log2_int(self.cover)
        return (cover << self.bit_len) + (cl << self.bit_len)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # covers N bits
        llen = self.llen(1)
        # temp: binary length expanded to a (shiftable) unary mask
        binlen = Signal((1 << self.bit_len) + 1, reset_less=True)
        comb += binlen.eq((Const(1, self.bit_len + 1) << (self.len_i)) - 1)
        comb += self.lexp_o.eq(binlen << self.addr_i)
        if self.cover == 1:
            return m
        # replicate each bit of the unary mask out to "cover" bits
        l = []
        print("llen", llen)
        for i in range(llen):
            l.append(Repl(self.lexp_o[i], self.cover))
        comb += self.rexp_o.eq(Cat(*l))
        return m

    def ports(self):
        return [self.len_i, self.addr_i, self.lexp_o, ]


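# A pure-Python reference model of LenExpand (a sketch treating the
# docstring examples as the contract; the helper name is illustrative).
# The unary expansion is ((1 << len) - 1) << addr, and the "cover"
# variant replicates each bit of that mask `cover` times, exactly as
# the Repl/Cat loop above does in hardware.
def _len_expand_model(length, addr, cover=1):
    lexp = ((1 << length) - 1) << addr
    if cover == 1:
        return lexp
    rexp = 0
    for i in range(lexp.bit_length()):
        if (lexp >> i) & 1:
            rexp |= ((1 << cover) - 1) << (i * cover)
    return rexp

# e.g. _len_expand_model(8, 0b0101) == 0b1111111100000, matching the
# docstring example, and _len_expand_model(l, a, 8) matches the
# assertions in part_addr_byte() below.

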
class TwinPartialAddrBitmap(PartialAddrMatch):
    """TwinPartialAddrBitMap

    designed to be connected to via LDSTSplitter, which generates
    *pairs* of addresses and covers the misalignment across cache
    line boundaries *in the splitter*.  LDSTSplitter also takes care
    of expanding the LSBs of each address into a bitmap, itself.

    the key difference between this and PartialAddrMatch is the
    knowledge (fact) that pairs of addresses from the same LDSTSplitter
    are exactly 1 apart, and therefore *guaranteed* not to overlap:
    is_match specially takes that into account.
    """
    def __init__(self, n_adr, lsbwid, bitlen):
        self.lsbwid = lsbwid # number of bits to turn into unary
        self.midlen = bitlen - lsbwid
        PartialAddrMatch.__init__(self, n_adr, self.midlen)

        # input: length of the LOAD/STORE
        expwid = 1 + self.lsbwid # XXX assume LD/ST no greater than 8
        self.lexp_i = Array(Signal(1 << expwid, reset_less=True,
                                   name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = Array(Signal(bitlen, reset_less=True,
                                     name="fadr") for i in range(n_adr))

        # registers for expanded len
        self.len_r = Array(Signal(expwid, reset_less=True, name="l_r")
                           for i in range(self.n_adr))

    def elaborate(self, platform):
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # intermediaries
        adrs_r, l = self.adrs_r, self.l
        expwid = 1 + self.lsbwid

        for i in range(self.n_adr):
            # compare only the middle bits: drop the bottom lsbwid bits
            # (those are covered by the unary bitmap instead)
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

            # copy in expanded-lengths and latch them
            latchregister(m, self.lexp_i[i], self.len_r[i], l.q[i])

        return m

    # TODO make this a module.  too much.
    def is_match(self, i, j):
        if i == j:
            return Const(0) # don't match against self!
        # we know that pairs have addr and addr+1, therefore it is
        # guaranteed that they will not match.
        if (i // 2) == (j // 2):
            return Const(0) # don't match against twin, either.

        # the bitmask contains data for *two* cache lines (16 bytes).
        # however len==8 only covers *half* a cache line so we only
        # need to compare half the bits
        expwid = 1 << self.lsbwid
        #if i % 2 == 1 or j % 2 == 1: # XXX hmmm...
        #    expwid >>= 1

        # straight compare: binary top bits of addr, *unary* compare on bottom
        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
                      (self.len_r[i][:expwid] & self.len_r[j][:expwid]).bool()
        return straight_eq

    def __iter__(self):
        yield from self.faddrs_i
        yield from self.lexp_i
        yield self.addr_en_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)


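# Illustrative pure-Python sketch of the twin-match rule above (the
# helper name and data layout are assumptions, not part of the design):
# two active entries clash when their mid-bits are equal *and* their
# unary byte-masks overlap; entries from the same split pair
# (i // 2 == j // 2) never clash, by construction.
def _twin_match_model(i, j, mids, masks):
    if i == j or (i // 2) == (j // 2):
        return False
    return mids[i] == mids[j] and (masks[i] & masks[j]) != 0

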
class PartialAddrBitmap(PartialAddrMatch):
    """PartialAddrBitMap

    makes two comparisons for each address, with each (addr,len)
    being extended to a unary byte-map.

    two comparisons are needed because when an address is misaligned,
    the byte-map is split into two halves.  example:

    address = 0b1011011, len=8 => 0b101 and shift of 11 (0b1011)
    len in unary is 0b0000 0000 1111 1111
    when shifted this becomes TWO addresses:

    * 0b101   and a byte-map of 0b1111 1000 0000 0000 (len-mask shifted by 11)
    * 0b101+1 and a byte-map of 0b0000 0000 0000 0111 (overlaps onto next 16)

    therefore, because this now covers two addresses, we need *two*
    comparisons per address *not* one.
    """
    def __init__(self, n_adr, lsbwid, bitlen):
        self.lsbwid = lsbwid # number of bits to turn into unary
        self.midlen = bitlen - lsbwid
        PartialAddrMatch.__init__(self, n_adr, self.midlen)

        # input: length of the LOAD/STORE
        self.len_i = Array(Signal(lsbwid, reset_less=True,
                                  name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = Array(Signal(bitlen, reset_less=True,
                                     name="fadr") for i in range(n_adr))

        # intermediary: address + 1
        self.addr1s = Array(Signal(self.midlen, reset_less=True,
                                   name="adr1")
                            for i in range(n_adr))

        # expanded lengths, needed in match
        expwid = 1 + self.lsbwid # XXX assume LD/ST no greater than 8
        self.lexp = Array(Signal(1 << expwid, reset_less=True,
                                 name="a_l")
                          for i in range(self.n_adr))

    def elaborate(self, platform):
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # intermediaries
        adrs_r, l = self.adrs_r, self.l
        len_r = Array(Signal(self.lsbwid, reset_less=True,
                             name="l_r")
                      for i in range(self.n_adr))

        for i in range(self.n_adr):
            # create a bit-expander for each address
            be = LenExpand(self.lsbwid)
            setattr(m.submodules, "le%d" % i, be)
            # compare only the middle bits: drop the bottom lsbwid bits
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

            # copy in lengths and latch them
            latchregister(m, self.len_i[i], len_r[i], l.q[i])

            # add one to intermediate addresses
            comb += self.addr1s[i].eq(self.adrs_r[i] + 1)

            # put the bottom bits of each address into each LenExpander.
            comb += be.len_i.eq(len_r[i])
            comb += be.addr_i.eq(self.faddrs_i[i][:self.lsbwid])
            # connect expander output
            comb += self.lexp[i].eq(be.lexp_o)

        return m

    # TODO make this a module.  too much.
    def is_match(self, i, j):
        if i == j:
            return Const(0) # don't match against self!
        # the bitmask contains data for *two* cache lines (16 bytes).
        # however len==8 only covers *half* a cache line so we only
        # need to compare half the bits
        expwid = 1 << self.lsbwid
        hexp = expwid >> 1
        expwid2 = expwid + hexp
        print(self.lsbwid, expwid)
        # straight compare: binary top bits of addr, *unary* compare on bottom
        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
                      (self.lexp[i][:expwid] & self.lexp[j][:expwid]).bool()
        # compare i (addr+1) to j (addr), but top unary against bottom unary
        i1_eq_j = (self.addr1s[i] == self.adrs_r[j]) & \
                  (self.lexp[i][expwid:expwid2] & self.lexp[j][:hexp]).bool()
        # compare i (addr) to j (addr+1), but bottom unary against top unary
        i_eq_j1 = (self.adrs_r[i] == self.addr1s[j]) & \
                  (self.lexp[i][:hexp] & self.lexp[j][expwid:expwid2]).bool()
        return straight_eq | i1_eq_j | i_eq_j1

    def __iter__(self):
        yield from self.faddrs_i
        yield from self.len_i
        #yield self.addr_we_i
        yield self.addr_en_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)


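# Worked pure-Python example (illustrative; helper name assumed) of the
# misaligned split described in the PartialAddrBitmap docstring: expand
# (addr, len) into one or two (mid-address, 16-bit byte-map) pairs.
def _split_misaligned_model(addr, length, lsbwid=4):
    unary = 1 << lsbwid                  # bytes per byte-map (16)
    mid = addr >> lsbwid                 # binary mid/top bits
    shift = addr & (unary - 1)           # bottom bits (0b1011 == 11)
    full = ((1 << length) - 1) << shift  # len-mask shifted into place
    lo = full & ((1 << unary) - 1)       # byte-map for mid
    hi = full >> unary                   # overflow onto mid+1
    return [(mid, lo)] + ([(mid + 1, hi)] if hi else [])

# e.g. _split_misaligned_model(0b1011011, 8) ==
#          [(0b101, 0b1111100000000000), (0b110, 0b111)]
# matching the docstring example above.

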
def part_addr_sim(dut):
    # minimal stimulus sketch for PartialAddrMatch (the original body
    # drove ports belonging to a different module and was disabled with
    # an early "return"): latch the same address into two entries,
    # then release one of them.
    yield dut.addrs_i[0].eq(0b0010110)
    yield dut.addr_en_i.eq(0b001)
    yield
    yield dut.addr_en_i.eq(0)
    yield
    yield dut.addrs_i[1].eq(0b0010110)
    yield dut.addr_en_i.eq(0b010)
    yield
    yield dut.addr_en_i.eq(0)
    yield
    yield dut.addr_rs_i.eq(0b010)
    yield
    yield dut.addr_rs_i.eq(0)
    yield

def part_addr_bit(dut):
    #                    0b110                 | 0b101               |
    # 0b101 1011 / 8 ==> 0b0000 0000 0000 0111 | 1111 1000 0000 0000 |
    yield dut.len_i[0].eq(8)
    yield dut.faddrs_i[0].eq(0b1011011)
    yield dut.addr_en_i[0].eq(1)
    yield
    yield dut.addr_en_i[0].eq(0)
    yield
    #                    0b110                 | 0b101               |
    # 0b110 0010 / 2 ==> 0b0000 0000 0000 1100 | 0000 0000 0000 0000 |
    yield dut.len_i[1].eq(2)
    yield dut.faddrs_i[1].eq(0b1100010)
    yield dut.addr_en_i[1].eq(1)
    yield
    yield dut.addr_en_i[1].eq(0)
    yield
    #                    0b110                 | 0b101               |
    # 0b101 1010 / 2 ==> 0b0000 0000 0000 0000 | 0000 1100 0000 0000 |
    yield dut.len_i[2].eq(2)
    yield dut.faddrs_i[2].eq(0b1011010)
    yield dut.addr_en_i[2].eq(1)
    yield
    yield dut.addr_en_i[2].eq(0)
    yield
    #                    0b110                 | 0b101               |
    # 0b101 1001 / 2 ==> 0b0000 0000 0000 0000 | 0000 0110 0000 0000 |
    yield dut.len_i[2].eq(2)
    yield dut.faddrs_i[2].eq(0b1011001)
    yield dut.addr_en_i[2].eq(1)
    yield
    yield dut.addr_en_i[2].eq(0)
    yield
    yield dut.addr_rs_i[1].eq(1)
    yield
    yield dut.addr_rs_i[1].eq(0)
    yield


def part_addr_byte(dut):
    for l in range(8):
        for a in range(1 << dut.bit_len):
            maskbit = (1 << l) - 1
            mask = (1 << (l * 8)) - 1
            yield dut.len_i.eq(l)
            yield dut.addr_i.eq(a)
            yield Settle()
            lexp = yield dut.lexp_o
            exp = yield dut.rexp_o
            print("pa", l, a, bin(lexp), hex(exp))
            assert exp == (mask << (a * 8))
            assert lexp == (maskbit << a)


def test_lenexpand_byte():
    dut = LenExpand(4, 8)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_len_expand_byte.il", "w") as f:
        f.write(vl)
    run_simulation(dut, part_addr_byte(dut), vcd_name='test_part_byte.vcd')


def test_part_addr():
    dut = LenExpand(4)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_len_expand.il", "w") as f:
        f.write(vl)

    dut = TwinPartialAddrBitmap(3, 4, 10)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_twin_part_bit.il", "w") as f:
        f.write(vl)

    dut = PartialAddrBitmap(3, 4, 10)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_part_bit.il", "w") as f:
        f.write(vl)

    run_simulation(dut, part_addr_bit(dut), vcd_name='test_part_bit.vcd')

    dut = PartialAddrMatch(3, 10)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_part_addr.il", "w") as f:
        f.write(vl)

    run_simulation(dut, part_addr_sim(dut), vcd_name='test_part_addr.vcd')

if __name__ == '__main__':
    test_part_addr()
    test_lenexpand_byte()