diff --git a/core/block_processor.go b/core/block_processor.go index 60f0258c4..5172636dd 100644 --- a/core/block_processor.go +++ b/core/block_processor.go @@ -195,14 +195,16 @@ func (sm *BlockProcessor) Process(block *types.Block) (logs vm.Logs, receipts ty defer sm.mutex.Unlock() if sm.bc.HasBlock(block.Hash()) { - return nil, nil, &KnownBlockError{block.Number(), block.Hash()} + if _, err := state.New(block.Root(), sm.chainDb); err == nil { + return nil, nil, &KnownBlockError{block.Number(), block.Hash()} + } } - - if !sm.bc.HasBlock(block.ParentHash()) { - return nil, nil, ParentError(block.ParentHash()) + if parent := sm.bc.GetBlock(block.ParentHash()); parent != nil { + if _, err := state.New(parent.Root(), sm.chainDb); err == nil { + return sm.processWithParent(block, parent) + } } - parent := sm.bc.GetBlock(block.ParentHash()) - return sm.processWithParent(block, parent) + return nil, nil, ParentError(block.ParentHash()) } func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs vm.Logs, receipts types.Receipts, err error) { diff --git a/core/blockchain.go b/core/blockchain.go index 490552ea0..f14ff363c 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -18,11 +18,13 @@ package core import ( + crand "crypto/rand" "errors" "fmt" "io" + "math" "math/big" - "math/rand" + mrand "math/rand" "runtime" "sync" "sync/atomic" @@ -89,7 +91,8 @@ type BlockChain struct { procInterrupt int32 // interrupt signaler for block processing wg sync.WaitGroup - pow pow.PoW + pow pow.PoW + rand *mrand.Rand } func NewBlockChain(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (*BlockChain, error) { @@ -112,6 +115,12 @@ func NewBlockChain(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (*Bl futureBlocks: futureBlocks, pow: pow, } + // Seed a fast but crypto originating random generator + seed, err := crand.Int(crand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + return nil, err + } + bc.rand = mrand.New(mrand.NewSource(seed.Int64())) bc.genesisBlock = bc.GetBlockByNumber(0) if bc.genesisBlock == nil { @@ -178,21 +187,21 @@ func (self *BlockChain) loadLastState() error { fastTd := self.GetTd(self.currentFastBlock.Hash()) glog.V(logger.Info).Infof("Last header: #%d [%x…] TD=%v", self.currentHeader.Number, self.currentHeader.Hash().Bytes()[:4], headerTd) - glog.V(logger.Info).Infof("Fast block: #%d [%x…] TD=%v", self.currentFastBlock.Number(), self.currentFastBlock.Hash().Bytes()[:4], fastTd) glog.V(logger.Info).Infof("Last block: #%d [%x…] TD=%v", self.currentBlock.Number(), self.currentBlock.Hash().Bytes()[:4], blockTd) + glog.V(logger.Info).Infof("Fast block: #%d [%x…] TD=%v", self.currentFastBlock.Number(), self.currentFastBlock.Hash().Bytes()[:4], fastTd) return nil } -// SetHead rewind the local chain to a new head entity. In the case of headers, -// everything above the new head will be deleted and the new one set. In the case -// of blocks though, the head may be further rewound if block bodies are missing -// (non-archive nodes after a fast sync). +// SetHead rewinds the local chain to a new head. In the case of headers, everything +// above the new head will be deleted and the new one set. In the case of blocks +// though, the head may be further rewound if block bodies are missing (non-archive +// nodes after a fast sync). func (bc *BlockChain) SetHead(head uint64) { bc.mu.Lock() defer bc.mu.Unlock() - // Figure out the highest known canonical assignment + // Figure out the highest known canonical headers and/or blocks height := uint64(0) if bc.currentHeader != nil { if hh := bc.currentHeader.Number.Uint64(); hh > height { @@ -266,7 +275,7 @@ func (bc *BlockChain) SetHead(head uint64) { // FastSyncCommitHead sets the current head block to the one defined by the hash // irrelevant what the chain contents were prior. func (self *BlockChain) FastSyncCommitHead(hash common.Hash) error { - // Make sure that both the block as well at it's state trie exists + // Make sure that both the block as well at its state trie exists block := self.GetBlock(hash) if block == nil { return fmt.Errorf("non existent block [%x…]", hash[:4]) @@ -298,7 +307,7 @@ func (self *BlockChain) LastBlockHash() common.Hash { } // CurrentHeader retrieves the current head header of the canonical chain. The -// header is retrieved from the chain manager's internal cache. +// header is retrieved from the blockchain's internal cache. func (self *BlockChain) CurrentHeader() *types.Header { self.mu.RLock() defer self.mu.RUnlock() @@ -307,7 +316,7 @@ func (self *BlockChain) CurrentHeader() *types.Header { } // CurrentBlock retrieves the current head block of the canonical chain. The -// block is retrieved from the chain manager's internal cache. +// block is retrieved from the blockchain's internal cache. func (self *BlockChain) CurrentBlock() *types.Block { self.mu.RLock() defer self.mu.RUnlock() @@ -316,7 +325,7 @@ func (self *BlockChain) CurrentBlock() *types.Block { } // CurrentFastBlock retrieves the current fast-sync head block of the canonical -// chain. The block is retrieved from the chain manager's internal cache. +// chain. The block is retrieved from the blockchain's internal cache. func (self *BlockChain) CurrentFastBlock() *types.Block { self.mu.RLock() defer self.mu.RUnlock() @@ -353,7 +362,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) { bc.mu.Lock() defer bc.mu.Unlock() - // Prepare the genesis block and reinitialize the chain + // Prepare the genesis block and reinitialise the chain if err := WriteTd(bc.chainDb, genesis.Hash(), genesis.Difficulty()); err != nil { glog.Fatalf("failed to write genesis block TD: %v", err) } @@ -403,7 +412,7 @@ func (self *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error { // insert injects a new head block into the current block chain. This method // assumes that the block is indeed a true head. It will also reset the head // header and the head fast sync block to this very same block to prevent them -// from diverging on a different header chain. +// from pointing to a possibly old canonical chain (i.e. side chain by now). // // Note, this function assumes that the `mu` mutex is held! func (bc *BlockChain) insert(block *types.Block) { @@ -625,10 +634,10 @@ const ( // writeHeader writes a header into the local chain, given that its parent is // already known. If the total difficulty of the newly inserted header becomes -// greater than the old known TD, the canonical chain is re-routed. +// greater than the current known TD, the canonical chain is re-routed. // // Note: This method is not concurrent-safe with inserting blocks simultaneously -// into the chain, as side effects caused by reorganizations cannot be emulated +// into the chain, as side effects caused by reorganisations cannot be emulated // without the real blocks. Hence, writing headers directly should only be done // in two scenarios: pure-header mode of operation (light clients), or properly // separated header/block phases (non-archive clients). @@ -678,10 +687,9 @@ func (self *BlockChain) writeHeader(header *types.Header) error { return nil } -// InsertHeaderChain will attempt to insert the given header chain in to the -// local chain, possibly creating a fork. If an error is returned, it will -// return the index number of the failing header as well an error describing -// what went wrong. +// InsertHeaderChain attempts to insert the given header chain in to the local +// chain, possibly creating a reorg. If an error is returned, it will return the +// index number of the failing header as well an error describing what went wrong. // // The verify parameter can be used to fine tune whether nonce verification // should be done or not. The reason behind the optional check is because some @@ -702,7 +710,7 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int) // Generate the list of headers that should be POW verified verify := make([]bool, len(chain)) for i := 0; i < len(verify)/checkFreq; i++ { - index := i*checkFreq + rand.Intn(checkFreq) + index := i*checkFreq + self.rand.Intn(checkFreq) if index >= len(verify) { index = len(verify) - 1 } @@ -766,10 +774,6 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int) pending.Wait() // If anything failed, report - if atomic.LoadInt32(&self.procInterrupt) == 1 { - glog.V(logger.Debug).Infoln("premature abort during receipt chain processing") - return 0, nil - } if failed > 0 { for i, err := range errs { if err != nil { @@ -807,6 +811,9 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int) // Rollback is designed to remove a chain of links from the database that aren't // certain enough to be valid. func (self *BlockChain) Rollback(chain []common.Hash) { + self.mu.Lock() + defer self.mu.Unlock() + for i := len(chain) - 1; i >= 0; i-- { hash := chain[i] @@ -905,6 +912,12 @@ func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain glog.Fatal(errs[index]) return } + if err := WriteMipmapBloom(self.chainDb, block.NumberU64(), receipts); err != nil { + errs[index] = fmt.Errorf("failed to write log blooms: %v", err) + atomic.AddInt32(&failed, 1) + glog.Fatal(errs[index]) + return + } atomic.AddInt32(&stats.processed, 1) } } @@ -920,10 +933,6 @@ func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain pending.Wait() // If anything failed, report - if atomic.LoadInt32(&self.procInterrupt) == 1 { - glog.V(logger.Debug).Infoln("premature abort during receipt chain processing") - return 0, nil - } if failed > 0 { for i, err := range errs { if err != nil { @@ -931,6 +940,10 @@ func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain } } } + if atomic.LoadInt32(&self.procInterrupt) == 1 { + glog.V(logger.Debug).Infoln("premature abort during receipt chain processing") + return 0, nil + } // Update the head fast sync block if better self.mu.Lock() head := blockChain[len(errs)-1] diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 01667c21e..8ddc5032b 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -452,7 +452,7 @@ func makeBlockChainWithDiff(genesis *types.Block, d []int, seed byte) []*types.B func chm(genesis *types.Block, db ethdb.Database) *BlockChain { var eventMux event.TypeMux - bc := &BlockChain{chainDb: db, genesisBlock: genesis, eventMux: &eventMux, pow: FakePow{}} + bc := &BlockChain{chainDb: db, genesisBlock: genesis, eventMux: &eventMux, pow: FakePow{}, rand: rand.New(rand.NewSource(0))} bc.headerCache, _ = lru.New(100) bc.bodyCache, _ = lru.New(100) bc.bodyRLPCache, _ = lru.New(100) diff --git a/core/chain_util.go b/core/chain_util.go index 907e6668c..ddff381a1 100644 --- a/core/chain_util.go +++ b/core/chain_util.go @@ -394,7 +394,7 @@ func WriteMipmapBloom(db ethdb.Database, number uint64, receipts types.Receipts) bloomDat, _ := db.Get(key) bloom := types.BytesToBloom(bloomDat) for _, receipt := range receipts { - for _, log := range receipt.Logs() { + for _, log := range receipt.Logs { bloom.Add(log.Address.Big()) } } diff --git a/core/chain_util_test.go b/core/chain_util_test.go index bc5aa9776..0bbcbbe53 100644 --- a/core/chain_util_test.go +++ b/core/chain_util_test.go @@ -345,15 +345,15 @@ func TestMipmapBloom(t *testing.T) { db, _ := ethdb.NewMemDatabase() receipt1 := new(types.Receipt) - receipt1.SetLogs(vm.Logs{ + receipt1.Logs = vm.Logs{ &vm.Log{Address: common.BytesToAddress([]byte("test"))}, &vm.Log{Address: common.BytesToAddress([]byte("address"))}, - }) + } receipt2 := new(types.Receipt) - receipt2.SetLogs(vm.Logs{ + receipt2.Logs = vm.Logs{ &vm.Log{Address: common.BytesToAddress([]byte("test"))}, &vm.Log{Address: common.BytesToAddress([]byte("address1"))}, - }) + } WriteMipmapBloom(db, 1, types.Receipts{receipt1}) WriteMipmapBloom(db, 2, types.Receipts{receipt2}) @@ -368,15 +368,15 @@ func TestMipmapBloom(t *testing.T) { // reset db, _ = ethdb.NewMemDatabase() receipt := new(types.Receipt) - receipt.SetLogs(vm.Logs{ + receipt.Logs = vm.Logs{ &vm.Log{Address: common.BytesToAddress([]byte("test"))}, - }) + } WriteMipmapBloom(db, 999, types.Receipts{receipt1}) receipt = new(types.Receipt) - receipt.SetLogs(vm.Logs{ + receipt.Logs = vm.Logs{ &vm.Log{Address: common.BytesToAddress([]byte("test 1"))}, - }) + } WriteMipmapBloom(db, 1000, types.Receipts{receipt}) bloom := GetMipmapBloom(db, 1000, 1000) @@ -403,22 +403,22 @@ func TestMipmapChain(t *testing.T) { defer db.Close() genesis := WriteGenesisBlockForTesting(db, GenesisAccount{addr, big.NewInt(1000000)}) - chain := GenerateChain(genesis, db, 1010, func(i int, gen *BlockGen) { + chain, receipts := GenerateChain(genesis, db, 1010, func(i int, gen *BlockGen) { var receipts types.Receipts switch i { case 1: receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{ + receipt.Logs = vm.Logs{ &vm.Log{ Address: addr, Topics: []common.Hash{hash1}, }, - }) + } gen.AddUncheckedReceipt(receipt) receipts = types.Receipts{receipt} case 1000: receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{&vm.Log{Address: addr2}}) + receipt.Logs = vm.Logs{&vm.Log{Address: addr2}} gen.AddUncheckedReceipt(receipt) receipts = types.Receipts{receipt} @@ -431,7 +431,7 @@ func TestMipmapChain(t *testing.T) { } WriteMipmapBloom(db, uint64(i+1), receipts) }) - for _, block := range chain { + for i, block := range chain { WriteBlock(db, block) if err := WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { t.Fatalf("failed to insert block number: %v", err) @@ -439,7 +439,7 @@ func TestMipmapChain(t *testing.T) { if err := WriteHeadBlockHash(db, block.Hash()); err != nil { t.Fatalf("failed to insert block number: %v", err) } - if err := PutBlockReceipts(db, block, block.Receipts()); err != nil { + if err := PutBlockReceipts(db, block.Hash(), receipts[i]); err != nil { t.Fatal("error writing block receipts:", err) } } diff --git a/core/state/sync.go b/core/state/sync.go index 5a388886c..ef2b4b84c 100644 --- a/core/state/sync.go +++ b/core/state/sync.go @@ -26,14 +26,13 @@ import ( "github.com/ethereum/go-ethereum/trie" ) -// StateSync is the main state synchronisation scheduler, which provides yet the +// StateSync is the main state synchronisation scheduler, which provides yet the // unknown state hashes to retrieve, accepts node data associated with said hashes // and reconstructs the state database step by step until all is done. type StateSync trie.TrieSync // NewStateSync create a new state trie download scheduler. func NewStateSync(root common.Hash, database ethdb.Database) *StateSync { - // Pre-declare the result syncer t var syncer *trie.TrieSync callback := func(leaf []byte, parent common.Hash) error { diff --git a/core/state/sync_test.go b/core/state/sync_test.go index f0376d484..0dab372ba 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -38,7 +38,7 @@ type testAccount struct { func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { // Create an empty state db, _ := ethdb.NewMemDatabase() - state := New(common.Hash{}, db) + state, _ := New(common.Hash{}, db) // Fill it with some arbitrary data accounts := []*testAccount{} @@ -68,7 +68,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { // checkStateAccounts cross references a reconstructed state with an expected // account array. func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { - state := New(root, db) + state, _ := New(root, db) for i, acc := range accounts { if balance := state.GetBalance(acc.address); balance.Cmp(acc.balance) != 0 { diff --git a/core/types/receipt.go b/core/types/receipt.go index aea5b3e91..e7d5203a3 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -67,7 +67,7 @@ func (r *Receipt) DecodeRLP(s *rlp.Stream) error { return nil } -// RlpEncode implements common.RlpEncode required for SHA derivation. +// RlpEncode implements common.RlpEncode required for SHA3 derivation. func (r *Receipt) RlpEncode() []byte { bytes, err := rlp.EncodeToBytes(r) if err != nil { @@ -82,7 +82,7 @@ func (r *Receipt) String() string { } // ReceiptForStorage is a wrapper around a Receipt that flattens and parses the -// entire content of a receipt, opposed to only the consensus fields originally. +// entire content of a receipt, as opposed to only the consensus fields originally. type ReceiptForStorage Receipt // EncodeRLP implements rlp.Encoder, and flattens all content fields of a receipt @@ -95,8 +95,8 @@ func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error { return rlp.Encode(w, []interface{}{r.PostState, r.CumulativeGasUsed, r.Bloom, r.TxHash, r.ContractAddress, logs, r.GasUsed}) } -// DecodeRLP implements rlp.Decoder, and loads the consensus fields of a receipt -// from an RLP stream. +// DecodeRLP implements rlp.Decoder, and loads both consensus and implementation +// fields of a receipt from an RLP stream. func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error { var receipt struct { PostState []byte @@ -125,7 +125,7 @@ func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error { // Receipts is a wrapper around a Receipt array to implement types.DerivableList. type Receipts []*Receipt -// RlpEncode implements common.RlpEncode required for SHA derivation. +// RlpEncode implements common.RlpEncode required for SHA3 derivation. func (r Receipts) RlpEncode() []byte { bytes, err := rlp.EncodeToBytes(r) if err != nil { diff --git a/core/vm/log.go b/core/vm/log.go index 526221e43..191e3a253 100644 --- a/core/vm/log.go +++ b/core/vm/log.go @@ -66,6 +66,6 @@ func (l *Log) String() string { type Logs []*Log // LogForStorage is a wrapper around a Log that flattens and parses the entire -// content of a log, opposed to only the consensus fields originally (by hiding +// content of a log, as opposed to only the consensus fields originally (by hiding // the rlp interface methods). type LogForStorage Log diff --git a/eth/backend.go b/eth/backend.go index 0a3791783..a4f656ecd 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -391,7 +391,6 @@ func New(config *Config) (*Ethereum, error) { if err == core.ErrNoGenesis { return nil, fmt.Errorf(`Genesis block not found. Please supply a genesis block with the "--genesis /path/to/file" argument`) } - return nil, err } newPool := core.NewTxPool(eth.EventMux(), eth.blockchain.State, eth.blockchain.GasLimit) diff --git a/eth/backend_test.go b/eth/backend_test.go index 220426c17..0379fc843 100644 --- a/eth/backend_test.go +++ b/eth/backend_test.go @@ -16,17 +16,17 @@ func TestMipmapUpgrade(t *testing.T) { addr := common.BytesToAddress([]byte("jeff")) genesis := core.WriteGenesisBlockForTesting(db) - chain := core.GenerateChain(genesis, db, 10, func(i int, gen *core.BlockGen) { + chain, receipts := core.GenerateChain(genesis, db, 10, func(i int, gen *core.BlockGen) { var receipts types.Receipts switch i { case 1: receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{&vm.Log{Address: addr}}) + receipt.Logs = vm.Logs{&vm.Log{Address: addr}} gen.AddUncheckedReceipt(receipt) receipts = types.Receipts{receipt} case 2: receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{&vm.Log{Address: addr}}) + receipt.Logs = vm.Logs{&vm.Log{Address: addr}} gen.AddUncheckedReceipt(receipt) receipts = types.Receipts{receipt} } @@ -37,7 +37,7 @@ func TestMipmapUpgrade(t *testing.T) { t.Fatal(err) } }) - for _, block := range chain { + for i, block := range chain { core.WriteBlock(db, block) if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { t.Fatalf("failed to insert block number: %v", err) @@ -45,7 +45,7 @@ func TestMipmapUpgrade(t *testing.T) { if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { t.Fatalf("failed to insert block number: %v", err) } - if err := core.PutBlockReceipts(db, block, block.Receipts()); err != nil { + if err := core.PutBlockReceipts(db, block.Hash(), receipts[i]); err != nil { t.Fatal("error writing block receipts:", err) } } diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 0298dfa0b..4bcbd8557 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -18,7 +18,9 @@ package downloader import ( + "crypto/rand" "errors" + "fmt" "math" "math/big" "strings" @@ -59,9 +61,11 @@ var ( maxQueuedStates = 256 * 1024 // [eth/63] Maximum number of state requests to queue (DOS protection) maxResultsProcess = 256 // Number of download results to import at once into the chain - headerCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync - minCheckedHeaders = 2048 // Number of headers to verify fully when approaching the chain head - minFullBlocks = 1024 // Number of blocks to retrieve fully even in fast sync + fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync + fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected + fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it + fsPivotInterval = 512 // Number of headers out of which to randomize the pivot point + fsMinFullBlocks = 1024 // Number of blocks to retrieve fully even in fast sync ) var ( @@ -85,12 +89,14 @@ var ( errCancelHeaderFetch = errors.New("block header download canceled (requested)") errCancelBodyFetch = errors.New("block body download canceled (requested)") errCancelReceiptFetch = errors.New("receipt download canceled (requested)") + errCancelStateFetch = errors.New("state data download canceled (requested)") errNoSyncActive = errors.New("no sync active") ) type Downloader struct { - mode SyncMode // Synchronisation mode defining the strategies used - mux *event.TypeMux // Event multiplexer to announce sync operation events + mode SyncMode // Synchronisation mode defining the strategy used (per sync cycle) + noFast bool // Flag to disable fast syncing in case of a security error + mux *event.TypeMux // Event multiplexer to announce sync operation events queue *queue // Scheduler for selecting the hashes to download peers *peerSet // Set of active peers from which download can proceed @@ -150,13 +156,13 @@ type Downloader struct { } // New creates a new downloader to fetch hashes and blocks from remote peers. -func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn, +func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn, getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, headFastBlock headFastBlockRetrievalFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn, insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader { return &Downloader{ - mode: mode, + mode: FullSync, mux: mux, queue: newQueue(stateDb), peers: newPeerSet(), @@ -188,19 +194,28 @@ func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader he } } -// Boundaries retrieves the synchronisation boundaries, specifically the origin -// block where synchronisation started at (may have failed/suspended) and the -// latest known block which the synchonisation targets. -func (d *Downloader) Boundaries() (uint64, uint64) { +// Progress retrieves the synchronisation boundaries, specifically the origin +// block where synchronisation started at (may have failed/suspended); the block +// or header sync is currently at; and the latest known block which the sync targets. +func (d *Downloader) Progress() (uint64, uint64, uint64) { d.syncStatsLock.RLock() defer d.syncStatsLock.RUnlock() - return d.syncStatsChainOrigin, d.syncStatsChainHeight + current := uint64(0) + switch d.mode { + case FullSync: + current = d.headBlock().NumberU64() + case FastSync: + current = d.headFastBlock().NumberU64() + case LightSync: + current = d.headHeader().Number.Uint64() + } + return d.syncStatsChainOrigin, current, d.syncStatsChainHeight } // Synchronising returns whether the downloader is currently retrieving blocks. func (d *Downloader) Synchronising() bool { - return atomic.LoadInt32(&d.synchronising) > 0 + return atomic.LoadInt32(&d.synchronising) > 0 || atomic.LoadInt32(&d.processing) > 0 } // RegisterPeer injects a new download peer into the set of block source to be @@ -233,10 +248,10 @@ func (d *Downloader) UnregisterPeer(id string) error { // Synchronise tries to sync up our local block chain with a remote peer, both // adding various sanity checks as well as wrapping it with various log entries. -func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int) { +func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode SyncMode) { glog.V(logger.Detail).Infof("Attempting synchronisation: %v, head [%x…], TD %v", id, head[:4], td) - switch err := d.synchronise(id, head, td); err { + switch err := d.synchronise(id, head, td, mode); err { case nil: glog.V(logger.Detail).Infof("Synchronisation completed") @@ -258,7 +273,7 @@ func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int) { // synchronise will select the peer and use it for synchronising. If an empty string is given // it will use the best peer possible and synchronize if it's TD is higher than our own. If any of the // checks fail an error will be returned. This method is synchronous -func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int) error { +func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode SyncMode) error { // Mock out the synchonisation if testing if d.synchroniseMock != nil { return d.synchroniseMock(id, hash) @@ -298,6 +313,11 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int) error d.cancelCh = make(chan struct{}) d.cancelLock.Unlock() + // Set the requested sync mode, unless it's forbidden + d.mode = mode + if d.mode == FastSync && d.noFast { + d.mode = FullSync + } // Retrieve the origin peer and initiate the downloading process p := d.peers.Peer(id) if p == nil { @@ -306,13 +326,6 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int) error return d.syncWithPeer(p, hash, td) } -/* -// Has checks if the downloader knows about a particular hash, meaning that its -// either already downloaded of pending retrieval. -func (d *Downloader) Has(hash common.Hash) bool { - return d.queue.Has(hash) -} -*/ // syncWithPeer starts a block synchronization based on the hash chain from the // specified peer and head hash. func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err error) { @@ -387,8 +400,28 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e // Initiate the sync using a concurrent header and content retrieval algorithm pivot := uint64(0) - if latest > uint64(minFullBlocks) { - pivot = latest - uint64(minFullBlocks) + switch d.mode { + case LightSync: + pivot = latest + + case FastSync: + // Calculate the new fast/slow sync pivot point + pivotOffset, err := rand.Int(rand.Reader, big.NewInt(int64(fsPivotInterval))) + if err != nil { + panic(fmt.Sprintf("Failed to access crypto random source: %v", err)) + } + if latest > uint64(fsMinFullBlocks)+pivotOffset.Uint64() { + pivot = latest - uint64(fsMinFullBlocks) - pivotOffset.Uint64() + } + // If the point is below the origin, move origin back to ensure state download + if pivot < origin { + if pivot > 0 { + origin = pivot - 1 + } else { + origin = 0 + } + } + glog.V(logger.Debug).Infof("Fast syncing until pivot block #%d", pivot) } d.queue.Prepare(origin+1, d.mode, pivot) @@ -396,10 +429,10 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e d.syncInitHook(origin, latest) } errc := make(chan error, 4) - go func() { errc <- d.fetchHeaders(p, td, origin+1, latest) }() // Headers are always retrieved - go func() { errc <- d.fetchBodies(origin + 1) }() // Bodies are retrieved during normal and fast sync - go func() { errc <- d.fetchReceipts(origin + 1) }() // Receipts are retrieved during fast sync - go func() { errc <- d.fetchNodeData() }() // Node state data is retrieved during fast sync + go func() { errc <- d.fetchHeaders(p, td, origin+1) }() // Headers are always retrieved + go func() { errc <- d.fetchBodies(origin + 1) }() // Bodies are retrieved during normal and fast sync + go func() { errc <- d.fetchReceipts(origin + 1) }() // Receipts are retrieved during fast sync + go func() { errc <- d.fetchNodeData() }() // Node state data is retrieved during fast sync // If any fetcher fails, cancel the others var fail error @@ -844,7 +877,7 @@ func (d *Downloader) fetchBlocks61(from uint64) error { for _, peer := range idles { // Short circuit if throttling activated - if d.queue.ThrottleBlocks() { + if d.queue.ShouldThrottleBlocks() { throttled = true break } @@ -860,8 +893,13 @@ func (d *Downloader) fetchBlocks61(from uint64) error { } // Fetch the chunk and make sure any errors return the hashes to the queue if err := peer.Fetch61(request); err != nil { - glog.V(logger.Error).Infof("%v: fetch failed, rescheduling", peer) - d.queue.CancelBlocks(request) + // Although we could try and make an attempt to fix this, this error really + // means that we've double allocated a fetch task to a peer. If that is the + // case, the internal state of the downloader and the queue is very wrong so + // better hard crash and note the error instead of silently accumulating into + // a much bigger issue. + panic(fmt.Sprintf("%v: fetch assignment failed, hard panic", peer)) + d.queue.CancelBlocks(request) // noop for now } } // Make sure that we have peers available for fetching. If all peers have been tried @@ -1051,28 +1089,34 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) { // // The queue parameter can be used to switch between queuing headers for block // body download too, or directly import as pure header chains. -func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) error { +func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { glog.V(logger.Debug).Infof("%v: downloading headers from #%d", p, from) defer glog.V(logger.Debug).Infof("%v: header download terminated", p) + // Calculate the pivoting point for switching from fast to slow sync + pivot := d.queue.FastSyncPivot() + // Keep a count of uncertain headers to roll back rollback := []*types.Header{} defer func() { if len(rollback) > 0 { + // Flatten the headers and roll them back hashes := make([]common.Hash, len(rollback)) for i, header := range rollback { hashes[i] = header.Hash() } + lh, lfb, lb := d.headHeader().Number, d.headFastBlock().Number(), d.headBlock().Number() d.rollback(hashes) + glog.V(logger.Warn).Infof("Rolled back %d headers (LH: %d->%d, FB: %d->%d, LB: %d->%d)", + len(hashes), lh, d.headHeader().Number, lfb, d.headFastBlock().Number(), lb, d.headBlock().Number()) + + // If we're already past the pivot point, this could be an attack, disable fast sync + if rollback[len(rollback)-1].Number.Uint64() > pivot { + d.noFast = true + } } }() - // Calculate the pivoting point for switching from fast to slow sync - pivot := uint64(0) - if d.mode == FastSync && latest > uint64(minFullBlocks) { - pivot = latest - uint64(minFullBlocks) - } else if d.mode == LightSync { - pivot = latest - } + // Create a timeout timer, and the associated hash fetcher request := time.Now() // time of the last fetch request timeout := time.NewTimer(0) // timer to dump a non-responsive active peer @@ -1135,6 +1179,19 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) err if !gotHeaders && td.Cmp(d.getTd(d.headBlock().Hash())) > 0 { return errStallingPeer } + // If fast or light syncing, ensure promised headers are indeed delivered. This is + // needed to detect scenarios where an attacker feeds a bad pivot and then bails out + // of delivering the post-pivot blocks that would flag the invalid content. + // + // This check cannot be executed "as is" for full imports, since blocks may still be + // queued for processing when the header download completes. However, as long as the + // peer gave us something useful, we're already happy/progressed (above check). + if d.mode == FastSync || d.mode == LightSync { + if td.Cmp(d.getTd(d.headHeader().Hash())) > 0 { + return errStallingPeer + } + } + rollback = nil return nil } gotHeaders = true @@ -1152,8 +1209,8 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) err } } // If we're importing pure headers, verify based on their recentness - frequency := headerCheckFrequency - if headers[len(headers)-1].Number.Uint64()+uint64(minCheckedHeaders) > pivot { + frequency := fsHeaderCheckFrequency + if headers[len(headers)-1].Number.Uint64()+uint64(fsHeaderForceVerify) > pivot { frequency = 1 } if n, err := d.insertHeaders(headers, frequency); err != nil { @@ -1162,11 +1219,8 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) err } // All verifications passed, store newly found uncertain headers rollback = append(rollback, unknown...) - if len(rollback) > minCheckedHeaders { - rollback = append(rollback[:0], rollback[len(rollback)-minCheckedHeaders:]...) - } - if headers[len(headers)-1].Number.Uint64() >= pivot { - rollback = rollback[:0] + if len(rollback) > fsHeaderSafetyNet { + rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...) } } if d.mode == FullSync || d.mode == FastSync { @@ -1230,12 +1284,11 @@ func (d *Downloader) fetchBodies(from uint64) error { expire = func() []string { return d.queue.ExpireBodies(bodyHardTTL) } fetch = func(p *peer, req *fetchRequest) error { return p.FetchBodies(req) } capacity = func(p *peer) int { return p.BlockCapacity() } - getIdles = func() ([]*peer, int) { return d.peers.BodyIdlePeers() } - setIdle = func(p *peer) { p.SetBlocksIdle() } + setIdle = func(p *peer) { p.SetBodiesIdle() } ) err := d.fetchParts(errCancelBodyFetch, d.bodyCh, deliver, d.bodyWakeCh, expire, - d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ThrottleBlocks, d.queue.ReserveBodies, - d.bodyFetchHook, fetch, d.queue.CancelBodies, capacity, getIdles, setIdle, "Body") + d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ShouldThrottleBlocks, d.queue.ReserveBodies, + d.bodyFetchHook, fetch, d.queue.CancelBodies, capacity, d.peers.BodyIdlePeers, setIdle, "Body") glog.V(logger.Debug).Infof("Block body download terminated: %v", err) return err @@ -1252,13 +1305,13 @@ func (d *Downloader) fetchReceipts(from uint64) error { pack := packet.(*receiptPack) return d.queue.DeliverReceipts(pack.peerId, pack.receipts) } - expire = func() []string { return d.queue.ExpireReceipts(bodyHardTTL) } + expire = func() []string { return d.queue.ExpireReceipts(receiptHardTTL) } fetch = func(p *peer, req *fetchRequest) error { return p.FetchReceipts(req) } capacity = func(p *peer) int { return p.ReceiptCapacity() } setIdle = func(p *peer) { p.SetReceiptsIdle() } ) err := d.fetchParts(errCancelReceiptFetch, d.receiptCh, deliver, d.receiptWakeCh, expire, - d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ThrottleReceipts, d.queue.ReserveReceipts, + d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ShouldThrottleReceipts, d.queue.ReserveReceipts, d.receiptFetchHook, fetch, d.queue.CancelReceipts, capacity, d.peers.ReceiptIdlePeers, setIdle, "Receipt") glog.V(logger.Debug).Infof("Receipt download terminated: %v", err) @@ -1307,9 +1360,9 @@ func (d *Downloader) fetchNodeData() error { capacity = func(p *peer) int { return p.NodeDataCapacity() } setIdle = func(p *peer) { p.SetNodeDataIdle() } ) - err := d.fetchParts(errCancelReceiptFetch, d.stateCh, deliver, d.stateWakeCh, expire, + err := d.fetchParts(errCancelStateFetch, d.stateCh, deliver, d.stateWakeCh, expire, d.queue.PendingNodeData, d.queue.InFlightNodeData, throttle, reserve, nil, fetch, - d.queue.CancelNodeData, capacity, d.peers.ReceiptIdlePeers, setIdle, "State") + d.queue.CancelNodeData, capacity, d.peers.NodeDataIdlePeers, setIdle, "State") glog.V(logger.Debug).Infof("Node state data download terminated: %v", err) return err @@ -1323,7 +1376,7 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv fetchHook func([]*types.Header), fetch func(*peer, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peer) int, idle func() ([]*peer, int), setIdle func(*peer), kind string) error { - // Create a ticker to detect expired retreival tasks + // Create a ticker to detect expired retrieval tasks ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() @@ -1366,11 +1419,6 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv // The hash chain is invalid (blocks are not ordered properly), abort return err - case errInvalidBody, errInvalidReceipt: - // The peer delivered something very bad, drop immediately - glog.V(logger.Error).Infof("%s: delivered invalid %s, dropping", peer, strings.ToLower(kind)) - d.dropPeer(peer.id) - case errNoFetchesPending: // Peer probably timed out with its delivery but came through // in the end, demote, but allow to to pull from this peer. @@ -1475,8 +1523,13 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv fetchHook(request.Headers) } if err := fetch(peer, request); err != nil { - glog.V(logger.Error).Infof("%v: %s fetch failed, rescheduling", peer, strings.ToLower(kind)) - cancel(request) + // Although we could try and make an attempt to fix this, this error really + // means that we've double allocated a fetch task to a peer. If that is the + // case, the internal state of the downloader and the queue is very wrong so + // better hard crash and note the error instead of silently accumulating into + // a much bigger issue. + panic(fmt.Sprintf("%v: %s fetch assignment failed, hard panic", peer, strings.ToLower(kind))) + cancel(request) // noop for now } running = true } @@ -1526,6 +1579,7 @@ func (d *Downloader) process() { // Repeat the processing as long as there are results to process for { // Fetch the next batch of results + pivot := d.queue.FastSyncPivot() // Fetch pivot before results to prevent reset race results := d.queue.TakeResults() if len(results) == 0 { return @@ -1545,7 +1599,6 @@ func (d *Downloader) process() { } // Retrieve the a batch of results to import var ( - headers = make([]*types.Header, 0, maxResultsProcess) blocks = make([]*types.Block, 0, maxResultsProcess) receipts = make([]types.Receipts, 0, maxResultsProcess) ) @@ -1556,11 +1609,9 @@ func (d *Downloader) process() { blocks = append(blocks, types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)) case d.mode == FastSync: blocks = append(blocks, types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)) - if result.Header.Number.Uint64() <= d.queue.fastSyncPivot { + if result.Header.Number.Uint64() <= pivot { receipts = append(receipts, result.Receipts) } - case d.mode == LightSync: - headers = append(headers, result.Header) } } // Try to process the results, aborting if there's an error @@ -1569,12 +1620,10 @@ func (d *Downloader) process() { index int ) switch { - case len(headers) > 0: - index, err = d.insertHeaders(headers, headerCheckFrequency) - case len(receipts) > 0: index, err = d.insertReceipts(blocks, receipts) - if err == nil && blocks[len(blocks)-1].NumberU64() == d.queue.fastSyncPivot { + if err == nil && blocks[len(blocks)-1].NumberU64() == pivot { + glog.V(logger.Debug).Infof("Committing block #%d [%x…] as the new head", blocks[len(blocks)-1].Number(), blocks[len(blocks)-1].Hash().Bytes()[:4]) index, err = len(blocks)-1, d.commitHeadBlock(blocks[len(blocks)-1].Hash()) } default: diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index f01650ebd..ef6f74a6b 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -136,7 +136,7 @@ type downloadTester struct { } // newTester creates a new downloader test mocker. -func newTester(mode SyncMode) *downloadTester { +func newTester() *downloadTester { tester := &downloadTester{ ownHashes: []common.Hash{genesis.Hash()}, ownHeaders: map[common.Hash]*types.Header{genesis.Hash(): genesis.Header()}, @@ -150,7 +150,7 @@ func newTester(mode SyncMode) *downloadTester { peerChainTds: make(map[string]map[common.Hash]*big.Int), } tester.stateDb, _ = ethdb.NewMemDatabase() - tester.downloader = New(mode, tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, + tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd, tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer) @@ -158,7 +158,7 @@ func newTester(mode SyncMode) *downloadTester { } // sync starts synchronizing with a remote peer, blocking until it completes. -func (dl *downloadTester) sync(id string, td *big.Int) error { +func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error { dl.lock.RLock() hash := dl.peerHashes[id][0] // If no particular TD was requested, load from the peer's blockchain @@ -170,7 +170,7 @@ func (dl *downloadTester) sync(id string, td *big.Int) error { } dl.lock.RUnlock() - err := dl.downloader.synchronise(id, hash, td) + err := dl.downloader.synchronise(id, hash, td, mode) for { // If the queue is empty and processing stopped, break if dl.downloader.queue.Idle() && atomic.LoadInt32(&dl.downloader.processing) == 0 { @@ -214,7 +214,7 @@ func (dl *downloadTester) headHeader() *types.Header { defer dl.lock.RUnlock() for i := len(dl.ownHashes) - 1; i >= 0; i-- { - if header := dl.getHeader(dl.ownHashes[i]); header != nil { + if header := dl.ownHeaders[dl.ownHashes[i]]; header != nil { return header } } @@ -227,7 +227,7 @@ func (dl *downloadTester) headBlock() *types.Block { defer dl.lock.RUnlock() for i := len(dl.ownHashes) - 1; i >= 0; i-- { - if block := dl.getBlock(dl.ownHashes[i]); block != nil { + if block := dl.ownBlocks[dl.ownHashes[i]]; block != nil { if _, err := dl.stateDb.Get(block.Root().Bytes()); err == nil { return block } @@ -242,7 +242,7 @@ func (dl *downloadTester) headFastBlock() *types.Block { defer dl.lock.RUnlock() for i := len(dl.ownHashes) - 1; i >= 0; i-- { - if block := dl.getBlock(dl.ownHashes[i]); block != nil { + if block := dl.ownBlocks[dl.ownHashes[i]]; block != nil { return block } } @@ -291,7 +291,7 @@ func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) } dl.ownHashes = append(dl.ownHashes, header.Hash()) dl.ownHeaders[header.Hash()] = header - dl.ownChainTd[header.Hash()] = dl.ownChainTd[header.ParentHash] + dl.ownChainTd[header.Hash()] = new(big.Int).Add(dl.ownChainTd[header.ParentHash], header.Difficulty) } return len(headers), nil } @@ -305,11 +305,13 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) { if _, ok := dl.ownBlocks[block.ParentHash()]; !ok { return i, errors.New("unknown parent") } - dl.ownHashes = append(dl.ownHashes, block.Hash()) - dl.ownHeaders[block.Hash()] = block.Header() + if _, ok := dl.ownHeaders[block.Hash()]; !ok { + dl.ownHashes = append(dl.ownHashes, block.Hash()) + dl.ownHeaders[block.Hash()] = block.Header() + } dl.ownBlocks[block.Hash()] = block - dl.stateDb.Put(block.Root().Bytes(), []byte{}) - dl.ownChainTd[block.Hash()] = dl.ownChainTd[block.ParentHash()] + dl.stateDb.Put(block.Root().Bytes(), []byte{0x00}) + dl.ownChainTd[block.Hash()] = new(big.Int).Add(dl.ownChainTd[block.ParentHash()], block.Difficulty()) } return len(blocks), nil } @@ -381,7 +383,19 @@ func (dl *downloadTester) newSlowPeer(id string, version int, hashes []common.Ha dl.peerReceipts[id] = make(map[common.Hash]types.Receipts) dl.peerChainTds[id] = make(map[common.Hash]*big.Int) - for _, hash := range hashes { + genesis := hashes[len(hashes)-1] + if header := headers[genesis]; header != nil { + dl.peerHeaders[id][genesis] = header + dl.peerChainTds[id][genesis] = header.Difficulty + } + if block := blocks[genesis]; block != nil { + dl.peerBlocks[id][genesis] = block + dl.peerChainTds[id][genesis] = block.Difficulty() + } + + for i := len(hashes) - 2; i >= 0; i-- { + hash := hashes[i] + if header, ok := headers[hash]; ok { dl.peerHeaders[id][hash] = header if _, ok := dl.peerHeaders[id][header.ParentHash]; ok { @@ -627,21 +641,28 @@ func assertOwnChain(t *testing.T, tester *downloadTester, length int) { // number of items of the various chain components. func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) { // Initialize the counters for the first fork - headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-minFullBlocks - if receipts < 0 { - receipts = 1 + headers, blocks := lengths[0], lengths[0] + + minReceipts, maxReceipts := lengths[0]-fsMinFullBlocks-fsPivotInterval, lengths[0]-fsMinFullBlocks + if minReceipts < 0 { + minReceipts = 1 + } + if maxReceipts < 0 { + maxReceipts = 1 } // Update the counters for each subsequent fork for _, length := range lengths[1:] { headers += length - common blocks += length - common - receipts += length - common - minFullBlocks + + minReceipts += length - common - fsMinFullBlocks - fsPivotInterval + maxReceipts += length - common - fsMinFullBlocks } switch tester.downloader.mode { case FullSync: - receipts = 1 + minReceipts, maxReceipts = 1, 1 case LightSync: - blocks, receipts = 1, 1 + blocks, minReceipts, maxReceipts = 1, 1, 1 } if hs := len(tester.ownHeaders); hs != headers { t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers) @@ -649,14 +670,20 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng if bs := len(tester.ownBlocks); bs != blocks { t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks) } - if rs := len(tester.ownReceipts); rs != receipts { - t.Fatalf("synchronised receipts mismatch: have %v, want %v", rs, receipts) + if rs := len(tester.ownReceipts); rs < minReceipts || rs > maxReceipts { + t.Fatalf("synchronised receipts mismatch: have %v, want between [%v, %v]", rs, minReceipts, maxReceipts) } // Verify the state trie too for fast syncs if tester.downloader.mode == FastSync { - if index := lengths[len(lengths)-1] - minFullBlocks - 1; index > 0 { - if statedb := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil { - t.Fatalf("state reconstruction failed") + index := 0 + if pivot := int(tester.downloader.queue.fastSyncPivot); pivot < common { + index = pivot + } else { + index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot) + } + if index > 0 { + if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil || err != nil { + t.Fatalf("state reconstruction failed: %v", err) } } } @@ -678,11 +705,11 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) { targetBlocks := blockCacheLimit - 15 hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - tester := newTester(mode) + tester := newTester() tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Synchronise with the peer and make sure all relevant data was retrieved - if err := tester.sync("peer", nil); err != nil { + if err := tester.sync("peer", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } assertOwnChain(t, tester, targetBlocks+1) @@ -702,7 +729,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { targetBlocks := 8 * blockCacheLimit hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - tester := newTester(mode) + tester := newTester() tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Wrap the importer to allow stepping @@ -714,7 +741,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { // Start a synchronisation concurrently errc := make(chan error) go func() { - errc <- tester.sync("peer", nil) + errc <- tester.sync("peer", nil, mode) }() // Iteratively take some blocks, always checking the retrieval count for { @@ -726,10 +753,11 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { break } // Wait a bit for sync to throttle itself - var cached int + var cached, frozen int for start := time.Now(); time.Since(start) < time.Second; { time.Sleep(25 * time.Millisecond) + tester.lock.RLock() tester.downloader.queue.lock.RLock() cached = len(tester.downloader.queue.blockDonePool) if mode == FastSync { @@ -739,16 +767,23 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { } } } + frozen = int(atomic.LoadUint32(&blocked)) + retrieved = len(tester.ownBlocks) tester.downloader.queue.lock.RUnlock() + tester.lock.RUnlock() - if cached == blockCacheLimit || len(tester.ownBlocks)+cached+int(atomic.LoadUint32(&blocked)) == targetBlocks+1 { + if cached == blockCacheLimit || retrieved+cached+frozen == targetBlocks+1 { break } } // Make sure we filled up the cache, then exhaust it time.Sleep(25 * time.Millisecond) // give it a chance to screw up - if cached != blockCacheLimit && len(tester.ownBlocks)+cached+int(atomic.LoadUint32(&blocked)) != targetBlocks+1 { - t.Fatalf("block count mismatch: have %v, want %v (owned %v, target %v)", cached, blockCacheLimit, len(tester.ownBlocks), targetBlocks+1) + + tester.lock.RLock() + retrieved = len(tester.ownBlocks) + tester.lock.RUnlock() + if cached != blockCacheLimit && retrieved+cached+frozen != targetBlocks+1 { + t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheLimit, retrieved, frozen, targetBlocks+1) } // Permit the blocked blocks to import if atomic.LoadUint32(&blocked) > 0 { @@ -779,18 +814,18 @@ func testForkedSynchronisation(t *testing.T, protocol int, mode SyncMode) { common, fork := MaxHashFetch, 2*MaxHashFetch hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil) - tester := newTester(mode) + tester := newTester() tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA) tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB) // Synchronise with the peer and make sure all blocks were retrieved - if err := tester.sync("fork A", nil); err != nil { + if err := tester.sync("fork A", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } assertOwnChain(t, tester, common+fork+1) // Synchronise with the second peer and make sure that fork is pulled too - if err := tester.sync("fork B", nil); err != nil { + if err := tester.sync("fork B", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } assertOwnForkedChain(t, tester, common+1, []int{common + fork + 1, common + fork + 1}) @@ -798,7 +833,7 @@ func testForkedSynchronisation(t *testing.T, protocol int, mode SyncMode) { // Tests that an inactive downloader will not accept incoming hashes and blocks. func TestInactiveDownloader61(t *testing.T) { - tester := newTester(FullSync) + tester := newTester() // Check that neither hashes nor blocks are accepted if err := tester.downloader.DeliverHashes("bad peer", []common.Hash{}); err != errNoSyncActive { @@ -812,7 +847,7 @@ func TestInactiveDownloader61(t *testing.T) { // Tests that an inactive downloader will not accept incoming block headers and // bodies. func TestInactiveDownloader62(t *testing.T) { - tester := newTester(FullSync) + tester := newTester() // Check that neither block headers nor bodies are accepted if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive { @@ -826,7 +861,7 @@ func TestInactiveDownloader62(t *testing.T) { // Tests that an inactive downloader will not accept incoming block headers, // bodies and receipts. func TestInactiveDownloader63(t *testing.T) { - tester := newTester(FullSync) + tester := newTester() // Check that neither block headers nor bodies are accepted if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive { @@ -860,7 +895,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) { } hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - tester := newTester(mode) + tester := newTester() tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Make sure canceling works with a pristine downloader @@ -869,7 +904,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) { t.Errorf("download queue not idle") } // Synchronise with the peer, but cancel afterwards - if err := tester.sync("peer", nil); err != nil { + if err := tester.sync("peer", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } tester.downloader.cancel() @@ -893,12 +928,12 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) { targetBlocks := targetPeers*blockCacheLimit - 15 hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - tester := newTester(mode) + tester := newTester() for i := 0; i < targetPeers; i++ { id := fmt.Sprintf("peer #%d", i) tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts) } - if err := tester.sync("peer #0", nil); err != nil { + if err := tester.sync("peer #0", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } assertOwnChain(t, tester, targetBlocks+1) @@ -920,14 +955,14 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) // Create peers of every type - tester := newTester(mode) - tester.newPeer("peer 61", 61, hashes, headers, blocks, receipts) - tester.newPeer("peer 62", 62, hashes, headers, blocks, receipts) + tester := newTester() + tester.newPeer("peer 61", 61, hashes, nil, blocks, nil) + tester.newPeer("peer 62", 62, hashes, headers, blocks, nil) tester.newPeer("peer 63", 63, hashes, headers, blocks, receipts) tester.newPeer("peer 64", 64, hashes, headers, blocks, receipts) - // Synchronise with the requestd peer and make sure all blocks were retrieved - if err := tester.sync(fmt.Sprintf("peer %d", protocol), nil); err != nil { + // Synchronise with the requested peer and make sure all blocks were retrieved + if err := tester.sync(fmt.Sprintf("peer %d", protocol), nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } assertOwnChain(t, tester, targetBlocks+1) @@ -955,7 +990,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { targetBlocks := 2*blockCacheLimit - 15 hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - tester := newTester(mode) + tester := newTester() tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Instrument the downloader to signal body requests @@ -967,7 +1002,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { atomic.AddInt32(&receiptsHave, int32(len(headers))) } // Synchronise with the peer and make sure all blocks were retrieved - if err := tester.sync("peer", nil); err != nil { + if err := tester.sync("peer", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } assertOwnChain(t, tester, targetBlocks+1) @@ -980,7 +1015,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { } } for hash, receipt := range receipts { - if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= uint64(targetBlocks-minFullBlocks) { + if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= tester.downloader.queue.fastSyncPivot { receiptsNeeded++ } } @@ -1006,19 +1041,19 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) { targetBlocks := blockCacheLimit - 15 hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - tester := newTester(mode) + tester := newTester() // Attempt a full sync with an attacker feeding gapped headers tester.newPeer("attack", protocol, hashes, headers, blocks, receipts) missing := targetBlocks / 2 delete(tester.peerHeaders["attack"], hashes[missing]) - if err := tester.sync("attack", nil); err == nil { + if err := tester.sync("attack", nil, mode); err == nil { t.Fatalf("succeeded attacker synchronisation") } // Synchronise with the valid peer and make sure sync succeeds tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) - if err := tester.sync("valid", nil); err != nil { + if err := tester.sync("valid", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } assertOwnChain(t, tester, targetBlocks+1) @@ -1038,7 +1073,7 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) { targetBlocks := blockCacheLimit - 15 hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - tester := newTester(mode) + tester := newTester() // Attempt a full sync with an attacker feeding shifted headers tester.newPeer("attack", protocol, hashes, headers, blocks, receipts) @@ -1046,12 +1081,12 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) { delete(tester.peerBlocks["attack"], hashes[len(hashes)-2]) delete(tester.peerReceipts["attack"], hashes[len(hashes)-2]) - if err := tester.sync("attack", nil); err == nil { + if err := tester.sync("attack", nil, mode); err == nil { t.Fatalf("succeeded attacker synchronisation") } // Synchronise with the valid peer and make sure sync succeeds tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) - if err := tester.sync("valid", nil); err != nil { + if err := tester.sync("valid", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } assertOwnChain(t, tester, targetBlocks+1) @@ -1064,92 +1099,81 @@ func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback( func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { // Create a small enough block chain to download - targetBlocks := 3*minCheckedHeaders + minFullBlocks + targetBlocks := 3*fsHeaderSafetyNet + fsMinFullBlocks hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - tester := newTester(mode) + tester := newTester() - // Attempt to sync with an attacker that feeds junk during the fast sync phase + // Attempt to sync with an attacker that feeds junk during the fast sync phase. + // This should result in the last fsHeaderSafetyNet headers being rolled back. tester.newPeer("fast-attack", protocol, hashes, headers, blocks, receipts) - missing := minCheckedHeaders + MaxHeaderFetch + 1 + missing := fsHeaderSafetyNet + MaxHeaderFetch + 1 delete(tester.peerHeaders["fast-attack"], hashes[len(hashes)-missing]) - if err := tester.sync("fast-attack", nil); err == nil { + if err := tester.sync("fast-attack", nil, mode); err == nil { t.Fatalf("succeeded fast attacker synchronisation") } if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch { - t.Fatalf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch) + t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch) } - // Attempt to sync with an attacker that feeds junk during the block import phase + // Attempt to sync with an attacker that feeds junk during the block import phase. + // This should result in both the last fsHeaderSafetyNet number of headers being + // rolled back, and also the pivot point being reverted to a non-block status. tester.newPeer("block-attack", protocol, hashes, headers, blocks, receipts) - missing = 3*minCheckedHeaders + MaxHeaderFetch + 1 + missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1 delete(tester.peerHeaders["block-attack"], hashes[len(hashes)-missing]) - if err := tester.sync("block-attack", nil); err == nil { + if err := tester.sync("block-attack", nil, mode); err == nil { t.Fatalf("succeeded block attacker synchronisation") } + if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch { + t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch) + } if mode == FastSync { - // Fast sync should not discard anything below the verified pivot point - if head := tester.headHeader().Number.Int64(); int(head) < 3*minCheckedHeaders { - t.Fatalf("rollback head mismatch: have %v, want at least %v", head, 3*minCheckedHeaders) + if head := tester.headBlock().NumberU64(); head != 0 { + t.Errorf("fast sync pivot block #%d not rolled back", head) } - } else if mode == LightSync { - // Light sync should still discard data as before - if head := tester.headHeader().Number.Int64(); int(head) > 3*minCheckedHeaders { - t.Fatalf("rollback head mismatch: have %v, want at most %v", head, 3*minCheckedHeaders) - } - } - // Synchronise with the valid peer and make sure sync succeeds - tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) - if err := tester.sync("valid", nil); err != nil { - t.Fatalf("failed to synchronise blocks: %v", err) } - assertOwnChain(t, tester, targetBlocks+1) -} + // Attempt to sync with an attacker that withholds promised blocks after the + // fast sync pivot point. This could be a trial to leave the node with a bad + // but already imported pivot block. + tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts) + missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1 -// Tests that if a peer sends an invalid block piece (body or receipt) for a -// requested block, it gets dropped immediately by the downloader. -func TestInvalidContentAttack62(t *testing.T) { testInvalidContentAttack(t, 62, FullSync) } -func TestInvalidContentAttack63Full(t *testing.T) { testInvalidContentAttack(t, 63, FullSync) } -func TestInvalidContentAttack63Fast(t *testing.T) { testInvalidContentAttack(t, 63, FastSync) } -func TestInvalidContentAttack64Full(t *testing.T) { testInvalidContentAttack(t, 64, FullSync) } -func TestInvalidContentAttack64Fast(t *testing.T) { testInvalidContentAttack(t, 64, FastSync) } -func TestInvalidContentAttack64Light(t *testing.T) { testInvalidContentAttack(t, 64, LightSync) } - -func testInvalidContentAttack(t *testing.T, protocol int, mode SyncMode) { - // Create two peers, one feeding invalid block bodies - targetBlocks := 4*blockCacheLimit - 15 - hashes, headers, validBlocks, validReceipts := makeChain(targetBlocks, 0, genesis, nil) - - invalidBlocks := make(map[common.Hash]*types.Block) - for hash, block := range validBlocks { - invalidBlocks[hash] = types.NewBlockWithHeader(block.Header()) - } - invalidReceipts := make(map[common.Hash]types.Receipts) - for hash, _ := range validReceipts { - invalidReceipts[hash] = types.Receipts{&types.Receipt{}} + tester.downloader.noFast = false + tester.downloader.syncInitHook = func(uint64, uint64) { + for i := missing; i <= len(hashes); i++ { + delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i]) + } + tester.downloader.syncInitHook = nil } - tester := newTester(mode) - tester.newPeer("valid", protocol, hashes, headers, validBlocks, validReceipts) - if mode != LightSync { - tester.newPeer("body attack", protocol, hashes, headers, invalidBlocks, validReceipts) + if err := tester.sync("withhold-attack", nil, mode); err == nil { + t.Fatalf("succeeded withholding attacker synchronisation") + } + if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch { + t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch) } if mode == FastSync { - tester.newPeer("receipt attack", protocol, hashes, headers, validBlocks, invalidReceipts) + if head := tester.headBlock().NumberU64(); head != 0 { + t.Errorf("fast sync pivot block #%d not rolled back", head) + } } - // Synchronise with the valid peer (will pull contents from the attacker too) - if err := tester.sync("valid", nil); err != nil { + // Synchronise with the valid peer and make sure sync succeeds. Since the last + // rollback should also disable fast syncing for this process, verify that we + // did a fresh full sync. Note, we can't assert anything about the receipts + // since we won't purge the database of them, hence we can't use asserOwnChain. + tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) + if err := tester.sync("valid", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } - assertOwnChain(t, tester, targetBlocks+1) - - // Make sure the attacker was detected and dropped in the mean time - if _, ok := tester.peerHashes["body attack"]; ok { - t.Fatalf("block body attacker not detected/dropped") + if hs := len(tester.ownHeaders); hs != len(headers) { + t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, len(headers)) } - if _, ok := tester.peerHashes["receipt attack"]; ok { - t.Fatalf("receipt attacker not detected/dropped") + if mode != LightSync { + if bs := len(tester.ownBlocks); bs != len(blocks) { + t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, len(blocks)) + } } } @@ -1164,11 +1188,11 @@ func TestHighTDStarvationAttack64Fast(t *testing.T) { testHighTDStarvationAttac func TestHighTDStarvationAttack64Light(t *testing.T) { testHighTDStarvationAttack(t, 64, LightSync) } func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) { - tester := newTester(mode) + tester := newTester() hashes, headers, blocks, receipts := makeChain(0, 0, genesis, nil) tester.newPeer("attack", protocol, []common.Hash{hashes[0]}, headers, blocks, receipts) - if err := tester.sync("attack", big.NewInt(1000000)); err != errStallingPeer { + if err := tester.sync("attack", big.NewInt(1000000), mode); err != errStallingPeer { t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer) } } @@ -1206,7 +1230,7 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) { {errCancelBodyFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop } // Run the tests and check disconnection status - tester := newTester(FullSync) + tester := newTester() for i, tt := range tests { // Register a new peer and ensure it's presence id := fmt.Sprintf("test %d", i) @@ -1219,120 +1243,125 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) { // Simulate a synchronisation and check the required result tester.downloader.synchroniseMock = func(string, common.Hash) error { return tt.result } - tester.downloader.Synchronise(id, genesis.Hash(), big.NewInt(1000)) + tester.downloader.Synchronise(id, genesis.Hash(), big.NewInt(1000), FullSync) if _, ok := tester.peerHashes[id]; !ok != tt.drop { t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", i, tt.result, !ok, tt.drop) } } } -// Tests that synchronisation boundaries (origin block number and highest block -// number) is tracked and updated correctly. -func TestSyncBoundaries61(t *testing.T) { testSyncBoundaries(t, 61, FullSync) } -func TestSyncBoundaries62(t *testing.T) { testSyncBoundaries(t, 62, FullSync) } -func TestSyncBoundaries63Full(t *testing.T) { testSyncBoundaries(t, 63, FullSync) } -func TestSyncBoundaries63Fast(t *testing.T) { testSyncBoundaries(t, 63, FastSync) } -func TestSyncBoundaries64Full(t *testing.T) { testSyncBoundaries(t, 64, FullSync) } -func TestSyncBoundaries64Fast(t *testing.T) { testSyncBoundaries(t, 64, FastSync) } -func TestSyncBoundaries64Light(t *testing.T) { testSyncBoundaries(t, 64, LightSync) } - -func testSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { +// Tests that synchronisation progress (origin block number, current block number +// and highest block number) is tracked and updated correctly. +func TestSyncProgress61(t *testing.T) { testSyncProgress(t, 61, FullSync) } +func TestSyncProgress62(t *testing.T) { testSyncProgress(t, 62, FullSync) } +func TestSyncProgress63Full(t *testing.T) { testSyncProgress(t, 63, FullSync) } +func TestSyncProgress63Fast(t *testing.T) { testSyncProgress(t, 63, FastSync) } +func TestSyncProgress64Full(t *testing.T) { testSyncProgress(t, 64, FullSync) } +func TestSyncProgress64Fast(t *testing.T) { testSyncProgress(t, 64, FastSync) } +func TestSyncProgress64Light(t *testing.T) { testSyncProgress(t, 64, LightSync) } + +func testSyncProgress(t *testing.T, protocol int, mode SyncMode) { // Create a small enough block chain to download targetBlocks := blockCacheLimit - 15 hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - // Set a sync init hook to catch boundary changes + // Set a sync init hook to catch progress changes starting := make(chan struct{}) progress := make(chan struct{}) - tester := newTester(mode) + tester := newTester() tester.downloader.syncInitHook = func(origin, latest uint64) { starting <- struct{}{} <-progress } - // Retrieve the sync boundaries and ensure they are zero (pristine sync) - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != 0 { - t.Fatalf("Pristine boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, 0) + // Retrieve the sync progress and ensure they are zero (pristine sync) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 { + t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0) } - // Synchronise half the blocks and check initial boundaries + // Synchronise half the blocks and check initial progress tester.newPeer("peer-half", protocol, hashes[targetBlocks/2:], headers, blocks, receipts) pending := new(sync.WaitGroup) pending.Add(1) go func() { defer pending.Done() - if err := tester.sync("peer-half", nil); err != nil { + if err := tester.sync("peer-half", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } }() <-starting - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks/2+1) { - t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks/2+1) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks/2+1) { + t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks/2+1) } progress <- struct{}{} pending.Wait() - // Synchronise all the blocks and check continuation boundaries + // Synchronise all the blocks and check continuation progress tester.newPeer("peer-full", protocol, hashes, headers, blocks, receipts) pending.Add(1) go func() { defer pending.Done() - if err := tester.sync("peer-full", nil); err != nil { + if err := tester.sync("peer-full", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } }() <-starting - if origin, latest := tester.downloader.Boundaries(); origin != uint64(targetBlocks/2+1) || latest != uint64(targetBlocks) { - t.Fatalf("Completing boundary mismatch: have %v/%v, want %v/%v", origin, latest, targetBlocks/2+1, targetBlocks) + if origin, current, latest := tester.downloader.Progress(); origin != uint64(targetBlocks/2+1) || current != uint64(targetBlocks/2+1) || latest != uint64(targetBlocks) { + t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, targetBlocks/2+1, targetBlocks/2+1, targetBlocks) } progress <- struct{}{} pending.Wait() + + // Check final progress after successful sync + if origin, current, latest := tester.downloader.Progress(); origin != uint64(targetBlocks/2+1) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) { + t.Fatalf("Final progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, targetBlocks/2+1, targetBlocks, targetBlocks) + } } -// Tests that synchronisation boundaries (origin block number and highest block +// Tests that synchronisation progress (origin block number and highest block // number) is tracked and updated correctly in case of a fork (or manual head // revertal). -func TestForkedSyncBoundaries61(t *testing.T) { testForkedSyncBoundaries(t, 61, FullSync) } -func TestForkedSyncBoundaries62(t *testing.T) { testForkedSyncBoundaries(t, 62, FullSync) } -func TestForkedSyncBoundaries63Full(t *testing.T) { testForkedSyncBoundaries(t, 63, FullSync) } -func TestForkedSyncBoundaries63Fast(t *testing.T) { testForkedSyncBoundaries(t, 63, FastSync) } -func TestForkedSyncBoundaries64Full(t *testing.T) { testForkedSyncBoundaries(t, 64, FullSync) } -func TestForkedSyncBoundaries64Fast(t *testing.T) { testForkedSyncBoundaries(t, 64, FastSync) } -func TestForkedSyncBoundaries64Light(t *testing.T) { testForkedSyncBoundaries(t, 64, LightSync) } - -func testForkedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { +func TestForkedSyncProgress61(t *testing.T) { testForkedSyncProgress(t, 61, FullSync) } +func TestForkedSyncProgress62(t *testing.T) { testForkedSyncProgress(t, 62, FullSync) } +func TestForkedSyncProgress63Full(t *testing.T) { testForkedSyncProgress(t, 63, FullSync) } +func TestForkedSyncProgress63Fast(t *testing.T) { testForkedSyncProgress(t, 63, FastSync) } +func TestForkedSyncProgress64Full(t *testing.T) { testForkedSyncProgress(t, 64, FullSync) } +func TestForkedSyncProgress64Fast(t *testing.T) { testForkedSyncProgress(t, 64, FastSync) } +func TestForkedSyncProgress64Light(t *testing.T) { testForkedSyncProgress(t, 64, LightSync) } + +func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) { // Create a forked chain to simulate origin revertal common, fork := MaxHashFetch, 2*MaxHashFetch hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil) - // Set a sync init hook to catch boundary changes + // Set a sync init hook to catch progress changes starting := make(chan struct{}) progress := make(chan struct{}) - tester := newTester(mode) + tester := newTester() tester.downloader.syncInitHook = func(origin, latest uint64) { starting <- struct{}{} <-progress } - // Retrieve the sync boundaries and ensure they are zero (pristine sync) - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != 0 { - t.Fatalf("Pristine boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, 0) + // Retrieve the sync progress and ensure they are zero (pristine sync) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 { + t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0) } - // Synchronise with one of the forks and check boundaries + // Synchronise with one of the forks and check progress tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA) pending := new(sync.WaitGroup) pending.Add(1) go func() { defer pending.Done() - if err := tester.sync("fork A", nil); err != nil { + if err := tester.sync("fork A", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } }() <-starting - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(len(hashesA)-1) { - t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, len(hashesA)-1) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(len(hashesA)-1) { + t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, len(hashesA)-1) } progress <- struct{}{} pending.Wait() @@ -1340,52 +1369,57 @@ func testForkedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { // Simulate a successful sync above the fork tester.downloader.syncStatsChainOrigin = tester.downloader.syncStatsChainHeight - // Synchronise with the second fork and check boundary resets + // Synchronise with the second fork and check progress resets tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB) pending.Add(1) go func() { defer pending.Done() - if err := tester.sync("fork B", nil); err != nil { + if err := tester.sync("fork B", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } }() <-starting - if origin, latest := tester.downloader.Boundaries(); origin != uint64(common) || latest != uint64(len(hashesB)-1) { - t.Fatalf("Forking boundary mismatch: have %v/%v, want %v/%v", origin, latest, common, len(hashesB)-1) + if origin, current, latest := tester.downloader.Progress(); origin != uint64(common) || current != uint64(len(hashesA)-1) || latest != uint64(len(hashesB)-1) { + t.Fatalf("Forking progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, common, len(hashesA)-1, len(hashesB)-1) } progress <- struct{}{} pending.Wait() + + // Check final progress after successful sync + if origin, current, latest := tester.downloader.Progress(); origin != uint64(common) || current != uint64(len(hashesB)-1) || latest != uint64(len(hashesB)-1) { + t.Fatalf("Final progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, common, len(hashesB)-1, len(hashesB)-1) + } } -// Tests that if synchronisation is aborted due to some failure, then the boundary +// Tests that if synchronisation is aborted due to some failure, then the progress // origin is not updated in the next sync cycle, as it should be considered the // continuation of the previous sync and not a new instance. -func TestFailedSyncBoundaries61(t *testing.T) { testFailedSyncBoundaries(t, 61, FullSync) } -func TestFailedSyncBoundaries62(t *testing.T) { testFailedSyncBoundaries(t, 62, FullSync) } -func TestFailedSyncBoundaries63Full(t *testing.T) { testFailedSyncBoundaries(t, 63, FullSync) } -func TestFailedSyncBoundaries63Fast(t *testing.T) { testFailedSyncBoundaries(t, 63, FastSync) } -func TestFailedSyncBoundaries64Full(t *testing.T) { testFailedSyncBoundaries(t, 64, FullSync) } -func TestFailedSyncBoundaries64Fast(t *testing.T) { testFailedSyncBoundaries(t, 64, FastSync) } -func TestFailedSyncBoundaries64Light(t *testing.T) { testFailedSyncBoundaries(t, 64, LightSync) } - -func testFailedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { +func TestFailedSyncProgress61(t *testing.T) { testFailedSyncProgress(t, 61, FullSync) } +func TestFailedSyncProgress62(t *testing.T) { testFailedSyncProgress(t, 62, FullSync) } +func TestFailedSyncProgress63Full(t *testing.T) { testFailedSyncProgress(t, 63, FullSync) } +func TestFailedSyncProgress63Fast(t *testing.T) { testFailedSyncProgress(t, 63, FastSync) } +func TestFailedSyncProgress64Full(t *testing.T) { testFailedSyncProgress(t, 64, FullSync) } +func TestFailedSyncProgress64Fast(t *testing.T) { testFailedSyncProgress(t, 64, FastSync) } +func TestFailedSyncProgress64Light(t *testing.T) { testFailedSyncProgress(t, 64, LightSync) } + +func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) { // Create a small enough block chain to download targetBlocks := blockCacheLimit - 15 hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) - // Set a sync init hook to catch boundary changes + // Set a sync init hook to catch progress changes starting := make(chan struct{}) progress := make(chan struct{}) - tester := newTester(mode) + tester := newTester() tester.downloader.syncInitHook = func(origin, latest uint64) { starting <- struct{}{} <-progress } - // Retrieve the sync boundaries and ensure they are zero (pristine sync) - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != 0 { - t.Fatalf("Pristine boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, 0) + // Retrieve the sync progress and ensure they are zero (pristine sync) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 { + t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0) } // Attempt a full sync with a faulty peer tester.newPeer("faulty", protocol, hashes, headers, blocks, receipts) @@ -1399,62 +1433,67 @@ func testFailedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { go func() { defer pending.Done() - if err := tester.sync("faulty", nil); err == nil { + if err := tester.sync("faulty", nil, mode); err == nil { t.Fatalf("succeeded faulty synchronisation") } }() <-starting - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks) { - t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks) { + t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks) } progress <- struct{}{} pending.Wait() - // Synchronise with a good peer and check that the boundary origin remind the same after a failure + // Synchronise with a good peer and check that the progress origin remind the same after a failure tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) pending.Add(1) go func() { defer pending.Done() - if err := tester.sync("valid", nil); err != nil { + if err := tester.sync("valid", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } }() <-starting - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks) { - t.Fatalf("Completing boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current > uint64(targetBlocks/2) || latest != uint64(targetBlocks) { + t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/0-%v/%v", origin, current, latest, 0, targetBlocks/2, targetBlocks) } progress <- struct{}{} pending.Wait() + + // Check final progress after successful sync + if origin, current, latest := tester.downloader.Progress(); origin > uint64(targetBlocks/2) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) { + t.Fatalf("Final progress mismatch: have %v/%v/%v, want 0-%v/%v/%v", origin, current, latest, targetBlocks/2, targetBlocks, targetBlocks) + } } // Tests that if an attacker fakes a chain height, after the attack is detected, -// the boundary height is successfully reduced at the next sync invocation. -func TestFakedSyncBoundaries61(t *testing.T) { testFakedSyncBoundaries(t, 61, FullSync) } -func TestFakedSyncBoundaries62(t *testing.T) { testFakedSyncBoundaries(t, 62, FullSync) } -func TestFakedSyncBoundaries63Full(t *testing.T) { testFakedSyncBoundaries(t, 63, FullSync) } -func TestFakedSyncBoundaries63Fast(t *testing.T) { testFakedSyncBoundaries(t, 63, FastSync) } -func TestFakedSyncBoundaries64Full(t *testing.T) { testFakedSyncBoundaries(t, 64, FullSync) } -func TestFakedSyncBoundaries64Fast(t *testing.T) { testFakedSyncBoundaries(t, 64, FastSync) } -func TestFakedSyncBoundaries64Light(t *testing.T) { testFakedSyncBoundaries(t, 64, LightSync) } - -func testFakedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { +// the progress height is successfully reduced at the next sync invocation. +func TestFakedSyncProgress61(t *testing.T) { testFakedSyncProgress(t, 61, FullSync) } +func TestFakedSyncProgress62(t *testing.T) { testFakedSyncProgress(t, 62, FullSync) } +func TestFakedSyncProgress63Full(t *testing.T) { testFakedSyncProgress(t, 63, FullSync) } +func TestFakedSyncProgress63Fast(t *testing.T) { testFakedSyncProgress(t, 63, FastSync) } +func TestFakedSyncProgress64Full(t *testing.T) { testFakedSyncProgress(t, 64, FullSync) } +func TestFakedSyncProgress64Fast(t *testing.T) { testFakedSyncProgress(t, 64, FastSync) } +func TestFakedSyncProgress64Light(t *testing.T) { testFakedSyncProgress(t, 64, LightSync) } + +func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) { // Create a small block chain targetBlocks := blockCacheLimit - 15 hashes, headers, blocks, receipts := makeChain(targetBlocks+3, 0, genesis, nil) - // Set a sync init hook to catch boundary changes + // Set a sync init hook to catch progress changes starting := make(chan struct{}) progress := make(chan struct{}) - tester := newTester(mode) + tester := newTester() tester.downloader.syncInitHook = func(origin, latest uint64) { starting <- struct{}{} <-progress } - // Retrieve the sync boundaries and ensure they are zero (pristine sync) - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != 0 { - t.Fatalf("Pristine boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, 0) + // Retrieve the sync progress and ensure they are zero (pristine sync) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 { + t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0) } // Create and sync with an attacker that promises a higher chain than available tester.newPeer("attack", protocol, hashes, headers, blocks, receipts) @@ -1469,31 +1508,36 @@ func testFakedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { go func() { defer pending.Done() - if err := tester.sync("attack", nil); err == nil { + if err := tester.sync("attack", nil, mode); err == nil { t.Fatalf("succeeded attacker synchronisation") } }() <-starting - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks+3) { - t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks+3) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks+3) { + t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks+3) } progress <- struct{}{} pending.Wait() - // Synchronise with a good peer and check that the boundary height has been reduced to the true value + // Synchronise with a good peer and check that the progress height has been reduced to the true value tester.newPeer("valid", protocol, hashes[3:], headers, blocks, receipts) pending.Add(1) go func() { defer pending.Done() - if err := tester.sync("valid", nil); err != nil { + if err := tester.sync("valid", nil, mode); err != nil { t.Fatalf("failed to synchronise blocks: %v", err) } }() <-starting - if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks) { - t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks) + if origin, current, latest := tester.downloader.Progress(); origin != 0 || current > uint64(targetBlocks) || latest != uint64(targetBlocks) { + t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/0-%v/%v", origin, current, latest, 0, targetBlocks, targetBlocks) } progress <- struct{}{} pending.Wait() + + // Check final progress after successful sync + if origin, current, latest := tester.downloader.Progress(); origin > uint64(targetBlocks) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) { + t.Fatalf("Final progress mismatch: have %v/%v/%v, want 0-%v/%v/%v", origin, current, latest, targetBlocks, targetBlocks, targetBlocks) + } } diff --git a/eth/downloader/modes.go b/eth/downloader/modes.go index 8916dbb79..ec339c074 100644 --- a/eth/downloader/modes.go +++ b/eth/downloader/modes.go @@ -20,7 +20,7 @@ package downloader type SyncMode int const ( - FullSync SyncMode = iota // Synchronise the entire block-chain history from full blocks - FastSync // Quikcly download the headers, full sync only at the chain head + FullSync SyncMode = iota // Synchronise the entire blockchain history from full blocks + FastSync // Quickly download the headers, full sync only at the chain head LightSync // Download only the headers and terminate afterwards ) diff --git a/eth/downloader/peer.go b/eth/downloader/peer.go index 5011d5d46..1f457cb15 100644 --- a/eth/downloader/peer.go +++ b/eth/downloader/peer.go @@ -124,6 +124,10 @@ func (p *peer) Reset() { // Fetch61 sends a block retrieval request to the remote peer. func (p *peer) Fetch61(request *fetchRequest) error { + // Sanity check the protocol version + if p.version != 61 { + panic(fmt.Sprintf("block fetch [eth/61] requested on eth/%d", p.version)) + } // Short circuit if the peer is already fetching if !atomic.CompareAndSwapInt32(&p.blockIdle, 0, 1) { return errAlreadyFetching @@ -142,6 +146,10 @@ func (p *peer) Fetch61(request *fetchRequest) error { // FetchBodies sends a block body retrieval request to the remote peer. func (p *peer) FetchBodies(request *fetchRequest) error { + // Sanity check the protocol version + if p.version < 62 { + panic(fmt.Sprintf("body fetch [eth/62+] requested on eth/%d", p.version)) + } // Short circuit if the peer is already fetching if !atomic.CompareAndSwapInt32(&p.blockIdle, 0, 1) { return errAlreadyFetching @@ -160,6 +168,10 @@ func (p *peer) FetchBodies(request *fetchRequest) error { // FetchReceipts sends a receipt retrieval request to the remote peer. func (p *peer) FetchReceipts(request *fetchRequest) error { + // Sanity check the protocol version + if p.version < 63 { + panic(fmt.Sprintf("body fetch [eth/63+] requested on eth/%d", p.version)) + } // Short circuit if the peer is already fetching if !atomic.CompareAndSwapInt32(&p.receiptIdle, 0, 1) { return errAlreadyFetching @@ -178,6 +190,10 @@ func (p *peer) FetchReceipts(request *fetchRequest) error { // FetchNodeData sends a node state data retrieval request to the remote peer. func (p *peer) FetchNodeData(request *fetchRequest) error { + // Sanity check the protocol version + if p.version < 63 { + panic(fmt.Sprintf("node data fetch [eth/63+] requested on eth/%d", p.version)) + } // Short circuit if the peer is already fetching if !atomic.CompareAndSwapInt32(&p.stateIdle, 0, 1) { return errAlreadyFetching @@ -196,35 +212,35 @@ func (p *peer) FetchNodeData(request *fetchRequest) error { // SetBlocksIdle sets the peer to idle, allowing it to execute new retrieval requests. // Its block retrieval allowance will also be updated either up- or downwards, -// depending on whether the previous fetch completed in time or not. +// depending on whether the previous fetch completed in time. func (p *peer) SetBlocksIdle() { p.setIdle(p.blockStarted, blockSoftTTL, blockHardTTL, MaxBlockFetch, &p.blockCapacity, &p.blockIdle) } // SetBodiesIdle sets the peer to idle, allowing it to execute new retrieval requests. // Its block body retrieval allowance will also be updated either up- or downwards, -// depending on whether the previous fetch completed in time or not. +// depending on whether the previous fetch completed in time. func (p *peer) SetBodiesIdle() { - p.setIdle(p.blockStarted, bodySoftTTL, bodyHardTTL, MaxBlockFetch, &p.blockCapacity, &p.blockIdle) + p.setIdle(p.blockStarted, bodySoftTTL, bodyHardTTL, MaxBodyFetch, &p.blockCapacity, &p.blockIdle) } // SetReceiptsIdle sets the peer to idle, allowing it to execute new retrieval requests. // Its receipt retrieval allowance will also be updated either up- or downwards, -// depending on whether the previous fetch completed in time or not. +// depending on whether the previous fetch completed in time. func (p *peer) SetReceiptsIdle() { p.setIdle(p.receiptStarted, receiptSoftTTL, receiptHardTTL, MaxReceiptFetch, &p.receiptCapacity, &p.receiptIdle) } // SetNodeDataIdle sets the peer to idle, allowing it to execute new retrieval // requests. Its node data retrieval allowance will also be updated either up- or -// downwards, depending on whether the previous fetch completed in time or not. +// downwards, depending on whether the previous fetch completed in time. func (p *peer) SetNodeDataIdle() { p.setIdle(p.stateStarted, stateSoftTTL, stateSoftTTL, MaxStateFetch, &p.stateCapacity, &p.stateIdle) } // setIdle sets the peer to idle, allowing it to execute new retrieval requests. // Its data retrieval allowance will also be updated either up- or downwards, -// depending on whether the previous fetch completed in time or not. +// depending on whether the previous fetch completed in time. func (p *peer) setIdle(started time.Time, softTTL, hardTTL time.Duration, maxFetch int, capacity, idle *int32) { // Update the peer's download allowance based on previous performance scale := 2.0 diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 17fbb1c7f..56b46e285 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -56,9 +56,8 @@ type fetchRequest struct { Time time.Time // Time when the request was made } -// fetchResult is the assembly collecting partial results from potentially more -// than one fetcher routines, until all outstanding retrievals complete and the -// result as a whole can be processed. +// fetchResult is a struct collecting partial results from data fetchers until +// all outstanding pieces complete and the result as a whole can be processed. type fetchResult struct { Pending int // Number of data fetches still pending @@ -89,7 +88,7 @@ type queue struct { receiptPendPool map[string]*fetchRequest // [eth/63] Currently pending receipt retrieval operations receiptDonePool map[common.Hash]struct{} // [eth/63] Set of the completed receipt fetches - stateTaskIndex int // [eth/63] Counter indexing the added hashes to ensure prioritized retrieval order + stateTaskIndex int // [eth/63] Counter indexing the added hashes to ensure prioritised retrieval order stateTaskPool map[common.Hash]int // [eth/63] Pending node data retrieval tasks, mapping to their priority stateTaskQueue *prque.Prque // [eth/63] Priority queue of the hashes to fetch the node data for statePendPool map[string]*fetchRequest // [eth/63] Currently pending node data retrieval operations @@ -97,10 +96,10 @@ type queue struct { stateDatabase ethdb.Database // [eth/63] Trie database to populate during state reassembly stateScheduler *state.StateSync // [eth/63] State trie synchronisation scheduler and integrator stateProcessors int32 // [eth/63] Number of currently running state processors - stateSchedLock sync.RWMutex // [eth/63] Lock serializing access to the state scheduler + stateSchedLock sync.RWMutex // [eth/63] Lock serialising access to the state scheduler resultCache []*fetchResult // Downloaded but not yet delivered fetch results - resultOffset uint64 // Offset of the first cached fetch result in the block-chain + resultOffset uint64 // Offset of the first cached fetch result in the block chain lock sync.RWMutex } @@ -131,6 +130,9 @@ func (q *queue) Reset() { q.lock.Lock() defer q.lock.Unlock() + q.stateSchedLock.Lock() + defer q.stateSchedLock.Unlock() + q.mode = FullSync q.fastSyncPivot = 0 @@ -233,9 +235,17 @@ func (q *queue) Idle() bool { return (queued + pending + cached) == 0 } -// ThrottleBlocks checks if the download should be throttled (active block (body) +// FastSyncPivot retrieves the currently used fast sync pivot point. +func (q *queue) FastSyncPivot() uint64 { + q.lock.RLock() + defer q.lock.RUnlock() + + return q.fastSyncPivot +} + +// ShouldThrottleBlocks checks if the download should be throttled (active block (body) // fetches exceed block cache). -func (q *queue) ThrottleBlocks() bool { +func (q *queue) ShouldThrottleBlocks() bool { q.lock.RLock() defer q.lock.RUnlock() @@ -248,9 +258,9 @@ func (q *queue) ThrottleBlocks() bool { return pending >= len(q.resultCache)-len(q.blockDonePool) } -// ThrottleReceipts checks if the download should be throttled (active receipt +// ShouldThrottleReceipts checks if the download should be throttled (active receipt // fetches exceed block cache). -func (q *queue) ThrottleReceipts() bool { +func (q *queue) ShouldThrottleReceipts() bool { q.lock.RLock() defer q.lock.RUnlock() @@ -269,7 +279,7 @@ func (q *queue) Schedule61(hashes []common.Hash, fifo bool) []common.Hash { q.lock.Lock() defer q.lock.Unlock() - // Insert all the hashes prioritized in the arrival order + // Insert all the hashes prioritised in the arrival order inserts := make([]common.Hash, 0, len(hashes)) for _, hash := range hashes { // Skip anything we already have @@ -297,10 +307,10 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header { q.lock.Lock() defer q.lock.Unlock() - // Insert all the headers prioritized by the contained block number + // Insert all the headers prioritised by the contained block number inserts := make([]*types.Header, 0, len(headers)) for _, header := range headers { - // Make sure chain order is honored and preserved throughout + // Make sure chain order is honoured and preserved throughout hash := header.Hash() if header.Number == nil || header.Number.Uint64() != from { glog.V(logger.Warn).Infof("Header #%v [%x] broke chain ordering, expected %d", header.Number, hash[:4], from) @@ -347,19 +357,29 @@ func (q *queue) GetHeadResult() *fetchResult { q.lock.RLock() defer q.lock.RUnlock() + // If there are no results pending, return nil if len(q.resultCache) == 0 || q.resultCache[0] == nil { return nil } + // If the next result is still incomplete, return nil if q.resultCache[0].Pending > 0 { return nil } + // If the next result is the fast sync pivot... if q.mode == FastSync && q.resultCache[0].Header.Number.Uint64() == q.fastSyncPivot { + // If the pivot state trie is still being pulled, return nil if len(q.stateTaskPool) > 0 { return nil } if q.PendingNodeData() > 0 { return nil } + // If the state is done, but not enough post-pivot headers were verified, stall... + for i := 0; i < fsHeaderForceVerify; i++ { + if i+1 >= len(q.resultCache) || q.resultCache[i+1] == nil { + return nil + } + } } return q.resultCache[0] } @@ -372,7 +392,7 @@ func (q *queue) TakeResults() []*fetchResult { // Accumulate all available results results := []*fetchResult{} - for _, result := range q.resultCache { + for i, result := range q.resultCache { // Stop if no more results are ready if result == nil || result.Pending > 0 { break @@ -385,6 +405,16 @@ func (q *queue) TakeResults() []*fetchResult { if q.PendingNodeData() > 0 { break } + // Even is state fetch is done, ensure post-pivot headers passed verifications + safe := true + for j := 0; j < fsHeaderForceVerify; j++ { + if i+j+1 >= len(q.resultCache) || q.resultCache[i+j+1] == nil { + safe = false + } + } + if !safe { + break + } } // If we've just inserted the fast sync pivot, stop as the following batch needs different insertion if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot+1 && len(results) > 0 { @@ -411,6 +441,9 @@ func (q *queue) TakeResults() []*fetchResult { // ReserveBlocks reserves a set of block hashes for the given peer, skipping any // previously failed download. func (q *queue) ReserveBlocks(p *peer, count int) *fetchRequest { + q.lock.Lock() + defer q.lock.Unlock() + return q.reserveHashes(p, count, q.hashQueue, nil, q.blockPendPool, len(q.resultCache)-len(q.blockDonePool)) } @@ -430,17 +463,21 @@ func (q *queue) ReserveNodeData(p *peer, count int) *fetchRequest { } } } + q.lock.Lock() + defer q.lock.Unlock() + return q.reserveHashes(p, count, q.stateTaskQueue, generator, q.statePendPool, count) } // reserveHashes reserves a set of hashes for the given peer, skipping previously // failed ones. +// +// Note, this method expects the queue lock to be already held for writing. The +// reason the lock is not obtained in here is because the parameters already need +// to access the queue, so they already need a lock anyway. func (q *queue) reserveHashes(p *peer, count int, taskQueue *prque.Prque, taskGen func(int), pendPool map[string]*fetchRequest, maxPending int) *fetchRequest { - q.lock.Lock() - defer q.lock.Unlock() - - // Short circuit if the peer's already downloading something (sanity check not - // to corrupt state) + // Short circuit if the peer's already downloading something (sanity check to + // not corrupt state) if _, ok := pendPool[p.id]; ok { return nil } @@ -492,30 +529,37 @@ func (q *queue) reserveHashes(p *peer, count int, taskQueue *prque.Prque, taskGe // previously failed downloads. Beside the next batch of needed fetches, it also // returns a flag whether empty blocks were queued requiring processing. func (q *queue) ReserveBodies(p *peer, count int) (*fetchRequest, bool, error) { - noop := func(header *types.Header) bool { + isNoop := func(header *types.Header) bool { return header.TxHash == types.EmptyRootHash && header.UncleHash == types.EmptyUncleHash } - return q.reserveHeaders(p, count, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, q.blockDonePool, noop) + q.lock.Lock() + defer q.lock.Unlock() + + return q.reserveHeaders(p, count, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, q.blockDonePool, isNoop) } // ReserveReceipts reserves a set of receipt fetches for the given peer, skipping // any previously failed downloads. Beside the next batch of needed fetches, it // also returns a flag whether empty receipts were queued requiring importing. func (q *queue) ReserveReceipts(p *peer, count int) (*fetchRequest, bool, error) { - noop := func(header *types.Header) bool { + isNoop := func(header *types.Header) bool { return header.ReceiptHash == types.EmptyRootHash } - return q.reserveHeaders(p, count, q.receiptTaskPool, q.receiptTaskQueue, q.receiptPendPool, q.receiptDonePool, noop) + q.lock.Lock() + defer q.lock.Unlock() + + return q.reserveHeaders(p, count, q.receiptTaskPool, q.receiptTaskQueue, q.receiptPendPool, q.receiptDonePool, isNoop) } // reserveHeaders reserves a set of data download operations for a given peer, // skipping any previously failed ones. This method is a generic version used // by the individual special reservation functions. +// +// Note, this method expects the queue lock to be already held for writing. The +// reason the lock is not obtained in here is because the parameters already need +// to access the queue, so they already need a lock anyway. func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque, - pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}, noop func(*types.Header) bool) (*fetchRequest, bool, error) { - q.lock.Lock() - defer q.lock.Unlock() - + pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}, isNoop func(*types.Header) bool) (*fetchRequest, bool, error) { // Short circuit if the pool has been depleted, or if the peer's already // downloading something (sanity check not to corrupt state) if taskQueue.Empty() { @@ -537,7 +581,7 @@ func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*typ for proc := 0; proc < space && len(send) < count && !taskQueue.Empty(); proc++ { header := taskQueue.PopItem().(*types.Header) - // If we're the first to request this task, initialize the result container + // If we're the first to request this task, initialise the result container index := int(header.Number.Int64() - int64(q.resultOffset)) if index >= len(q.resultCache) || index < 0 { return nil, false, errInvalidChain @@ -553,7 +597,7 @@ func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*typ } } // If this fetch task is a noop, skip this fetch operation - if noop(header) { + if isNoop(header) { donePool[header.Hash()] = struct{}{} delete(taskPool, header.Hash()) @@ -562,7 +606,7 @@ func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*typ progress = true continue } - // Otherwise if not a known unknown block, add to the retrieve list + // Otherwise unless the peer is known not to have the data, add to the retrieve list if p.ignored.Has(header.Hash()) { skip = append(skip, header) } else { @@ -655,35 +699,48 @@ func (q *queue) Revoke(peerId string) { } // ExpireBlocks checks for in flight requests that exceeded a timeout allowance, -// canceling them and returning the responsible peers for penalization. +// canceling them and returning the responsible peers for penalisation. func (q *queue) ExpireBlocks(timeout time.Duration) []string { + q.lock.Lock() + defer q.lock.Unlock() + return q.expire(timeout, q.blockPendPool, q.hashQueue, blockTimeoutMeter) } // ExpireBodies checks for in flight block body requests that exceeded a timeout -// allowance, canceling them and returning the responsible peers for penalization. +// allowance, canceling them and returning the responsible peers for penalisation. func (q *queue) ExpireBodies(timeout time.Duration) []string { + q.lock.Lock() + defer q.lock.Unlock() + return q.expire(timeout, q.blockPendPool, q.blockTaskQueue, bodyTimeoutMeter) } // ExpireReceipts checks for in flight receipt requests that exceeded a timeout -// allowance, canceling them and returning the responsible peers for penalization. +// allowance, canceling them and returning the responsible peers for penalisation. func (q *queue) ExpireReceipts(timeout time.Duration) []string { + q.lock.Lock() + defer q.lock.Unlock() + return q.expire(timeout, q.receiptPendPool, q.receiptTaskQueue, receiptTimeoutMeter) } // ExpireNodeData checks for in flight node data requests that exceeded a timeout -// allowance, canceling them and returning the responsible peers for penalization. +// allowance, canceling them and returning the responsible peers for penalisation. func (q *queue) ExpireNodeData(timeout time.Duration) []string { + q.lock.Lock() + defer q.lock.Unlock() + return q.expire(timeout, q.statePendPool, q.stateTaskQueue, stateTimeoutMeter) } // expire is the generic check that move expired tasks from a pending pool back // into a task pool, returning all entities caught with expired tasks. +// +// Note, this method expects the queue lock to be already held for writing. The +// reason the lock is not obtained in here is because the parameters already need +// to access the queue, so they already need a lock anyway. func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest, taskQueue *prque.Prque, timeoutMeter metrics.Meter) []string { - q.lock.Lock() - defer q.lock.Unlock() - // Iterate over the expired requests and return each to the queue peers := []string{} for id, request := range pendPool { @@ -764,7 +821,7 @@ func (q *queue) DeliverBlocks(id string, blocks []*types.Block) error { case len(errs) == 1 && (errs[0] == errInvalidChain || errs[0] == errInvalidBlock): return errs[0] - case len(errs) == len(request.Headers): + case len(errs) == len(blocks): return errStaleDelivery default: @@ -774,6 +831,9 @@ func (q *queue) DeliverBlocks(id string, blocks []*types.Block) error { // DeliverBodies injects a block body retrieval response into the results queue. func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLists [][]*types.Header) error { + q.lock.Lock() + defer q.lock.Unlock() + reconstruct := func(header *types.Header, index int, result *fetchResult) error { if types.DeriveSha(types.Transactions(txLists[index])) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { return errInvalidBody @@ -787,6 +847,9 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLi // DeliverReceipts injects a receipt retrieval response into the results queue. func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) error { + q.lock.Lock() + defer q.lock.Unlock() + reconstruct := func(header *types.Header, index int, result *fetchResult) error { if types.DeriveSha(types.Receipts(receiptList[index])) != header.ReceiptHash { return errInvalidReceipt @@ -798,11 +861,12 @@ func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) error } // deliver injects a data retrieval response into the results queue. +// +// Note, this method expects the queue lock to be already held for writing. The +// reason the lock is not obtained in here is because the parameters already need +// to access the queue, so they already need a lock anyway. func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque, pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}, reqTimer metrics.Timer, results int, reconstruct func(header *types.Header, index int, result *fetchResult) error) error { - q.lock.Lock() - defer q.lock.Unlock() - // Short circuit if the data was never requested request := pendPool[id] if request == nil { @@ -818,7 +882,10 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ } } // Assemble each of the results with their headers and retrieved data parts - errs := make([]error, 0) + var ( + failure error + useful bool + ) for i, header := range request.Headers { // Short circuit assembly if no more fetch results are found if i >= results { @@ -827,15 +894,16 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ // Reconstruct the next result if contents match up index := int(header.Number.Int64() - int64(q.resultOffset)) if index >= len(q.resultCache) || index < 0 || q.resultCache[index] == nil { - errs = []error{errInvalidChain} + failure = errInvalidChain break } if err := reconstruct(header, i, q.resultCache[index]); err != nil { - errs = []error{err} + failure = err break } donePool[header.Hash()] = struct{}{} q.resultCache[index].Pending-- + useful = true // Clean up a successful fetch request.Headers[i] = nil @@ -847,19 +915,16 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ taskQueue.Push(header, -float32(header.Number.Uint64())) } } - // If none of the blocks were good, it's a stale delivery + // If none of the data was good, it's a stale delivery switch { - case len(errs) == 0: - return nil - - case len(errs) == 1 && (errs[0] == errInvalidChain || errs[0] == errInvalidBody || errs[0] == errInvalidReceipt): - return errs[0] + case failure == nil || failure == errInvalidChain: + return failure - case len(errs) == len(request.Headers): - return errStaleDelivery + case useful: + return fmt.Errorf("partial failure: %v", failure) default: - return fmt.Errorf("multiple failures: %v", errs) + return errStaleDelivery } } @@ -876,7 +941,7 @@ func (q *queue) DeliverNodeData(id string, data [][]byte, callback func(error, i stateReqTimer.UpdateSince(request.Time) delete(q.statePendPool, id) - // If no data was retrieved, mark them as unavailable for the origin peer + // If no data was retrieved, mark their hashes as unavailable for the origin peer if len(data) == 0 { for hash, _ := range request.Hashes { request.Peer.ignored.Add(hash) @@ -955,9 +1020,6 @@ func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64) { if q.resultOffset < offset { q.resultOffset = offset } - q.fastSyncPivot = 0 - if mode == FastSync { - q.fastSyncPivot = pivot - } + q.fastSyncPivot = pivot q.mode = mode } diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index b8ec1fc55..d88d91982 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -142,9 +142,11 @@ type Fetcher struct { dropPeer peerDropFn // Drops a peer for misbehaving // Testing hooks - fetchingHook func([]common.Hash) // Method to call upon starting a block (eth/61) or header (eth/62) fetch - completingHook func([]common.Hash) // Method to call upon starting a block body fetch (eth/62) - importedHook func(*types.Block) // Method to call upon successful block import (both eth/61 and eth/62) + announceChangeHook func(common.Hash, bool) // Method to call upon adding or deleting a hash from the announce list + queueChangeHook func(common.Hash, bool) // Method to call upon adding or deleting a block from the import queue + fetchingHook func([]common.Hash) // Method to call upon starting a block (eth/61) or header (eth/62) fetch + completingHook func([]common.Hash) // Method to call upon starting a block body fetch (eth/62) + importedHook func(*types.Block) // Method to call upon successful block import (both eth/61 and eth/62) } // New creates a block fetcher to retrieve blocks based on hash announcements. @@ -324,11 +326,16 @@ func (f *Fetcher) loop() { height := f.chainHeight() for !f.queue.Empty() { op := f.queue.PopItem().(*inject) - + if f.queueChangeHook != nil { + f.queueChangeHook(op.block.Hash(), false) + } // If too high up the chain or phase, continue later number := op.block.NumberU64() if number > height+1 { f.queue.Push(op, -float32(op.block.NumberU64())) + if f.queueChangeHook != nil { + f.queueChangeHook(op.block.Hash(), true) + } break } // Otherwise if fresh and still unknown, try and import @@ -372,6 +379,9 @@ func (f *Fetcher) loop() { } f.announces[notification.origin] = count f.announced[notification.hash] = append(f.announced[notification.hash], notification) + if f.announceChangeHook != nil && len(f.announced[notification.hash]) == 1 { + f.announceChangeHook(notification.hash, true) + } if len(f.announced) == 1 { f.rescheduleFetch(fetchTimer) } @@ -714,7 +724,9 @@ func (f *Fetcher) enqueue(peer string, block *types.Block) { f.queues[peer] = count f.queued[hash] = op f.queue.Push(op, -float32(block.NumberU64())) - + if f.queueChangeHook != nil { + f.queueChangeHook(op.block.Hash(), true) + } if glog.V(logger.Debug) { glog.Infof("Peer %s: queued block #%d [%x…], total %v", peer, block.NumberU64(), hash.Bytes()[:4], f.queue.Size()) } @@ -781,7 +793,9 @@ func (f *Fetcher) forgetHash(hash common.Hash) { } } delete(f.announced, hash) - + if f.announceChangeHook != nil { + f.announceChangeHook(hash, false) + } // Remove any pending fetches and decrement the DOS counters if announce := f.fetching[hash]; announce != nil { f.announces[announce.origin]-- diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index 170a80aba..2404c8cfa 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -145,6 +145,9 @@ func (f *fetcherTester) insertChain(blocks types.Blocks) (int, error) { // dropPeer is an emulator for the peer removal, simply accumulating the various // peers dropped by the fetcher. func (f *fetcherTester) dropPeer(peer string) { + f.lock.Lock() + defer f.lock.Unlock() + f.drops[peer] = true } @@ -608,8 +611,11 @@ func TestDistantPropagationDiscarding(t *testing.T) { // Create a tester and simulate a head block being the middle of the above chain tester := newTester() + + tester.lock.Lock() tester.hashes = []common.Hash{head} tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} + tester.lock.Unlock() // Ensure that a block with a lower number than the threshold is discarded tester.fetcher.Enqueue("lower", blocks[hashes[low]]) @@ -641,8 +647,11 @@ func testDistantAnnouncementDiscarding(t *testing.T, protocol int) { // Create a tester and simulate a head block being the middle of the above chain tester := newTester() + + tester.lock.Lock() tester.hashes = []common.Hash{head} tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} + tester.lock.Unlock() headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack) bodyFetcher := tester.makeBodyFetcher(blocks, 0) @@ -687,14 +696,22 @@ func testInvalidNumberAnnouncement(t *testing.T, protocol int) { tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), nil, headerFetcher, bodyFetcher) verifyImportEvent(t, imported, false) - if !tester.drops["bad"] { + tester.lock.RLock() + dropped := tester.drops["bad"] + tester.lock.RUnlock() + + if !dropped { t.Fatalf("peer with invalid numbered announcement not dropped") } // Make sure a good announcement passes without a drop tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), nil, headerFetcher, bodyFetcher) verifyImportEvent(t, imported, true) - if tester.drops["good"] { + tester.lock.RLock() + dropped = tester.drops["good"] + tester.lock.RUnlock() + + if dropped { t.Fatalf("peer with valid numbered announcement dropped") } verifyImportDone(t, imported) @@ -752,9 +769,15 @@ func testHashMemoryExhaustionAttack(t *testing.T, protocol int) { // Create a tester with instrumented import hooks tester := newTester() - imported := make(chan *types.Block) + imported, announces := make(chan *types.Block), int32(0) tester.fetcher.importedHook = func(block *types.Block) { imported <- block } - + tester.fetcher.announceChangeHook = func(hash common.Hash, added bool) { + if added { + atomic.AddInt32(&announces, 1) + } else { + atomic.AddInt32(&announces, -1) + } + } // Create a valid chain and an infinite junk chain targetBlocks := hashLimit + 2*maxQueueDist hashes, blocks := makeChain(targetBlocks, 0, genesis) @@ -782,8 +805,8 @@ func testHashMemoryExhaustionAttack(t *testing.T, protocol int) { tester.fetcher.Notify("attacker", attack[i], 1 /* don't distance drop */, time.Now(), nil, attackerHeaderFetcher, attackerBodyFetcher) } } - if len(tester.fetcher.announced) != hashLimit+maxQueueDist { - t.Fatalf("queued announce count mismatch: have %d, want %d", len(tester.fetcher.announced), hashLimit+maxQueueDist) + if count := atomic.LoadInt32(&announces); count != hashLimit+maxQueueDist { + t.Fatalf("queued announce count mismatch: have %d, want %d", count, hashLimit+maxQueueDist) } // Wait for fetches to complete verifyImportCount(t, imported, maxQueueDist) @@ -807,9 +830,15 @@ func TestBlockMemoryExhaustionAttack(t *testing.T) { // Create a tester with instrumented import hooks tester := newTester() - imported := make(chan *types.Block) + imported, enqueued := make(chan *types.Block), int32(0) tester.fetcher.importedHook = func(block *types.Block) { imported <- block } - + tester.fetcher.queueChangeHook = func(hash common.Hash, added bool) { + if added { + atomic.AddInt32(&enqueued, 1) + } else { + atomic.AddInt32(&enqueued, -1) + } + } // Create a valid chain and a batch of dangling (but in range) blocks targetBlocks := hashLimit + 2*maxQueueDist hashes, blocks := makeChain(targetBlocks, 0, genesis) @@ -825,7 +854,7 @@ func TestBlockMemoryExhaustionAttack(t *testing.T) { tester.fetcher.Enqueue("attacker", block) } time.Sleep(200 * time.Millisecond) - if queued := tester.fetcher.queue.Size(); queued != blockLimit { + if queued := atomic.LoadInt32(&enqueued); queued != blockLimit { t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit) } // Queue up a batch of valid blocks, and check that a new peer is allowed to do so @@ -833,7 +862,7 @@ func TestBlockMemoryExhaustionAttack(t *testing.T) { tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-3-i]]) } time.Sleep(100 * time.Millisecond) - if queued := tester.fetcher.queue.Size(); queued != blockLimit+maxQueueDist-1 { + if queued := atomic.LoadInt32(&enqueued); queued != blockLimit+maxQueueDist-1 { t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit+maxQueueDist-1) } // Insert the missing piece (and sanity check the import) diff --git a/eth/filters/filter_test.go b/eth/filters/filter_test.go index 9e7538fac..a5418e2e7 100644 --- a/eth/filters/filter_test.go +++ b/eth/filters/filter_test.go @@ -16,9 +16,9 @@ import ( func makeReceipt(addr common.Address) *types.Receipt { receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{ + receipt.Logs = vm.Logs{ &vm.Log{Address: addr}, - }) + } receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) return receipt } @@ -41,7 +41,7 @@ func BenchmarkMipmaps(b *testing.B) { defer db.Close() genesis := core.WriteGenesisBlockForTesting(db, core.GenesisAccount{addr1, big.NewInt(1000000)}) - chain := core.GenerateChain(genesis, db, 100010, func(i int, gen *core.BlockGen) { + chain, receipts := core.GenerateChain(genesis, db, 100010, func(i int, gen *core.BlockGen) { var receipts types.Receipts switch i { case 2403: @@ -70,7 +70,7 @@ func BenchmarkMipmaps(b *testing.B) { } core.WriteMipmapBloom(db, uint64(i+1), receipts) }) - for _, block := range chain { + for i, block := range chain { core.WriteBlock(db, block) if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { b.Fatalf("failed to insert block number: %v", err) @@ -78,11 +78,10 @@ func BenchmarkMipmaps(b *testing.B) { if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { b.Fatalf("failed to insert block number: %v", err) } - if err := core.PutBlockReceipts(db, block, block.Receipts()); err != nil { + if err := core.PutBlockReceipts(db, block.Hash(), receipts[i]); err != nil { b.Fatal("error writing block receipts:", err) } } - b.ResetTimer() filter := New(db) @@ -118,47 +117,47 @@ func TestFilters(t *testing.T) { defer db.Close() genesis := core.WriteGenesisBlockForTesting(db, core.GenesisAccount{addr, big.NewInt(1000000)}) - chain := core.GenerateChain(genesis, db, 1000, func(i int, gen *core.BlockGen) { + chain, receipts := core.GenerateChain(genesis, db, 1000, func(i int, gen *core.BlockGen) { var receipts types.Receipts switch i { case 1: receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{ + receipt.Logs = vm.Logs{ &vm.Log{ Address: addr, Topics: []common.Hash{hash1}, }, - }) + } gen.AddUncheckedReceipt(receipt) receipts = types.Receipts{receipt} case 2: receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{ + receipt.Logs = vm.Logs{ &vm.Log{ Address: addr, Topics: []common.Hash{hash2}, }, - }) + } gen.AddUncheckedReceipt(receipt) receipts = types.Receipts{receipt} case 998: receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{ + receipt.Logs = vm.Logs{ &vm.Log{ Address: addr, Topics: []common.Hash{hash3}, }, - }) + } gen.AddUncheckedReceipt(receipt) receipts = types.Receipts{receipt} case 999: receipt := types.NewReceipt(nil, new(big.Int)) - receipt.SetLogs(vm.Logs{ + receipt.Logs = vm.Logs{ &vm.Log{ Address: addr, Topics: []common.Hash{hash4}, }, - }) + } gen.AddUncheckedReceipt(receipt) receipts = types.Receipts{receipt} } @@ -173,7 +172,7 @@ func TestFilters(t *testing.T) { // by one core.WriteMipmapBloom(db, uint64(i+1), receipts) }) - for _, block := range chain { + for i, block := range chain { core.WriteBlock(db, block) if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { t.Fatalf("failed to insert block number: %v", err) @@ -181,7 +180,7 @@ func TestFilters(t *testing.T) { if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { t.Fatalf("failed to insert block number: %v", err) } - if err := core.PutBlockReceipts(db, block, block.Receipts()); err != nil { + if err := core.PutBlockReceipts(db, block.Hash(), receipts[i]); err != nil { t.Fatal("error writing block receipts:", err) } } diff --git a/eth/handler.go b/eth/handler.go index 725178035..7dc7de80e 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -84,6 +84,11 @@ type ProtocolManager struct { // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // with the ethereum network. func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool txPool, pow pow.PoW, blockchain *core.BlockChain, chaindb ethdb.Database) (*ProtocolManager, error) { + // Figure out whether to allow fast sync or not + if fastSync && blockchain.CurrentBlock().NumberU64() > 0 { + glog.V(logger.Info).Infof("blockchain not empty, fast sync disabled") + fastSync = false + } // Create the protocol manager with the base fields manager := &ProtocolManager{ fastSync: fastSync, @@ -103,7 +108,7 @@ func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool if fastSync && version < eth63 { continue } - // Compatible, initialize the sub-protocol + // Compatible; initialise the sub-protocol version := version // Closure for the run manager.SubProtocols = append(manager.SubProtocols, p2p.Protocol{ Name: "eth", @@ -120,13 +125,9 @@ func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool return nil, errIncompatibleConfig } // Construct the different synchronisation mechanisms - syncMode := downloader.FullSync - if fastSync { - syncMode = downloader.FastSync - } - manager.downloader = downloader.New(syncMode, chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader, - blockchain.GetBlock, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, - blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer) + manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader, blockchain.GetBlock, + blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, blockchain.GetTd, + blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer) validator := func(block *types.Block, parent *types.Block) error { return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false) diff --git a/eth/handler_test.go b/eth/handler_test.go index 843b02fd4..ab2ce54b1 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -443,7 +443,9 @@ func testGetNodeData(t *testing.T, protocol int) { // Fetch for now the entire chain db hashes := []common.Hash{} for _, key := range pm.chaindb.(*ethdb.MemDatabase).Keys() { - hashes = append(hashes, common.BytesToHash(key)) + if len(key) == len(common.Hash{}) { + hashes = append(hashes, common.BytesToHash(key)) + } } p2p.Send(peer.app, 0x0d, hashes) msg, err := peer.app.ReadMsg() diff --git a/eth/metrics.go b/eth/metrics.go index cfab3bcb3..8231a06ff 100644 --- a/eth/metrics.go +++ b/eth/metrics.go @@ -101,7 +101,7 @@ func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) { packets, traffic = reqBlockInPacketsMeter, reqBlockInTrafficMeter case rw.version >= eth62 && msg.Code == BlockHeadersMsg: - packets, traffic = reqBlockInPacketsMeter, reqBlockInTrafficMeter + packets, traffic = reqHeaderInPacketsMeter, reqHeaderInTrafficMeter case rw.version >= eth62 && msg.Code == BlockBodiesMsg: packets, traffic = reqBodyInPacketsMeter, reqBodyInTrafficMeter diff --git a/eth/sync.go b/eth/sync.go index 6295083e2..b69a24556 100644 --- a/eth/sync.go +++ b/eth/sync.go @@ -22,6 +22,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/discover" @@ -165,5 +166,20 @@ func (pm *ProtocolManager) synchronise(peer *peer) { return } // Otherwise try to sync with the downloader - pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td()) + mode := downloader.FullSync + if pm.fastSync { + mode = downloader.FastSync + } + pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), mode) + + // If fast sync was enabled, and we synced up, disable it + if pm.fastSync { + for pm.downloader.Synchronising() { + time.Sleep(100 * time.Millisecond) + } + if pm.blockchain.CurrentBlock().NumberU64() > 0 { + glog.V(logger.Info).Infof("fast sync complete, auto disabling") + pm.fastSync = false + } + } } diff --git a/eth/sync_test.go b/eth/sync_test.go new file mode 100644 index 000000000..f3a6718ab --- /dev/null +++ b/eth/sync_test.go @@ -0,0 +1,53 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +// Tests that fast sync gets disabled as soon as a real block is successfully +// imported into the blockchain. +func TestFastSyncDisabling(t *testing.T) { + // Create a pristine protocol manager, check that fast sync is left enabled + pmEmpty := newTestProtocolManagerMust(t, true, 0, nil, nil) + if !pmEmpty.fastSync { + t.Fatalf("fast sync disabled on pristine blockchain") + } + // Create a full protocol manager, check that fast sync gets disabled + pmFull := newTestProtocolManagerMust(t, true, 1024, nil, nil) + if pmFull.fastSync { + t.Fatalf("fast sync not disabled on non-empty blockchain") + } + // Sync up the two peers + io1, io2 := p2p.MsgPipe() + + go pmFull.handle(pmFull.newPeer(63, NetworkId, p2p.NewPeer(discover.NodeID{}, "empty", nil), io2)) + go pmEmpty.handle(pmEmpty.newPeer(63, NetworkId, p2p.NewPeer(discover.NodeID{}, "full", nil), io1)) + + time.Sleep(250 * time.Millisecond) + pmEmpty.synchronise(pmEmpty.peers.BestPeer()) + + // Check that fast sync was disabled + if pmEmpty.fastSync { + t.Fatalf("fast sync not disabled after successful synchronisation") + } +} diff --git a/ethdb/memory_database.go b/ethdb/memory_database.go index 330834fa4..01273b9db 100644 --- a/ethdb/memory_database.go +++ b/ethdb/memory_database.go @@ -17,6 +17,7 @@ package ethdb import ( + "errors" "fmt" "sync" @@ -56,7 +57,10 @@ func (db *MemDatabase) Get(key []byte) ([]byte, error) { db.lock.RLock() defer db.lock.RUnlock() - return db.db[string(key)], nil + if entry, ok := db.db[string(key)]; ok { + return entry, nil + } + return nil, errors.New("not found") } func (db *MemDatabase) Keys() [][]byte { @@ -132,8 +136,8 @@ func (b *memBatch) Write() error { b.lock.RLock() defer b.lock.RUnlock() - b.db.lock.RLock() - defer b.db.lock.RUnlock() + b.db.lock.Lock() + defer b.db.lock.Unlock() for _, kv := range b.writes { b.db.db[string(kv.k)] = kv.v diff --git a/rpc/api/eth.go b/rpc/api/eth.go index 6db006a46..4722682ff 100644 --- a/rpc/api/eth.go +++ b/rpc/api/eth.go @@ -168,9 +168,7 @@ func (self *ethApi) IsMining(req *shared.Request) (interface{}, error) { } func (self *ethApi) IsSyncing(req *shared.Request) (interface{}, error) { - current := self.ethereum.BlockChain().CurrentBlock().NumberU64() - origin, height := self.ethereum.Downloader().Boundaries() - + origin, current, height := self.ethereum.Downloader().Progress() if current < height { return map[string]interface{}{ "startingBlock": newHexNum(big.NewInt(int64(origin)).Bytes()), diff --git a/trie/sync.go b/trie/sync.go index bb112fb62..d55399d06 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -31,7 +31,7 @@ type request struct { object *node // Target node to populate with retrieved data (hashnode originally) parents []*request // Parent state nodes referencing this entry (notify all upon completion) - depth int // Depth level within the trie the node is located to prioritize DFS + depth int // Depth level within the trie the node is located to prioritise DFS deps int // Number of dependencies before allowed to commit this node callback TrieSyncLeafCallback // Callback to invoke if a leaf node it reached on this branch