oppc/code: convert everything into calls
authorDmitry Selyutin <ghostmansd@gmail.com>
Sat, 13 Jan 2024 12:33:48 +0000 (15:33 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Tue, 16 Jan 2024 19:10:07 +0000 (22:10 +0300)
src/openpower/oppc/pc_code.py

index bf16efaa184311ad73851e1d50b08511dc972117..4675d268a825febda514ebe2dff52f2f7bd11b54 100644 (file)
@@ -1,7 +1,9 @@
 import collections
+import contextlib
 
 import openpower.oppc.pc_ast as pc_ast
 import openpower.oppc.pc_util as pc_util
+import openpower.oppc.pc_pseudocode as pc_pseudocode
 
 
 class CodeVisitor(pc_util.Visitor):
@@ -13,15 +15,16 @@ class CodeVisitor(pc_util.Visitor):
         self.__decls = collections.defaultdict(list)
         self.__regfetch = collections.defaultdict(list)
         self.__regstore = collections.defaultdict(list)
+        self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)
 
         super().__init__(root=root)
 
-        self.__code[self.__header].emit("void")
-        self.__code[self.__header].emit(f"oppc_{name}(void) {{")
+        self.__code[self.__header].emit(stmt="void")
+        self.__code[self.__header].emit(stmt=f"oppc_{name}(void) {{")
         with self.__code[self.__header]:
             for decl in self.__decls:
-                self.__code[self.__header].emit(f"uint64_t {decl};")
-        self.__code[self.__footer].emit(f"}}")
+                self.__code[self.__header].emit(stmt=f"struct oppc_int {decl};")
+        self.__code[self.__footer].emit(stmt=f"}}")
 
     def __iter__(self):
         yield from self.__code[self.__header]
@@ -31,6 +34,46 @@ class CodeVisitor(pc_util.Visitor):
     def __getitem__(self, node):
         return self.__code[node]
 
+    @contextlib.contextmanager
+    def pseudocode(self, node):
+        for (level, stmt) in self.__pseudocode[node]:
+            self[node].emit(stmt=f"/* {stmt} */", level=level)
+        yield
+
+    def transient(self, node):
+        return [
+            (0, "/* transient */"),
+            (0, "&(struct oppc_int){"),
+            (1, ".count = ctx->XLEN,"),
+            (1, ".array = {},"),
+            (0, "}"),
+        ]
+
+    def integer(self, node, bits, value):
+        with self.pseudocode(node=node):
+            self[node].emit(stmt="&(struct oppc_int){")
+            with self[node]:
+                self[node].emit(stmt=f".count = {bits},",)
+                self[node].emit(stmt=f".array = {{{value}}},",)
+            self[node].emit(stmt="}")
+
+    def call(self, node, code, prefix="", suffix=""):
+        with self.pseudocode(node=node):
+            self[node].emit(stmt=f"{prefix}(")
+            with self[node]:
+                for chunk in code[:-1]:
+                    for (level, stmt) in chunk:
+                        if not (not stmt or
+                                stmt.startswith("/*") or
+                                stmt.endswith((",", "(", "{", "*/"))):
+                            stmt = (stmt + ",")
+                        self[node].emit(stmt=stmt, level=level)
+                if len(code) > 0:
+                    for (level, stmt) in code[-1]:
+                        if stmt:
+                            self[node].emit(stmt=stmt, level=level)
+            self[node].emit(stmt=f"){suffix}")
+
     @pc_util.Hook(pc_ast.Scope)
     def Scope(self, node):
         yield node
@@ -47,43 +90,34 @@ class CodeVisitor(pc_util.Visitor):
         if isinstance(node.rvalue, (pc_ast.GPR, pc_ast.FPR)):
             self.__regfetch[str(node.rvalue)].append(node.rvalue)
 
-        rvalue = str(self[node.rvalue])
+        rvalue = self[node.rvalue]
         if isinstance(node.rvalue, pc_ast.IfExpr):
-            rvalue = " ".join([
+            rvalue = [(0, " ".join([
                 str(self[node.rvalue.test]),
                 "?",
                 str(self[node.rvalue.body[0]]),
                 ":",
                 str(self[node.rvalue.orelse[0]]),
-            ])
+            ]))]
 
         if isinstance(node.lvalue, pc_ast.SubscriptExpr):
