// SPDX-License-Identifier: MIT

pragma solidity ^0.8.20;

import {Test, stdError} from "forge-std/Test.sol";

import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";

/// @dev Fuzz tests for the `Math` library. Each test checks the library result against an
/// independent property or a reference computation written with native EVM arithmetic.
contract MathTest is Test {
    // CEILDIV

    /// @dev `ceilDiv(a, b)` must equal `a / b`, plus one when the division leaves a remainder.
    function testCeilDiv(uint256 a, uint256 b) public {
        // Division by zero is out of scope for this test.
        vm.assume(b > 0);

        uint256 result = Math.ceilDiv(a, b);

        if (result == 0) {
            // A zero ceiling is only possible for a zero numerator.
            assertEq(a, 0);
        } else {
            uint256 expect = a / b;
            // `expect * b <= a` by construction, so this product cannot overflow.
            if (expect * b < a) {
                expect += 1;
            }
            assertEq(result, expect);
        }
    }

    // SQRT

    /// @dev The result of `sqrt` must bracket the real square root according to the rounding direction.
    function testSqrt(uint256 input, uint8 r) public {
        Math.Rounding rounding = _asRounding(r);

        uint256 result = Math.sqrt(input, rounding);

        // square of result is bigger than input
        if (_squareBigger(result, input)) {
            assertTrue(Math.unsignedRoundsUp(rounding));
            assertTrue(_squareSmaller(result - 1, input));
        }
        // square of result is smaller than input
        else if (_squareSmaller(result, input)) {
            assertFalse(Math.unsignedRoundsUp(rounding));
            assertTrue(_squareBigger(result + 1, input));
        }
        // input is perfect square
        else {
            assertEq(result * result, input);
        }
    }

    /// @dev Returns true if `value * value > ref`. An overflowing square counts as bigger.
    function _squareBigger(uint256 value, uint256 ref) private pure returns (bool) {
        (bool noOverflow, uint256 square) = Math.tryMul(value, value);
        return !noOverflow || square > ref;
    }

    /// @dev Returns true if `value * value < ref`. Uses checked arithmetic, so it reverts if the square overflows.
    function _squareSmaller(uint256 value, uint256 ref) private pure returns (bool) {
        return value * value < ref;
    }

    // INV

    /// @dev Arbitrary modulus: an inverse may not exist, so a zero result is accepted.
    function testInvMod(uint256 value, uint256 p) public {
        _testInvMod(value, p, true);
    }

    /// @dev Modulus 2: `value` is bound to `[1, p - 1]` and a zero result is rejected.
    function testInvMod2(uint256 seed) public {
        uint256 p = 2; // prime
        _testInvMod(bound(seed, 1, p - 1), p, false);
    }

    /// @dev Modulus 17: `value` is bound to `[1, p - 1]` and a zero result is rejected.
    function testInvMod17(uint256 seed) public {
        uint256 p = 17; // prime
        _testInvMod(bound(seed, 1, p - 1), p, false);
    }

    /// @dev Modulus 65537: `value` is bound to `[1, p - 1]` and a zero result is rejected.
    function testInvMod65537(uint256 seed) public {
        uint256 p = 65537; // prime
        _testInvMod(bound(seed, 1, p - 1), p, false);
    }

    /// @dev 256-bit prime modulus: `value` is bound to `[1, p - 1]` and a zero result is rejected.
    function testInvModP256(uint256 seed) public {
        uint256 p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff; // prime
        _testInvMod(bound(seed, 1, p - 1), p, false);
    }

    /// @dev Shared check for `invMod`.
    /// A non-zero result must satisfy `value * inverse == 1 (mod p)` and be smaller than `p`.
    /// A zero result is only tolerated when `allowZero` is set.
    function _testInvMod(uint256 value, uint256 p, bool allowZero) private {
        uint256 inverse = Math.invMod(value, p);
        if (inverse != 0) {
            assertEq(mulmod(value, inverse, p), 1);
            assertLt(inverse, p);
        } else {
            assertTrue(allowZero);
        }
    }

    // LOG2

    /// @dev `2 ** log2(input)` must bracket `input` according to the rounding direction; zero input yields zero.
    function testLog2(uint256 input, uint8 r) public {
        Math.Rounding rounding = _asRounding(r);

        uint256 result = Math.log2(input, rounding);

        if (input == 0) {
            assertEq(result, 0);
        } else if (_powerOf2Bigger(result, input)) {
            assertTrue(Math.unsignedRoundsUp(rounding));
            assertTrue(_powerOf2Smaller(result - 1, input));
        } else if (_powerOf2Smaller(result, input)) {
            assertFalse(Math.unsignedRoundsUp(rounding));
            assertTrue(_powerOf2Bigger(result + 1, input));
        } else {
            assertEq(2 ** result, input);
        }
    }

    /// @dev Returns true if `2 ** value > ref`, treating an overflowing power as bigger.
    function _powerOf2Bigger(uint256 value, uint256 ref) private pure returns (bool) {
        return value >= 256 || 2 ** value > ref; // 2**256 overflows uint256
    }

    /// @dev Returns true if `2 ** value < ref`. Checked arithmetic: reverts if the power overflows.
    function _powerOf2Smaller(uint256 value, uint256 ref) private pure returns (bool) {
        return 2 ** value < ref;
    }

    // LOG10

    /// @dev `10 ** log10(input)` must bracket `input` according to the rounding direction; zero input yields zero.
    function testLog10(uint256 input, uint8 r) public {
        Math.Rounding rounding = _asRounding(r);

        uint256 result = Math.log10(input, rounding);

        if (input == 0) {
            assertEq(result, 0);
        } else if (_powerOf10Bigger(result, input)) {
            assertTrue(Math.unsignedRoundsUp(rounding));
            assertTrue(_powerOf10Smaller(result - 1, input));
        } else if (_powerOf10Smaller(result, input)) {
            assertFalse(Math.unsignedRoundsUp(rounding));
            assertTrue(_powerOf10Bigger(result + 1, input));
        } else {
            assertEq(10 ** result, input);
        }
    }

    /// @dev Returns true if `10 ** value > ref`, treating an overflowing power as bigger.
    function _powerOf10Bigger(uint256 value, uint256 ref) private pure returns (bool) {
        return value >= 78 || 10 ** value > ref; // 10**78 overflows uint256
    }

    /// @dev Returns true if `10 ** value < ref`. Checked arithmetic: reverts if the power overflows.
    function _powerOf10Smaller(uint256 value, uint256 ref) private pure returns (bool) {
        return 10 ** value < ref;
    }

    // LOG256

    /// @dev `256 ** log256(input)` must bracket `input` according to the rounding direction; zero input yields zero.
    function testLog256(uint256 input, uint8 r) public {
        Math.Rounding rounding = _asRounding(r);

        uint256 result = Math.log256(input, rounding);

        if (input == 0) {
            assertEq(result, 0);
        } else if (_powerOf256Bigger(result, input)) {
            assertTrue(Math.unsignedRoundsUp(rounding));
            assertTrue(_powerOf256Smaller(result - 1, input));
        } else if (_powerOf256Smaller(result, input)) {
            assertFalse(Math.unsignedRoundsUp(rounding));
            assertTrue(_powerOf256Bigger(result + 1, input));
        } else {
            assertEq(256 ** result, input);
        }
    }

    /// @dev Returns true if `256 ** value > ref`, treating an overflowing power as bigger.
    function _powerOf256Bigger(uint256 value, uint256 ref) private pure returns (bool) {
        return value >= 32 || 256 ** value > ref; // 256**32 overflows uint256
    }

    /// @dev Returns true if `256 ** value < ref`. Checked arithmetic: reverts if the power overflows.
    function _powerOf256Smaller(uint256 value, uint256 ref) private pure returns (bool) {
        return 256 ** value < ref;
    }

    // MULDIV

    /// @dev Checks `x * y == mulDiv(x, y, d) * d + (x * y % d)` using 512-bit arithmetic.
    function testMulDiv(uint256 x, uint256 y, uint256 d) public {
        // Full precision for x * y
        (uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y);

        // Assume result won't overflow (see {testMulDivDomain})
        // This also checks that `d` is positive
        vm.assume(xyHi < d);

        // Perform muldiv
        uint256 q = Math.mulDiv(x, y, d);

        // Full precision for q * d
        (uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
        // Add remainder of x * y / d (computed as rem = (x * y % d))
        (uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d));
        uint256 qdRemHi = qdHi + c;

        // Full precision check that x * y = q * d + rem
        assertEq(xyHi, qdRemHi);
        assertEq(xyLo, qdRemLo);
    }

    /// @dev Complement of {testMulDiv}: outside its domain, `mulDiv` is expected to revert.
    function testMulDivDomain(uint256 x, uint256 y, uint256 d) public {
        (uint256 xyHi, ) = _mulHighLow(x, y);

        // Violate {testMulDiv} assumption (covers d is 0 and result overflow)
        vm.assume(xyHi >= d);

        // we are outside the scope of {testMulDiv}, we expect muldiv to revert
        vm.expectRevert(d == 0 ? stdError.divisionError : stdError.arithmeticError);
        Math.mulDiv(x, y, d);
    }

    // MOD EXP

    /// @dev `modExp` on words must match the native reference and revert with a division error for a zero modulus.
    function testModExp(uint256 b, uint256 e, uint256 m) public {
        if (m == 0) {
            vm.expectRevert(stdError.divisionError);
        }
        uint256 result = Math.modExp(b, e, m);
        assertLt(result, m);
        assertEq(result, _nativeModExp(b, e, m));
    }

    /// @dev `tryModExp` on words must succeed exactly when the modulus is non-zero and match the native reference.
    function testTryModExp(uint256 b, uint256 e, uint256 m) public {
        (bool success, uint256 result) = Math.tryModExp(b, e, m);
        assertEq(success, m != 0);
        if (success) {
            assertLt(result, m);
            assertEq(result, _nativeModExp(b, e, m));
        } else {
            assertEq(result, 0);
        }
    }

    /// @dev Same as {testModExp}, through the `bytes` overload with 32-byte operands.
    function testModExpMemory(uint256 b, uint256 e, uint256 m) public {
        if (m == 0) {
            vm.expectRevert(stdError.divisionError);
        }
        bytes memory result = Math.modExp(abi.encodePacked(b), abi.encodePacked(e), abi.encodePacked(m));
        assertEq(result.length, 0x20);
        uint256 res = abi.decode(result, (uint256));
        assertLt(res, m);
        assertEq(res, _nativeModExp(b, e, m));
    }

    /// @dev Same as {testTryModExp}, through the `bytes` overload with 32-byte operands.
    function testTryModExpMemory(uint256 b, uint256 e, uint256 m) public {
        (bool success, bytes memory result) = Math.tryModExp(
            abi.encodePacked(b),
            abi.encodePacked(e),
            abi.encodePacked(m)
        );
        // Mirror {testTryModExp}: failure is only acceptable for a zero modulus. Without this check a
        // spurious `false` would pass through the `else` branch unnoticed.
        assertEq(success, m != 0);
        if (success) {
            assertEq(result.length, 0x20); // m is a uint256, so abi.encodePacked(m).length is 0x20
            uint256 res = abi.decode(result, (uint256));
            assertLt(res, m);
            assertEq(res, _nativeModExp(b, e, m));
        } else {
            assertEq(result.length, 0);
        }
    }

    /// @dev Reference modular exponentiation by square-and-multiply, built on the `mulmod` opcode.
    function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
        // Everything is congruent to 0 modulo 1; the loop below would otherwise return 1 for e == 0.
        if (m == 1) return 0;
        uint256 r = 1;
        while (e > 0) {
            // Multiply the accumulator in when the current low bit of the exponent is set.
            if (e % 2 > 0) {
                r = mulmod(r, b, m);
            }
            b = mulmod(b, b, m);
            e >>= 1;
        }
        return r;
    }

    // Helpers

    /// @dev Converts a fuzzed `uint8` into a `Math.Rounding`, discarding values outside the enum range.
    /// The bound is inclusive so that the last enum member is also exercised by the tests.
    function _asRounding(uint8 r) private pure returns (Math.Rounding) {
        vm.assume(r <= uint8(type(Math.Rounding).max));
        return Math.Rounding(r);
    }

    /// @dev Full 512-bit product of `x` and `y`, returned as (`high`, `low`) 256-bit words.
    function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
        // Split each operand into 128-bit halves: x = x1 * 2**128 + x0, y = y1 * 2**128 + y0.
        (uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
        (uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);

        // Schoolbook multiplication on 128-bit limbs (four partial products). The z2/z1/z0 naming follows
        // https://en.wikipedia.org/wiki/Karatsuba_algorithm
        // but the middle term is computed directly as two products (z1a + z1b), not with Karatsuba's trick.
        uint256 z2 = x1 * y1;
        uint256 z1a = x1 * y0;
        uint256 z1b = x0 * y1;
        uint256 z0 = x0 * y0;

        // Carry out of the low word: sum the pieces that land in bits [128, 256) and keep what spills over.
        uint256 carry = ((z1a & type(uint128).max) + (z1b & type(uint128).max) + (z0 >> 128)) >> 128;

        high = z2 + (z1a >> 128) + (z1b >> 128) + carry;

        // The low word is the product modulo 2**256, i.e. the wrapping multiplication.
        unchecked {
            low = x * y;
        }
    }

    /// @dev Wrapping addition of `x` and `y`; `carry` is 1 if the sum overflowed 256 bits, 0 otherwise.
    function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) {
        unchecked {
            res = x + y;
        }
        // A wrapped sum is strictly smaller than either operand.
        carry = res < x ? 1 : 0;
    }
}