add pre_ra_insert_copies
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
index e6ffe4e606076577b57bb64a208b8b70a42c7954..382d66c457f5dacbe45561f262e464b1b9fa84f6 100644 (file)
@@ -56,6 +56,52 @@ class Fn:
         for op in self.ops:
             op.pre_ra_sim(state)
 
+    def pre_ra_insert_copies(self):
+        # type: () -> None
+        orig_ops = list(self.ops)
+        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
+        self.ops.clear()
+        for op in orig_ops:
+            for i in range(len(op.inputs)):
+                inp = copied_outputs[op.inputs[i]]
+                if inp.ty.base_ty is BaseTy.I64:
+                    maxvl = inp.ty.reg_len
+                    if inp.ty.reg_len != 1:
+                        setvl = self.append_new_op(OpKind.SetVLI,
+                                                   immediates=[maxvl])
+                        vl = setvl.outputs[0]
+                        mv = self.append_new_op(OpKind.VecCopyToReg,
+                                                inputs=[inp, vl], maxvl=maxvl)
+                    else:
+                        mv = self.append_new_op(OpKind.CopyToReg, inputs=[inp])
+                    op.inputs[i] = mv.outputs[0]
+                elif inp.ty.base_ty is BaseTy.CA \
+                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
+                    # all copies would be no-ops, so we don't need to copy
+                    op.inputs[i] = inp
+                else:
+                    assert_never(inp.ty.base_ty)
+            self.ops.append(op)
+            for out in op.outputs:
+                if out.ty.base_ty is BaseTy.I64:
+                    maxvl = out.ty.reg_len
+                    if out.ty.reg_len != 1:
+                        setvl = self.append_new_op(OpKind.SetVLI,
+                                                   immediates=[maxvl])
+                        vl = setvl.outputs[0]
+                        mv = self.append_new_op(OpKind.VecCopyFromReg,
+                                                inputs=[out, vl], maxvl=maxvl)
+                    else:
+                        mv = self.append_new_op(OpKind.CopyFromReg,
+                                                inputs=[out])
+                    copied_outputs[out] = mv.outputs[0]
+                elif out.ty.base_ty is BaseTy.CA \
+                        or out.ty.base_ty is BaseTy.VL_MAXVL:
+                    # all copies would be no-ops, so we don't need to copy
+                    copied_outputs[out] = out
+                else:
+                    assert_never(out.ty.base_ty)
+
 
 @unique
 @final