better fuzzing tests

pull/5035/head
Hadrien Croubois 1 week ago
parent d1a07f01f3
commit d0326d9793
No known key found for this signature in database
GPG Key ID: B53810561A746A06
  1. 56
      test/utils/math/Math.t.sol

@ -14,6 +14,8 @@ contract MathTest is Test {
// ADD512 & MUL512 // ADD512 & MUL512
function testAdd512(uint256 a, uint256 b) public pure { function testAdd512(uint256 a, uint256 b) public pure {
(uint256 high, uint256 low) = Math.add512(a, b); (uint256 high, uint256 low) = Math.add512(a, b);
// test against tryAdd
(bool success, uint256 result) = Math.tryAdd(a, b); (bool success, uint256 result) = Math.tryAdd(a, b);
if (success) { if (success) {
assertEq(high, 0); assertEq(high, 0);
@ -21,10 +23,17 @@ contract MathTest is Test {
} else { } else {
assertEq(high, 1); assertEq(high, 1);
} }
// test against unchecked
unchecked {
assertEq(low, a + b); // unchecked allow overflow
}
} }
function testMul512(uint256 a, uint256 b) public pure { function testMul512(uint256 a, uint256 b) public pure {
(uint256 high, uint256 low) = Math.mul512(a, b); (uint256 high, uint256 low) = Math.mul512(a, b);
// test against tryMul
(bool success, uint256 result) = Math.tryMul(a, b); (bool success, uint256 result) = Math.tryMul(a, b);
if (success) { if (success) {
assertEq(high, 0); assertEq(high, 0);
@ -32,6 +41,16 @@ contract MathTest is Test {
} else { } else {
assertGt(high, 0); assertGt(high, 0);
} }
// test against unchecked
unchecked {
assertEq(low, a * b); // unchecked allow overflow
}
// test against alternative method
(uint256 _high, uint256 _low) = _mulKaratsuba(a, b);
assertEq(high, _high);
assertEq(low, _low);
} }
// MIN & MAX // MIN & MAX
@ -207,7 +226,7 @@ contract MathTest is Test {
// MULDIV // MULDIV
function testMulDiv(uint256 x, uint256 y, uint256 d) public pure { function testMulDiv(uint256 x, uint256 y, uint256 d) public pure {
// Full precision for x * y // Full precision for x * y
(uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y); (uint256 xyHi, uint256 xyLo) = Math.mul512(x, y);
// Assume result won't overflow (see {testMulDivDomain}) // Assume result won't overflow (see {testMulDivDomain})
// This also checks that `d` is positive // This also checks that `d` is positive
@ -217,9 +236,9 @@ contract MathTest is Test {
uint256 q = Math.mulDiv(x, y, d); uint256 q = Math.mulDiv(x, y, d);
// Full precision for q * d // Full precision for q * d
(uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d); (uint256 qdHi, uint256 qdLo) = Math.mul512(q, d);
// Add remainder of x * y / d (computed as rem = (x * y % d)) // Add remainder of x * y / d (computed as rem = (x * y % d))
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d)); (uint256 c, uint256 qdRemLo) = Math.add512(qdLo, mulmod(x, y, d));
uint256 qdRemHi = qdHi + c; uint256 qdRemHi = qdHi + c;
// Full precision check that x * y = q * d + rem // Full precision check that x * y = q * d + rem
@ -228,7 +247,7 @@ contract MathTest is Test {
} }
function testMulDivDomain(uint256 x, uint256 y, uint256 d) public { function testMulDivDomain(uint256 x, uint256 y, uint256 d) public {
(uint256 xyHi, ) = _mulHighLow(x, y); (uint256 xyHi, ) = Math.mul512(x, y);
// Violate {testMulDiv} assumption (covers d is 0 and result overflow) // Violate {testMulDiv} assumption (covers d is 0 and result overflow)
vm.assume(xyHi >= d); vm.assume(xyHi >= d);
@ -286,26 +305,13 @@ contract MathTest is Test {
} }
} }
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
if (m == 1) return 0;
uint256 r = 1;
while (e > 0) {
if (e % 2 > 0) {
r = mulmod(r, b, m);
}
b = mulmod(b, b, m);
e >>= 1;
}
return r;
}
// Helpers // Helpers
function _asRounding(uint8 r) private pure returns (Math.Rounding) { function _asRounding(uint8 r) private pure returns (Math.Rounding) {
vm.assume(r < uint8(type(Math.Rounding).max)); vm.assume(r < uint8(type(Math.Rounding).max));
return Math.Rounding(r); return Math.Rounding(r);
} }
function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) { function _mulKaratsuba(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
(uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128); (uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
(uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128); (uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);
@ -325,10 +331,16 @@ contract MathTest is Test {
} }
} }
function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) { function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
unchecked { if (m == 1) return 0;
res = x + y; uint256 r = 1;
while (e > 0) {
if (e % 2 > 0) {
r = mulmod(r, b, m);
}
b = mulmod(b, b, m);
e >>= 1;
} }
carry = res < x ? 1 : 0; return r;
} }
} }

Loading…
Cancel
Save