From e3cfe1c5ddbea74cca0494bb28361c8096e0160d Mon Sep 17 00:00:00 2001
From: cairo <cairoeth@protonmail.com>
Date: Mon, 30 Sep 2024 09:05:44 -0700
Subject: [PATCH] Fix P256 corner cases (#5218)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: Ernesto García <ernestognw@gmail.com>
---
 .solcover.js                          |  8 +++
 contracts/utils/cryptography/P256.sol | 94 ++++++++++++++++++++-------
 test/helpers/iterate.js               |  4 +-
 test/utils/cryptography/P256.t.sol    | 12 ++--
 4 files changed, 88 insertions(+), 30 deletions(-)

diff --git a/.solcover.js b/.solcover.js
index e0dea5e2c..f079998cf 100644
--- a/.solcover.js
+++ b/.solcover.js
@@ -10,4 +10,12 @@ module.exports = {
     fgrep: '[skip-on-coverage]',
     invert: true,
   },
+  // Work around stack too deep for coverage
+  configureYulOptimizer: true,
+  solcOptimizerDetails: {
+    yul: true,
+    yulDetails: {
+      optimizerSteps: '',
+    },
+  },
 };
diff --git a/contracts/utils/cryptography/P256.sol b/contracts/utils/cryptography/P256.sol
index 1c46e38b0..3028505ba 100644
--- a/contracts/utils/cryptography/P256.sol
+++ b/contracts/utils/cryptography/P256.sol
@@ -185,6 +185,13 @@ library P256 {
     /**
      * @dev Point addition on the jacobian coordinates
      * Reference: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#addition-add-1998-cmo-2
+     *
+     * Note that:
+     *
+     * - `addition-add-1998-cmo-2` doesn't support identical input points. This version is modified to use
+     * the `h` and `r` values computed by `addition-add-1998-cmo-2` to detect identical inputs, and fallback to
+     * `doubling-dbl-1998-cmo-2` if needed.
+     * - if one of the points is at infinity (i.e. `z=0`), the result is undefined.
      */
     function _jAdd(
         JPoint memory p1,
@@ -197,25 +204,53 @@ library P256 {
             let z1 := mload(add(p1, 0x40))
             let zz1 := mulmod(z1, z1, p) // zz1 = z1²
             let s1 := mulmod(mload(add(p1, 0x20)), mulmod(mulmod(z2, z2, p), z2, p), p) // s1 = y1*z2³
-            let r := addmod(mulmod(y2, mulmod(zz1, z1, p), p), sub(p, s1), p) // r = s2-s1 = y2*z1³-s1
+            let r := addmod(mulmod(y2, mulmod(zz1, z1, p), p), sub(p, s1), p) // r = s2-s1 = y2*z1³-s1 = y2*z1³-y1*z2³
             let u1 := mulmod(mload(p1), mulmod(z2, z2, p), p) // u1 = x1*z2²
-            let h := addmod(mulmod(x2, zz1, p), sub(p, u1), p) // h = u2-u1 = x2*z1²-u1
-            let hh := mulmod(h, h, p) // h²
+            let h := addmod(mulmod(x2, zz1, p), sub(p, u1), p) // h = u2-u1 = x2*z1²-u1 = x2*z1²-x1*z2²
+
+            // detect edge cases where inputs are identical
+            switch and(iszero(r), iszero(h))
+            // case 0: points are different
+            case 0 {
+                let hh := mulmod(h, h, p) // h²
+
+                // x' = r²-h³-2*u1*h²
+                rx := addmod(
+                    addmod(mulmod(r, r, p), sub(p, mulmod(h, hh, p)), p),
+                    sub(p, mulmod(2, mulmod(u1, hh, p), p)),
+                    p
+                )
+                // y' = r*(u1*h²-x')-s1*h³
+                ry := addmod(
+                    mulmod(r, addmod(mulmod(u1, hh, p), sub(p, rx), p), p),
+                    sub(p, mulmod(s1, mulmod(h, hh, p), p)),
+                    p
+                )
+                // z' = h*z1*z2
+                rz := mulmod(h, mulmod(z1, z2, p), p)
+            }
+            // case 1: points are equal
+            case 1 {
+                let x := x2
+                let y := y2
+                let z := z2
+                let yy := mulmod(y, y, p)
+                let zz := mulmod(z, z, p)
+                let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴
+                let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y²
+
+                // x' = t = m²-2*s
+                rx := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p)
 
-            // x' = r²-h³-2*u1*h²
-            rx := addmod(
-                addmod(mulmod(r, r, p), sub(p, mulmod(h, hh, p)), p),
-                sub(p, mulmod(2, mulmod(u1, hh, p), p)),
-                p
-            )
-            // y' = r*(u1*h²-x')-s1*h³
-            ry := addmod(
-                mulmod(r, addmod(mulmod(u1, hh, p), sub(p, rx), p), p),
-                sub(p, mulmod(s1, mulmod(h, hh, p), p)),
-                p
-            )
-            // z' = h*z1*z2
-            rz := mulmod(h, mulmod(z1, z2, p), p)
+                // y' = m*(s-t)-8*y⁴ = m*(s-x')-8*y⁴
+                // cut the computation to avoid stack too deep
+                let rytmp1 := sub(p, mulmod(8, mulmod(yy, yy, p), p)) // -8*y⁴
+                let rytmp2 := addmod(s, sub(p, rx), p) // s-x'
+                ry := addmod(mulmod(m, rytmp2, p), rytmp1, p) // m*(s-x')-8*y⁴
+
+                // z' = 2*y*z
+                rz := mulmod(2, mulmod(y, z, p), p)
+            }
         }
     }
 
