split out most Mem methods into MemCommon base class
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 22 Sep 2023 22:40:30 +0000 (15:40 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 23 Sep 2023 01:22:05 +0000 (18:22 -0700)
src/openpower/decoder/isa/mem.py
src/openpower/test/state.py

index 6640829232b98841fd62edc9437c7f225fffc6d5..0f4639b6b0331262332abfdbf9a6bf7d208da601 100644 (file)
@@ -16,6 +16,8 @@ from collections import defaultdict
 from openpower.decoder.selectable_int import SelectableInt
 from openpower.util import log, LogKind
 import math
+import enum
+from cached_property import cached_property
 
 
 def swap_order(x, nbytes):
@@ -50,10 +52,22 @@ def process_mem(initial_mem, row_bytes=8):
     return res
 
 
-class Mem:
+@enum.unique
+class _ReadReason(enum.Enum):
+    Read = enum.auto()
+    SubWordWrite = enum.auto()
+    Dump = enum.auto()
+    Execute = enum.auto()
 
-    def __init__(self, row_bytes=8, initial_mem=None, misaligned_ok=False):
-        self.mem = {}
+    @cached_property
+    def read_default(self):
+        if self in (self.SubWordWrite, self.Dump):
+            return 0
+        return None
+
+
+class MemCommon:
+    def __init__(self, row_bytes, initial_mem, misaligned_ok):
         self.bytes_per_word = row_bytes
         self.word_log2 = math.ceil(math.log2(row_bytes))
         self.last_ld_addr = None
@@ -67,6 +81,16 @@ class Mem:
             # val = swap_order(val, width)
             self.st(addr, val, width, swap=False)
 
+    def _read_word(self, word_idx, reason):
+        raise NotImplementedError
+
+    def _write_word(self, word_idx, value):
+        raise NotImplementedError
+
+    def word_idxs(self):
+        raise NotImplementedError
+        yield 0
+
     def _get_shifter_mask(self, wid, remainder):
         shifter = ((self.bytes_per_word - wid) - remainder) * \
             8  # bits per byte
@@ -79,7 +103,7 @@ class Mem:
 
     # TODO: Implement ld/st of lesser width
     def ld(self, address, width=8, swap=True, check_in_mem=False,
-           instr_fetch=False):
+           instr_fetch=False, reason=None):
         log("ld from addr 0x%x width %d" % (address, width),
             swap, check_in_mem, instr_fetch)
         self.last_ld_addr = address  # record last load
@@ -92,12 +116,14 @@ class Mem:
                                (remainder, width))
             exc.dar = ldaddr
             raise exc
-        if address in self.mem:
-            val = self.mem[address]
-        elif check_in_mem:
-            return None
-        else:
-            val = 0
+        if reason is None:
+            reason = _ReadReason.Execute if instr_fetch else _ReadReason.Read
+        val = self._read_word(address, reason)
+        if val is None:
+            if check_in_mem:
+                return None
+            else:
+                val = 0
         log("ld mem @ 0x%x rem %d : 0x%x" % (ldaddr, remainder, val))
 
         if width != self.bytes_per_word:
@@ -125,17 +151,15 @@ class Mem:
         if swap:
             v = swap_order(v, width)
         if width != self.bytes_per_word:
-            if addr in self.mem:
-                val = self.mem[addr]
-            else:
-                val = 0
+            val = self._read_word(addr, _ReadReason.SubWordWrite)
             shifter, mask = self._get_shifter_mask(width, remainder)
             val &= ~(mask << shifter)
             val |= v << shifter
-            self.mem[addr] = val
+            self._write_word(addr, val)
         else:
-            self.mem[addr] = v
-        log("mem @ 0x%x: 0x%x" % (staddr, self.mem[addr]))
+            val = v
+            self._write_word(addr, v)
+        log("mem @ 0x%x: 0x%x" % (staddr, val))
 
     def st(self, st_addr, v, width=8, swap=True):
         self.last_st_addr = st_addr  # record last store
@@ -171,22 +195,22 @@ class Mem:
         self.st(addr.value, val.value, sz, swap=False)
 
     def dump(self, printout=True, asciidump=False):
-        keys = list(self.mem.keys())
+        keys = list(self.word_idxs())
         keys.sort()
         res = []
         for k in keys:
-            res.append(((k*8), self.mem[k]))
+            v = self._read_word(k, _ReadReason.Dump)
+            res.append((k*8, v))
             if not printout:
                 continue
             s = ""
             if asciidump:
                 for i in range(8):
-                    c = chr(self.mem[k] >> (i*8) & 0xff)
+                    c = chr(v >> (i*8) & 0xff)
                     if not c.isprintable():
                         c = "."
                     s += c
-            print("%016x: %016x" % ((k*8) & 0xffffffffffffffff,
-                                    self.mem[k]), s)
+            print("%016x: %016x" % ((k*8) & 0xffffffffffffffff, v), s)
         return res
 
     def log_fancy(self, *, kind=LogKind.Default, name="Memory",
@@ -199,10 +223,10 @@ class Mem:
             return bytearray(line_size)
         mem_lines = defaultdict(make_line)
         subword_range = range(1 << self.word_log2)
-        for k in self.mem.keys():
+        for k in self.word_idxs():
             addr = k << self.word_log2
             for _ in subword_range:
-                v = self.ld(addr, width=1)
+                v = self.ld(addr, width=1, reason=_ReadReason.Dump)
                 mem_lines[addr >> log2_line_size][addr & subline_mask] = v
                 addr += 1
 
@@ -231,3 +255,18 @@ class Mem:
             lines.append(line_str)
         lines = "\n".join(lines)
         log(f"\n{name}:\n{lines}\n", kind=kind)
+
+
+class Mem(MemCommon):
+    def __init__(self, row_bytes=8, initial_mem=None, misaligned_ok=False):
+        self.mem = {}
+        super().__init__(row_bytes, initial_mem, misaligned_ok)
+
+    def _read_word(self, word_idx, reason):
+        return self.mem.get(word_idx, reason.read_default)
+
+    def _write_word(self, word_idx, value):
+        self.mem[word_idx] = value
+
+    def word_idxs(self):
+        return self.mem.keys()
index b2eb250ad9a4ce90e192e646fae748f91a89094c..d32f4b3efdd00338b1b224698346422310aa3a36 100644 (file)
@@ -423,7 +423,7 @@ class SimState(State):
         mem = self.sim.mem
         if isinstance(mem, RADIX):
             mem = mem.mem
-        keys = list(mem.mem.keys())
+        keys = list(mem.word_idxs())
         self.mem = {}
         # from each address in the underlying mem-simulated dictionary
         # issue a 64-bit LD (with no byte-swapping)