insndb: refactor visitors
authorDmitry Selyutin <ghostmansd@gmail.com>
Wed, 7 Jun 2023 08:58:36 +0000 (11:58 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Wed, 7 Jun 2023 09:36:23 +0000 (12:36 +0300)
src/openpower/insndb/core.py
src/openpower/insndb/db.py

index 511d695ac150c3754d9c50647f9bc5c490b3faab..042b938d7dd3705dce783f049f1a3933235fd849 100644 (file)
@@ -56,18 +56,34 @@ from openpower.decoder.power_fields import (
 from openpower.decoder.pseudo.pagereader import ISA as _ISA
 
 
+class Node:
+    @property
+    def subnodes(self):
+        yield from ()
+
+
 class Visitor:
     @_contextlib.contextmanager
-    def db(self, db):
-        yield db
+    def Node(self, node, depth):
+        yield node
+        for subnode in node.subnodes:
+            manager = subnode.__class__.__name__
+            manager = getattr(self, manager, self.Node)
+            with manager(node=subnode, depth=(depth + 1)):
+                pass
 
-    @_contextlib.contextmanager
-    def record(self, record):
-        yield record
+    def __getattr__(self, attr):
+        return self.Node
+
+    def __call__(self, node, depth):
+        manager = node.__class__.__name__
+        manager = getattr(self, manager, self.Node)
+        return manager(node=node, depth=depth)
 
-    @_contextlib.contextmanager
-    def extra(self, extra):
-        yield extra
+
+def visit(visitor, node):
+    with visitor(node=node, depth=0):
+        pass
 
 
 @_functools.total_ordering
@@ -824,7 +840,7 @@ class MarkdownRecord:
 
 
 @_dataclasses.dataclass(eq=True, frozen=True)
-class Extra:
+class Extra(Node):
     name: str
     sel: _typing.Union[
         _In1Sel, _In2Sel, _In3Sel, _CRInSel, _CRIn2Sel,
@@ -839,10 +855,9 @@ class Extra:
             pass
 
 
-
 @_functools.total_ordering
 @_dataclasses.dataclass(eq=True, frozen=True)
-class Record:
+class Record(Node):
     name: str
     section: Section
     ppc: PPCRecord
@@ -850,11 +865,10 @@ class Record:
     mdwn: MarkdownRecord
     svp64: SVP64Record = None
 
-    def visit(self, visitor):
-        with visitor.record(record=self) as record:
-            for (name, fields) in record.extras.items():
-                extra = Extra(name=name, **fields)
-                extra.visit(visitor=visitor)
+    @property
+    def subnodes(self):
+        for (name, fields) in self.extras.items():
+            yield Extra(name=name, **fields)
 
     @property
     def extras(self):
@@ -3703,7 +3717,7 @@ class SVP64Database:
         return None
 
 
-class Database:
+class Database(Node):
     def __init__(self, root):
         root = _pathlib.Path(root)
         mdwndb = MarkdownDatabase()
@@ -3737,10 +3751,10 @@ class Database:
 
         return super().__init__()
 
-    def visit(self, visitor):
-        with visitor.db(db=self) as db:
-            for record in self.__db:
-                record.visit(visitor=visitor)
+    @property
+    def subnodes(self):
+        for record in self.__db:
+            yield record
 
     def __repr__(self):
         return repr(self.__db)
index aadd3b91b95ffba99d319a69f9c35c013f1cb6a5..73f4abe2a278a78622fa988119ea66338927c755 100644 (file)
@@ -9,6 +9,7 @@ from openpower.decoder.power_enums import (
 from openpower.insndb.core import (
     Database,
     Visitor,
+    visit,
 )
 
 
@@ -38,55 +39,27 @@ class SVP64Instruction(Instruction):
 class BaseVisitor(Visitor):
     def __init__(self, **arguments):
         self.__arguments = types.MappingProxyType(arguments)
-        self.__current_db = None
-        self.__current_record = None
-        self.__current_extra = None
         return super().__init__()
 
-    @property
-    def arguments(self):
-        return self.__arguments
-
-    @property
-    def current_db(self):
-        return self.__current_db
-
-    @property
-    def current_record(self):
-        return self.__current_record
-
-    @property
-    def current_extra(self):
-        return self.__current_extra
-
-    @contextlib.contextmanager
-    def db(self, db):
-        self.__current_db = db
-        yield db
-        self.__current_db = None
-
-    @contextlib.contextmanager
-    def record(self, record):
-        self.__current_record = record
-        yield record
-        self.__current_record = None
-
-    @contextlib.contextmanager
-    def extra(self, extra):
-        self.__current_extra = extra
-        yield extra
-        self.__current_extra = None
+    def __getitem__(self, argument):
+        return self.__arguments[argument]
 
 
 class ListVisitor(BaseVisitor):
     @contextlib.contextmanager
-    def record(self, record):
-        print(record.name)
-        yield record
+    def Record(self, node, depth):
+        print(node.name)
+        yield node
 
 
 class InstructionVisitor(BaseVisitor):
-    pass
+    @contextlib.contextmanager
+    def Database(self, node, depth):
+        yield node
+        for subnode in node.subnodes:
+            if subnode.name == self["insn"]:
+                with self(node=subnode, depth=(depth + 1)):
+                    pass
 
 
 class SVP64InstructionVisitor(InstructionVisitor):
@@ -95,48 +68,41 @@ class SVP64InstructionVisitor(InstructionVisitor):
 
 class OpcodesVisitor(InstructionVisitor):
     @contextlib.contextmanager
-    def record(self, record):
-        for opcode in record.opcodes:
+    def Record(self, node, depth):
+        for opcode in node.opcodes:
             print(opcode)
+        yield node
 
 
 class OperandsVisitor(InstructionVisitor):
     @contextlib.contextmanager
-    def record(self, record):
-        with super().record(record=record):
-            if self.current_record.name == self.arguments["insn"]:
-                for operand in record.dynamic_operands:
-                    print(operand.name, ",".join(map(str, operand.span)))
-                for operand in record.static_operands:
-                    if operand.name not in ("PO", "XO"):
-                        desc = f"{operand.name}={operand.value}"
-                        print(desc, ",".join(map(str, operand.span)))
-
-        yield record
+    def Record(self, node, depth):
+        for operand in node.dynamic_operands:
+            print(operand.name, ",".join(map(str, operand.span)))
+        for operand in node.static_operands:
+            if operand.name not in ("PO", "XO"):
+                desc = f"{operand.name}={operand.value}"
+                print(desc, ",".join(map(str, operand.span)))
+        yield node
 
 
 class PCodeVisitor(InstructionVisitor):
     @contextlib.contextmanager
-    def record(self, record):
-        with super().record(record=record):
-            if self.current_record.name == self.arguments["insn"]:
-                for line in record.pcode:
-                    print(line)
+    def Record(self, node, depth):
+        for line in node.pcode:
+            print(line)
+        yield node
 
 
 class ExtrasVisitor(SVP64InstructionVisitor):
     @contextlib.contextmanager
-    def extra(self, extra):
-        with super().extra(extra=extra) as extra:
-            if self.current_record.name == self.arguments["insn"]:
-                print(extra.name)
-                print("    sel", extra.sel)
-                print("    reg", extra.reg)
-                print("    seltype", extra.seltype)
-                print("    idx", extra.idx)
-                pass
-
-        yield extra
+    def Extra(self, node, depth):
+        print(node.name)
+        print("    sel", node.sel)
+        print("    reg", node.reg)
+        print("    seltype", node.seltype)
+        print("    idx", node.idx)
+        yield node
 
 
 def main():
@@ -188,7 +154,7 @@ def main():
     visitor = commands[command][0](**args)
 
     db = Database(find_wiki_dir())
-    db.visit(visitor=visitor)
+    visit(visitor=visitor, node=db)
 
 
 if __name__ == "__main__":