@@ -228,8 +263,8 @@ library P256 {
             let p := P
             let yy := mulmod(y, y, p)
             let zz := mulmod(z, z, p)
-            let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y²
             let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴
+            let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y²
 
             // x' = t = m²-2*s
             rx := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p)
@@ -244,10 +279,11 @@ library P256 {
      * @dev Compute G·u1 + P·u2 using the precomputed points for G and P (see {_preComputeJacobianPoints}).
      *
      * Uses Strauss Shamir trick for EC multiplication
-     * https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method.
-     * We optimise on this a bit to do with 2 bits at a time rather than a single bit.
-     * The individual points for a single pass are precomputed.
-     * Overall this reduces the number of additions while keeping the same number of doublings.
+     * https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method
+     *
+     * We optimize this for 2 bits at a time rather than a single bit. The individual points for a single pass are
+     * precomputed. Overall this reduces the number of additions while keeping the same number of
+     * doublings
      */
     function _jMultShamir(
         JPoint[16] memory points,
@@ -263,9 +299,14 @@ library P256 {
                     (x, y, z) = _jDouble(x, y, z);
                     (x, y, z) = _jDouble(x, y, z);
                 }
-                // Read 2 bits of u1, and 2 bits of u2. Combining the two give a lookup index in the table.
+                // Read 2 bits of u1, and 2 bits of u2. Combining the two gives the lookup index in the table.
                 uint256 pos = ((u1 >> 252) & 0xc) | ((u2 >> 254) & 0x3);
-                if (pos > 0) {
+                // Points that have z = 0 are points at infinity. They are the additive 0 of the group
+                // - if the lookup point is a 0, we can skip it
+                // - otherwise:
+                //   - if the current point (x, y, z) is 0, we use the lookup point as our new value (0+P=P)
+                //   - if the current point (x, y, z) is not 0, both points are valid and we can use `_jAdd`
+                if (points[pos].z != 0) {
                     if (z == 0) {
                         (x, y, z) = (points[pos].x, points[pos].y, points[pos].z);
                     } else {
@@ -291,6 +332,11 @@ library P256 {
      * │  8 │ 2g 2g+p 2g+2p 2g+3p │
      * │ 12 │ 3g 3g+p 3g+2p 3g+3p │
      * └────┴─────────────────────┘
+     *
+     * Note that `_jAdd` (and thus `_jAddPoint`) does not handle the case where one of the inputs is a point at
+     * infinity (z = 0). However, we know that since `N ≡ 1 mod 2` and `N ≡ 1 mod 3`, there is no point P such that
+     * 2P = 0 or 3P = 0. This guarantees that g, 2g, 3g, p, 2p, 3p are all non-zero, and that all `_jAddPoint` calls
+     * have valid inputs.
      */
     function _preComputeJacobianPoints(uint256 px, uint256 py) private pure returns (JPoint[16] memory points) {
         points[0x00] = JPoint(0, 0, 0); // 0,0
diff --git a/test/helpers/iterate.js b/test/helpers/iterate.js
index ef4526e13..c7403d523 100644
--- a/test/helpers/iterate.js
+++ b/test/helpers/iterate.js
@@ -13,11 +13,11 @@ module.exports = {
   // Range from start to end in increment
   // Example: range(17,42,7) → [17,24,31,38]
   range: (start, stop = undefined, step = 1) => {
-    if (!stop) {
+    if (stop == undefined) {
       stop = start;
       start = 0;
     }
-    return start < stop ? Array.from({ length: Math.ceil((stop - start) / step) }, (_, i) => start + i * step) : [];
+    return start < stop ? Array.from({ length: (stop - start + step - 1) / step }, (_, i) => start + i * step) : [];
   },
 
   // Unique elements, with an optional getter function
diff --git a/test/utils/cryptography/P256.t.sol b/test/utils/cryptography/P256.t.sol
index 1391afd76..8b95ff225 100644
--- a/test/utils/cryptography/P256.t.sol
+++ b/test/utils/cryptography/P256.t.sol
@@ -9,8 +9,8 @@ import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
 
 contract P256Test is Test {
     /// forge-config: default.fuzz.runs = 512
-    function testVerify(uint256 seed, bytes32 digest) public {
-        uint256 privateKey = bound(uint256(keccak256(abi.encode(seed))), 1, P256.N - 1);
+    function testVerify(bytes32 digest, uint256 seed) public {
+        uint256 privateKey = _asPrivateKey(seed);
 
         (bytes32 x, bytes32 y) = P256PublicKey.getPublicKey(privateKey);
         (bytes32 r, bytes32 s) = vm.signP256(privateKey, digest);
@@ -20,8 +20,8 @@ contract P256Test is Test {
     }
 
     /// forge-config: default.fuzz.runs = 512
-    function testRecover(uint256 seed, bytes32 digest) public {
-        uint256 privateKey = bound(uint256(keccak256(abi.encode(seed))), 1, P256.N - 1);
+    function testRecover(bytes32 digest, uint256 seed) public {
+        uint256 privateKey = _asPrivateKey(seed);
 
         (bytes32 x, bytes32 y) = P256PublicKey.getPublicKey(privateKey);
         (bytes32 r, bytes32 s) = vm.signP256(privateKey, digest);
@@ -31,6 +31,10 @@ contract P256Test is Test {
         assertTrue((qx0 == x && qy0 == y) || (qx1 == x && qy1 == y));
     }
 
+    function _asPrivateKey(uint256 seed) private pure returns (uint256) {
+        return bound(seed, 1, P256.N - 1);
+    }
+
     function _ensureLowerS(bytes32 s) private pure returns (bytes32) {
         uint256 _s = uint256(s);
         unchecked {