oppc/code: introduce transient
authorDmitry Selyutin <ghostmansd@gmail.com>
Sun, 14 Jan 2024 11:25:17 +0000 (14:25 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Tue, 16 Jan 2024 19:10:07 +0000 (22:10 +0300)
src/openpower/oppc/pc_code.py

index 4675d268a825febda514ebe2dff52f2f7bd11b54..7b37a2944a13ed3fdeace1e3571225c9510684f2 100644 (file)
@@ -6,6 +6,17 @@ import openpower.oppc.pc_util as pc_util
 import openpower.oppc.pc_pseudocode as pc_pseudocode
 
 
+class Transient(pc_ast.Node):
+    def __init__(self, value="UINT64_C(0)", bits="(uint8_t)ctx->XLEN"):
+        self.__value = value
+        self.__bits = bits
+
+        return super().__init__()
+
+    def __str__(self):
+        return f"oppc_transient(&(struct oppc_int){{}}, {self.__value}, {self.__bits})"
+
+
 class CodeVisitor(pc_util.Visitor):
     def __init__(self, name, root):
         self.__root = root
@@ -34,29 +45,18 @@ class CodeVisitor(pc_util.Visitor):
     def __getitem__(self, node):
         return self.__code[node]
 
+    def transient(self, node, value="UINT64_C(0)", bits="(uint8_t)ctx->XLEN"):
+        transient = Transient(value=value, bits=bits)
+        with self.pseudocode(node=node):
+            self.traverse(root=transient)
+        return transient
+
     @contextlib.contextmanager
     def pseudocode(self, node):
         for (level, stmt) in self.__pseudocode[node]:
             self[node].emit(stmt=f"/* {stmt} */", level=level)
         yield
 
-    def transient(self, node):
-        return [
-            (0, "/* transient */"),
-            (0, "&(struct oppc_int){"),
-            (1, ".count = ctx->XLEN,"),
-            (1, ".array = {},"),
-            (0, "}"),
-        ]
-
-    def integer(self, node, bits, value):
-        with self.pseudocode(node=node):
-            self[node].emit(stmt="&(struct oppc_int){")
-            with self[node]:
-                self[node].emit(stmt=f".count = {bits},",)
-                self[node].emit(stmt=f".array = {{{value}}},",)
-            self[node].emit(stmt="}")
-
     def call(self, node, code, prefix="", suffix=""):
         with self.pseudocode(node=node):
             self[node].emit(stmt=f"{prefix}(")
@@ -138,8 +138,9 @@ class CodeVisitor(pc_util.Visitor):
                 self[node.right],
             ])
         else:
+            transient = self.transient(node=node)
             self.call(prefix=str(self[node.op]), node=node, code=[
-                self.transient(node=node),
+                self[transient],
                 self[node.left],
                 self[node.right],
             ])
@@ -189,10 +190,10 @@ class CodeVisitor(pc_util.Visitor):
         fmt = hex
         value = str(node)
         if isinstance(node, pc_ast.BinLiteral):
-            bits = str(len(value[2:]))
+            bits = f"UINT8_C({str(len(value[2:]))})"
             value = int(value, 2)
         elif isinstance(node, pc_ast.HexLiteral):
-            bits = str(len(value[2:]) * 4)
+            bits = f"UINT8_C({str(len(value[2:]) * 4)})"
             value = int(value, 16)
         else:
             bits = "ctx->XLEN"
@@ -200,7 +201,15 @@ class CodeVisitor(pc_util.Visitor):
             fmt = str
         if (value > ((2**64) - 1)):
             raise NotImplementedError()
-        self.integer(node=node, bits=bits, value=fmt(value))
+        value = f"UINT64_C({fmt(value)})"
+        transient = self.transient(node=node, value=value, bits=bits)
+        for (level, stmt) in self[transient]:
+            self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(Transient)
+    def Transient(self, node):
+        yield node
+        self[node].emit(stmt=str(node))
 
     @pc_util.Hook(pc_ast.GPR)
     def GPR(self, node):
@@ -217,8 +226,9 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.RepeatExpr)
     def RepeatExpr(self, node):
         yield node
+        transient = self.transient(node=node)
         self.call(prefix="oppc_repeat", node=node, code=[
-            self.transient(node=node),
+            self[transient],
             self[node.subject],
             self[node.times],
         ])
@@ -226,7 +236,10 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.XLEN)
     def XLEN(self, node):
         yield node
-        self.integer(node=node, bits="ctx->XLEN", value="ctx->XLEN")
+        (value, bits) = ("ctx->XLEN", "(uint8_t)ctx->XLEN")
+        transient = self.transient(node=node, value=value, bits=bits)
+        for (level, stmt) in self[transient]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.SubscriptExpr)
     def SubscriptExpr(self, node):