diff --git a/core/tx_pool.go b/core/tx_pool.go index f41fbe069a..0ad7651795 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -19,6 +19,7 @@ package core import ( "errors" "fmt" + "math" "math/big" "sort" "sync" @@ -105,11 +106,11 @@ var ( // blockChain provides the state of blockchain and current gas limit to do // some pre checks in tx pool and event subscribers. type blockChain interface { - CurrentHeader() *types.Header - SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription - + CurrentBlock() *types.Block GetBlock(hash common.Hash, number uint64) *types.Block StateAt(root common.Hash) (*state.StateDB, error) + + SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription } // TxPoolConfig are the configuration parameters of the transaction pool. @@ -223,7 +224,7 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block } pool.locals = newAccountSet(pool.signer) pool.priced = newTxPricedList(&pool.all) - pool.reset(nil, chain.CurrentHeader()) + pool.reset(nil, chain.CurrentBlock().Header()) // If local transactions and journaling is enabled, load from disk if !config.NoLocals && config.Journal != "" { @@ -265,7 +266,7 @@ func (pool *TxPool) loop() { defer journal.Stop() // Track the previous head headers for transaction reorgs - head := pool.chain.CurrentHeader() + head := pool.chain.CurrentBlock() // Keep waiting for and reacting to the various events for { @@ -277,8 +278,8 @@ func (pool *TxPool) loop() { if pool.chainconfig.IsHomestead(ev.Block.Number()) { pool.homestead = true } - pool.reset(head, ev.Block.Header()) - head = ev.Block.Header() + pool.reset(head.Header(), ev.Block.Header()) + head = ev.Block pool.mu.Unlock() } @@ -344,43 +345,52 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { var reinject types.Transactions if oldHead != nil && oldHead.Hash() != newHead.ParentHash { - var discarded, included types.Transactions - - var ( - rem = pool.chain.GetBlock(oldHead.Hash(), oldHead.Number.Uint64()) - add = pool.chain.GetBlock(newHead.Hash(), newHead.Number.Uint64()) - ) - for rem.NumberU64() > add.NumberU64() { - discarded = append(discarded, rem.Transactions()...) - if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { - log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) - return - } - } - for add.NumberU64() > rem.NumberU64() { - included = append(included, add.Transactions()...) - if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { - log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) - return + // If the reorg is too deep, avoid doing it (will happen during fast sync) + oldNum := oldHead.Number.Uint64() + newNum := newHead.Number.Uint64() + + if depth := uint64(math.Abs(float64(oldNum) - float64(newNum))); depth > 64 { + log.Warn("Skipping deep transaction reorg", "depth", depth) + } else { + // Reorg seems shallow enough to pull in all transactions into memory + var discarded, included types.Transactions + + var ( + rem = pool.chain.GetBlock(oldHead.Hash(), oldHead.Number.Uint64()) + add = pool.chain.GetBlock(newHead.Hash(), newHead.Number.Uint64()) + ) + for rem.NumberU64() > add.NumberU64() { + discarded = append(discarded, rem.Transactions()...) + if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { + log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) + return + } } - } - for rem.Hash() != add.Hash() { - discarded = append(discarded, rem.Transactions()...) - if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { - log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) - return + for add.NumberU64() > rem.NumberU64() { + included = append(included, add.Transactions()...) + if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { + log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) + return + } } - included = append(included, add.Transactions()...) - if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { - log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) - return + for rem.Hash() != add.Hash() { + discarded = append(discarded, rem.Transactions()...) + if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { + log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) + return + } + included = append(included, add.Transactions()...) + if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { + log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) + return + } } + reinject = types.TxDifference(discarded, included) } - reinject = types.TxDifference(discarded, included) } // Initialize the internal state to the current head if newHead == nil { - newHead = pool.chain.CurrentHeader() // Special case during testing + newHead = pool.chain.CurrentBlock().Header() // Special case during testing } statedb, err := pool.chain.StateAt(newHead.Root) if err != nil { diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index cdd45b4b1d..17d7368774 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -50,24 +50,24 @@ type testBlockChain struct { chainHeadFeed *event.Feed } -func (bc *testBlockChain) CurrentHeader() *types.Header { - return &types.Header{ +func (bc *testBlockChain) CurrentBlock() *types.Block { + return types.NewBlock(&types.Header{ GasLimit: bc.gasLimit, - } -} - -func (bc *testBlockChain) SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription { - return bc.chainHeadFeed.Subscribe(ch) + }, nil, nil, nil) } func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { - return types.NewBlock(bc.CurrentHeader(), nil, nil, nil) + return bc.CurrentBlock() } func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) { return bc.statedb, nil } +func (bc *testBlockChain) SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription { + return bc.chainHeadFeed.Subscribe(ch) +} + func transaction(nonce uint64, gaslimit *big.Int, key *ecdsa.PrivateKey) *types.Transaction { return pricedTransaction(nonce, gaslimit, big.NewInt(1), key) }