|
|
|
@ -14,6 +14,8 @@ contract MathTest is Test { |
|
|
|
|
// ADD512 & MUL512 |
|
|
|
|
function testAdd512(uint256 a, uint256 b) public pure { |
|
|
|
|
(uint256 high, uint256 low) = Math.add512(a, b); |
|
|
|
|
|
|
|
|
|
// test against tryAdd |
|
|
|
|
(bool success, uint256 result) = Math.tryAdd(a, b); |
|
|
|
|
if (success) { |
|
|
|
|
assertEq(high, 0); |
|
|
|
@ -21,10 +23,17 @@ contract MathTest is Test { |
|
|
|
|
} else { |
|
|
|
|
assertEq(high, 1); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// test against unchecked |
|
|
|
|
unchecked { |
|
|
|
|
assertEq(low, a + b); // unchecked allow overflow |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
function testMul512(uint256 a, uint256 b) public pure { |
|
|
|
|
(uint256 high, uint256 low) = Math.mul512(a, b); |
|
|
|
|
|
|
|
|
|
// test against tryMul |
|
|
|
|
(bool success, uint256 result) = Math.tryMul(a, b); |
|
|
|
|
if (success) { |
|
|
|
|
assertEq(high, 0); |
|
|
|
@ -32,6 +41,16 @@ contract MathTest is Test { |
|
|
|
|
} else { |
|
|
|
|
assertGt(high, 0); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// test against unchecked |
|
|
|
|
unchecked { |
|
|
|
|
assertEq(low, a * b); // unchecked allow overflow |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// test against alternative method |
|
|
|
|
(uint256 _high, uint256 _low) = _mulKaratsuba(a, b); |
|
|
|
|
assertEq(high, _high); |
|
|
|
|
assertEq(low, _low); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// MIN & MAX |
|
|
|
@ -207,7 +226,7 @@ contract MathTest is Test { |
|
|
|
|
// MULDIV |
|
|
|
|
function testMulDiv(uint256 x, uint256 y, uint256 d) public pure { |
|
|
|
|
// Full precision for x * y |
|
|
|
|
(uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y); |
|
|
|
|
(uint256 xyHi, uint256 xyLo) = Math.mul512(x, y); |
|
|
|
|
|
|
|
|
|
// Assume result won't overflow (see {testMulDivDomain}) |
|
|
|
|
// This also checks that `d` is positive |
|
|
|
@ -217,9 +236,9 @@ contract MathTest is Test { |
|
|
|
|
uint256 q = Math.mulDiv(x, y, d); |
|
|
|
|
|
|
|
|
|
// Full precision for q * d |
|
|
|
|
(uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d); |
|
|
|
|
(uint256 qdHi, uint256 qdLo) = Math.mul512(q, d); |
|
|
|
|
// Add remainder of x * y / d (computed as rem = (x * y % d)) |
|
|
|
|
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d)); |
|
|
|
|
(uint256 c, uint256 qdRemLo) = Math.add512(qdLo, mulmod(x, y, d)); |
|
|
|
|
uint256 qdRemHi = qdHi + c; |
|
|
|
|
|
|
|
|
|
// Full precision check that x * y = q * d + rem |
|
|
|
@ -228,7 +247,7 @@ contract MathTest is Test { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
function testMulDivDomain(uint256 x, uint256 y, uint256 d) public { |
|
|
|
|
(uint256 xyHi, ) = _mulHighLow(x, y); |
|
|
|
|
(uint256 xyHi, ) = Math.mul512(x, y); |
|
|
|
|
|
|
|
|
|
// Violate {testMulDiv} assumption (covers d is 0 and result overflow) |
|
|
|
|
vm.assume(xyHi >= d); |
|
|
|
@ -286,26 +305,13 @@ contract MathTest is Test { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) { |
|
|
|
|
if (m == 1) return 0; |
|
|
|
|
uint256 r = 1; |
|
|
|
|
while (e > 0) { |
|
|
|
|
if (e % 2 > 0) { |
|
|
|
|
r = mulmod(r, b, m); |
|
|
|
|
} |
|
|
|
|
b = mulmod(b, b, m); |
|
|
|
|
e >>= 1; |
|
|
|
|
} |
|
|
|
|
return r; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Helpers |
|
|
|
|
function _asRounding(uint8 r) private pure returns (Math.Rounding) { |
|
|
|
|
vm.assume(r < uint8(type(Math.Rounding).max)); |
|
|
|
|
return Math.Rounding(r); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) { |
|
|
|
|
function _mulKaratsuba(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) { |
|
|
|
|
(uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128); |
|
|
|
|
(uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128); |
|
|
|
|
|
|
|
|
@ -325,10 +331,16 @@ contract MathTest is Test { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) { |
|
|
|
|
unchecked { |
|
|
|
|
res = x + y; |
|
|
|
|
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) { |
|
|
|
|
if (m == 1) return 0; |
|
|
|
|
uint256 r = 1; |
|
|
|
|
while (e > 0) { |
|
|
|
|
if (e % 2 > 0) { |
|
|
|
|
r = mulmod(r, b, m); |
|
|
|
|
} |
|
|
|
|
b = mulmod(b, b, m); |
|
|
|
|
e >>= 1; |
|
|
|
|
} |
|
|
|
|
carry = res < x ? 1 : 0; |
|
|
|
|
return r; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|