add WIP powmod_256 -- asm test is currently disabled since divmod is too slow
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 6 Oct 2023 02:57:29 +0000 (19:57 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 30 Nov 2023 07:55:28 +0000 (23:55 -0800)
src/openpower/decoder/isa/test_caller_svp64_powmod.py
src/openpower/test/bigint/powmod.py

index c9675ee24c37db7b46b465c756c83e8de1fbae43..055b40d428c1257aa345fde68b4f423805e00587 100644 (file)
@@ -14,7 +14,8 @@ related bugs:
 import unittest
 from functools import lru_cache
 import os
-from openpower.test.bigint.powmod import PowModCases, python_divmod_algorithm
+from openpower.test.bigint.powmod import (
+    PowModCases, python_divmod_algorithm, python_powmod_256_algorithm)
 from openpower.test.runner import TestRunnerBase
 
 
@@ -31,6 +32,15 @@ class TestPythonAlgorithms(unittest.TestCase):
                     self.assertEqual(out_q, q)
                     self.assertEqual(out_r, r)
 
+    def test_python_powmod_algorithm(self):
+        for base, exp, mod in PowModCases.powmod_256_test_inputs():
+            expected = pow(base, exp, mod)
+            with self.subTest(base=f"{base:#_x}", exp=f"{exp:#_x}",
+                              mod=f"{mod:#_x}", expected=f"{expected:#_x}"):
+                out = python_powmod_256_algorithm(base, exp, mod)
+                with self.subTest(out=f"{out:#_x}"):
+                    self.assertEqual(expected, out)
+
 
 # writing the test_caller invocation this way makes it work with pytest
 
index 8cdeeabdbf38d58a2a0c667bb987eabaa373e491..4f2f2d936451cd6b6caba4691a7e06ea69d85b2b 100644 (file)
@@ -267,6 +267,88 @@ def python_divmod_algorithm(n, d, width=256, log_regex=False):
     return q, r
 
 
+POWMOD_256_ASM = (
+    # base is in r4-7, exp is in r8-11, mod is in r32-35
+    "powmod_256:",
+    "mfspr 0, 8 # mflr 0",
+    "std 0, 16(1)",  # save return address
+    "setvl 0, 0, 18, 0, 1, 1",  # set VL to 18
+    "sv.std *14, -144(1)",  # save all callee-save registers
+    "stdu 1, -176(1)",  # create stack frame as required by ABI
+
+    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
+    "sv.or *16, *4, *4",  # move base to r16-19
+    "sv.or *20, *8, *8",  # move exp to r20-23
+    "sv.or *24, *32, *32",  # move mod to r24-27
+    "sv.addi *28, 0, 0",  # retval in r28-31
+    "addi 28, 0, 1",  # retval = 1
+
+    "addi 14, 0, 256",  # ctr in r14
+
+    "powmod_256_loop:",
+    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
+    "addi 3, 0, 1 # li 3, 1",  # shift amount
+    "addi 0, 0, 0 # li 0, 0",  # dsrd carry
+    "sv.dsrd/mrr *20, *20, 3, 0",  # exp >>= 1; shifted out bit in r0
+    "cmpli 0, 1, 0, 0 # cmpldi 0, 0",
+    "bc 12, 2, powmod_256_else # beq powmod_256_else",  # if lsb:
+
+    "sv.or *4, *28, *28",  # copy retval to r4-7
+    "sv.or *8, *16, *16",  # copy base to r8-11
+    "bl mul_256_to_512",  # prod = retval * base
+    # prod in r4-11
+
+    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
+    "sv.or *32, *24, *24",  # copy mod to r32-35
+
+    "bl divmod_512_by_256",  # prod % mod
+    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
+    "sv.or *28, *8, *8",  # retval = prod % mod
+
+    "powmod_256_else:",
+
+    "sv.or *4, *16, *16",  # copy base to r4-7
+    "sv.or *8, *16, *16",  # copy base to r8-11
+    "bl mul_256_to_512",  # prod = base * base
+    # prod in r4-11
+
+    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
+    "sv.or *32, *24, *24",  # copy mod to r32-35
+
+    "bl divmod_512_by_256",  # prod % mod
+    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
+    "sv.or *16, *8, *8",  # base = prod % mod
+
+    "addic. 14, 14, -1",  # decrement ctr and compare against zero
+    "bc 4, 2, powmod_256_loop # bne powmod_256_loop",
+
+    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
+    "sv.or *4, *28, *28",  # move retval to r4-7
+
+    "addi 1, 1, 176",  # teardown stack frame
+    "ld 0, 16(1)",
+    "mtspr 8, 0 # mtlr 0",  # restore return address
+    "setvl 0, 0, 18, 0, 1, 1",  # set VL to 18
+    "sv.ld *14, -144(1)",  # restore all callee-save registers
+    "bclr 20, 0, 0 # blr",
+    *MUL_256_X_256_TO_512_ASM,
+    *DIVMOD_512x256_TO_256x256_ASM,
+)
+
+
+def python_powmod_256_algorithm(base, exp, mod):
+    retval = 1
+    for _ in range(256):
+        lsb = bool(exp & 1)  # rshift and retrieve lsb
+        exp >>= 1
+        if lsb:
+            prod = retval * base
+            retval = prod % mod
+        prod = base * base
+        base = prod % mod
+    return retval
+
+
 class PowModCases(TestAccumulatorBase):
     def call_case(self, instructions, expected, initial_regs, src_loc_at=0):
         stop_at_pc = 0x10000000
@@ -354,8 +436,47 @@ class PowModCases(TestAccumulatorBase):
 
                 self.call_case(DIVMOD_512x256_TO_256x256_ASM, e, initial_regs)
 
-    # TODO: add 256-bit modular exponentiation
+    @staticmethod
+    def powmod_256_test_inputs():
+        for i in range(3):
+            base = hash_256(f"powmod256 input base {i}")
+            exp = hash_256(f"powmod256 input exp {i}")
+            mod = hash_256(f"powmod256 input mod {i}")
+            if i == 0:
+                base = 2
+                exp = 2 ** 256 - 1
+                mod = 2 ** 256 - 189  # largest prime less than 2 ** 256
+            if mod == 0:
+                mod = 1
+            base %= mod
+            yield (base, exp, mod)
+
+    @skip_case("FIXME: divmod is too slow to test powmod")
+    def case_powmod_256(self):
+        for base, exp, mod in PowModCases.powmod_256_test_inputs():
+            expected = pow(base, exp, mod)
+            with self.subTest(base=f"{base:#_x}", exp=f"{exp:#_x}",
+                              mod=f"{mod:#_x}", expected=f"{expected:#_x}"):
+                # registers start filled with junk
+                initial_regs = [0xABCDEF] * 128
+                for i in range(4):
+                    # write n in LE order to regs 4-7
+                    initial_regs[4 + i] = (base >> (64 * i)) % 2**64
+                for i in range(4):
+                    # write n in LE order to regs 8-11
+                    initial_regs[8 + i] = (exp >> (64 * i)) % 2**64
+                for i in range(4):
+                    # write d in LE order to regs 32-35
+                    initial_regs[32 + i] = (mod >> (64 * i)) % 2**64
+                # only check regs up to r7 since that's where the output is.
+                # don't check CR
+                e = ExpectedState(int_regs=initial_regs[:8], crregs=0)
+                e.ca = None  # ignored
+                for i in range(4):
+                    # write output in LE order to regs 4-7
+                    e.intregs[4 + i] = (expected >> (64 * i)) % 2**64
 
+                self.call_case(POWMOD_256_ASM, e, initial_regs)
 
 # for running "quick" simple investigations
 if __name__ == "__main__":