1 """MMU
2
3 based on Anton Blanchard microwatt mmu.vhdl
4
5 """
6 from enum import Enum, unique
7 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
8 from nmigen.cli import main
9 from nmigen.cli import rtlil
10 from nmutil.iocontrol import RecordObject
11 from nmutil.byterev import byte_reverse
12 from nmutil.mask import Mask
13 from nmutil.util import Display
14
15 if True:
16 from nmigen.back.pysim import Simulator, Delay, Settle
17 else:
18 from nmigen.sim.cxxsim import Simulator, Delay, Settle
19 from nmutil.util import wrap

from soc.experiment.mem_types import (LoadStore1ToMMUType,
                                      MMUToLoadStore1Type,
                                      MMUToDCacheType,
                                      DCacheToMMUType,
                                      MMUToICacheType)


@unique
class State(Enum):
    IDLE = 0            # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9

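# Typical flow of a translation request through the state machine,
# as a sketch of the transitions coded below:
#
#   IDLE -> PROC_TBL_READ -> PROC_TBL_WAIT  (no cached process tbl entry)
#        -> SEGMENT_CHECK -> RADIX_LOOKUP -> RADIX_READ_WAIT
#           (LOOKUP/READ_WAIT repeat once per tree level)
#        -> RADIX_LOAD_TLB -> TLB_WAIT (dside) or IDLE (iside)
#        -> RADIX_FINISH -> IDLE
#
# tlbie and mtspr requests take the short path
# IDLE -> DO_TLBIE -> TLB_WAIT -> RADIX_FINISH.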

class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State)  # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMMUType()
        self.l_out = MMUToLoadStore1Type()
        self.d_out = MMUToDCacheType()
        self.d_in = DCacheToMMUType()
        self.i_out = MMUToICacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        pt_valid = Signal()
        pgtbl = Signal(64)
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        rts = Signal(6)
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
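        # note the RTS field is stored non-contiguously in the table
        # entry (3 bits at [5:8], 2 bits at [61:63]); the effective
        # virtual address size works out as rts + 31 bits, which is
        # what segment_check() relies on below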

        # mbits == number of address bits to index top
        # level of tree
        mbits = Signal(6)
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
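        # pgbase keeps bits [8:56] of the table entry as the table base
        # address, with the low 8 bits forced to zero: the field layout
        # implies the tree base is at least 256-byte aligned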

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2-3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)

            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value.  Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        # rts == radix tree size, # address bits being translated
        rts = Signal(6)
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        mbits = Signal(6)
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use
        # finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits == 0):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)
        with m.Else():
            comb += v.state.eq(State.SEGMENT_CHECK)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        comb += v.pde.eq(data)

        # test valid bit
        with m.If(data[63]):
            with m.If(data[62]):
                # check permissions and RC bits
                perm_ok = Signal()
                comb += perm_ok.eq(0)
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq((data[1] | data[2])
                                           & (~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support
                        # for now deny execute
                        # permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                rc_ok = Signal()
                comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
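                # data[8] is the PTE Reference bit and data[7] the
                # Change bit: a load needs R set, a store needs both
                # R and C set.  The walker does not update R/C itself;
                # rc_error below reports the condition to loadstore1.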

                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence
                    # over RC error
                    comb += v.rc_error.eq(perm_ok)

            with m.Else():
                mbits = Signal(6)
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) |
                          (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(v.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)

        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb

        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
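        # the address must be canonical (bits 62 and 63 equal) and any
        # bits above the range covered by the tree, below bit 62, must
        # be zero; otherwise flag a segment error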
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31-12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            sync += Display("MMU got tlb miss for %x", rin.addr)

        with m.If(l_out.done):
            sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            sync += Display("MMU completing op with err invalid=%d "
                            "badtree=%d", l_out.invalid, l_out.badtree)

        with m.If(rin.state == State.RADIX_LOOKUP):
            sync += Display("radix lookup shift=%d msize=%d",
                            rin.shift, rin.mask_size)

        with m.If(r.state == State.RADIX_LOOKUP):
            sync += Display("send load addr=%x addrsh=%d mask=%d",
                            d_out.addr, addrsh, mask)

        sync += r.eq(rin)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtable_addr = Signal(64)
        pgtable_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr
        # generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
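        # e.g. mask_size == 13 gives mask == 0x1fff: 13 address bits
        # select a PTE within this level.  The bottom 5 bits are always
        # included, since every level indexes at least 32 entries
        # (mbits < 5 is rejected as badtree elsewhere).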

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)
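        # finalmask has r.shift low bits set: the virtual-address bits
        # not yet translated at the current point in the walk.  It is
        # used for the segment check, for indexing the process table,
        # and for assembling large-page TLB entries.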

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

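        # a successful iside load returns straight to IDLE from
        # RADIX_LOAD_TLB (the icache TLB load has no completion
        # handshake), so done/err must be computed for that transition
        # as well as for RADIX_FINISH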
        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

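        # kernel addresses (r.addr[63] set) leave effpid at zero,
        # selecting process-table entry 0; user addresses use the
        # current PID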
        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

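        # addrsh (declared above) is the group of address bits indexing
        # the current tree level.  A sketch of its derivation, following
        # the microwatt mmu.vhdl original: shift address bits 12-61
        # right by r.shift and keep the low 16 bits.
        comb += addrsh.eq(r.addr[12:62] >> r.shift)
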
        comb += prtable_addr.eq(Cat(
                                    C(0b0000, 4),
                                    effpid[0:8],
                                    (r.prtbl[12:36] & ~finalmask[0:24]) |
                                    (effpid[8:32] & finalmask[0:24]),
                                    r.prtbl[36:56]
                                ))
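        # process-table entry address: 16-byte entries (4 zero LSBs)
        # indexed by effpid.  While the process table is being read,
        # r.shift holds the PRTBL size field, so finalmask decides how
        # many upper PID bits index the table rather than coming from
        # the PRTBL base.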

        comb += pgtable_addr.eq(Cat(
                                    C(0b000, 3),
                                    (r.pgbase[3:19] & ~mask) |
                                    (addrsh & mask),
                                    r.pgbase[19:56]
                                ))
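        # next-level PTE address: 8-byte PTEs (3 zero LSBs); mask picks
        # the index bits out of the shifted address (addrsh), the rest
        # come from the current table base (r.pgbase)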

        comb += pte.eq(Cat(
                           r.pde[0:12],
                           (r.pde[12:56] & ~finalmask) |
                           (r.addr[12:56] & finalmask),
                       ))
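        # final TLB entry: attribute/permission bits [0:12] come from
        # the leaf PDE; of the real-page-number bits, those under
        # finalmask (offset bits for pages larger than 4kB) come from
        # the original address, the rest from the PDE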

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtable_addr)
        with m.Else():
            comb += addr.eq(pgtable_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


def mmu_sim(dut):
    # Minimal stimulus sketch: issue a tlbie request and check that it
    # is forwarded on the dcache interface.  A full translation-walk
    # test would need a companion process modelling the dcache, i.e.
    # answering d_in.done / d_in.data; that is not attempted here.
    yield dut.l_in.valid.eq(1)
    yield dut.l_in.tlbie.eq(1)
    yield
    yield dut.l_in.valid.eq(0)
    yield dut.l_in.tlbie.eq(0)
    yield Settle()
    tlbie = yield dut.d_out.tlbie
    print("d_out.tlbie =", tlbie)
    assert tlbie == 1
    # without a dcache model d_in.done never arrives, so just run a
    # few more cycles with the FSM sitting in TLB_WAIT
    for _ in range(4):
        yield


def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[])  # TODO: use dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    m = Module()
    m.submodules.mmu = dut

    # nmigen Simulation
    sim = Simulator(m)
    sim.add_clock(1e-6)

    sim.add_sync_process(wrap(mmu_sim(dut)))
    with sim.write_vcd('test_mmu.vcd'):
        sim.run()


if __name__ == '__main__':
    test_mmu()