1 """MMU
2
3 based on Anton Blanchard microwatt mmu.vhdl
4
5 """
6 from enum import Enum, unique
7 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
8 from nmigen.cli import main
9 from nmigen.cli import rtlil
10 from nmutil.iocontrol import RecordObject
11 from nmutil.byterev import byte_reverse
12 from nmutil.mask import Mask
13 from nmutil.util import Display
14
# select the simulator back-end: pysim (default) or cxxsim
if True:
    from nmigen.back.pysim import Simulator, Delay, Settle
else:
    from nmigen.sim.cxxsim import Simulator, Delay, Settle
from nmutil.util import wrap

from soc.experiment.mem_types import (LoadStore1ToMMUType,
                                      MMUToLoadStore1Type,
                                      MMUToDCacheType,
                                      DCacheToMMUType,
                                      MMUToICacheType)


@unique
class State(Enum):
    IDLE = 0            # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9

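# Typical fault-free walk sequence:
#   IDLE -> PROC_TBL_READ -> PROC_TBL_WAIT -> SEGMENT_CHECK
#        -> RADIX_LOOKUP -> RADIX_READ_WAIT  (once per tree level)
#        -> RADIX_LOAD_TLB -> TLB_WAIT -> RADIX_FINISH -> IDLE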

class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State)  # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMMUType()
        self.l_out = MMUToLoadStore1Type()
        self.d_out = MMUToDCacheType()
        self.d_in = DCacheToMMUType()
        self.i_out = MMUToICacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        pt_valid = Signal()
        pgtbl = Signal(64)
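        # quadrant select: effective-address bit 63 picks the user
        # (quadrant 0, pgtbl0) or kernel (quadrant 3, pgtbl3) copy
        # of the cached process-table entry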
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        rts = Signal(6)
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        mbits = Signal(6)
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
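        # (pgbase is the Root Page Directory Base: bits 8:56 of the
        # table entry, shifted up 8 bits to form a byte address)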

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))
            comb += v.priv.eq(l_in.priv)

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value.  Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        # rts == radix tree size, # address bits being translated
        rts = Signal(6)
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        mbits = Signal(6)
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use
        # finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits == 0):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)
        with m.Else():
            comb += v.state.eq(State.SEGMENT_CHECK)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        comb += v.pde.eq(data)
        # test valid bit
        with m.If(data[63]):
            # test leaf bit
            with m.If(data[62]):
                # check permissions and RC bits
                perm_ok = Signal()
                comb += perm_ok.eq(0)
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq((data[1] | data[2])
                                           & (~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support
                        # for now deny execute
                        # permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

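                # Reference and Change bits: R (data[8]) must already
                # be set, and C (data[7]) must also be set for a store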
                rc_ok = Signal()
                comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence
                    # over RC error
                    comb += v.rc_error.eq(perm_ok)
            with m.Else():
                mbits = Signal(6)
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(r.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)
        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb
        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        # addr[62] must equal addr[63], and the address bits above
        # the translated range (not covered by finalmask) must be zero
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31-12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def mmu_0(self, m, r, rin, l_in, l_out, d_out, addrsh, mask):
        comb = m.d.comb
        sync = m.d.sync

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            sync += Display("MMU got tlb miss for %x", rin.addr)

        with m.If(l_out.done):
            sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            sync += Display("MMU completing op with err invalid=%d"
                            " badtree=%d", l_out.invalid, l_out.badtree)

        with m.If(rin.state == State.RADIX_LOOKUP):
            sync += Display("radix lookup shift=%d msize=%d",
                            rin.shift, rin.mask_size)

        with m.If(r.state == State.RADIX_LOOKUP):
            sync += Display("send load addr=%x addrsh=%d mask=%d",
                            d_out.addr, addrsh, mask)

        # register the next state
        sync += r.eq(rin)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        self.mmu_0(m, r, rin, l_in, l_out, d_out, addrsh, mask)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtable_addr = Signal(64)
        pgtable_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        # combinatorial defaults: v carries r forward unchanged;
        # single-cycle pulse signals are cleared every cycle
        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
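        # (the low 5 index bits are always used; pte_mask extends the
        # mask by a further mask_size-5 bits on top of them)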

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)
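        # (finalmask is also reused for the segment check and for
        # indexing the process table, via the values loaded into
        # v.shift in radix_tree_idle / proc_tbl_wait)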

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    # radix_read_wait handles the non-present-PTE
                    # case (valid bit clear) and generates a DSI there
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

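        # process-table entry address: index the table at r.prtbl by
        # effpid, with the index bits masked by the table size
        # (finalmask, since v.shift was loaded from PRTBL's size field)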
        comb += prtable_addr.eq(Cat(
                    C(0b0000, 4),
                    effpid[0:8],
                    (r.prtbl[12:36] & ~finalmask[0:24]) |
                    (effpid[8:32] & finalmask[0:24]),
                    r.prtbl[36:56]
                ))

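        # page-directory entry address for the current radix level:
        # merge the shifted address bits (addrsh) into the page base
        # under the level-size mask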
        comb += pgtable_addr.eq(Cat(
                    C(0b000, 3),
                    (r.pgbase[3:19] & ~mask) |
                    (addrsh & mask),
                    r.pgbase[19:56]
                ))

        comb += pte.eq(Cat(
                    r.pde[0:12],
                    (r.pde[12:56] & ~finalmask) |
                    (r.addr[12:56] & finalmask),
                ))
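        # the TLB entry takes its RPN from the leaf PDE; for large
        # pages, the low RPN bits (covered by finalmask) come from
        # the faulting address instead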

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtable_addr)
        with m.Else():
            comb += addr.eq(pgtable_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


def mmu_sim(dut):
    # minimal smoke-test sketch: no dcache model is wired up in
    # test_mmu, so a full table walk cannot complete.  Issue a d-side
    # load request and check that the MMU does not (incorrectly)
    # report done while stalled waiting for the process-table read
    # to be answered.
    yield dut.l_in.priv.eq(1)
    yield dut.l_in.load.eq(1)
    yield dut.l_in.addr.eq(0x10000)
    yield dut.l_in.valid.eq(1)
    yield
    yield dut.l_in.valid.eq(0)

    # the FSM is now in PROC_TBL_READ/PROC_TBL_WAIT, waiting on
    # d_in.done, which never arrives
    for _ in range(5):
        yield
    done = yield dut.l_out.done
    assert done == 0, "walk cannot complete without a dcache response"


def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[])  # TODO: use dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    m = Module()
    m.submodules.mmu = dut

    # nmigen Simulation
    sim = Simulator(m)
    sim.add_clock(1e-6)

    sim.add_sync_process(wrap(mmu_sim(dut)))
    with sim.write_vcd('test_mmu.vcd'):
        sim.run()


if __name__ == '__main__':
    test_mmu()