-            index = str(self[node.lvalue.index])
-            subject = str(self[node.lvalue.subject])
-            args = ", ".join([
-                f"&{subject}",
-                index,
+            self.call(prefix="oppc_subscript_assign", suffix=";", node=node, code=[
+                self[node.lvalue.subject],
+                self[node.lvalue.index],
                 rvalue,
             ])
-            self[node].emit(f"oppc_subscript_assign({args});")
         elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
-            start = str(self[node.lvalue.start])
-            end = str(self[node.lvalue.end])
-            subject = str(self[node.lvalue.subject])
-            args = ", ".join([
-                f"&{subject}",
-                start,
-                end,
+            self.call(prefix="oppc_range_subscript_assign", suffix=";", node=node, code=[
+                self[node.lvalue.subject],
+                self[node.lvalue.start],
+                self[node.lvalue.end],
                 rvalue,
             ])
-            self[node].emit(f"oppc_range_subscript_assign({args});")
         else:
-            stmt = " ".join([
-                str(self[node.lvalue]),
-                "=",
+            self.call(prefix="oppc_assign", suffix=";", node=node, code=[
+                self[node.lvalue],
                 rvalue,
             ])
-            self[node].emit(stmt=f"{stmt};")
 
     @pc_util.Hook(pc_ast.BinaryExpr)
     def BinaryExpr(self, node):
@@ -92,30 +126,30 @@ class CodeVisitor(pc_util.Visitor):
             self.__regfetch[str(node.left)].append(node.left)
         if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
             self.__regfetch[str(node.right)].append(node.left)
-        special = (
-            pc_ast.MulS,
-            pc_ast.MulU,
-            pc_ast.DivT,
-            pc_ast.Sqrt,
-            pc_ast.BitConcat
+
+        comparison = (
+            pc_ast.Lt, pc_ast.Le,
+            pc_ast.Eq, pc_ast.NotEq,
+            pc_ast.Ge, pc_ast.Gt,
         )
-        if isinstance(node.op, special):
-            raise NotImplementedError(node)
-        stmt = " ".join([
-            str(self[node.left]),
-            str(self[node.op]),
-            str(self[node.right]),
-        ])
-        self[node].emit(stmt=f"({stmt})")
+        if isinstance(node.op, comparison):
+            self.call(prefix=str(self[node.op]), node=node, code=[
+                self[node.left],
+                self[node.right],
+            ])
+        else:
+            self.call(prefix=str(self[node.op]), node=node, code=[
+                self.transient(node=node),
+                self[node.left],
+                self[node.right],
+            ])
 
     @pc_util.Hook(pc_ast.UnaryExpr)
     def UnaryExpr(self, node):
         yield node
-        stmt = "".join([
-            str(self[node.op]),
-            f"({str(self[node.value])})",
+        self.call(prefix=str(self[node.op]), node=node, code=[
+            self[node.value],
         ])
-        self[node].emit(stmt=stmt)
 
     @pc_util.Hook(
             pc_ast.Not, pc_ast.Add, pc_ast.Sub,
@@ -129,103 +163,166 @@ class CodeVisitor(pc_util.Visitor):
     def Op(self, node):
         yield node
         op = {
-            pc_ast.Not: "~",
-            pc_ast.Add: "+",
-            pc_ast.Sub: "-",
-            pc_ast.Mul: "*",
-            pc_ast.Div: "/",
-            pc_ast.Mod: "%",
-            pc_ast.Lt: "<",
-            pc_ast.Le: "<=",
-            pc_ast.Eq: "==",
-            pc_ast.NotEq: "!=",
-            pc_ast.Ge: ">=",
-            pc_ast.Gt: ">",
-            pc_ast.LShift: "<<",
-            pc_ast.RShift: "<<",
-            pc_ast.BitAnd: "&",
-            pc_ast.BitOr: "|",
-            pc_ast.BitXor: "^",
+            pc_ast.Not: "oppc_not",
+            pc_ast.Add: "oppc_add",
+            pc_ast.Sub: "oppc_sub",
+            pc_ast.Mul: "oppc_mul",
+            pc_ast.Div: "oppc_div",
+            pc_ast.Mod: "oppc_mod",
+            pc_ast.Lt: "oppc_lt",
+            pc_ast.Le: "oppc_le",
+            pc_ast.Eq: "oppc_eq",
+            pc_ast.NotEq: "oppc_noteq",
+            pc_ast.Ge: "oppc_ge",
+            pc_ast.Gt: "oppc_gt",
+            pc_ast.LShift: "oppc_lshift",
+            pc_ast.RShift: "oppc_rshift",
+            pc_ast.BitAnd: "oppc_and",
+            pc_ast.BitOr: "oppc_or",
+            pc_ast.BitXor: "oppc_xor",
         }[node.__class__]
         self[node].emit(stmt=op)
 
     @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
     def Integer(self, node):
         yield node
+        fmt = hex
+        value = str(node)
         if isinstance(node, pc_ast.BinLiteral):
-            base = 2
-        elif isinstance(node, pc_ast.DecLiteral):
-            base = 10
+            bits = str(len(value[2:]))
+            value = int(value, 2)
         elif isinstance(node, pc_ast.HexLiteral):
-            base = 16
+            bits = str(len(value[2:]) * 4)
+            value = int(value, 16)
         else:
-            raise ValueError(node)
-        self[node].emit(stmt=f"UINT64_C({hex(int(node, base))})")
+            bits = "ctx->XLEN"
+            value = int(value)
+            fmt = str
+        if (value > ((2**64) - 1)):
+            raise NotImplementedError()
+        self.integer(node=node, bits=bits, value=fmt(value))
 
     @pc_util.Hook(pc_ast.GPR)
     def GPR(self, node):
         yield node
-        self[node].emit(stmt=f"ctx->gpr[OPPC_GPR_{str(node)}]")
+        with self.pseudocode(node=node):
+            self[node].emit(stmt=f"&ctx->gpr[OPPC_GPR_{str(node)}]")
 
     @pc_util.Hook(pc_ast.FPR)
     def FPR(self, node):
         yield node
-        self[node].emit(stmt=f"ctx->fpr[OPPC_FPR_{str(node)}]")
+        with self.pseudocode(node=node):
+            self[node].emit(stmt=f"&ctx->fpr[OPPC_FPR_{str(node)}]")
 
     @pc_util.Hook(pc_ast.RepeatExpr)
     def RepeatExpr(self, node):
         yield node
-        subject = str(self[node.subject])
-        times = str(self[node.times])
-        self[node].emit(f"oppc_repeat({subject}, {times})")
+        self.call(prefix="oppc_repeat", node=node, code=[
+            self.transient(node=node),
+            self[node.subject],
+            self[node.times],
+        ])
 
     @pc_util.Hook(pc_ast.XLEN)
     def XLEN(self, node):
         yield node
-        self[node].emit(f"ctx->XLEN")
+        self.integer(node=node, bits="ctx->XLEN", value="ctx->XLEN")
 
     @pc_util.Hook(pc_ast.SubscriptExpr)
     def SubscriptExpr(self, node):
         yield node
-        index = str(self[node.index])
-        subject = str(self[node.subject])
-        self[node].emit(f"oppc_subscript({subject}, {index})")
+        self.call(prefix="oppc_subscript", node=node, code=[
+            self[node.subject],
+            self[node.index],
+        ])
 
     @pc_util.Hook(pc_ast.RangeSubscriptExpr)
     def RangeSubscriptExpr(self, node):
         yield node
-        start = str(self[node.start])
-        end = str(self[node.end])
-        subject = str(self[node.subject])
-        self[node].emit(f"oppc_range_subscript({subject}, {start}, {end})")
+        self.call(prefix="oppc_subscript", node=node, code=[
+            self[node.subject],
+            self[node.start],
+            self[node.end],
+        ])
 
     @pc_util.Hook(pc_ast.ForExpr)
     def ForExpr(self, node):
         yield node
-        subject = str(self[node.subject])
-        start = str(self[node.start])
-        end = str(self[node.end])
-        self[node].emit(f"for ({subject} = {start}; {subject} != ({end} + 1); ++{subject}) {{")
+
+        enter = pc_ast.AssignExpr(
+            lvalue=node.subject.clone(),
+            rvalue=node.start.clone(),
+        )
+        match = pc_ast.BinaryExpr(
+            left=node.subject.clone(),
+            op=pc_ast.Le("<="),
+            right=node.end.clone(),
+        )
+        leave = pc_ast.AssignExpr(
+            lvalue=node.subject.clone(),
+            rvalue=pc_ast.BinaryExpr(
+                left=node.subject.clone(),
+                op=pc_ast.Add("+"),
+                right=node.end.clone(),
+            ),
+        )
+        with self.pseudocode(node=node):
+            (level, stmt) = self[node][0]
+        self[node].clear()
+        self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt="for (")
+        with self[node]:
+            with self[node]:
+                for subnode in (enter, match, leave):
+                    self.__pseudocode.traverse(root=subnode)
+                    self.traverse(root=subnode)
+                    for (level, stmt) in self[subnode]:
+                        self[node].emit(stmt=stmt, level=level)
+                    (level, stmt) = self[node][-1]
+                    if subnode is match:
+                        stmt = f"{stmt};"
+                    elif subnode is leave:
+                        stmt = stmt[:-1]
+                    self[node][-1] = (level, stmt)
+        (level, stmt) = self[node][0]
+        self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt=") {")
+        for (level, stmt) in self[node.body]:
+            self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt="}")
+
+    @pc_util.Hook(pc_ast.WhileExpr)
+    def WhileExpr(self, node):
+        yield node
+        self[node].emit(stmt="while (")
+        with self[node]:
+            with self[node]:
+                for (level, stmt) in self[node.test]:
+                    self[node].emit(stmt=stmt, level=level)
+        self[node].emit(") {")
         for (level, stmt) in self[node.body]:
             self[node].emit(stmt=stmt, level=level)
-        self[node].emit(f"}}")
+        if node.orelse:
+            self[node].emit(stmt="} else {")
+            for (level, stmt) in self[node.orelse]:
+                self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt="}")
 
     @pc_util.Hook(pc_ast.IfExpr)
     def IfExpr(self, node):
         yield node
-        stmt = " ".join([
-            "if",
-            str(self[node.test]),
-            "{",
-        ])
-        self[node].emit(stmt=stmt)
+        self[node].emit(stmt="if (")
+        with self[node]:
+            for (level, stmt) in self[node.test]:
+                self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt=") {")
         for (level, stmt) in self[node.body]:
             self[node].emit(stmt=stmt, level=level)
         if node.orelse:
-            self[node].emit("} else {")
+            self[node].emit(stmt="} else {")
             for (level, stmt) in self[node.orelse]:
                 self[node].emit(stmt=stmt, level=level)
-        self[node].emit("}")
+        self[node].emit(stmt="}")
 
     @pc_util.Hook(pc_ast.Call.Name)
     def CallName(self, node):
@@ -238,22 +335,18 @@ class CodeVisitor(pc_util.Visitor):
         for subnode in node:
             if isinstance(subnode, (pc_ast.GPR, pc_ast.FPR)):
                 self.__regfetch[str(subnode)].append(subnode)
-        stmt = ", ".join(map(lambda subnode: str(self[subnode]), node))
-        self[node].emit(stmt=stmt)
 
     @pc_util.Hook(pc_ast.Call)
     def Call(self, node):
         yield node
-        name = str(self[node.name])
-        args = str(self[node.args])
-        stmt = f"{name}({args})"
-        self[node].emit(stmt=stmt)
+        code = tuple(map(lambda arg: self[arg], node.args))
+        self.call(prefix=str(node.name), node=node, code=code)
 
     @pc_util.Hook(pc_ast.Symbol)
     def Symbol(self, node):
         yield node
         self.__decls[str(node)].append(node)
-        self[node].emit(stmt=str(node))
+        self[node].emit(stmt=f"&{str(node)}")
 
     @pc_util.Hook(pc_ast.Node)
     def Node(self, node):