diff --git a/core/tx_noncer.go b/core/tx_noncer.go index 98a78e087e..aa87c643ae 100644 --- a/core/tx_noncer.go +++ b/core/tx_noncer.go @@ -17,6 +17,8 @@ package core import ( + "sync" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" ) @@ -27,6 +29,7 @@ import ( type txNoncer struct { fallback *state.StateDB nonces map[common.Address]uint64 + lock sync.Mutex } // newTxNoncer creates a new virtual state database to track the pool nonces. @@ -40,6 +43,11 @@ func newTxNoncer(statedb *state.StateDB) *txNoncer { // get returns the current nonce of an account, falling back to a real state // database if the account is unknown. func (txn *txNoncer) get(addr common.Address) uint64 { + // We use mutex for get operation is the underlying + // state will mutate db even for read access. + txn.lock.Lock() + defer txn.lock.Unlock() + if _, ok := txn.nonces[addr]; !ok { txn.nonces[addr] = txn.fallback.GetNonce(addr) } @@ -49,5 +57,23 @@ func (txn *txNoncer) get(addr common.Address) uint64 { // set inserts a new virtual nonce into the virtual state database to be returned // whenever the pool requests it instead of reaching into the real state database. func (txn *txNoncer) set(addr common.Address, nonce uint64) { + txn.lock.Lock() + defer txn.lock.Unlock() + + txn.nonces[addr] = nonce +} + +// setIfLower updates a new virtual nonce into the virtual state database if the +// the new one is lower. +func (txn *txNoncer) setIfLower(addr common.Address, nonce uint64) { + txn.lock.Lock() + defer txn.lock.Unlock() + + if _, ok := txn.nonces[addr]; !ok { + txn.nonces[addr] = txn.fallback.GetNonce(addr) + } + if txn.nonces[addr] <= nonce { + return + } txn.nonces[addr] = nonce } diff --git a/core/tx_pool.go b/core/tx_pool.go index 43caf16b18..c41d3fbd4a 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -854,9 +854,7 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { pool.enqueueTx(tx.Hash(), tx) } // Update the account nonce if needed - if nonce := tx.Nonce(); pool.pendingNonces.get(addr) > nonce { - pool.pendingNonces.set(addr, nonce) - } + pool.pendingNonces.setIfLower(addr, tx.Nonce()) // Reduce the pending counter pendingCounter.Dec(int64(1 + len(invalids))) return @@ -1232,9 +1230,7 @@ func (pool *TxPool) truncatePending() { pool.all.Remove(hash) // Update the account nonce to the dropped transaction - if nonce := tx.Nonce(); pool.pendingNonces.get(offenders[i]) > nonce { - pool.pendingNonces.set(offenders[i], nonce) - } + pool.pendingNonces.setIfLower(offenders[i], tx.Nonce()) log.Trace("Removed fairness-exceeding pending transaction", "hash", hash) } pool.priced.Removed(len(caps)) @@ -1261,9 +1257,7 @@ func (pool *TxPool) truncatePending() { pool.all.Remove(hash) // Update the account nonce to the dropped transaction - if nonce := tx.Nonce(); pool.pendingNonces.get(addr) > nonce { - pool.pendingNonces.set(addr, nonce) - } + pool.pendingNonces.setIfLower(addr, tx.Nonce()) log.Trace("Removed fairness-exceeding pending transaction", "hash", hash) } pool.priced.Removed(len(caps))