From a2bd1bb7f6a6d68029e02ec42ad3309fd2ba331b Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Fri, 24 Mar 2017 11:01:06 +0000 Subject: [PATCH] Add ReentrancyGuard --- contracts/ReentrancyGuard.sol | 28 +++++++++++++++++++ test/ReentrancyGuard.js | 31 +++++++++++++++++++++ test/helpers/ReentrancyAttack.sol | 11 ++++++++ test/helpers/ReentrancyMock.sol | 46 +++++++++++++++++++++++++++++++ test/helpers/expectThrow.js | 20 ++++++++++++++ 5 files changed, 136 insertions(+) create mode 100644 contracts/ReentrancyGuard.sol create mode 100644 test/ReentrancyGuard.js create mode 100644 test/helpers/ReentrancyAttack.sol create mode 100644 test/helpers/ReentrancyMock.sol create mode 100644 test/helpers/expectThrow.js diff --git a/contracts/ReentrancyGuard.sol b/contracts/ReentrancyGuard.sol new file mode 100644 index 000000000..02c3c49c8 --- /dev/null +++ b/contracts/ReentrancyGuard.sol @@ -0,0 +1,28 @@ +pragma solidity ^0.4.8; + +/// @title Helps contracts guard agains rentrancy attacks. +/// @author Remco Bloemen +/// @notice If you mark a function `nonReentrant`, you should also +/// mark it `external`. +contract ReentrancyGuard { + + /// @dev We use a single lock for the whole contract. + bool private rentrancy_lock = false; + + /// Prevent contract from calling itself, directly or indirectly. + /// @notice If you mark a function `nonReentrant`, you should also + /// mark it `external`. Calling one nonReentrant function from + /// another is not supported. Instead, you can implement a + /// `private` function doing the actual work, and a `external` + /// wrapper marked as `nonReentrant`. + modifier nonReentrant() { + if(rentrancy_lock == false) { + rentrancy_lock = true; + _; + rentrancy_lock = false; + } else { + throw; + } + } + +} diff --git a/test/ReentrancyGuard.js b/test/ReentrancyGuard.js new file mode 100644 index 000000000..b3145df1d --- /dev/null +++ b/test/ReentrancyGuard.js @@ -0,0 +1,31 @@ +'use strict'; +import expectThrow from './helpers/expectThrow'; +const ReentrancyMock = artifacts.require('./helper/ReentrancyMock.sol'); +const ReentrancyAttack = artifacts.require('./helper/ReentrancyAttack.sol'); + +contract('ReentrancyGuard', function(accounts) { + let reentrancyMock; + + beforeEach(async function() { + reentrancyMock = await ReentrancyMock.new(); + let initialCounter = await reentrancyMock.counter(); + assert.equal(initialCounter, 0); + }); + + it('should not allow remote callback', async function() { + let attacker = await ReentrancyAttack.new(); + await expectThrow(reentrancyMock.countAndCall(attacker.address)); + }); + + // The following are more side-effects that intended behaviour: + // I put them here as documentation, and to monitor any changes + // in the side-effects. + + it('should not allow local recursion', async function() { + await expectThrow(reentrancyMock.countLocalRecursive(10)); + }); + + it('should not allow indirect local recursion', async function() { + await expectThrow(reentrancyMock.countThisRecursive(10)); + }); +}); diff --git a/test/helpers/ReentrancyAttack.sol b/test/helpers/ReentrancyAttack.sol new file mode 100644 index 000000000..ce67683a6 --- /dev/null +++ b/test/helpers/ReentrancyAttack.sol @@ -0,0 +1,11 @@ +pragma solidity ^0.4.8; + +contract ReentrancyAttack { + + function callSender(bytes4 data) { + if(!msg.sender.call(data)) { + throw; + } + } + +} diff --git a/test/helpers/ReentrancyMock.sol b/test/helpers/ReentrancyMock.sol new file mode 100644 index 000000000..dbfb41209 --- /dev/null +++ b/test/helpers/ReentrancyMock.sol @@ -0,0 +1,46 @@ +pragma solidity ^0.4.8; + +import '../../contracts/ReentrancyGuard.sol'; +import './ReentrancyAttack.sol'; + +contract ReentrancyMock is ReentrancyGuard { + + uint256 public counter; + + function ReentrancyMock() { + counter = 0; + } + + function count() private { + counter += 1; + } + + function countLocalRecursive(uint n) public nonReentrant { + if(n > 0) { + count(); + countLocalRecursive(n - 1); + } + } + + function countThisRecursive(uint256 n) public nonReentrant { + bytes4 func = bytes4(keccak256("countThisRecursive(uint256)")); + if(n > 0) { + count(); + bool result = this.call(func, n - 1); + if(result != true) { + throw; + } + } + } + + function countAndCall(ReentrancyAttack attacker) public nonReentrant { + count(); + bytes4 func = bytes4(keccak256("callback()")); + attacker.callSender(func); + } + + function callback() external nonReentrant { + count(); + } + +} diff --git a/test/helpers/expectThrow.js b/test/helpers/expectThrow.js new file mode 100644 index 000000000..45bdcfdb0 --- /dev/null +++ b/test/helpers/expectThrow.js @@ -0,0 +1,20 @@ +export default async promise => { + try { + await promise; + } catch (error) { + // TODO: Check jump destination to destinguish between a throw + // and an actual invalid jump. + const invalidJump = error.message.search('invalid JUMP') >= 0; + // TODO: When we contract A calls contract B, and B throws, instead + // of an 'invalid jump', we get an 'out of gas' error. How do + // we distinguish this from an actual out of gas event? (The + // testrpc log actually show an 'invalid jump' event.) + const outOfGas = error.message.search('out of gas') >= 0; + assert( + invalidJump || outOfGas, + "Expected throw, got '" + error + "' instead", + ); + return; + } + assert.fail('Expected throw not received'); +};