move get_fetch_action to separate verilog file
[rv32.git] / cpu.py
diff --git a/cpu.py b/cpu.py
index 2f9d22b1af2b8bb4a83a5cdd0c849bbd45d47668..bdf234d742d833038a1cca4b209f71b5aa5093b3 100644 (file)
--- a/cpu.py
+++ b/cpu.py
@@ -190,69 +190,6 @@ class Fetch:
         self.output_instruction = Signal(32, name="fetch_ouutput_instruction")
         self.output_state = Signal(fetch_output_state,name="fetch_output_state")
 
-    def get_fetch_action(self, dc, load_store_misaligned, mi,
-                         branch_taken, misaligned_jump_target,
-                         csr_op_is_valid):
-        c = {}
-        c["default"] = self.action.eq(FA.default) # XXX should be 32'XXXXXXXX?
-        c[FOS.empty] = self.action.eq(FA.default)
-        c[FOS.trap] = self.action.eq(FA.ack_trap)
-
-        # illegal instruction -> error trap
-        i= If((dc.act & DA.trap_illegal_instruction) != 0,
-                 self.action.eq(FA.error_trap)
-              )
-
-        # ecall / ebreak -> noerror trap
-        i = i.Elif((dc.act & DA.trap_ecall_ebreak) != 0,
-                 self.action.eq(FA.noerror_trap))
-
-        # load/store: check alignment, check wait
-        i = i.Elif((dc.act & (DA.load | DA.store)) != 0,
-                If((load_store_misaligned | ~mi.rw_address_valid),
-                    self.action.eq(FA.error_trap) # misaligned or invalid addr
-                ).Elif(mi.rw_wait,
-                    self.action.eq(FA.wait) # wait
-                ).Else(
-                    self.action.eq(FA.default) # ok
-                )
-              )
-
-        # fence
-        i = i.Elif((dc.act & DA.fence) != 0,
-                 self.action.eq(FA.fence))
-
-        # branch -> misaligned=error, otherwise jump
-        i = i.Elif((dc.act & DA.branch) != 0,
-                If(misaligned_jump_target,
-                    self.action.eq(FA.error_trap)
-                ).Else(
-                    self.action.eq(FA.jump)
-                )
-              )
-
-        # jal/jalr -> misaligned=error, otherwise jump
-        i = i.Elif((dc.act & (DA.jal | DA.jalr)) != 0,
-                If(misaligned_jump_target,
-                    self.action.eq(FA.error_trap)
-                ).Else(
-                    self.action.eq(FA.jump)
-                )
-              )
-
-        # csr -> opvalid=ok, else error trap
-        i = i.Elif((dc.act & DA.csr) != 0,
-                If(csr_op_is_valid,
-                    self.action.eq(FA.default)
-                ).Else(
-                    self.action.eq(FA.error_trap)
-                )
-              )
-
-        c[FOS.valid] = i
-
-        return Case(self.output_state, c)
-
 class CSR:
     def __init__(self, comb, sync, dc, register_rs1):
         self.comb = comb
@@ -337,17 +274,17 @@ class Regs:
         self.comb = comb
         self.sync = sync
 
-        self.ra_en = Signal(reset=1)
-        self.rb_en = Signal(reset=1)
-        self.wen = Signal(name="register_wen")
+        self.ra_en = Signal(reset=1, name="regfile_ra_en") # TODO: ondemand en
+        self.rs1 = Signal(32, name="regfile_rs1")
+        self.rs_a = Signal(5, name="regfile_rs_a")
 
-        self.rs1 = Signal(32, name="register_rs1")
-        self.rs2 = Signal(32, name="register_rs2")
-        self.wval = Signal(32, name="register_wval")
+        self.rb_en = Signal(reset=1, name="regfile_rb_en") # TODO: ondemand en
+        self.rs2 = Signal(32, name="regfile_rs2")
+        self.rs_b = Signal(5, name="regfile_rs_b")
 
-        self.rs_a = Signal(5, name="register_rs_a")
-        self.rs_b = Signal(5, name="register_rs_b")
-        self.rd = Signal(32, name="register_rd")
+        self.w_en = Signal(name="regfile_w_en")
+        self.wval = Signal(32, name="regfile_wval")
+        self.rd = Signal(32, name="regfile_rd")
 
 class CPU(Module):
     """
@@ -443,7 +380,7 @@ class CPU(Module):
     def write_register(self, rd, val):
         return [self.regs.rd.eq(rd),
                 self.regs.wval.eq(val),
-                self.regs.wen.eq(1)
+                self.regs.w_en.eq(1)
                ]
 
     def handle_valid(self, mtvec, mip, minfo, misa, csr, mi, m, mstatus, mie,
@@ -455,7 +392,7 @@ class CPU(Module):
         i = If((ft.action == FA.ack_trap) | (ft.action == FA.noerror_trap),
                 [self.handle_trap(m, mstatus, ft, dc,
                                        load_store_misaligned),
-                 self.regs.wen.eq(0) # no writing to registers
+                 self.regs.w_en.eq(0) # no writing to registers
                 ]
               )
 
@@ -490,7 +427,7 @@ class CPU(Module):
         i = i.Elif((dc.act & (DA.fence | DA.fence_i |
                               DA.store | DA.branch)) != 0,
                 # do nothing
-               self.regs.wen.eq(0) # no writing to registers
+               self.regs.w_en.eq(0) # no writing to registers
               )
 
         return i
@@ -619,7 +556,7 @@ class CPU(Module):
         rf = Instance("RegFile", name="regfile",
            i_ra_en = self.regs.ra_en,
            i_rb_en = self.regs.rb_en,
-           i_w_en = self.regs.wen,
+           i_w_en = self.regs.w_en,
            o_read_a = self.regs.rs1,
            o_read_b = self.regs.rs2,
            i_writeval = self.regs.wval,
@@ -825,9 +762,18 @@ class CPU(Module):
         # CSR decoding
         csr = CSR(self.comb, self.sync, dc, self.regs.rs1)
 
-        self.comb += ft.get_fetch_action(dc, load_store_misaligned, mi,
-                                 branch_taken, misaligned_jump_target,
-                                 csr.op_is_valid)
+        fi = Instance("CPUFetchAction", name="cpu_fetch_action",
+            o_fetch_action = ft.action,
+            i_output_state = ft.output_state,
+            i_dc_act = dc.act,
+            i_load_store_misaligned = load_store_misaligned,
+            i_mi_rw_wait = mi.rw_wait,
+            i_mi_rw_address_valid = mi.rw_address_valid,
+            i_branch_taken = branch_taken,
+            i_misaligned_jump_target = misaligned_jump_target,
+            i_csr_op_is_valid = csr.op_is_valid)
+
+        self.specials += fi
 
         minfo = MInfo(self.comb)