From 3ac4add548178708f5401c26280b952beb244c1e Mon Sep 17 00:00:00 2001 From: jjz Date: Tue, 7 Jun 2022 14:26:45 +0800 Subject: [PATCH] Add sqrt for math (#3242) --- CHANGELOG.md | 1 + contracts/mocks/MathMock.sol | 4 ++ contracts/utils/math/Math.sol | 74 +++++++++++++++++++++++++++++++++++ test/utils/math/Math.test.js | 34 ++++++++++++++++ 4 files changed, 113 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5aa30d8c..f1649e013 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * `ERC20FlashMint`: Add customizable flash fee receiver. ([#3327](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3327)) * `ERC20TokenizedVault`: add an extension of `ERC20` that implements the ERC4626 Tokenized Vault Standard. ([#3171](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3171)) * `Math`: add a `mulDiv` function that can round the result either up or down. ([#3171](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3171)) + * `Math`: Add a `sqrt` function to compute square roots of integers, rounding either up or down. ([#3242](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3242)) * `Strings`: add a new overloaded function `toHexString` that converts an `address` with fixed length of 20 bytes to its not checksummed ASCII `string` hexadecimal representation. ([#3403](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3403)) * `EnumerableMap`: add new `UintToUintMap` map type. ([#3338](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3338)) * `EnumerableMap`: add new `Bytes32ToUintMap` map type. ([#3416](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3416)) diff --git a/contracts/mocks/MathMock.sol b/contracts/mocks/MathMock.sol index 3fac5e769..a9022aa4c 100644 --- a/contracts/mocks/MathMock.sol +++ b/contracts/mocks/MathMock.sol @@ -29,4 +29,8 @@ contract MathMock { ) public pure returns (uint256) { return Math.mulDiv(a, b, denominator, direction); } + + function sqrt(uint256 a, Math.Rounding direction) public pure returns (uint256) { + return Math.sqrt(a, direction); + } } diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 150138f76..470aa1f1d 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -149,4 +149,78 @@ library Math { } return result; } + + /** + * @dev Returns the square root of a number. It the number is not a perfect square, the value is rounded down. + * + * Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11). + */ + function sqrt(uint256 a) internal pure returns (uint256) { + if (a == 0) { + return 0; + } + + // For our first guess, we get the biggest power of 2 which is smaller than the square root of the target. + // We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have + // `msb(a) <= a < 2*msb(a)`. + // We also know that `k`, the position of the most significant bit, is such that `msb(a) = 2**k`. + // This gives `2**k < a <= 2**(k+1)` → `2**(k/2) <= sqrt(a) < 2 ** (k/2+1)`. + // Using an algorithm similar to the msb conmputation, we are able to compute `result = 2**(k/2)` which is a + // good first aproximation of `sqrt(a)` with at least 1 correct bit. + uint256 result = 1; + uint256 x = a; + if (x >> 128 > 0) { + x >>= 128; + result <<= 64; + } + if (x >> 64 > 0) { + x >>= 64; + result <<= 32; + } + if (x >> 32 > 0) { + x >>= 32; + result <<= 16; + } + if (x >> 16 > 0) { + x >>= 16; + result <<= 8; + } + if (x >> 8 > 0) { + x >>= 8; + result <<= 4; + } + if (x >> 4 > 0) { + x >>= 4; + result <<= 2; + } + if (x >> 2 > 0) { + result <<= 1; + } + + // At this point `result` is an estimation with one bit of precision. We know the true value is a uint128, + // since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at + // every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision + // into the expected uint128 result. + unchecked { + result = (result + a / result) >> 1; + result = (result + a / result) >> 1; + result = (result + a / result) >> 1; + result = (result + a / result) >> 1; + result = (result + a / result) >> 1; + result = (result + a / result) >> 1; + result = (result + a / result) >> 1; + return min(result, a / result); + } + } + + /** + * @notice Calculates sqrt(a), following the selected rounding direction. + */ + function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) { + uint256 result = sqrt(a); + if (rounding == Rounding.Up && result * result < a) { + result += 1; + } + return result; + } } diff --git a/test/utils/math/Math.test.js b/test/utils/math/Math.test.js index c93165804..a71deb50d 100644 --- a/test/utils/math/Math.test.js +++ b/test/utils/math/Math.test.js @@ -182,4 +182,38 @@ contract('Math', function (accounts) { }); }); }); + + describe('sqrt', function () { + it('rounds down', async function () { + expect(await this.math.sqrt(new BN('0'), Rounding.Down)).to.be.bignumber.equal('0'); + expect(await this.math.sqrt(new BN('1'), Rounding.Down)).to.be.bignumber.equal('1'); + expect(await this.math.sqrt(new BN('2'), Rounding.Down)).to.be.bignumber.equal('1'); + expect(await this.math.sqrt(new BN('3'), Rounding.Down)).to.be.bignumber.equal('1'); + expect(await this.math.sqrt(new BN('4'), Rounding.Down)).to.be.bignumber.equal('2'); + expect(await this.math.sqrt(new BN('144'), Rounding.Down)).to.be.bignumber.equal('12'); + expect(await this.math.sqrt(new BN('999999'), Rounding.Down)).to.be.bignumber.equal('999'); + expect(await this.math.sqrt(new BN('1000000'), Rounding.Down)).to.be.bignumber.equal('1000'); + expect(await this.math.sqrt(new BN('1000001'), Rounding.Down)).to.be.bignumber.equal('1000'); + expect(await this.math.sqrt(new BN('1002000'), Rounding.Down)).to.be.bignumber.equal('1000'); + expect(await this.math.sqrt(new BN('1002001'), Rounding.Down)).to.be.bignumber.equal('1001'); + expect(await this.math.sqrt(MAX_UINT256, Rounding.Down)) + .to.be.bignumber.equal('340282366920938463463374607431768211455'); + }); + + it('rounds up', async function () { + expect(await this.math.sqrt(new BN('0'), Rounding.Up)).to.be.bignumber.equal('0'); + expect(await this.math.sqrt(new BN('1'), Rounding.Up)).to.be.bignumber.equal('1'); + expect(await this.math.sqrt(new BN('2'), Rounding.Up)).to.be.bignumber.equal('2'); + expect(await this.math.sqrt(new BN('3'), Rounding.Up)).to.be.bignumber.equal('2'); + expect(await this.math.sqrt(new BN('4'), Rounding.Up)).to.be.bignumber.equal('2'); + expect(await this.math.sqrt(new BN('144'), Rounding.Up)).to.be.bignumber.equal('12'); + expect(await this.math.sqrt(new BN('999999'), Rounding.Up)).to.be.bignumber.equal('1000'); + expect(await this.math.sqrt(new BN('1000000'), Rounding.Up)).to.be.bignumber.equal('1000'); + expect(await this.math.sqrt(new BN('1000001'), Rounding.Up)).to.be.bignumber.equal('1001'); + expect(await this.math.sqrt(new BN('1002000'), Rounding.Up)).to.be.bignumber.equal('1001'); + expect(await this.math.sqrt(new BN('1002001'), Rounding.Up)).to.be.bignumber.equal('1001'); + expect(await this.math.sqrt(MAX_UINT256, Rounding.Up)) + .to.be.bignumber.equal('340282366920938463463374607431768211456'); + }); + }); });