insndb/core: support class walking
authorDmitry Selyutin <ghostmansd@gmail.com>
Fri, 9 Jun 2023 19:20:38 +0000 (22:20 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Fri, 9 Jun 2023 19:59:06 +0000 (22:59 +0300)
src/openpower/insndb/core.py

index 6a2054be930b215e98dbc1539fa3510bb8f81a2d..05042d2d2ed39163444c0533ac63c2acc4c12973 100644 (file)
@@ -4,6 +4,7 @@ import csv as _csv
 import dataclasses as _dataclasses
 import enum as _enum
 import functools as _functools
+import inspect as _inspect
 import os as _os
 import operator as _operator
 import pathlib as _pathlib
@@ -56,8 +57,21 @@ from openpower.decoder.power_fields import (
 from openpower.decoder.pseudo.pagereader import ISA as _ISA
 
 
+class walkmethod:
+    def __init__(self, walk):
+        self.__walk = walk
+        return super().__init__()
+
+    def __get__(self, instance, owner):
+        entity = instance
+        if instance is None:
+            entity = owner
+        return _functools.partial(self.__walk, entity)
+
+
 class Node:
-    def walk(self, match=None):
+    @walkmethod
+    def walk(clsself, match=None):
         return ()
 
 
@@ -76,14 +90,20 @@ class DataclassMeta(type):
 
 
 class Dataclass(metaclass=DataclassMeta):
-    def walk(self, match=None):
+    @walkmethod
+    def walk(clsself, match=None):
         if match is None:
             match = lambda subnode: True
 
-        def subnode(field):
-            return getattr(self, field.name)
+        def field_type(field):
+            return field.type
 
-        yield from filter(match, map(subnode, _dataclasses.fields(self)))
+        def field_value(field):
+            return getattr(clsself, field.name)
+
+        field = (field_type if isinstance(clsself, type) else field_value)
+
+        yield from filter(match, map(field, _dataclasses.fields(clsself)))
 
 
 class Visitor:
@@ -3724,13 +3744,17 @@ class Records(tuple):
     def __new__(cls, records):
         return super().__new__(cls, sorted(records))
 
-    def walk(self, match=None):
+    @walkmethod
+    def walk(clsself, match=None):
         if match is None:
             match = lambda subnode: True
 
-        for record in self:
-            if match(record):
-                yield record
+        if isinstance(clsself, type):
+            yield Record
+        else:
+            for record in clsself:
+                if match(record):
+                    yield record
 
 
 class Database(Node):
@@ -3767,12 +3791,16 @@ class Database(Node):
 
         return super().__init__()
 
-    def walk(self, match=None):
+    @walkmethod
+    def walk(clsself, match=None):
         if match is None:
             match = lambda subnode: True
 
-        if match(self.__db):
-            yield self.__db
+        if isinstance(clsself, type):
+            yield Records
+        else:
+            if match(clsself.__db):
+                yield clsself.__db
 
     def __repr__(self):
         return repr(self.__db)