# MMU
#
# License for the original mmu.vhdl by the microwatt authors: CC4
# License for the copyrighted modifications made in mmu.py: LGPLv3+
#
# Although this derivative work includes CC4-licensed material,
# it is covered by the LGPLv3+

"""MMU

based on Anton Blanchard microwatt mmu.vhdl

"""
from enum import Enum, unique
from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl)
from nmigen.cli import main
from nmigen.cli import rtlil
from nmutil.iocontrol import RecordObject
from nmutil.byterev import byte_reverse
from nmutil.mask import Mask, masked
from nmutil.util import Display, wrap

# select between the python simulator and cxxsim
if True:
    from nmigen.back.pysim import Simulator, Delay, Settle
else:
    from nmigen.sim.cxxsim import Simulator, Delay, Settle

from soc.experiment.mem_types import (LoadStore1ToMMUType,
                                      MMUToLoadStore1Type,
                                      MMUToDCacheType,
                                      DCacheToMMUType,
                                      MMUToICacheType)


@unique
class State(Enum):
    IDLE = 0 # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9
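    # typical successful dside walk, per the FSM in elaborate() below:
    #   IDLE -> PROC_TBL_READ -> PROC_TBL_WAIT -> SEGMENT_CHECK
    #        -> RADIX_LOOKUP -> RADIX_READ_WAIT (repeated per tree level)
    #        -> RADIX_LOAD_TLB -> TLB_WAIT -> RADIX_FINISH -> IDLE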


class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State) # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMMUType()
        self.l_out = MMUToLoadStore1Type()
        self.d_out = MMUToDCacheType()
        self.d_in = DCacheToMMUType()
        self.i_out = MMUToICacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        sync = m.d.sync

        pt_valid = Signal()
        pgtbl = Signal(64)
        rts = Signal(6)
        mbits = Signal(6)

        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
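        # worked example, using the process-table entry from the
        # testbench below (0x40000000000300ad): rts = 0b10101 = 21,
        # i.e. a 21+31 = 52-bit address space; mbits = 0x0d = 13, so
        # the top level of the tree is indexed by 13 address bits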

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))
            comb += v.priv.eq(l_in.priv)

            comb += Display("l_in.valid addr %x iside %d store %d "
                            "rts %x mbits %x pt_valid %d",
                            v.addr, v.iside, v.store,
                            rts, mbits, pt_valid)

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2-3.
                # RIC=2 flushes process table caches.
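                # (sprn[3] is the top bit of the 2-bit RIC field,
                # so this also covers RIC=3)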
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)

                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value. Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        rts = Signal(6)
        mbits = Signal(6)

        # rts == radix tree size, # address bits being translated
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
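        # e.g. the testbench PROCESS_TABLE_3 entry 0x40000000000300ad
        # gives rts=21, mbits=13 and pgbase=0x30000, which is the
        # RADIX_ROOT_PTE address in the dcache_get() memory model below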

        with m.If(mbits):
            comb += v.state.eq(State.SEGMENT_CHECK)
        with m.Else():
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb

        perm_ok = Signal()
        rc_ok = Signal()
        mbits = Signal(6)
        valid = Signal(1)
        leaf = Signal(1)

        with m.If(d_in.done & (r.state == State.RADIX_READ_WAIT)):
            comb += Display("RADRDWAIT %016x done %d "
                            "perm %d rc %d mbits %d rshift %d "
                            "valid %d leaf %d",
                            data, d_in.done, perm_ok, rc_ok,
                            mbits, r.shift, valid, leaf)

        # set pde
        comb += v.pde.eq(data)

        # test valid and leaf bits
        comb += valid.eq(data[63]) # valid=data[63]
        comb += leaf.eq(data[62])  # leaf=data[62]

        # valid & leaf
        with m.If(valid):
            with m.If(leaf):
                # check permissions and RC bits
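                # PTE low bits, as interpreted by the checks below
                # (following microwatt mmu.vhdl): data[3] = privileged-
                # only, data[2] = read, data[1] = write, data[0] =
                # execute, data[8] = R (reference), data[7] = C (change)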
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support for now
                        # deny execute permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                    comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence over RC error
                    comb += v.rc_error.eq(perm_ok)

            # valid & !leaf
            with m.Else():
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(v.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)
        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb

        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
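        # addresses must be canonical: bit 62 must match bit 63, and
        # any address bits above the translated range (i.e. not covered
        # by finalmask) must be zero, otherwise raise a segment error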
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31-12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def mmu_0(self, m, r, rin, l_in, l_out, d_out, addrsh, mask):
        comb = m.d.comb
        sync = m.d.sync

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            sync += Display("MMU got tlb miss for %x", rin.addr)

        with m.If(l_out.done):
            sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            sync += Display("MMU completing op with err invalid=%d "
                            "badtree=%d", l_out.invalid, l_out.badtree)

        with m.If(rin.state == State.RADIX_LOOKUP):
            sync += Display("radix lookup shift=%d msize=%d",
                            rin.shift, rin.mask_size)

        with m.If(r.state == State.RADIX_LOOKUP):
            sync += Display("send load addr=%x addrsh=%d mask=%d",
                            d_out.addr, addrsh, mask)
        sync += r.eq(rin)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        self.rin = rin = RegStage("r_in")
        r = RegStage("r")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        self.mmu_0(m, r, rin, l_in, l_out, d_out, addrsh, mask)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtb_adr = Signal(64)
        pgtb_adr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
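        # e.g. mask_size=13 gives pte_mask.shift=8; assuming Mask(n)
        # drives its low `shift` output bits high (as in the microwatt
        # finalmask loop), mask ends up as 13 ones, enough to index the
        # 2^13 PTEs of that tree level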

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)

        with m.If(r.state != State.IDLE):
            sync += Display("MMU state %d %016x", r.state, data)

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                sync += Display(" TBL_READ %016x", prtb_adr)
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        pr24 = Signal(24, reset_less=True)
        comb += pr24.eq(masked(r.prtbl[12:36], effpid[8:32], finalmask))
        comb += prtb_adr.eq(Cat(C(0, 4), effpid[0:8], pr24, r.prtbl[36:56]))
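        # e.g. with the testbench values below (prtbl=0x1000000, pid=0)
        # this produces prtb_adr=0x1000000, the PROCESS_TABLE_3 address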

        pg16 = Signal(16, reset_less=True)
        comb += pg16.eq(masked(r.pgbase[3:19], addrsh, mask))
        comb += pgtb_adr.eq(Cat(C(0, 3), pg16, r.pgbase[19:56]))

        pd44 = Signal(44, reset_less=True)
        comb += pd44.eq(masked(r.pde[12:56], r.addr[12:56], finalmask))
        comb += pte.eq(Cat(r.pde[0:12], pd44))
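        # the final PTE keeps the flag bits (low 12) from the leaf PDE;
        # for pages larger than 4kB, finalmask substitutes the address
        # bits below the page boundary from the requested address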

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtb_adr)
        with m.Else():
            comb += addr.eq(pgtb_adr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m

stop = False

def dcache_get(dut):
    """simulator process for getting memory load requests
    """

    global stop

    def b(x):
        return int.from_bytes(x.to_bytes(8, byteorder='little'),
                              byteorder='big', signed=False)
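
    # the radix tables are big-endian in memory and the MMU byte-swaps
    # d_in.data, so b() stores each constant already byte-reversed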

    mem = {0x10000:    # PARTITION_TABLE_2
           # PATB_GR=1 PRTB=0x1000 PRTS=0xb
           b(0x800000000100000b),

           0x30000:    # RADIX_ROOT_PTE
           # V = 1 L = 0 NLB = 0x400 NLS = 9
           b(0x8000000000040009),

           0x1000000:  # PROCESS_TABLE_3
           # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
           b(0x40000000000300ad),
          }

    while not stop:
        while True: # wait for dc_valid
            if stop:
                return
            dc_valid = yield (dut.d_out.valid)
            if dc_valid:
                break
            yield
        addr = yield dut.d_out.addr
        if addr not in mem:
            print(" DCACHE LOOKUP FAIL %x" % (addr))
            stop = True
            return

        data = mem[addr]
        yield dut.d_in.data.eq(data)
        print("dcache get %x data %x" % (addr, data))
        yield dut.d_in.done.eq(1)
        yield
        yield dut.d_in.done.eq(0)


def mmu_sim(dut):
    yield dut.rin.prtbl.eq(0x1000000) # set process table
    yield

    yield dut.l_in.load.eq(1)
    yield dut.l_in.priv.eq(1)
    yield dut.l_in.addr.eq(0x10000)
    yield dut.l_in.valid.eq(1)
    while True: # wait for done / err
        l_done = yield (dut.l_out.done)
        l_err = yield (dut.l_out.err)
        l_badtree = yield (dut.l_out.badtree)
        l_permerr = yield (dut.l_out.perm_error)
        l_rc_err = yield (dut.l_out.rc_error)
        l_segerr = yield (dut.l_out.segerr)
        l_invalid = yield (dut.l_out.invalid)
        if (l_done or l_err or l_badtree or
            l_permerr or l_rc_err or l_segerr or l_invalid):
            break
        yield
    addr = yield dut.d_out.addr
    pte = yield dut.d_out.pte
    print("translated done %d err %d badtree %d addr %x pte %x" %
          (l_done, l_err, l_badtree, addr, pte))

    global stop
    stop = True


def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[]) # TODO: use dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    m = Module()
    m.submodules.mmu = dut

    # nmigen Simulation
    sim = Simulator(m)
    sim.add_clock(1e-6)

    sim.add_sync_process(wrap(mmu_sim(dut)))
    sim.add_sync_process(wrap(dcache_get(dut)))
    with sim.write_vcd('test_mmu.vcd'):
        sim.run()


if __name__ == '__main__':
    test_mmu()