From cc807105a4334824d9f540cc085f586afb2a4c15 Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Thu, 9 May 2024 09:28:23 +0200 Subject: [PATCH] add 512bits add and mult operations --- contracts/utils/math/Math.sol | 39 +++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 0a431719d..d1f2776c0 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -17,6 +17,35 @@ library Math { Expand // Away from zero } + /** + * @dev Return the 512-bit addition of two uint256. + * + * The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low. + */ + function addFull(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) { + unchecked { + low = a + b; + high = SafeCast.toUint(low < a); + } + } + + /** + * @dev Return the 512-bit multiplication of two uint256. + * + * The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low. + */ + function mulFull(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) { + // Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use use the Chinese Remainder Theorem to reconstruct + // the 512 bit result. + unchecked { + low = a * b; + assembly { + let mm := mulmod(a, b, not(0)) + high := sub(sub(mm, low), lt(mm, low)) + } + } + } + /** * @dev Returns the addition of two unsigned integers, with an success flag (no overflow). */ @@ -143,15 +172,7 @@ library Math { */ function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) { unchecked { - // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use - // use the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256 - // variables such that product = prod1 * 2²⁵⁶ + prod0. - uint256 prod0 = x * y; // Least significant 256 bits of the product - uint256 prod1; // Most significant 256 bits of the product - assembly { - let mm := mulmod(x, y, not(0)) - prod1 := sub(sub(mm, prod0), lt(mm, prod0)) - } + (uint256 prod1, uint256 prod0) = mulFull(x, y); // Handle non-overflow cases, 256 by 256 division. if (prod1 == 0) {