diff --git a/contracts/payment/PullPayment.sol b/contracts/payment/PullPayment.sol index 130a7b0d6..694b68837 100644 --- a/contracts/payment/PullPayment.sol +++ b/contracts/payment/PullPayment.sol @@ -11,17 +11,19 @@ import '../SafeMath.sol'; */ contract PullPayment is SafeMath { mapping(address => uint) public payments; + uint public totalPayments; // store sent amount as credit to be pulled, called by payer function asyncSend(address dest, uint amount) internal { payments[dest] = safeAdd(payments[dest], amount); + totalPayments = safeAdd(totalPayments, amount); } // withdraw accumulated balance, called by payee function withdrawPayments() { address payee = msg.sender; uint payment = payments[payee]; - + if (payment == 0) { throw; } @@ -30,6 +32,7 @@ contract PullPayment is SafeMath { throw; } + totalPayments = safeSub(totalPayments, payment); payments[payee] = 0; if (!payee.send(payment)) { diff --git a/test/PullPayment.js b/test/PullPayment.js index 1953d892b..f8536d732 100644 --- a/test/PullPayment.js +++ b/test/PullPayment.js @@ -12,7 +12,9 @@ contract('PullPayment', function(accounts) { let ppce = await PullPaymentMock.new(); let callSend = await ppce.callSend(accounts[0], AMOUNT); let paymentsToAccount0 = await ppce.payments(accounts[0]); + let totalPayments = await ppce.totalPayments(); + assert.equal(totalPayments, AMOUNT); assert.equal(paymentsToAccount0, AMOUNT); }); @@ -21,7 +23,9 @@ contract('PullPayment', function(accounts) { let call1 = await ppce.callSend(accounts[0], 200); let call2 = await ppce.callSend(accounts[0], 300); let paymentsToAccount0 = await ppce.payments(accounts[0]); + let totalPayments = await ppce.totalPayments(); + assert.equal(totalPayments, 500); assert.equal(paymentsToAccount0, 500); }); @@ -35,6 +39,9 @@ contract('PullPayment', function(accounts) { let paymentsToAccount1 = await ppce.payments(accounts[1]); assert.equal(paymentsToAccount1, 300); + + let totalPayments = await ppce.totalPayments(); + assert.equal(totalPayments, 500); }); it("can withdraw payment", async function() { @@ -48,10 +55,16 @@ contract('PullPayment', function(accounts) { let payment1 = await ppce.payments(payee); assert.equal(payment1, AMOUNT); + let totalPayments = await ppce.totalPayments(); + assert.equal(totalPayments, AMOUNT); + let withdraw = await ppce.withdrawPayments({from: payee}); let payment2 = await ppce.payments(payee); assert.equal(payment2, 0); + totalPayments = await ppce.totalPayments(); + assert.equal(totalPayments, 0); + let balance = web3.eth.getBalance(payee); assert(Math.abs(balance-initialBalance-AMOUNT) < 1e16); });