bug 1151: work on python version of curve25519_mul "reduce" version
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 24 Feb 2024 18:48:54 +0000 (18:48 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 24 Feb 2024 18:48:54 +0000 (18:48 +0000)
src/openpower/decoder/isa/ed25519/curve25519_mul.py
src/openpower/decoder/isa/ed25519/ed25519util.py [new file with mode: 0644]

index 1b6d639074fb89207c980ce4e5ef6535d15954c9..9da698f1ed772e464fdf7394f1402969feb673f2 100644 (file)
@@ -34,6 +34,7 @@ r1 +=   c;
 """
 
 import random
+from ed25519util import add128_64, lo128, shr128, reduce_mask_51, MASK64
 
 def curve25519_mul(r, s):
 
@@ -43,8 +44,8 @@ def curve25519_mul(r, s):
         print("t%d += " % i, end='')
         for j in range(i+1):
             sidx = i-j
-            print("r%d*s%d + " % (i, sidx), end='')
-            t[i] += r[i] * s[sidx]
+            print("r%d*s%d + " % (j, sidx), end='')
+            t[i] += (r[j] * s[sidx]) & MASK64
         print()
 
     for i in range(1,5):
@@ -57,7 +58,7 @@ def curve25519_mul(r, s):
         for j in range(i):
             jidx, sidx = 4-j, 5-(i-j)
             print("r%d*s%d + " % (jidx, sidx), end='')
-            t[tidx] += r[jidx] * s[sidx]
+            t[tidx] += (r[jidx] * s[sidx]) & MASK64
         print()
 
     # this is the one where i *think* it possible to do some sort
@@ -65,20 +66,48 @@ def curve25519_mul(r, s):
 
     c = 0
     for i in range(5):
-        add128_64(t[i], c);
-        r[i] = lo128(t[i]) & reduce_mask_51;
-        shr128(c, t[i], 51);
+        t[i] = add128_64(t[i], c)
+        r[i] = lo128(t[i]) & reduce_mask_51
+        c = shr128(t[i], 51);
 
     r[0] +=   c * 19; c = r[0] >> 51; r[0] = r[0] & reduce_mask_51;
     r[1] +=   c;
 
+    return r
+
+
+def contract(a): # put array back to a bignum
+    res = 0
+    for i, x in enumerate(a):
+        res += x << (i*51)
+    return res
+
+
+def expand(a): # put bignum into an array
+    res = []
+    for i in range(5):
+        res.append(a & reduce_mask_51)
+        a >>= 51
+    return res
+
+
 if __name__ == '__main__':
     random.seed(2) # set the same seed (consistent test)
-    r, s = [], []
+    r, s = [0]*5, [0]*5
+    # dummy/obvious test
+    r[0] = (2<<53)-1
+    s[0] = 2<<60
     for j in range(5):
-        r.append(random.randint(0, 2^50))
-        s.append(random.randint(0, 2^50))
-    print ("r", r)
-    print ("s", s)
+        #r[j] = random.randint(0, 1<<50)
+        #s[j] = random.randint(0, 1<<50)
+        pass
+    rb, sb = contract(r), contract(s)
+    print ("r", list(map(hex,r)), hex(rb))
+    print ("s", list(map(hex,s)), hex(sb))
     t = curve25519_mul(r, s)
-    print ("t", t)
+    tb = contract(t)
+    print ("t", list(map(hex,t)))
+    print ("     ", hex(tb))
+
+    check = rb * sb % ((1<<255)-19)
+    print ("check", hex(check))
diff --git a/src/openpower/decoder/isa/ed25519/ed25519util.py b/src/openpower/decoder/isa/ed25519/ed25519util.py
new file mode 100644 (file)
index 0000000..d328758
--- /dev/null
@@ -0,0 +1,28 @@
+"""
+#define mul64x64_128(out,a,b) out = (uint128_t)a * b;
+#define shr128_pair(out,hi,lo,shift) 
+        out = (uint64_t)((((uint128_t)hi << 64) | lo) >> (shift));
+#define shl128_pair(out,hi,lo,shift) 
+        out = (uint64_t)(((((uint128_t)hi << 64) | lo) << (shift)) >> 64);
+#define shr128(out,in,shift) out = (uint64_t)(in >> (shift));
+#define shl128(out,in,shift) out = (uint64_t)((in << shift) >> 64);
+#define add128(a,b) a += b;
+#define add128_64(a,b) a += (uint64_t)b;
+#define lo128(a) ((uint64_t)a)
+#define hi128(a) ((uint64_t)(a >> 64))
+"""
+
+MASK64 = (1<<64)-1
+MASK128 = (1<<128)-1
+reduce_mask_51 = (1<<51)-1
+reduce_mask_40 = (1<<40)-1
+reduce_mask_56 = (1<<56)-1
+def mul64x64_128(a,b): return (a * b) & MASK128
+def shr128_pair(hi,lo,shift): return shr128((hi<<64)|lo, shift)
+def shl128_pair(hi,lo,shift): return shl128((hi<<64)|lo, shift)
+def shr128(a,shift): return lo128(a>>shift)
+def shl128(a,shift): return lo128((a<<shift)>>64)
+def add128(a,b): return (a + b) & MASK128
+def add128_64(a,b): return a + lo128(b)
+def lo128(a): return a & MASK64
+def hi128(a): return lo128(a>>64)