core, eth, trie: fix data races and merge/review issues

pull/1889/head
Péter Szilágyi 9 years ago
parent aa0538db0b
commit 5b0ee8ec30
  1. 14
      core/block_processor.go
  2. 71
      core/blockchain.go
  3. 2
      core/blockchain_test.go
  4. 2
      core/chain_util.go
  5. 28
      core/chain_util_test.go
  6. 3
      core/state/sync.go
  7. 4
      core/state/sync_test.go
  8. 10
      core/types/receipt.go
  9. 2
      core/vm/log.go
  10. 1
      eth/backend.go
  11. 10
      eth/backend_test.go
  12. 191
      eth/downloader/downloader.go
  13. 464
      eth/downloader/downloader_test.go
  14. 4
      eth/downloader/modes.go
  15. 28
      eth/downloader/peer.go
  16. 178
      eth/downloader/queue.go
  17. 26
      eth/fetcher/fetcher.go
  18. 49
      eth/fetcher/fetcher_test.go
  19. 33
      eth/filters/filter_test.go
  20. 17
      eth/handler.go
  21. 4
      eth/handler_test.go
  22. 2
      eth/metrics.go
  23. 18
      eth/sync.go
  24. 53
      eth/sync_test.go
  25. 10
      ethdb/memory_database.go
  26. 4
      rpc/api/eth.go
  27. 2
      trie/sync.go

@ -195,14 +195,16 @@ func (sm *BlockProcessor) Process(block *types.Block) (logs vm.Logs, receipts ty
defer sm.mutex.Unlock() defer sm.mutex.Unlock()
if sm.bc.HasBlock(block.Hash()) { if sm.bc.HasBlock(block.Hash()) {
return nil, nil, &KnownBlockError{block.Number(), block.Hash()} if _, err := state.New(block.Root(), sm.chainDb); err == nil {
return nil, nil, &KnownBlockError{block.Number(), block.Hash()}
}
} }
if parent := sm.bc.GetBlock(block.ParentHash()); parent != nil {
if !sm.bc.HasBlock(block.ParentHash()) { if _, err := state.New(parent.Root(), sm.chainDb); err == nil {
return nil, nil, ParentError(block.ParentHash()) return sm.processWithParent(block, parent)
}
} }
parent := sm.bc.GetBlock(block.ParentHash()) return nil, nil, ParentError(block.ParentHash())
return sm.processWithParent(block, parent)
} }
func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs vm.Logs, receipts types.Receipts, err error) { func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs vm.Logs, receipts types.Receipts, err error) {

@ -18,11 +18,13 @@
package core package core
import ( import (
crand "crypto/rand"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math"
"math/big" "math/big"
"math/rand" mrand "math/rand"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -89,7 +91,8 @@ type BlockChain struct {
procInterrupt int32 // interrupt signaler for block processing procInterrupt int32 // interrupt signaler for block processing
wg sync.WaitGroup wg sync.WaitGroup
pow pow.PoW pow pow.PoW
rand *mrand.Rand
} }
func NewBlockChain(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (*BlockChain, error) { func NewBlockChain(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (*BlockChain, error) {
@ -112,6 +115,12 @@ func NewBlockChain(chainDb ethdb.Database, pow pow.PoW, mux *event.TypeMux) (*Bl
futureBlocks: futureBlocks, futureBlocks: futureBlocks,
pow: pow, pow: pow,
} }
// Seed a fast but crypto originating random generator
seed, err := crand.Int(crand.Reader, big.NewInt(math.MaxInt64))
if err != nil {
return nil, err
}
bc.rand = mrand.New(mrand.NewSource(seed.Int64()))
bc.genesisBlock = bc.GetBlockByNumber(0) bc.genesisBlock = bc.GetBlockByNumber(0)
if bc.genesisBlock == nil { if bc.genesisBlock == nil {
@ -178,21 +187,21 @@ func (self *BlockChain) loadLastState() error {
fastTd := self.GetTd(self.currentFastBlock.Hash()) fastTd := self.GetTd(self.currentFastBlock.Hash())
glog.V(logger.Info).Infof("Last header: #%d [%x…] TD=%v", self.currentHeader.Number, self.currentHeader.Hash().Bytes()[:4], headerTd) glog.V(logger.Info).Infof("Last header: #%d [%x…] TD=%v", self.currentHeader.Number, self.currentHeader.Hash().Bytes()[:4], headerTd)
glog.V(logger.Info).Infof("Fast block: #%d [%x…] TD=%v", self.currentFastBlock.Number(), self.currentFastBlock.Hash().Bytes()[:4], fastTd)
glog.V(logger.Info).Infof("Last block: #%d [%x…] TD=%v", self.currentBlock.Number(), self.currentBlock.Hash().Bytes()[:4], blockTd) glog.V(logger.Info).Infof("Last block: #%d [%x…] TD=%v", self.currentBlock.Number(), self.currentBlock.Hash().Bytes()[:4], blockTd)
glog.V(logger.Info).Infof("Fast block: #%d [%x…] TD=%v", self.currentFastBlock.Number(), self.currentFastBlock.Hash().Bytes()[:4], fastTd)
return nil return nil
} }
// SetHead rewind the local chain to a new head entity. In the case of headers, // SetHead rewinds the local chain to a new head. In the case of headers, everything
// everything above the new head will be deleted and the new one set. In the case // above the new head will be deleted and the new one set. In the case of blocks
// of blocks though, the head may be further rewound if block bodies are missing // though, the head may be further rewound if block bodies are missing (non-archive
// (non-archive nodes after a fast sync). // nodes after a fast sync).
func (bc *BlockChain) SetHead(head uint64) { func (bc *BlockChain) SetHead(head uint64) {
bc.mu.Lock() bc.mu.Lock()
defer bc.mu.Unlock() defer bc.mu.Unlock()
// Figure out the highest known canonical assignment // Figure out the highest known canonical headers and/or blocks
height := uint64(0) height := uint64(0)
if bc.currentHeader != nil { if bc.currentHeader != nil {
if hh := bc.currentHeader.Number.Uint64(); hh > height { if hh := bc.currentHeader.Number.Uint64(); hh > height {
@ -266,7 +275,7 @@ func (bc *BlockChain) SetHead(head uint64) {
// FastSyncCommitHead sets the current head block to the one defined by the hash // FastSyncCommitHead sets the current head block to the one defined by the hash
// irrelevant what the chain contents were prior. // irrelevant what the chain contents were prior.
func (self *BlockChain) FastSyncCommitHead(hash common.Hash) error { func (self *BlockChain) FastSyncCommitHead(hash common.Hash) error {
// Make sure that both the block as well at it's state trie exists // Make sure that both the block as well at its state trie exists
block := self.GetBlock(hash) block := self.GetBlock(hash)
if block == nil { if block == nil {
return fmt.Errorf("non existent block [%x…]", hash[:4]) return fmt.Errorf("non existent block [%x…]", hash[:4])
@ -298,7 +307,7 @@ func (self *BlockChain) LastBlockHash() common.Hash {
} }
// CurrentHeader retrieves the current head header of the canonical chain. The // CurrentHeader retrieves the current head header of the canonical chain. The
// header is retrieved from the chain manager's internal cache. // header is retrieved from the blockchain's internal cache.
func (self *BlockChain) CurrentHeader() *types.Header { func (self *BlockChain) CurrentHeader() *types.Header {
self.mu.RLock() self.mu.RLock()
defer self.mu.RUnlock() defer self.mu.RUnlock()
@ -307,7 +316,7 @@ func (self *BlockChain) CurrentHeader() *types.Header {
} }
// CurrentBlock retrieves the current head block of the canonical chain. The // CurrentBlock retrieves the current head block of the canonical chain. The
// block is retrieved from the chain manager's internal cache. // block is retrieved from the blockchain's internal cache.
func (self *BlockChain) CurrentBlock() *types.Block { func (self *BlockChain) CurrentBlock() *types.Block {
self.mu.RLock() self.mu.RLock()
defer self.mu.RUnlock() defer self.mu.RUnlock()
@ -316,7 +325,7 @@ func (self *BlockChain) CurrentBlock() *types.Block {
} }
// CurrentFastBlock retrieves the current fast-sync head block of the canonical // CurrentFastBlock retrieves the current fast-sync head block of the canonical
// chain. The block is retrieved from the chain manager's internal cache. // chain. The block is retrieved from the blockchain's internal cache.
func (self *BlockChain) CurrentFastBlock() *types.Block { func (self *BlockChain) CurrentFastBlock() *types.Block {
self.mu.RLock() self.mu.RLock()
defer self.mu.RUnlock() defer self.mu.RUnlock()
@ -353,7 +362,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) {
bc.mu.Lock() bc.mu.Lock()
defer bc.mu.Unlock() defer bc.mu.Unlock()
// Prepare the genesis block and reinitialize the chain // Prepare the genesis block and reinitialise the chain
if err := WriteTd(bc.chainDb, genesis.Hash(), genesis.Difficulty()); err != nil { if err := WriteTd(bc.chainDb, genesis.Hash(), genesis.Difficulty()); err != nil {
glog.Fatalf("failed to write genesis block TD: %v", err) glog.Fatalf("failed to write genesis block TD: %v", err)
} }
@ -403,7 +412,7 @@ func (self *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
// insert injects a new head block into the current block chain. This method // insert injects a new head block into the current block chain. This method
// assumes that the block is indeed a true head. It will also reset the head // assumes that the block is indeed a true head. It will also reset the head
// header and the head fast sync block to this very same block to prevent them // header and the head fast sync block to this very same block to prevent them
// from diverging on a different header chain. // from pointing to a possibly old canonical chain (i.e. side chain by now).
// //
// Note, this function assumes that the `mu` mutex is held! // Note, this function assumes that the `mu` mutex is held!
func (bc *BlockChain) insert(block *types.Block) { func (bc *BlockChain) insert(block *types.Block) {
@ -625,10 +634,10 @@ const (
// writeHeader writes a header into the local chain, given that its parent is // writeHeader writes a header into the local chain, given that its parent is
// already known. If the total difficulty of the newly inserted header becomes // already known. If the total difficulty of the newly inserted header becomes
// greater than the old known TD, the canonical chain is re-routed. // greater than the current known TD, the canonical chain is re-routed.
// //
// Note: This method is not concurrent-safe with inserting blocks simultaneously // Note: This method is not concurrent-safe with inserting blocks simultaneously
// into the chain, as side effects caused by reorganizations cannot be emulated // into the chain, as side effects caused by reorganisations cannot be emulated
// without the real blocks. Hence, writing headers directly should only be done // without the real blocks. Hence, writing headers directly should only be done
// in two scenarios: pure-header mode of operation (light clients), or properly // in two scenarios: pure-header mode of operation (light clients), or properly
// separated header/block phases (non-archive clients). // separated header/block phases (non-archive clients).
@ -678,10 +687,9 @@ func (self *BlockChain) writeHeader(header *types.Header) error {
return nil return nil
} }
// InsertHeaderChain will attempt to insert the given header chain in to the // InsertHeaderChain attempts to insert the given header chain in to the local
// local chain, possibly creating a fork. If an error is returned, it will // chain, possibly creating a reorg. If an error is returned, it will return the
// return the index number of the failing header as well an error describing // index number of the failing header as well an error describing what went wrong.
// what went wrong.
// //
// The verify parameter can be used to fine tune whether nonce verification // The verify parameter can be used to fine tune whether nonce verification
// should be done or not. The reason behind the optional check is because some // should be done or not. The reason behind the optional check is because some
@ -702,7 +710,7 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int)
// Generate the list of headers that should be POW verified // Generate the list of headers that should be POW verified
verify := make([]bool, len(chain)) verify := make([]bool, len(chain))
for i := 0; i < len(verify)/checkFreq; i++ { for i := 0; i < len(verify)/checkFreq; i++ {
index := i*checkFreq + rand.Intn(checkFreq) index := i*checkFreq + self.rand.Intn(checkFreq)
if index >= len(verify) { if index >= len(verify) {
index = len(verify) - 1 index = len(verify) - 1
} }
@ -766,10 +774,6 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int)
pending.Wait() pending.Wait()
// If anything failed, report // If anything failed, report
if atomic.LoadInt32(&self.procInterrupt) == 1 {
glog.V(logger.Debug).Infoln("premature abort during receipt chain processing")
return 0, nil
}
if failed > 0 { if failed > 0 {
for i, err := range errs { for i, err := range errs {
if err != nil { if err != nil {
@ -807,6 +811,9 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int)
// Rollback is designed to remove a chain of links from the database that aren't // Rollback is designed to remove a chain of links from the database that aren't
// certain enough to be valid. // certain enough to be valid.
func (self *BlockChain) Rollback(chain []common.Hash) { func (self *BlockChain) Rollback(chain []common.Hash) {
self.mu.Lock()
defer self.mu.Unlock()
for i := len(chain) - 1; i >= 0; i-- { for i := len(chain) - 1; i >= 0; i-- {
hash := chain[i] hash := chain[i]
@ -905,6 +912,12 @@ func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain
glog.Fatal(errs[index]) glog.Fatal(errs[index])
return return
} }
if err := WriteMipmapBloom(self.chainDb, block.NumberU64(), receipts); err != nil {
errs[index] = fmt.Errorf("failed to write log blooms: %v", err)
atomic.AddInt32(&failed, 1)
glog.Fatal(errs[index])
return
}
atomic.AddInt32(&stats.processed, 1) atomic.AddInt32(&stats.processed, 1)
} }
} }
@ -920,10 +933,6 @@ func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain
pending.Wait() pending.Wait()
// If anything failed, report // If anything failed, report
if atomic.LoadInt32(&self.procInterrupt) == 1 {
glog.V(logger.Debug).Infoln("premature abort during receipt chain processing")
return 0, nil
}
if failed > 0 { if failed > 0 {
for i, err := range errs { for i, err := range errs {
if err != nil { if err != nil {
@ -931,6 +940,10 @@ func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain
} }
} }
} }
if atomic.LoadInt32(&self.procInterrupt) == 1 {
glog.V(logger.Debug).Infoln("premature abort during receipt chain processing")
return 0, nil
}
// Update the head fast sync block if better // Update the head fast sync block if better
self.mu.Lock() self.mu.Lock()
head := blockChain[len(errs)-1] head := blockChain[len(errs)-1]

@ -452,7 +452,7 @@ func makeBlockChainWithDiff(genesis *types.Block, d []int, seed byte) []*types.B
func chm(genesis *types.Block, db ethdb.Database) *BlockChain { func chm(genesis *types.Block, db ethdb.Database) *BlockChain {
var eventMux event.TypeMux var eventMux event.TypeMux
bc := &BlockChain{chainDb: db, genesisBlock: genesis, eventMux: &eventMux, pow: FakePow{}} bc := &BlockChain{chainDb: db, genesisBlock: genesis, eventMux: &eventMux, pow: FakePow{}, rand: rand.New(rand.NewSource(0))}
bc.headerCache, _ = lru.New(100) bc.headerCache, _ = lru.New(100)
bc.bodyCache, _ = lru.New(100) bc.bodyCache, _ = lru.New(100)
bc.bodyRLPCache, _ = lru.New(100) bc.bodyRLPCache, _ = lru.New(100)

@ -394,7 +394,7 @@ func WriteMipmapBloom(db ethdb.Database, number uint64, receipts types.Receipts)
bloomDat, _ := db.Get(key) bloomDat, _ := db.Get(key)
bloom := types.BytesToBloom(bloomDat) bloom := types.BytesToBloom(bloomDat)
for _, receipt := range receipts { for _, receipt := range receipts {
for _, log := range receipt.Logs() { for _, log := range receipt.Logs {
bloom.Add(log.Address.Big()) bloom.Add(log.Address.Big())
} }
} }

@ -345,15 +345,15 @@ func TestMipmapBloom(t *testing.T) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
receipt1 := new(types.Receipt) receipt1 := new(types.Receipt)
receipt1.SetLogs(vm.Logs{ receipt1.Logs = vm.Logs{
&vm.Log{Address: common.BytesToAddress([]byte("test"))}, &vm.Log{Address: common.BytesToAddress([]byte("test"))},
&vm.Log{Address: common.BytesToAddress([]byte("address"))}, &vm.Log{Address: common.BytesToAddress([]byte("address"))},
}) }
receipt2 := new(types.Receipt) receipt2 := new(types.Receipt)
receipt2.SetLogs(vm.Logs{ receipt2.Logs = vm.Logs{
&vm.Log{Address: common.BytesToAddress([]byte("test"))}, &vm.Log{Address: common.BytesToAddress([]byte("test"))},
&vm.Log{Address: common.BytesToAddress([]byte("address1"))}, &vm.Log{Address: common.BytesToAddress([]byte("address1"))},
}) }
WriteMipmapBloom(db, 1, types.Receipts{receipt1}) WriteMipmapBloom(db, 1, types.Receipts{receipt1})
WriteMipmapBloom(db, 2, types.Receipts{receipt2}) WriteMipmapBloom(db, 2, types.Receipts{receipt2})
@ -368,15 +368,15 @@ func TestMipmapBloom(t *testing.T) {
// reset // reset
db, _ = ethdb.NewMemDatabase() db, _ = ethdb.NewMemDatabase()
receipt := new(types.Receipt) receipt := new(types.Receipt)
receipt.SetLogs(vm.Logs{ receipt.Logs = vm.Logs{
&vm.Log{Address: common.BytesToAddress([]byte("test"))}, &vm.Log{Address: common.BytesToAddress([]byte("test"))},
}) }
WriteMipmapBloom(db, 999, types.Receipts{receipt1}) WriteMipmapBloom(db, 999, types.Receipts{receipt1})
receipt = new(types.Receipt) receipt = new(types.Receipt)
receipt.SetLogs(vm.Logs{ receipt.Logs = vm.Logs{
&vm.Log{Address: common.BytesToAddress([]byte("test 1"))}, &vm.Log{Address: common.BytesToAddress([]byte("test 1"))},
}) }
WriteMipmapBloom(db, 1000, types.Receipts{receipt}) WriteMipmapBloom(db, 1000, types.Receipts{receipt})
bloom := GetMipmapBloom(db, 1000, 1000) bloom := GetMipmapBloom(db, 1000, 1000)
@ -403,22 +403,22 @@ func TestMipmapChain(t *testing.T) {
defer db.Close() defer db.Close()
genesis := WriteGenesisBlockForTesting(db, GenesisAccount{addr, big.NewInt(1000000)}) genesis := WriteGenesisBlockForTesting(db, GenesisAccount{addr, big.NewInt(1000000)})
chain := GenerateChain(genesis, db, 1010, func(i int, gen *BlockGen) { chain, receipts := GenerateChain(genesis, db, 1010, func(i int, gen *BlockGen) {
var receipts types.Receipts var receipts types.Receipts
switch i { switch i {
case 1: case 1:
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{ receipt.Logs = vm.Logs{
&vm.Log{ &vm.Log{
Address: addr, Address: addr,
Topics: []common.Hash{hash1}, Topics: []common.Hash{hash1},
}, },
}) }
gen.AddUncheckedReceipt(receipt) gen.AddUncheckedReceipt(receipt)
receipts = types.Receipts{receipt} receipts = types.Receipts{receipt}
case 1000: case 1000:
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{&vm.Log{Address: addr2}}) receipt.Logs = vm.Logs{&vm.Log{Address: addr2}}
gen.AddUncheckedReceipt(receipt) gen.AddUncheckedReceipt(receipt)
receipts = types.Receipts{receipt} receipts = types.Receipts{receipt}
@ -431,7 +431,7 @@ func TestMipmapChain(t *testing.T) {
} }
WriteMipmapBloom(db, uint64(i+1), receipts) WriteMipmapBloom(db, uint64(i+1), receipts)
}) })
for _, block := range chain { for i, block := range chain {
WriteBlock(db, block) WriteBlock(db, block)
if err := WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { if err := WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
t.Fatalf("failed to insert block number: %v", err) t.Fatalf("failed to insert block number: %v", err)
@ -439,7 +439,7 @@ func TestMipmapChain(t *testing.T) {
if err := WriteHeadBlockHash(db, block.Hash()); err != nil { if err := WriteHeadBlockHash(db, block.Hash()); err != nil {
t.Fatalf("failed to insert block number: %v", err) t.Fatalf("failed to insert block number: %v", err)
} }
if err := PutBlockReceipts(db, block, block.Receipts()); err != nil { if err := PutBlockReceipts(db, block.Hash(), receipts[i]); err != nil {
t.Fatal("error writing block receipts:", err) t.Fatal("error writing block receipts:", err)
} }
} }

@ -26,14 +26,13 @@ import (
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
) )
// StateSync is the main state synchronisation scheduler, which provides yet the // StateSync is the main state synchronisation scheduler, which provides yet the
// unknown state hashes to retrieve, accepts node data associated with said hashes // unknown state hashes to retrieve, accepts node data associated with said hashes
// and reconstructs the state database step by step until all is done. // and reconstructs the state database step by step until all is done.
type StateSync trie.TrieSync type StateSync trie.TrieSync
// NewStateSync create a new state trie download scheduler. // NewStateSync create a new state trie download scheduler.
func NewStateSync(root common.Hash, database ethdb.Database) *StateSync { func NewStateSync(root common.Hash, database ethdb.Database) *StateSync {
// Pre-declare the result syncer t
var syncer *trie.TrieSync var syncer *trie.TrieSync
callback := func(leaf []byte, parent common.Hash) error { callback := func(leaf []byte, parent common.Hash) error {

@ -38,7 +38,7 @@ type testAccount struct {
func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
// Create an empty state // Create an empty state
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
state := New(common.Hash{}, db) state, _ := New(common.Hash{}, db)
// Fill it with some arbitrary data // Fill it with some arbitrary data
accounts := []*testAccount{} accounts := []*testAccount{}
@ -68,7 +68,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
// checkStateAccounts cross references a reconstructed state with an expected // checkStateAccounts cross references a reconstructed state with an expected
// account array. // account array.
func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) {
state := New(root, db) state, _ := New(root, db)
for i, acc := range accounts { for i, acc := range accounts {
if balance := state.GetBalance(acc.address); balance.Cmp(acc.balance) != 0 { if balance := state.GetBalance(acc.address); balance.Cmp(acc.balance) != 0 {

@ -67,7 +67,7 @@ func (r *Receipt) DecodeRLP(s *rlp.Stream) error {
return nil return nil
} }
// RlpEncode implements common.RlpEncode required for SHA derivation. // RlpEncode implements common.RlpEncode required for SHA3 derivation.
func (r *Receipt) RlpEncode() []byte { func (r *Receipt) RlpEncode() []byte {
bytes, err := rlp.EncodeToBytes(r) bytes, err := rlp.EncodeToBytes(r)
if err != nil { if err != nil {
@ -82,7 +82,7 @@ func (r *Receipt) String() string {
} }
// ReceiptForStorage is a wrapper around a Receipt that flattens and parses the // ReceiptForStorage is a wrapper around a Receipt that flattens and parses the
// entire content of a receipt, opposed to only the consensus fields originally. // entire content of a receipt, as opposed to only the consensus fields originally.
type ReceiptForStorage Receipt type ReceiptForStorage Receipt
// EncodeRLP implements rlp.Encoder, and flattens all content fields of a receipt // EncodeRLP implements rlp.Encoder, and flattens all content fields of a receipt
@ -95,8 +95,8 @@ func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, []interface{}{r.PostState, r.CumulativeGasUsed, r.Bloom, r.TxHash, r.ContractAddress, logs, r.GasUsed}) return rlp.Encode(w, []interface{}{r.PostState, r.CumulativeGasUsed, r.Bloom, r.TxHash, r.ContractAddress, logs, r.GasUsed})
} }
// DecodeRLP implements rlp.Decoder, and loads the consensus fields of a receipt // DecodeRLP implements rlp.Decoder, and loads both consensus and implementation
// from an RLP stream. // fields of a receipt from an RLP stream.
func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error { func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error {
var receipt struct { var receipt struct {
PostState []byte PostState []byte
@ -125,7 +125,7 @@ func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error {
// Receipts is a wrapper around a Receipt array to implement types.DerivableList. // Receipts is a wrapper around a Receipt array to implement types.DerivableList.
type Receipts []*Receipt type Receipts []*Receipt
// RlpEncode implements common.RlpEncode required for SHA derivation. // RlpEncode implements common.RlpEncode required for SHA3 derivation.
func (r Receipts) RlpEncode() []byte { func (r Receipts) RlpEncode() []byte {
bytes, err := rlp.EncodeToBytes(r) bytes, err := rlp.EncodeToBytes(r)
if err != nil { if err != nil {

@ -66,6 +66,6 @@ func (l *Log) String() string {
type Logs []*Log type Logs []*Log
// LogForStorage is a wrapper around a Log that flattens and parses the entire // LogForStorage is a wrapper around a Log that flattens and parses the entire
// content of a log, opposed to only the consensus fields originally (by hiding // content of a log, as opposed to only the consensus fields originally (by hiding
// the rlp interface methods). // the rlp interface methods).
type LogForStorage Log type LogForStorage Log

@ -391,7 +391,6 @@ func New(config *Config) (*Ethereum, error) {
if err == core.ErrNoGenesis { if err == core.ErrNoGenesis {
return nil, fmt.Errorf(`Genesis block not found. Please supply a genesis block with the "--genesis /path/to/file" argument`) return nil, fmt.Errorf(`Genesis block not found. Please supply a genesis block with the "--genesis /path/to/file" argument`)
} }
return nil, err return nil, err
} }
newPool := core.NewTxPool(eth.EventMux(), eth.blockchain.State, eth.blockchain.GasLimit) newPool := core.NewTxPool(eth.EventMux(), eth.blockchain.State, eth.blockchain.GasLimit)

@ -16,17 +16,17 @@ func TestMipmapUpgrade(t *testing.T) {
addr := common.BytesToAddress([]byte("jeff")) addr := common.BytesToAddress([]byte("jeff"))
genesis := core.WriteGenesisBlockForTesting(db) genesis := core.WriteGenesisBlockForTesting(db)
chain := core.GenerateChain(genesis, db, 10, func(i int, gen *core.BlockGen) { chain, receipts := core.GenerateChain(genesis, db, 10, func(i int, gen *core.BlockGen) {
var receipts types.Receipts var receipts types.Receipts
switch i { switch i {
case 1: case 1:
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{&vm.Log{Address: addr}}) receipt.Logs = vm.Logs{&vm.Log{Address: addr}}
gen.AddUncheckedReceipt(receipt) gen.AddUncheckedReceipt(receipt)
receipts = types.Receipts{receipt} receipts = types.Receipts{receipt}
case 2: case 2:
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{&vm.Log{Address: addr}}) receipt.Logs = vm.Logs{&vm.Log{Address: addr}}
gen.AddUncheckedReceipt(receipt) gen.AddUncheckedReceipt(receipt)
receipts = types.Receipts{receipt} receipts = types.Receipts{receipt}
} }
@ -37,7 +37,7 @@ func TestMipmapUpgrade(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
}) })
for _, block := range chain { for i, block := range chain {
core.WriteBlock(db, block) core.WriteBlock(db, block)
if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
t.Fatalf("failed to insert block number: %v", err) t.Fatalf("failed to insert block number: %v", err)
@ -45,7 +45,7 @@ func TestMipmapUpgrade(t *testing.T) {
if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil {
t.Fatalf("failed to insert block number: %v", err) t.Fatalf("failed to insert block number: %v", err)
} }
if err := core.PutBlockReceipts(db, block, block.Receipts()); err != nil { if err := core.PutBlockReceipts(db, block.Hash(), receipts[i]); err != nil {
t.Fatal("error writing block receipts:", err) t.Fatal("error writing block receipts:", err)
} }
} }

@ -18,7 +18,9 @@
package downloader package downloader
import ( import (
"crypto/rand"
"errors" "errors"
"fmt"
"math" "math"
"math/big" "math/big"
"strings" "strings"
@ -59,9 +61,11 @@ var (
maxQueuedStates = 256 * 1024 // [eth/63] Maximum number of state requests to queue (DOS protection) maxQueuedStates = 256 * 1024 // [eth/63] Maximum number of state requests to queue (DOS protection)
maxResultsProcess = 256 // Number of download results to import at once into the chain maxResultsProcess = 256 // Number of download results to import at once into the chain
headerCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync
minCheckedHeaders = 2048 // Number of headers to verify fully when approaching the chain head fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected
minFullBlocks = 1024 // Number of blocks to retrieve fully even in fast sync fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it
fsPivotInterval = 512 // Number of headers out of which to randomize the pivot point
fsMinFullBlocks = 1024 // Number of blocks to retrieve fully even in fast sync
) )
var ( var (
@ -85,12 +89,14 @@ var (
errCancelHeaderFetch = errors.New("block header download canceled (requested)") errCancelHeaderFetch = errors.New("block header download canceled (requested)")
errCancelBodyFetch = errors.New("block body download canceled (requested)") errCancelBodyFetch = errors.New("block body download canceled (requested)")
errCancelReceiptFetch = errors.New("receipt download canceled (requested)") errCancelReceiptFetch = errors.New("receipt download canceled (requested)")
errCancelStateFetch = errors.New("state data download canceled (requested)")
errNoSyncActive = errors.New("no sync active") errNoSyncActive = errors.New("no sync active")
) )
type Downloader struct { type Downloader struct {
mode SyncMode // Synchronisation mode defining the strategies used mode SyncMode // Synchronisation mode defining the strategy used (per sync cycle)
mux *event.TypeMux // Event multiplexer to announce sync operation events noFast bool // Flag to disable fast syncing in case of a security error
mux *event.TypeMux // Event multiplexer to announce sync operation events
queue *queue // Scheduler for selecting the hashes to download queue *queue // Scheduler for selecting the hashes to download
peers *peerSet // Set of active peers from which download can proceed peers *peerSet // Set of active peers from which download can proceed
@ -150,13 +156,13 @@ type Downloader struct {
} }
// New creates a new downloader to fetch hashes and blocks from remote peers. // New creates a new downloader to fetch hashes and blocks from remote peers.
func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn, func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn,
getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, headFastBlock headFastBlockRetrievalFn, getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, headFastBlock headFastBlockRetrievalFn,
commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn,
insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader { insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader {
return &Downloader{ return &Downloader{
mode: mode, mode: FullSync,
mux: mux, mux: mux,
queue: newQueue(stateDb), queue: newQueue(stateDb),
peers: newPeerSet(), peers: newPeerSet(),
@ -188,19 +194,28 @@ func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader he
} }
} }
// Boundaries retrieves the synchronisation boundaries, specifically the origin // Progress retrieves the synchronisation boundaries, specifically the origin
// block where synchronisation started at (may have failed/suspended) and the // block where synchronisation started at (may have failed/suspended); the block
// latest known block which the synchonisation targets. // or header sync is currently at; and the latest known block which the sync targets.
func (d *Downloader) Boundaries() (uint64, uint64) { func (d *Downloader) Progress() (uint64, uint64, uint64) {
d.syncStatsLock.RLock() d.syncStatsLock.RLock()
defer d.syncStatsLock.RUnlock() defer d.syncStatsLock.RUnlock()
return d.syncStatsChainOrigin, d.syncStatsChainHeight current := uint64(0)
switch d.mode {
case FullSync:
current = d.headBlock().NumberU64()
case FastSync:
current = d.headFastBlock().NumberU64()
case LightSync:
current = d.headHeader().Number.Uint64()
}
return d.syncStatsChainOrigin, current, d.syncStatsChainHeight
} }
// Synchronising returns whether the downloader is currently retrieving blocks. // Synchronising returns whether the downloader is currently retrieving blocks.
func (d *Downloader) Synchronising() bool { func (d *Downloader) Synchronising() bool {
return atomic.LoadInt32(&d.synchronising) > 0 return atomic.LoadInt32(&d.synchronising) > 0 || atomic.LoadInt32(&d.processing) > 0
} }
// RegisterPeer injects a new download peer into the set of block source to be // RegisterPeer injects a new download peer into the set of block source to be
@ -233,10 +248,10 @@ func (d *Downloader) UnregisterPeer(id string) error {
// Synchronise tries to sync up our local block chain with a remote peer, both // Synchronise tries to sync up our local block chain with a remote peer, both
// adding various sanity checks as well as wrapping it with various log entries. // adding various sanity checks as well as wrapping it with various log entries.
func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int) { func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode SyncMode) {
glog.V(logger.Detail).Infof("Attempting synchronisation: %v, head [%x…], TD %v", id, head[:4], td) glog.V(logger.Detail).Infof("Attempting synchronisation: %v, head [%x…], TD %v", id, head[:4], td)
switch err := d.synchronise(id, head, td); err { switch err := d.synchronise(id, head, td, mode); err {
case nil: case nil:
glog.V(logger.Detail).Infof("Synchronisation completed") glog.V(logger.Detail).Infof("Synchronisation completed")
@ -258,7 +273,7 @@ func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int) {
// synchronise will select the peer and use it for synchronising. If an empty string is given // synchronise will select the peer and use it for synchronising. If an empty string is given
// it will use the best peer possible and synchronize if it's TD is higher than our own. If any of the // it will use the best peer possible and synchronize if it's TD is higher than our own. If any of the
// checks fail an error will be returned. This method is synchronous // checks fail an error will be returned. This method is synchronous
func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int) error { func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode SyncMode) error {
// Mock out the synchonisation if testing // Mock out the synchonisation if testing
if d.synchroniseMock != nil { if d.synchroniseMock != nil {
return d.synchroniseMock(id, hash) return d.synchroniseMock(id, hash)
@ -298,6 +313,11 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int) error
d.cancelCh = make(chan struct{}) d.cancelCh = make(chan struct{})
d.cancelLock.Unlock() d.cancelLock.Unlock()
// Set the requested sync mode, unless it's forbidden
d.mode = mode
if d.mode == FastSync && d.noFast {
d.mode = FullSync
}
// Retrieve the origin peer and initiate the downloading process // Retrieve the origin peer and initiate the downloading process
p := d.peers.Peer(id) p := d.peers.Peer(id)
if p == nil { if p == nil {
@ -306,13 +326,6 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int) error
return d.syncWithPeer(p, hash, td) return d.syncWithPeer(p, hash, td)
} }
/*
// Has checks if the downloader knows about a particular hash, meaning that its
// either already downloaded of pending retrieval.
func (d *Downloader) Has(hash common.Hash) bool {
return d.queue.Has(hash)
}
*/
// syncWithPeer starts a block synchronization based on the hash chain from the // syncWithPeer starts a block synchronization based on the hash chain from the
// specified peer and head hash. // specified peer and head hash.
func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err error) { func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err error) {
@ -387,8 +400,28 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e
// Initiate the sync using a concurrent header and content retrieval algorithm // Initiate the sync using a concurrent header and content retrieval algorithm
pivot := uint64(0) pivot := uint64(0)
if latest > uint64(minFullBlocks) { switch d.mode {
pivot = latest - uint64(minFullBlocks) case LightSync:
pivot = latest
case FastSync:
// Calculate the new fast/slow sync pivot point
pivotOffset, err := rand.Int(rand.Reader, big.NewInt(int64(fsPivotInterval)))
if err != nil {
panic(fmt.Sprintf("Failed to access crypto random source: %v", err))
}
if latest > uint64(fsMinFullBlocks)+pivotOffset.Uint64() {
pivot = latest - uint64(fsMinFullBlocks) - pivotOffset.Uint64()
}
// If the point is below the origin, move origin back to ensure state download
if pivot < origin {
if pivot > 0 {
origin = pivot - 1
} else {
origin = 0
}
}
glog.V(logger.Debug).Infof("Fast syncing until pivot block #%d", pivot)
} }
d.queue.Prepare(origin+1, d.mode, pivot) d.queue.Prepare(origin+1, d.mode, pivot)
@ -396,10 +429,10 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e
d.syncInitHook(origin, latest) d.syncInitHook(origin, latest)
} }
errc := make(chan error, 4) errc := make(chan error, 4)
go func() { errc <- d.fetchHeaders(p, td, origin+1, latest) }() // Headers are always retrieved go func() { errc <- d.fetchHeaders(p, td, origin+1) }() // Headers are always retrieved
go func() { errc <- d.fetchBodies(origin + 1) }() // Bodies are retrieved during normal and fast sync go func() { errc <- d.fetchBodies(origin + 1) }() // Bodies are retrieved during normal and fast sync
go func() { errc <- d.fetchReceipts(origin + 1) }() // Receipts are retrieved during fast sync go func() { errc <- d.fetchReceipts(origin + 1) }() // Receipts are retrieved during fast sync
go func() { errc <- d.fetchNodeData() }() // Node state data is retrieved during fast sync go func() { errc <- d.fetchNodeData() }() // Node state data is retrieved during fast sync
// If any fetcher fails, cancel the others // If any fetcher fails, cancel the others
var fail error var fail error
@ -844,7 +877,7 @@ func (d *Downloader) fetchBlocks61(from uint64) error {
for _, peer := range idles { for _, peer := range idles {
// Short circuit if throttling activated // Short circuit if throttling activated
if d.queue.ThrottleBlocks() { if d.queue.ShouldThrottleBlocks() {
throttled = true throttled = true
break break
} }
@ -860,8 +893,13 @@ func (d *Downloader) fetchBlocks61(from uint64) error {
} }
// Fetch the chunk and make sure any errors return the hashes to the queue // Fetch the chunk and make sure any errors return the hashes to the queue
if err := peer.Fetch61(request); err != nil { if err := peer.Fetch61(request); err != nil {
glog.V(logger.Error).Infof("%v: fetch failed, rescheduling", peer) // Although we could try and make an attempt to fix this, this error really
d.queue.CancelBlocks(request) // means that we've double allocated a fetch task to a peer. If that is the
// case, the internal state of the downloader and the queue is very wrong so
// better hard crash and note the error instead of silently accumulating into
// a much bigger issue.
panic(fmt.Sprintf("%v: fetch assignment failed, hard panic", peer))
d.queue.CancelBlocks(request) // noop for now
} }
} }
// Make sure that we have peers available for fetching. If all peers have been tried // Make sure that we have peers available for fetching. If all peers have been tried
@ -1051,28 +1089,34 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) {
// //
// The queue parameter can be used to switch between queuing headers for block // The queue parameter can be used to switch between queuing headers for block
// body download too, or directly import as pure header chains. // body download too, or directly import as pure header chains.
func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) error { func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error {
glog.V(logger.Debug).Infof("%v: downloading headers from #%d", p, from) glog.V(logger.Debug).Infof("%v: downloading headers from #%d", p, from)
defer glog.V(logger.Debug).Infof("%v: header download terminated", p) defer glog.V(logger.Debug).Infof("%v: header download terminated", p)
// Calculate the pivoting point for switching from fast to slow sync
pivot := d.queue.FastSyncPivot()
// Keep a count of uncertain headers to roll back // Keep a count of uncertain headers to roll back
rollback := []*types.Header{} rollback := []*types.Header{}
defer func() { defer func() {
if len(rollback) > 0 { if len(rollback) > 0 {
// Flatten the headers and roll them back
hashes := make([]common.Hash, len(rollback)) hashes := make([]common.Hash, len(rollback))
for i, header := range rollback { for i, header := range rollback {
hashes[i] = header.Hash() hashes[i] = header.Hash()
} }
lh, lfb, lb := d.headHeader().Number, d.headFastBlock().Number(), d.headBlock().Number()
d.rollback(hashes) d.rollback(hashes)
glog.V(logger.Warn).Infof("Rolled back %d headers (LH: %d->%d, FB: %d->%d, LB: %d->%d)",
len(hashes), lh, d.headHeader().Number, lfb, d.headFastBlock().Number(), lb, d.headBlock().Number())
// If we're already past the pivot point, this could be an attack, disable fast sync
if rollback[len(rollback)-1].Number.Uint64() > pivot {
d.noFast = true
}
} }
}() }()
// Calculate the pivoting point for switching from fast to slow sync
pivot := uint64(0)
if d.mode == FastSync && latest > uint64(minFullBlocks) {
pivot = latest - uint64(minFullBlocks)
} else if d.mode == LightSync {
pivot = latest
}
// Create a timeout timer, and the associated hash fetcher // Create a timeout timer, and the associated hash fetcher
request := time.Now() // time of the last fetch request request := time.Now() // time of the last fetch request
timeout := time.NewTimer(0) // timer to dump a non-responsive active peer timeout := time.NewTimer(0) // timer to dump a non-responsive active peer
@ -1135,6 +1179,19 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) err
if !gotHeaders && td.Cmp(d.getTd(d.headBlock().Hash())) > 0 { if !gotHeaders && td.Cmp(d.getTd(d.headBlock().Hash())) > 0 {
return errStallingPeer return errStallingPeer
} }
// If fast or light syncing, ensure promised headers are indeed delivered. This is
// needed to detect scenarios where an attacker feeds a bad pivot and then bails out
// of delivering the post-pivot blocks that would flag the invalid content.
//
// This check cannot be executed "as is" for full imports, since blocks may still be
// queued for processing when the header download completes. However, as long as the
// peer gave us something useful, we're already happy/progressed (above check).
if d.mode == FastSync || d.mode == LightSync {
if td.Cmp(d.getTd(d.headHeader().Hash())) > 0 {
return errStallingPeer
}
}
rollback = nil
return nil return nil
} }
gotHeaders = true gotHeaders = true
@ -1152,8 +1209,8 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) err
} }
} }
// If we're importing pure headers, verify based on their recentness // If we're importing pure headers, verify based on their recentness
frequency := headerCheckFrequency frequency := fsHeaderCheckFrequency
if headers[len(headers)-1].Number.Uint64()+uint64(minCheckedHeaders) > pivot { if headers[len(headers)-1].Number.Uint64()+uint64(fsHeaderForceVerify) > pivot {
frequency = 1 frequency = 1
} }
if n, err := d.insertHeaders(headers, frequency); err != nil { if n, err := d.insertHeaders(headers, frequency); err != nil {
@ -1162,11 +1219,8 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) err
} }
// All verifications passed, store newly found uncertain headers // All verifications passed, store newly found uncertain headers
rollback = append(rollback, unknown...) rollback = append(rollback, unknown...)
if len(rollback) > minCheckedHeaders { if len(rollback) > fsHeaderSafetyNet {
rollback = append(rollback[:0], rollback[len(rollback)-minCheckedHeaders:]...) rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...)
}
if headers[len(headers)-1].Number.Uint64() >= pivot {
rollback = rollback[:0]
} }
} }
if d.mode == FullSync || d.mode == FastSync { if d.mode == FullSync || d.mode == FastSync {
@ -1230,12 +1284,11 @@ func (d *Downloader) fetchBodies(from uint64) error {
expire = func() []string { return d.queue.ExpireBodies(bodyHardTTL) } expire = func() []string { return d.queue.ExpireBodies(bodyHardTTL) }
fetch = func(p *peer, req *fetchRequest) error { return p.FetchBodies(req) } fetch = func(p *peer, req *fetchRequest) error { return p.FetchBodies(req) }
capacity = func(p *peer) int { return p.BlockCapacity() } capacity = func(p *peer) int { return p.BlockCapacity() }
getIdles = func() ([]*peer, int) { return d.peers.BodyIdlePeers() } setIdle = func(p *peer) { p.SetBodiesIdle() }
setIdle = func(p *peer) { p.SetBlocksIdle() }
) )
err := d.fetchParts(errCancelBodyFetch, d.bodyCh, deliver, d.bodyWakeCh, expire, err := d.fetchParts(errCancelBodyFetch, d.bodyCh, deliver, d.bodyWakeCh, expire,
d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ThrottleBlocks, d.queue.ReserveBodies, d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ShouldThrottleBlocks, d.queue.ReserveBodies,
d.bodyFetchHook, fetch, d.queue.CancelBodies, capacity, getIdles, setIdle, "Body") d.bodyFetchHook, fetch, d.queue.CancelBodies, capacity, d.peers.BodyIdlePeers, setIdle, "Body")
glog.V(logger.Debug).Infof("Block body download terminated: %v", err) glog.V(logger.Debug).Infof("Block body download terminated: %v", err)
return err return err
@ -1252,13 +1305,13 @@ func (d *Downloader) fetchReceipts(from uint64) error {
pack := packet.(*receiptPack) pack := packet.(*receiptPack)
return d.queue.DeliverReceipts(pack.peerId, pack.receipts) return d.queue.DeliverReceipts(pack.peerId, pack.receipts)
} }
expire = func() []string { return d.queue.ExpireReceipts(bodyHardTTL) } expire = func() []string { return d.queue.ExpireReceipts(receiptHardTTL) }
fetch = func(p *peer, req *fetchRequest) error { return p.FetchReceipts(req) } fetch = func(p *peer, req *fetchRequest) error { return p.FetchReceipts(req) }
capacity = func(p *peer) int { return p.ReceiptCapacity() } capacity = func(p *peer) int { return p.ReceiptCapacity() }
setIdle = func(p *peer) { p.SetReceiptsIdle() } setIdle = func(p *peer) { p.SetReceiptsIdle() }
) )
err := d.fetchParts(errCancelReceiptFetch, d.receiptCh, deliver, d.receiptWakeCh, expire, err := d.fetchParts(errCancelReceiptFetch, d.receiptCh, deliver, d.receiptWakeCh, expire,
d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ThrottleReceipts, d.queue.ReserveReceipts, d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ShouldThrottleReceipts, d.queue.ReserveReceipts,
d.receiptFetchHook, fetch, d.queue.CancelReceipts, capacity, d.peers.ReceiptIdlePeers, setIdle, "Receipt") d.receiptFetchHook, fetch, d.queue.CancelReceipts, capacity, d.peers.ReceiptIdlePeers, setIdle, "Receipt")
glog.V(logger.Debug).Infof("Receipt download terminated: %v", err) glog.V(logger.Debug).Infof("Receipt download terminated: %v", err)
@ -1307,9 +1360,9 @@ func (d *Downloader) fetchNodeData() error {
capacity = func(p *peer) int { return p.NodeDataCapacity() } capacity = func(p *peer) int { return p.NodeDataCapacity() }
setIdle = func(p *peer) { p.SetNodeDataIdle() } setIdle = func(p *peer) { p.SetNodeDataIdle() }
) )
err := d.fetchParts(errCancelReceiptFetch, d.stateCh, deliver, d.stateWakeCh, expire, err := d.fetchParts(errCancelStateFetch, d.stateCh, deliver, d.stateWakeCh, expire,
d.queue.PendingNodeData, d.queue.InFlightNodeData, throttle, reserve, nil, fetch, d.queue.PendingNodeData, d.queue.InFlightNodeData, throttle, reserve, nil, fetch,
d.queue.CancelNodeData, capacity, d.peers.ReceiptIdlePeers, setIdle, "State") d.queue.CancelNodeData, capacity, d.peers.NodeDataIdlePeers, setIdle, "State")
glog.V(logger.Debug).Infof("Node state data download terminated: %v", err) glog.V(logger.Debug).Infof("Node state data download terminated: %v", err)
return err return err
@ -1323,7 +1376,7 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
fetchHook func([]*types.Header), fetch func(*peer, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peer) int, fetchHook func([]*types.Header), fetch func(*peer, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peer) int,
idle func() ([]*peer, int), setIdle func(*peer), kind string) error { idle func() ([]*peer, int), setIdle func(*peer), kind string) error {
// Create a ticker to detect expired retreival tasks // Create a ticker to detect expired retrieval tasks
ticker := time.NewTicker(100 * time.Millisecond) ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
@ -1366,11 +1419,6 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
// The hash chain is invalid (blocks are not ordered properly), abort // The hash chain is invalid (blocks are not ordered properly), abort
return err return err
case errInvalidBody, errInvalidReceipt:
// The peer delivered something very bad, drop immediately
glog.V(logger.Error).Infof("%s: delivered invalid %s, dropping", peer, strings.ToLower(kind))
d.dropPeer(peer.id)
case errNoFetchesPending: case errNoFetchesPending:
// Peer probably timed out with its delivery but came through // Peer probably timed out with its delivery but came through
// in the end, demote, but allow to to pull from this peer. // in the end, demote, but allow to to pull from this peer.
@ -1475,8 +1523,13 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
fetchHook(request.Headers) fetchHook(request.Headers)
} }
if err := fetch(peer, request); err != nil { if err := fetch(peer, request); err != nil {
glog.V(logger.Error).Infof("%v: %s fetch failed, rescheduling", peer, strings.ToLower(kind)) // Although we could try and make an attempt to fix this, this error really
cancel(request) // means that we've double allocated a fetch task to a peer. If that is the
// case, the internal state of the downloader and the queue is very wrong so
// better hard crash and note the error instead of silently accumulating into
// a much bigger issue.
panic(fmt.Sprintf("%v: %s fetch assignment failed, hard panic", peer, strings.ToLower(kind)))
cancel(request) // noop for now
} }
running = true running = true
} }
@ -1526,6 +1579,7 @@ func (d *Downloader) process() {
// Repeat the processing as long as there are results to process // Repeat the processing as long as there are results to process
for { for {
// Fetch the next batch of results // Fetch the next batch of results
pivot := d.queue.FastSyncPivot() // Fetch pivot before results to prevent reset race
results := d.queue.TakeResults() results := d.queue.TakeResults()
if len(results) == 0 { if len(results) == 0 {
return return
@ -1545,7 +1599,6 @@ func (d *Downloader) process() {
} }
// Retrieve the a batch of results to import // Retrieve the a batch of results to import
var ( var (
headers = make([]*types.Header, 0, maxResultsProcess)
blocks = make([]*types.Block, 0, maxResultsProcess) blocks = make([]*types.Block, 0, maxResultsProcess)
receipts = make([]types.Receipts, 0, maxResultsProcess) receipts = make([]types.Receipts, 0, maxResultsProcess)
) )
@ -1556,11 +1609,9 @@ func (d *Downloader) process() {
blocks = append(blocks, types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)) blocks = append(blocks, types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles))
case d.mode == FastSync: case d.mode == FastSync:
blocks = append(blocks, types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)) blocks = append(blocks, types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles))
if result.Header.Number.Uint64() <= d.queue.fastSyncPivot { if result.Header.Number.Uint64() <= pivot {
receipts = append(receipts, result.Receipts) receipts = append(receipts, result.Receipts)
} }
case d.mode == LightSync:
headers = append(headers, result.Header)
} }
} }
// Try to process the results, aborting if there's an error // Try to process the results, aborting if there's an error
@ -1569,12 +1620,10 @@ func (d *Downloader) process() {
index int index int
) )
switch { switch {
case len(headers) > 0:
index, err = d.insertHeaders(headers, headerCheckFrequency)
case len(receipts) > 0: case len(receipts) > 0:
index, err = d.insertReceipts(blocks, receipts) index, err = d.insertReceipts(blocks, receipts)
if err == nil && blocks[len(blocks)-1].NumberU64() == d.queue.fastSyncPivot { if err == nil && blocks[len(blocks)-1].NumberU64() == pivot {
glog.V(logger.Debug).Infof("Committing block #%d [%x…] as the new head", blocks[len(blocks)-1].Number(), blocks[len(blocks)-1].Hash().Bytes()[:4])
index, err = len(blocks)-1, d.commitHeadBlock(blocks[len(blocks)-1].Hash()) index, err = len(blocks)-1, d.commitHeadBlock(blocks[len(blocks)-1].Hash())
} }
default: default:

@ -136,7 +136,7 @@ type downloadTester struct {
} }
// newTester creates a new downloader test mocker. // newTester creates a new downloader test mocker.
func newTester(mode SyncMode) *downloadTester { func newTester() *downloadTester {
tester := &downloadTester{ tester := &downloadTester{
ownHashes: []common.Hash{genesis.Hash()}, ownHashes: []common.Hash{genesis.Hash()},
ownHeaders: map[common.Hash]*types.Header{genesis.Hash(): genesis.Header()}, ownHeaders: map[common.Hash]*types.Header{genesis.Hash(): genesis.Header()},
@ -150,7 +150,7 @@ func newTester(mode SyncMode) *downloadTester {
peerChainTds: make(map[string]map[common.Hash]*big.Int), peerChainTds: make(map[string]map[common.Hash]*big.Int),
} }
tester.stateDb, _ = ethdb.NewMemDatabase() tester.stateDb, _ = ethdb.NewMemDatabase()
tester.downloader = New(mode, tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd, tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer) tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)
@ -158,7 +158,7 @@ func newTester(mode SyncMode) *downloadTester {
} }
// sync starts synchronizing with a remote peer, blocking until it completes. // sync starts synchronizing with a remote peer, blocking until it completes.
func (dl *downloadTester) sync(id string, td *big.Int) error { func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error {
dl.lock.RLock() dl.lock.RLock()
hash := dl.peerHashes[id][0] hash := dl.peerHashes[id][0]
// If no particular TD was requested, load from the peer's blockchain // If no particular TD was requested, load from the peer's blockchain
@ -170,7 +170,7 @@ func (dl *downloadTester) sync(id string, td *big.Int) error {
} }
dl.lock.RUnlock() dl.lock.RUnlock()
err := dl.downloader.synchronise(id, hash, td) err := dl.downloader.synchronise(id, hash, td, mode)
for { for {
// If the queue is empty and processing stopped, break // If the queue is empty and processing stopped, break
if dl.downloader.queue.Idle() && atomic.LoadInt32(&dl.downloader.processing) == 0 { if dl.downloader.queue.Idle() && atomic.LoadInt32(&dl.downloader.processing) == 0 {
@ -214,7 +214,7 @@ func (dl *downloadTester) headHeader() *types.Header {
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
for i := len(dl.ownHashes) - 1; i >= 0; i-- { for i := len(dl.ownHashes) - 1; i >= 0; i-- {
if header := dl.getHeader(dl.ownHashes[i]); header != nil { if header := dl.ownHeaders[dl.ownHashes[i]]; header != nil {
return header return header
} }
} }
@ -227,7 +227,7 @@ func (dl *downloadTester) headBlock() *types.Block {
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
for i := len(dl.ownHashes) - 1; i >= 0; i-- { for i := len(dl.ownHashes) - 1; i >= 0; i-- {
if block := dl.getBlock(dl.ownHashes[i]); block != nil { if block := dl.ownBlocks[dl.ownHashes[i]]; block != nil {
if _, err := dl.stateDb.Get(block.Root().Bytes()); err == nil { if _, err := dl.stateDb.Get(block.Root().Bytes()); err == nil {
return block return block
} }
@ -242,7 +242,7 @@ func (dl *downloadTester) headFastBlock() *types.Block {
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
for i := len(dl.ownHashes) - 1; i >= 0; i-- { for i := len(dl.ownHashes) - 1; i >= 0; i-- {
if block := dl.getBlock(dl.ownHashes[i]); block != nil { if block := dl.ownBlocks[dl.ownHashes[i]]; block != nil {
return block return block
} }
} }
@ -291,7 +291,7 @@ func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int)
} }
dl.ownHashes = append(dl.ownHashes, header.Hash()) dl.ownHashes = append(dl.ownHashes, header.Hash())
dl.ownHeaders[header.Hash()] = header dl.ownHeaders[header.Hash()] = header
dl.ownChainTd[header.Hash()] = dl.ownChainTd[header.ParentHash] dl.ownChainTd[header.Hash()] = new(big.Int).Add(dl.ownChainTd[header.ParentHash], header.Difficulty)
} }
return len(headers), nil return len(headers), nil
} }
@ -305,11 +305,13 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
if _, ok := dl.ownBlocks[block.ParentHash()]; !ok { if _, ok := dl.ownBlocks[block.ParentHash()]; !ok {
return i, errors.New("unknown parent") return i, errors.New("unknown parent")
} }
dl.ownHashes = append(dl.ownHashes, block.Hash()) if _, ok := dl.ownHeaders[block.Hash()]; !ok {
dl.ownHeaders[block.Hash()] = block.Header() dl.ownHashes = append(dl.ownHashes, block.Hash())
dl.ownHeaders[block.Hash()] = block.Header()
}
dl.ownBlocks[block.Hash()] = block dl.ownBlocks[block.Hash()] = block
dl.stateDb.Put(block.Root().Bytes(), []byte{}) dl.stateDb.Put(block.Root().Bytes(), []byte{0x00})
dl.ownChainTd[block.Hash()] = dl.ownChainTd[block.ParentHash()] dl.ownChainTd[block.Hash()] = new(big.Int).Add(dl.ownChainTd[block.ParentHash()], block.Difficulty())
} }
return len(blocks), nil return len(blocks), nil
} }
@ -381,7 +383,19 @@ func (dl *downloadTester) newSlowPeer(id string, version int, hashes []common.Ha
dl.peerReceipts[id] = make(map[common.Hash]types.Receipts) dl.peerReceipts[id] = make(map[common.Hash]types.Receipts)
dl.peerChainTds[id] = make(map[common.Hash]*big.Int) dl.peerChainTds[id] = make(map[common.Hash]*big.Int)
for _, hash := range hashes { genesis := hashes[len(hashes)-1]
if header := headers[genesis]; header != nil {
dl.peerHeaders[id][genesis] = header
dl.peerChainTds[id][genesis] = header.Difficulty
}
if block := blocks[genesis]; block != nil {
dl.peerBlocks[id][genesis] = block
dl.peerChainTds[id][genesis] = block.Difficulty()
}
for i := len(hashes) - 2; i >= 0; i-- {
hash := hashes[i]
if header, ok := headers[hash]; ok { if header, ok := headers[hash]; ok {
dl.peerHeaders[id][hash] = header dl.peerHeaders[id][hash] = header
if _, ok := dl.peerHeaders[id][header.ParentHash]; ok { if _, ok := dl.peerHeaders[id][header.ParentHash]; ok {
@ -627,21 +641,28 @@ func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
// number of items of the various chain components. // number of items of the various chain components.
func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) { func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
// Initialize the counters for the first fork // Initialize the counters for the first fork
headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-minFullBlocks headers, blocks := lengths[0], lengths[0]
if receipts < 0 {
receipts = 1 minReceipts, maxReceipts := lengths[0]-fsMinFullBlocks-fsPivotInterval, lengths[0]-fsMinFullBlocks
if minReceipts < 0 {
minReceipts = 1
}
if maxReceipts < 0 {
maxReceipts = 1
} }
// Update the counters for each subsequent fork // Update the counters for each subsequent fork
for _, length := range lengths[1:] { for _, length := range lengths[1:] {
headers += length - common headers += length - common
blocks += length - common blocks += length - common
receipts += length - common - minFullBlocks
minReceipts += length - common - fsMinFullBlocks - fsPivotInterval
maxReceipts += length - common - fsMinFullBlocks
} }
switch tester.downloader.mode { switch tester.downloader.mode {
case FullSync: case FullSync:
receipts = 1 minReceipts, maxReceipts = 1, 1
case LightSync: case LightSync:
blocks, receipts = 1, 1 blocks, minReceipts, maxReceipts = 1, 1, 1
} }
if hs := len(tester.ownHeaders); hs != headers { if hs := len(tester.ownHeaders); hs != headers {
t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers) t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers)
@ -649,14 +670,20 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
if bs := len(tester.ownBlocks); bs != blocks { if bs := len(tester.ownBlocks); bs != blocks {
t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks) t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks)
} }
if rs := len(tester.ownReceipts); rs != receipts { if rs := len(tester.ownReceipts); rs < minReceipts || rs > maxReceipts {
t.Fatalf("synchronised receipts mismatch: have %v, want %v", rs, receipts) t.Fatalf("synchronised receipts mismatch: have %v, want between [%v, %v]", rs, minReceipts, maxReceipts)
} }
// Verify the state trie too for fast syncs // Verify the state trie too for fast syncs
if tester.downloader.mode == FastSync { if tester.downloader.mode == FastSync {
if index := lengths[len(lengths)-1] - minFullBlocks - 1; index > 0 { index := 0
if statedb := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil { if pivot := int(tester.downloader.queue.fastSyncPivot); pivot < common {
t.Fatalf("state reconstruction failed") index = pivot
} else {
index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot)
}
if index > 0 {
if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil || err != nil {
t.Fatalf("state reconstruction failed: %v", err)
} }
} }
} }
@ -678,11 +705,11 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode) tester := newTester()
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
// Synchronise with the peer and make sure all relevant data was retrieved // Synchronise with the peer and make sure all relevant data was retrieved
if err := tester.sync("peer", nil); err != nil { if err := tester.sync("peer", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, targetBlocks+1) assertOwnChain(t, tester, targetBlocks+1)
@ -702,7 +729,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
targetBlocks := 8 * blockCacheLimit targetBlocks := 8 * blockCacheLimit
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode) tester := newTester()
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
// Wrap the importer to allow stepping // Wrap the importer to allow stepping
@ -714,7 +741,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
// Start a synchronisation concurrently // Start a synchronisation concurrently
errc := make(chan error) errc := make(chan error)
go func() { go func() {
errc <- tester.sync("peer", nil) errc <- tester.sync("peer", nil, mode)
}() }()
// Iteratively take some blocks, always checking the retrieval count // Iteratively take some blocks, always checking the retrieval count
for { for {
@ -726,10 +753,11 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
break break
} }
// Wait a bit for sync to throttle itself // Wait a bit for sync to throttle itself
var cached int var cached, frozen int
for start := time.Now(); time.Since(start) < time.Second; { for start := time.Now(); time.Since(start) < time.Second; {
time.Sleep(25 * time.Millisecond) time.Sleep(25 * time.Millisecond)
tester.lock.RLock()
tester.downloader.queue.lock.RLock() tester.downloader.queue.lock.RLock()
cached = len(tester.downloader.queue.blockDonePool) cached = len(tester.downloader.queue.blockDonePool)
if mode == FastSync { if mode == FastSync {
@ -739,16 +767,23 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
} }
} }
} }
frozen = int(atomic.LoadUint32(&blocked))
retrieved = len(tester.ownBlocks)
tester.downloader.queue.lock.RUnlock() tester.downloader.queue.lock.RUnlock()
tester.lock.RUnlock()
if cached == blockCacheLimit || len(tester.ownBlocks)+cached+int(atomic.LoadUint32(&blocked)) == targetBlocks+1 { if cached == blockCacheLimit || retrieved+cached+frozen == targetBlocks+1 {
break break
} }
} }
// Make sure we filled up the cache, then exhaust it // Make sure we filled up the cache, then exhaust it
time.Sleep(25 * time.Millisecond) // give it a chance to screw up time.Sleep(25 * time.Millisecond) // give it a chance to screw up
if cached != blockCacheLimit && len(tester.ownBlocks)+cached+int(atomic.LoadUint32(&blocked)) != targetBlocks+1 {
t.Fatalf("block count mismatch: have %v, want %v (owned %v, target %v)", cached, blockCacheLimit, len(tester.ownBlocks), targetBlocks+1) tester.lock.RLock()
retrieved = len(tester.ownBlocks)
tester.lock.RUnlock()
if cached != blockCacheLimit && retrieved+cached+frozen != targetBlocks+1 {
t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheLimit, retrieved, frozen, targetBlocks+1)
} }
// Permit the blocked blocks to import // Permit the blocked blocks to import
if atomic.LoadUint32(&blocked) > 0 { if atomic.LoadUint32(&blocked) > 0 {
@ -779,18 +814,18 @@ func testForkedSynchronisation(t *testing.T, protocol int, mode SyncMode) {
common, fork := MaxHashFetch, 2*MaxHashFetch common, fork := MaxHashFetch, 2*MaxHashFetch
hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil) hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil)
tester := newTester(mode) tester := newTester()
tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA) tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA)
tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB) tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB)
// Synchronise with the peer and make sure all blocks were retrieved // Synchronise with the peer and make sure all blocks were retrieved
if err := tester.sync("fork A", nil); err != nil { if err := tester.sync("fork A", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, common+fork+1) assertOwnChain(t, tester, common+fork+1)
// Synchronise with the second peer and make sure that fork is pulled too // Synchronise with the second peer and make sure that fork is pulled too
if err := tester.sync("fork B", nil); err != nil { if err := tester.sync("fork B", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnForkedChain(t, tester, common+1, []int{common + fork + 1, common + fork + 1}) assertOwnForkedChain(t, tester, common+1, []int{common + fork + 1, common + fork + 1})
@ -798,7 +833,7 @@ func testForkedSynchronisation(t *testing.T, protocol int, mode SyncMode) {
// Tests that an inactive downloader will not accept incoming hashes and blocks. // Tests that an inactive downloader will not accept incoming hashes and blocks.
func TestInactiveDownloader61(t *testing.T) { func TestInactiveDownloader61(t *testing.T) {
tester := newTester(FullSync) tester := newTester()
// Check that neither hashes nor blocks are accepted // Check that neither hashes nor blocks are accepted
if err := tester.downloader.DeliverHashes("bad peer", []common.Hash{}); err != errNoSyncActive { if err := tester.downloader.DeliverHashes("bad peer", []common.Hash{}); err != errNoSyncActive {
@ -812,7 +847,7 @@ func TestInactiveDownloader61(t *testing.T) {
// Tests that an inactive downloader will not accept incoming block headers and // Tests that an inactive downloader will not accept incoming block headers and
// bodies. // bodies.
func TestInactiveDownloader62(t *testing.T) { func TestInactiveDownloader62(t *testing.T) {
tester := newTester(FullSync) tester := newTester()
// Check that neither block headers nor bodies are accepted // Check that neither block headers nor bodies are accepted
if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive { if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive {
@ -826,7 +861,7 @@ func TestInactiveDownloader62(t *testing.T) {
// Tests that an inactive downloader will not accept incoming block headers, // Tests that an inactive downloader will not accept incoming block headers,
// bodies and receipts. // bodies and receipts.
func TestInactiveDownloader63(t *testing.T) { func TestInactiveDownloader63(t *testing.T) {
tester := newTester(FullSync) tester := newTester()
// Check that neither block headers nor bodies are accepted // Check that neither block headers nor bodies are accepted
if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive { if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive {
@ -860,7 +895,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) {
} }
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode) tester := newTester()
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
// Make sure canceling works with a pristine downloader // Make sure canceling works with a pristine downloader
@ -869,7 +904,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) {
t.Errorf("download queue not idle") t.Errorf("download queue not idle")
} }
// Synchronise with the peer, but cancel afterwards // Synchronise with the peer, but cancel afterwards
if err := tester.sync("peer", nil); err != nil { if err := tester.sync("peer", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
tester.downloader.cancel() tester.downloader.cancel()
@ -893,12 +928,12 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
targetBlocks := targetPeers*blockCacheLimit - 15 targetBlocks := targetPeers*blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode) tester := newTester()
for i := 0; i < targetPeers; i++ { for i := 0; i < targetPeers; i++ {
id := fmt.Sprintf("peer #%d", i) id := fmt.Sprintf("peer #%d", i)
tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts) tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts)
} }
if err := tester.sync("peer #0", nil); err != nil { if err := tester.sync("peer #0", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, targetBlocks+1) assertOwnChain(t, tester, targetBlocks+1)
@ -920,14 +955,14 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
// Create peers of every type // Create peers of every type
tester := newTester(mode) tester := newTester()
tester.newPeer("peer 61", 61, hashes, headers, blocks, receipts) tester.newPeer("peer 61", 61, hashes, nil, blocks, nil)
tester.newPeer("peer 62", 62, hashes, headers, blocks, receipts) tester.newPeer("peer 62", 62, hashes, headers, blocks, nil)
tester.newPeer("peer 63", 63, hashes, headers, blocks, receipts) tester.newPeer("peer 63", 63, hashes, headers, blocks, receipts)
tester.newPeer("peer 64", 64, hashes, headers, blocks, receipts) tester.newPeer("peer 64", 64, hashes, headers, blocks, receipts)
// Synchronise with the requestd peer and make sure all blocks were retrieved // Synchronise with the requested peer and make sure all blocks were retrieved
if err := tester.sync(fmt.Sprintf("peer %d", protocol), nil); err != nil { if err := tester.sync(fmt.Sprintf("peer %d", protocol), nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, targetBlocks+1) assertOwnChain(t, tester, targetBlocks+1)
@ -955,7 +990,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
targetBlocks := 2*blockCacheLimit - 15 targetBlocks := 2*blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode) tester := newTester()
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
// Instrument the downloader to signal body requests // Instrument the downloader to signal body requests
@ -967,7 +1002,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
atomic.AddInt32(&receiptsHave, int32(len(headers))) atomic.AddInt32(&receiptsHave, int32(len(headers)))
} }
// Synchronise with the peer and make sure all blocks were retrieved // Synchronise with the peer and make sure all blocks were retrieved
if err := tester.sync("peer", nil); err != nil { if err := tester.sync("peer", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, targetBlocks+1) assertOwnChain(t, tester, targetBlocks+1)
@ -980,7 +1015,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
} }
} }
for hash, receipt := range receipts { for hash, receipt := range receipts {
if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= uint64(targetBlocks-minFullBlocks) { if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= tester.downloader.queue.fastSyncPivot {
receiptsNeeded++ receiptsNeeded++
} }
} }
@ -1006,19 +1041,19 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode) tester := newTester()
// Attempt a full sync with an attacker feeding gapped headers // Attempt a full sync with an attacker feeding gapped headers
tester.newPeer("attack", protocol, hashes, headers, blocks, receipts) tester.newPeer("attack", protocol, hashes, headers, blocks, receipts)
missing := targetBlocks / 2 missing := targetBlocks / 2
delete(tester.peerHeaders["attack"], hashes[missing]) delete(tester.peerHeaders["attack"], hashes[missing])
if err := tester.sync("attack", nil); err == nil { if err := tester.sync("attack", nil, mode); err == nil {
t.Fatalf("succeeded attacker synchronisation") t.Fatalf("succeeded attacker synchronisation")
} }
// Synchronise with the valid peer and make sure sync succeeds // Synchronise with the valid peer and make sure sync succeeds
tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
if err := tester.sync("valid", nil); err != nil { if err := tester.sync("valid", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, targetBlocks+1) assertOwnChain(t, tester, targetBlocks+1)
@ -1038,7 +1073,7 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode) tester := newTester()
// Attempt a full sync with an attacker feeding shifted headers // Attempt a full sync with an attacker feeding shifted headers
tester.newPeer("attack", protocol, hashes, headers, blocks, receipts) tester.newPeer("attack", protocol, hashes, headers, blocks, receipts)
@ -1046,12 +1081,12 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
delete(tester.peerBlocks["attack"], hashes[len(hashes)-2]) delete(tester.peerBlocks["attack"], hashes[len(hashes)-2])
delete(tester.peerReceipts["attack"], hashes[len(hashes)-2]) delete(tester.peerReceipts["attack"], hashes[len(hashes)-2])
if err := tester.sync("attack", nil); err == nil { if err := tester.sync("attack", nil, mode); err == nil {
t.Fatalf("succeeded attacker synchronisation") t.Fatalf("succeeded attacker synchronisation")
} }
// Synchronise with the valid peer and make sure sync succeeds // Synchronise with the valid peer and make sure sync succeeds
tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
if err := tester.sync("valid", nil); err != nil { if err := tester.sync("valid", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, targetBlocks+1) assertOwnChain(t, tester, targetBlocks+1)
@ -1064,92 +1099,81 @@ func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(
func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := 3*minCheckedHeaders + minFullBlocks targetBlocks := 3*fsHeaderSafetyNet + fsMinFullBlocks
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
tester := newTester(mode) tester := newTester()
// Attempt to sync with an attacker that feeds junk during the fast sync phase // Attempt to sync with an attacker that feeds junk during the fast sync phase.
// This should result in the last fsHeaderSafetyNet headers being rolled back.
tester.newPeer("fast-attack", protocol, hashes, headers, blocks, receipts) tester.newPeer("fast-attack", protocol, hashes, headers, blocks, receipts)
missing := minCheckedHeaders + MaxHeaderFetch + 1 missing := fsHeaderSafetyNet + MaxHeaderFetch + 1
delete(tester.peerHeaders["fast-attack"], hashes[len(hashes)-missing]) delete(tester.peerHeaders["fast-attack"], hashes[len(hashes)-missing])
if err := tester.sync("fast-attack", nil); err == nil { if err := tester.sync("fast-attack", nil, mode); err == nil {
t.Fatalf("succeeded fast attacker synchronisation") t.Fatalf("succeeded fast attacker synchronisation")
} }
if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch { if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch {
t.Fatalf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch) t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch)
} }
// Attempt to sync with an attacker that feeds junk during the block import phase // Attempt to sync with an attacker that feeds junk during the block import phase.
// This should result in both the last fsHeaderSafetyNet number of headers being
// rolled back, and also the pivot point being reverted to a non-block status.
tester.newPeer("block-attack", protocol, hashes, headers, blocks, receipts) tester.newPeer("block-attack", protocol, hashes, headers, blocks, receipts)
missing = 3*minCheckedHeaders + MaxHeaderFetch + 1 missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1
delete(tester.peerHeaders["block-attack"], hashes[len(hashes)-missing]) delete(tester.peerHeaders["block-attack"], hashes[len(hashes)-missing])
if err := tester.sync("block-attack", nil); err == nil { if err := tester.sync("block-attack", nil, mode); err == nil {
t.Fatalf("succeeded block attacker synchronisation") t.Fatalf("succeeded block attacker synchronisation")
} }
if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
}
if mode == FastSync { if mode == FastSync {
// Fast sync should not discard anything below the verified pivot point if head := tester.headBlock().NumberU64(); head != 0 {
if head := tester.headHeader().Number.Int64(); int(head) < 3*minCheckedHeaders { t.Errorf("fast sync pivot block #%d not rolled back", head)
t.Fatalf("rollback head mismatch: have %v, want at least %v", head, 3*minCheckedHeaders)
} }
} else if mode == LightSync {
// Light sync should still discard data as before
if head := tester.headHeader().Number.Int64(); int(head) > 3*minCheckedHeaders {
t.Fatalf("rollback head mismatch: have %v, want at most %v", head, 3*minCheckedHeaders)
}
}
// Synchronise with the valid peer and make sure sync succeeds
tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
if err := tester.sync("valid", nil); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, targetBlocks+1) // Attempt to sync with an attacker that withholds promised blocks after the
} // fast sync pivot point. This could be a trial to leave the node with a bad
// but already imported pivot block.
tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts)
missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1
// Tests that if a peer sends an invalid block piece (body or receipt) for a tester.downloader.noFast = false
// requested block, it gets dropped immediately by the downloader. tester.downloader.syncInitHook = func(uint64, uint64) {
func TestInvalidContentAttack62(t *testing.T) { testInvalidContentAttack(t, 62, FullSync) } for i := missing; i <= len(hashes); i++ {
func TestInvalidContentAttack63Full(t *testing.T) { testInvalidContentAttack(t, 63, FullSync) } delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i])
func TestInvalidContentAttack63Fast(t *testing.T) { testInvalidContentAttack(t, 63, FastSync) } }
func TestInvalidContentAttack64Full(t *testing.T) { testInvalidContentAttack(t, 64, FullSync) } tester.downloader.syncInitHook = nil
func TestInvalidContentAttack64Fast(t *testing.T) { testInvalidContentAttack(t, 64, FastSync) }
func TestInvalidContentAttack64Light(t *testing.T) { testInvalidContentAttack(t, 64, LightSync) }
func testInvalidContentAttack(t *testing.T, protocol int, mode SyncMode) {
// Create two peers, one feeding invalid block bodies
targetBlocks := 4*blockCacheLimit - 15
hashes, headers, validBlocks, validReceipts := makeChain(targetBlocks, 0, genesis, nil)
invalidBlocks := make(map[common.Hash]*types.Block)
for hash, block := range validBlocks {
invalidBlocks[hash] = types.NewBlockWithHeader(block.Header())
}
invalidReceipts := make(map[common.Hash]types.Receipts)
for hash, _ := range validReceipts {
invalidReceipts[hash] = types.Receipts{&types.Receipt{}}
} }
tester := newTester(mode) if err := tester.sync("withhold-attack", nil, mode); err == nil {
tester.newPeer("valid", protocol, hashes, headers, validBlocks, validReceipts) t.Fatalf("succeeded withholding attacker synchronisation")
if mode != LightSync { }
tester.newPeer("body attack", protocol, hashes, headers, invalidBlocks, validReceipts) if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
} }
if mode == FastSync { if mode == FastSync {
tester.newPeer("receipt attack", protocol, hashes, headers, validBlocks, invalidReceipts) if head := tester.headBlock().NumberU64(); head != 0 {
t.Errorf("fast sync pivot block #%d not rolled back", head)
}
} }
// Synchronise with the valid peer (will pull contents from the attacker too) // Synchronise with the valid peer and make sure sync succeeds. Since the last
if err := tester.sync("valid", nil); err != nil { // rollback should also disable fast syncing for this process, verify that we
// did a fresh full sync. Note, we can't assert anything about the receipts
// since we won't purge the database of them, hence we can't use asserOwnChain.
tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
if err := tester.sync("valid", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
assertOwnChain(t, tester, targetBlocks+1) if hs := len(tester.ownHeaders); hs != len(headers) {
t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, len(headers))
// Make sure the attacker was detected and dropped in the mean time
if _, ok := tester.peerHashes["body attack"]; ok {
t.Fatalf("block body attacker not detected/dropped")
} }
if _, ok := tester.peerHashes["receipt attack"]; ok { if mode != LightSync {
t.Fatalf("receipt attacker not detected/dropped") if bs := len(tester.ownBlocks); bs != len(blocks) {
t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, len(blocks))
}
} }
} }
@ -1164,11 +1188,11 @@ func TestHighTDStarvationAttack64Fast(t *testing.T) { testHighTDStarvationAttac
func TestHighTDStarvationAttack64Light(t *testing.T) { testHighTDStarvationAttack(t, 64, LightSync) } func TestHighTDStarvationAttack64Light(t *testing.T) { testHighTDStarvationAttack(t, 64, LightSync) }
func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) { func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) {
tester := newTester(mode) tester := newTester()
hashes, headers, blocks, receipts := makeChain(0, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(0, 0, genesis, nil)
tester.newPeer("attack", protocol, []common.Hash{hashes[0]}, headers, blocks, receipts) tester.newPeer("attack", protocol, []common.Hash{hashes[0]}, headers, blocks, receipts)
if err := tester.sync("attack", big.NewInt(1000000)); err != errStallingPeer { if err := tester.sync("attack", big.NewInt(1000000), mode); err != errStallingPeer {
t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer) t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer)
} }
} }
@ -1206,7 +1230,7 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
{errCancelBodyFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop {errCancelBodyFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop
} }
// Run the tests and check disconnection status // Run the tests and check disconnection status
tester := newTester(FullSync) tester := newTester()
for i, tt := range tests { for i, tt := range tests {
// Register a new peer and ensure it's presence // Register a new peer and ensure it's presence
id := fmt.Sprintf("test %d", i) id := fmt.Sprintf("test %d", i)
@ -1219,120 +1243,125 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
// Simulate a synchronisation and check the required result // Simulate a synchronisation and check the required result
tester.downloader.synchroniseMock = func(string, common.Hash) error { return tt.result } tester.downloader.synchroniseMock = func(string, common.Hash) error { return tt.result }
tester.downloader.Synchronise(id, genesis.Hash(), big.NewInt(1000)) tester.downloader.Synchronise(id, genesis.Hash(), big.NewInt(1000), FullSync)
if _, ok := tester.peerHashes[id]; !ok != tt.drop { if _, ok := tester.peerHashes[id]; !ok != tt.drop {
t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", i, tt.result, !ok, tt.drop) t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", i, tt.result, !ok, tt.drop)
} }
} }
} }
// Tests that synchronisation boundaries (origin block number and highest block // Tests that synchronisation progress (origin block number, current block number
// number) is tracked and updated correctly. // and highest block number) is tracked and updated correctly.
func TestSyncBoundaries61(t *testing.T) { testSyncBoundaries(t, 61, FullSync) } func TestSyncProgress61(t *testing.T) { testSyncProgress(t, 61, FullSync) }
func TestSyncBoundaries62(t *testing.T) { testSyncBoundaries(t, 62, FullSync) } func TestSyncProgress62(t *testing.T) { testSyncProgress(t, 62, FullSync) }
func TestSyncBoundaries63Full(t *testing.T) { testSyncBoundaries(t, 63, FullSync) } func TestSyncProgress63Full(t *testing.T) { testSyncProgress(t, 63, FullSync) }
func TestSyncBoundaries63Fast(t *testing.T) { testSyncBoundaries(t, 63, FastSync) } func TestSyncProgress63Fast(t *testing.T) { testSyncProgress(t, 63, FastSync) }
func TestSyncBoundaries64Full(t *testing.T) { testSyncBoundaries(t, 64, FullSync) } func TestSyncProgress64Full(t *testing.T) { testSyncProgress(t, 64, FullSync) }
func TestSyncBoundaries64Fast(t *testing.T) { testSyncBoundaries(t, 64, FastSync) } func TestSyncProgress64Fast(t *testing.T) { testSyncProgress(t, 64, FastSync) }
func TestSyncBoundaries64Light(t *testing.T) { testSyncBoundaries(t, 64, LightSync) } func TestSyncProgress64Light(t *testing.T) { testSyncProgress(t, 64, LightSync) }
func testSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
// Set a sync init hook to catch boundary changes // Set a sync init hook to catch progress changes
starting := make(chan struct{}) starting := make(chan struct{})
progress := make(chan struct{}) progress := make(chan struct{})
tester := newTester(mode) tester := newTester()
tester.downloader.syncInitHook = func(origin, latest uint64) { tester.downloader.syncInitHook = func(origin, latest uint64) {
starting <- struct{}{} starting <- struct{}{}
<-progress <-progress
} }
// Retrieve the sync boundaries and ensure they are zero (pristine sync) // Retrieve the sync progress and ensure they are zero (pristine sync)
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != 0 { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 {
t.Fatalf("Pristine boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, 0) t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0)
} }
// Synchronise half the blocks and check initial boundaries // Synchronise half the blocks and check initial progress
tester.newPeer("peer-half", protocol, hashes[targetBlocks/2:], headers, blocks, receipts) tester.newPeer("peer-half", protocol, hashes[targetBlocks/2:], headers, blocks, receipts)
pending := new(sync.WaitGroup) pending := new(sync.WaitGroup)
pending.Add(1) pending.Add(1)
go func() { go func() {
defer pending.Done() defer pending.Done()
if err := tester.sync("peer-half", nil); err != nil { if err := tester.sync("peer-half", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
}() }()
<-starting <-starting
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks/2+1) { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks/2+1) {
t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks/2+1) t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks/2+1)
} }
progress <- struct{}{} progress <- struct{}{}
pending.Wait() pending.Wait()
// Synchronise all the blocks and check continuation boundaries // Synchronise all the blocks and check continuation progress
tester.newPeer("peer-full", protocol, hashes, headers, blocks, receipts) tester.newPeer("peer-full", protocol, hashes, headers, blocks, receipts)
pending.Add(1) pending.Add(1)
go func() { go func() {
defer pending.Done() defer pending.Done()
if err := tester.sync("peer-full", nil); err != nil { if err := tester.sync("peer-full", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
}() }()
<-starting <-starting
if origin, latest := tester.downloader.Boundaries(); origin != uint64(targetBlocks/2+1) || latest != uint64(targetBlocks) { if origin, current, latest := tester.downloader.Progress(); origin != uint64(targetBlocks/2+1) || current != uint64(targetBlocks/2+1) || latest != uint64(targetBlocks) {
t.Fatalf("Completing boundary mismatch: have %v/%v, want %v/%v", origin, latest, targetBlocks/2+1, targetBlocks) t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, targetBlocks/2+1, targetBlocks/2+1, targetBlocks)
} }
progress <- struct{}{} progress <- struct{}{}
pending.Wait() pending.Wait()
// Check final progress after successful sync
if origin, current, latest := tester.downloader.Progress(); origin != uint64(targetBlocks/2+1) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) {
t.Fatalf("Final progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, targetBlocks/2+1, targetBlocks, targetBlocks)
}
} }
// Tests that synchronisation boundaries (origin block number and highest block // Tests that synchronisation progress (origin block number and highest block
// number) is tracked and updated correctly in case of a fork (or manual head // number) is tracked and updated correctly in case of a fork (or manual head
// revertal). // revertal).
func TestForkedSyncBoundaries61(t *testing.T) { testForkedSyncBoundaries(t, 61, FullSync) } func TestForkedSyncProgress61(t *testing.T) { testForkedSyncProgress(t, 61, FullSync) }
func TestForkedSyncBoundaries62(t *testing.T) { testForkedSyncBoundaries(t, 62, FullSync) } func TestForkedSyncProgress62(t *testing.T) { testForkedSyncProgress(t, 62, FullSync) }
func TestForkedSyncBoundaries63Full(t *testing.T) { testForkedSyncBoundaries(t, 63, FullSync) } func TestForkedSyncProgress63Full(t *testing.T) { testForkedSyncProgress(t, 63, FullSync) }
func TestForkedSyncBoundaries63Fast(t *testing.T) { testForkedSyncBoundaries(t, 63, FastSync) } func TestForkedSyncProgress63Fast(t *testing.T) { testForkedSyncProgress(t, 63, FastSync) }
func TestForkedSyncBoundaries64Full(t *testing.T) { testForkedSyncBoundaries(t, 64, FullSync) } func TestForkedSyncProgress64Full(t *testing.T) { testForkedSyncProgress(t, 64, FullSync) }
func TestForkedSyncBoundaries64Fast(t *testing.T) { testForkedSyncBoundaries(t, 64, FastSync) } func TestForkedSyncProgress64Fast(t *testing.T) { testForkedSyncProgress(t, 64, FastSync) }
func TestForkedSyncBoundaries64Light(t *testing.T) { testForkedSyncBoundaries(t, 64, LightSync) } func TestForkedSyncProgress64Light(t *testing.T) { testForkedSyncProgress(t, 64, LightSync) }
func testForkedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// Create a forked chain to simulate origin revertal // Create a forked chain to simulate origin revertal
common, fork := MaxHashFetch, 2*MaxHashFetch common, fork := MaxHashFetch, 2*MaxHashFetch
hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil) hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil)
// Set a sync init hook to catch boundary changes // Set a sync init hook to catch progress changes
starting := make(chan struct{}) starting := make(chan struct{})
progress := make(chan struct{}) progress := make(chan struct{})
tester := newTester(mode) tester := newTester()
tester.downloader.syncInitHook = func(origin, latest uint64) { tester.downloader.syncInitHook = func(origin, latest uint64) {
starting <- struct{}{} starting <- struct{}{}
<-progress <-progress
} }
// Retrieve the sync boundaries and ensure they are zero (pristine sync) // Retrieve the sync progress and ensure they are zero (pristine sync)
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != 0 { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 {
t.Fatalf("Pristine boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, 0) t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0)
} }
// Synchronise with one of the forks and check boundaries // Synchronise with one of the forks and check progress
tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA) tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA)
pending := new(sync.WaitGroup) pending := new(sync.WaitGroup)
pending.Add(1) pending.Add(1)
go func() { go func() {
defer pending.Done() defer pending.Done()
if err := tester.sync("fork A", nil); err != nil { if err := tester.sync("fork A", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
}() }()
<-starting <-starting
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(len(hashesA)-1) { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(len(hashesA)-1) {
t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, len(hashesA)-1) t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, len(hashesA)-1)
} }
progress <- struct{}{} progress <- struct{}{}
pending.Wait() pending.Wait()
@ -1340,52 +1369,57 @@ func testForkedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) {
// Simulate a successful sync above the fork // Simulate a successful sync above the fork
tester.downloader.syncStatsChainOrigin = tester.downloader.syncStatsChainHeight tester.downloader.syncStatsChainOrigin = tester.downloader.syncStatsChainHeight
// Synchronise with the second fork and check boundary resets // Synchronise with the second fork and check progress resets
tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB) tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB)
pending.Add(1) pending.Add(1)
go func() { go func() {
defer pending.Done() defer pending.Done()
if err := tester.sync("fork B", nil); err != nil { if err := tester.sync("fork B", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
}() }()
<-starting <-starting
if origin, latest := tester.downloader.Boundaries(); origin != uint64(common) || latest != uint64(len(hashesB)-1) { if origin, current, latest := tester.downloader.Progress(); origin != uint64(common) || current != uint64(len(hashesA)-1) || latest != uint64(len(hashesB)-1) {
t.Fatalf("Forking boundary mismatch: have %v/%v, want %v/%v", origin, latest, common, len(hashesB)-1) t.Fatalf("Forking progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, common, len(hashesA)-1, len(hashesB)-1)
} }
progress <- struct{}{} progress <- struct{}{}
pending.Wait() pending.Wait()
// Check final progress after successful sync
if origin, current, latest := tester.downloader.Progress(); origin != uint64(common) || current != uint64(len(hashesB)-1) || latest != uint64(len(hashesB)-1) {
t.Fatalf("Final progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, common, len(hashesB)-1, len(hashesB)-1)
}
} }
// Tests that if synchronisation is aborted due to some failure, then the boundary // Tests that if synchronisation is aborted due to some failure, then the progress
// origin is not updated in the next sync cycle, as it should be considered the // origin is not updated in the next sync cycle, as it should be considered the
// continuation of the previous sync and not a new instance. // continuation of the previous sync and not a new instance.
func TestFailedSyncBoundaries61(t *testing.T) { testFailedSyncBoundaries(t, 61, FullSync) } func TestFailedSyncProgress61(t *testing.T) { testFailedSyncProgress(t, 61, FullSync) }
func TestFailedSyncBoundaries62(t *testing.T) { testFailedSyncBoundaries(t, 62, FullSync) } func TestFailedSyncProgress62(t *testing.T) { testFailedSyncProgress(t, 62, FullSync) }
func TestFailedSyncBoundaries63Full(t *testing.T) { testFailedSyncBoundaries(t, 63, FullSync) } func TestFailedSyncProgress63Full(t *testing.T) { testFailedSyncProgress(t, 63, FullSync) }
func TestFailedSyncBoundaries63Fast(t *testing.T) { testFailedSyncBoundaries(t, 63, FastSync) } func TestFailedSyncProgress63Fast(t *testing.T) { testFailedSyncProgress(t, 63, FastSync) }
func TestFailedSyncBoundaries64Full(t *testing.T) { testFailedSyncBoundaries(t, 64, FullSync) } func TestFailedSyncProgress64Full(t *testing.T) { testFailedSyncProgress(t, 64, FullSync) }
func TestFailedSyncBoundaries64Fast(t *testing.T) { testFailedSyncBoundaries(t, 64, FastSync) } func TestFailedSyncProgress64Fast(t *testing.T) { testFailedSyncProgress(t, 64, FastSync) }
func TestFailedSyncBoundaries64Light(t *testing.T) { testFailedSyncBoundaries(t, 64, LightSync) } func TestFailedSyncProgress64Light(t *testing.T) { testFailedSyncProgress(t, 64, LightSync) }
func testFailedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
// Set a sync init hook to catch boundary changes // Set a sync init hook to catch progress changes
starting := make(chan struct{}) starting := make(chan struct{})
progress := make(chan struct{}) progress := make(chan struct{})
tester := newTester(mode) tester := newTester()
tester.downloader.syncInitHook = func(origin, latest uint64) { tester.downloader.syncInitHook = func(origin, latest uint64) {
starting <- struct{}{} starting <- struct{}{}
<-progress <-progress
} }
// Retrieve the sync boundaries and ensure they are zero (pristine sync) // Retrieve the sync progress and ensure they are zero (pristine sync)
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != 0 { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 {
t.Fatalf("Pristine boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, 0) t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0)
} }
// Attempt a full sync with a faulty peer // Attempt a full sync with a faulty peer
tester.newPeer("faulty", protocol, hashes, headers, blocks, receipts) tester.newPeer("faulty", protocol, hashes, headers, blocks, receipts)
@ -1399,62 +1433,67 @@ func testFailedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) {
go func() { go func() {
defer pending.Done() defer pending.Done()
if err := tester.sync("faulty", nil); err == nil { if err := tester.sync("faulty", nil, mode); err == nil {
t.Fatalf("succeeded faulty synchronisation") t.Fatalf("succeeded faulty synchronisation")
} }
}() }()
<-starting <-starting
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks) { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks) {
t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks) t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks)
} }
progress <- struct{}{} progress <- struct{}{}
pending.Wait() pending.Wait()
// Synchronise with a good peer and check that the boundary origin remind the same after a failure // Synchronise with a good peer and check that the progress origin remind the same after a failure
tester.newPeer("valid", protocol, hashes, headers, blocks, receipts) tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
pending.Add(1) pending.Add(1)
go func() { go func() {
defer pending.Done() defer pending.Done()
if err := tester.sync("valid", nil); err != nil { if err := tester.sync("valid", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
}() }()
<-starting <-starting
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks) { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current > uint64(targetBlocks/2) || latest != uint64(targetBlocks) {
t.Fatalf("Completing boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks) t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/0-%v/%v", origin, current, latest, 0, targetBlocks/2, targetBlocks)
} }
progress <- struct{}{} progress <- struct{}{}
pending.Wait() pending.Wait()
// Check final progress after successful sync
if origin, current, latest := tester.downloader.Progress(); origin > uint64(targetBlocks/2) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) {
t.Fatalf("Final progress mismatch: have %v/%v/%v, want 0-%v/%v/%v", origin, current, latest, targetBlocks/2, targetBlocks, targetBlocks)
}
} }
// Tests that if an attacker fakes a chain height, after the attack is detected, // Tests that if an attacker fakes a chain height, after the attack is detected,
// the boundary height is successfully reduced at the next sync invocation. // the progress height is successfully reduced at the next sync invocation.
func TestFakedSyncBoundaries61(t *testing.T) { testFakedSyncBoundaries(t, 61, FullSync) } func TestFakedSyncProgress61(t *testing.T) { testFakedSyncProgress(t, 61, FullSync) }
func TestFakedSyncBoundaries62(t *testing.T) { testFakedSyncBoundaries(t, 62, FullSync) } func TestFakedSyncProgress62(t *testing.T) { testFakedSyncProgress(t, 62, FullSync) }
func TestFakedSyncBoundaries63Full(t *testing.T) { testFakedSyncBoundaries(t, 63, FullSync) } func TestFakedSyncProgress63Full(t *testing.T) { testFakedSyncProgress(t, 63, FullSync) }
func TestFakedSyncBoundaries63Fast(t *testing.T) { testFakedSyncBoundaries(t, 63, FastSync) } func TestFakedSyncProgress63Fast(t *testing.T) { testFakedSyncProgress(t, 63, FastSync) }
func TestFakedSyncBoundaries64Full(t *testing.T) { testFakedSyncBoundaries(t, 64, FullSync) } func TestFakedSyncProgress64Full(t *testing.T) { testFakedSyncProgress(t, 64, FullSync) }
func TestFakedSyncBoundaries64Fast(t *testing.T) { testFakedSyncBoundaries(t, 64, FastSync) } func TestFakedSyncProgress64Fast(t *testing.T) { testFakedSyncProgress(t, 64, FastSync) }
func TestFakedSyncBoundaries64Light(t *testing.T) { testFakedSyncBoundaries(t, 64, LightSync) } func TestFakedSyncProgress64Light(t *testing.T) { testFakedSyncProgress(t, 64, LightSync) }
func testFakedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
// Create a small block chain // Create a small block chain
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheLimit - 15
hashes, headers, blocks, receipts := makeChain(targetBlocks+3, 0, genesis, nil) hashes, headers, blocks, receipts := makeChain(targetBlocks+3, 0, genesis, nil)
// Set a sync init hook to catch boundary changes // Set a sync init hook to catch progress changes
starting := make(chan struct{}) starting := make(chan struct{})
progress := make(chan struct{}) progress := make(chan struct{})
tester := newTester(mode) tester := newTester()
tester.downloader.syncInitHook = func(origin, latest uint64) { tester.downloader.syncInitHook = func(origin, latest uint64) {
starting <- struct{}{} starting <- struct{}{}
<-progress <-progress
} }
// Retrieve the sync boundaries and ensure they are zero (pristine sync) // Retrieve the sync progress and ensure they are zero (pristine sync)
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != 0 { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 {
t.Fatalf("Pristine boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, 0) t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0)
} }
// Create and sync with an attacker that promises a higher chain than available // Create and sync with an attacker that promises a higher chain than available
tester.newPeer("attack", protocol, hashes, headers, blocks, receipts) tester.newPeer("attack", protocol, hashes, headers, blocks, receipts)
@ -1469,31 +1508,36 @@ func testFakedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) {
go func() { go func() {
defer pending.Done() defer pending.Done()
if err := tester.sync("attack", nil); err == nil { if err := tester.sync("attack", nil, mode); err == nil {
t.Fatalf("succeeded attacker synchronisation") t.Fatalf("succeeded attacker synchronisation")
} }
}() }()
<-starting <-starting
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks+3) { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks+3) {
t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks+3) t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks+3)
} }
progress <- struct{}{} progress <- struct{}{}
pending.Wait() pending.Wait()
// Synchronise with a good peer and check that the boundary height has been reduced to the true value // Synchronise with a good peer and check that the progress height has been reduced to the true value
tester.newPeer("valid", protocol, hashes[3:], headers, blocks, receipts) tester.newPeer("valid", protocol, hashes[3:], headers, blocks, receipts)
pending.Add(1) pending.Add(1)
go func() { go func() {
defer pending.Done() defer pending.Done()
if err := tester.sync("valid", nil); err != nil { if err := tester.sync("valid", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
} }
}() }()
<-starting <-starting
if origin, latest := tester.downloader.Boundaries(); origin != 0 || latest != uint64(targetBlocks) { if origin, current, latest := tester.downloader.Progress(); origin != 0 || current > uint64(targetBlocks) || latest != uint64(targetBlocks) {
t.Fatalf("Initial boundary mismatch: have %v/%v, want %v/%v", origin, latest, 0, targetBlocks) t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/0-%v/%v", origin, current, latest, 0, targetBlocks, targetBlocks)
} }
progress <- struct{}{} progress <- struct{}{}
pending.Wait() pending.Wait()
// Check final progress after successful sync
if origin, current, latest := tester.downloader.Progress(); origin > uint64(targetBlocks) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) {
t.Fatalf("Final progress mismatch: have %v/%v/%v, want 0-%v/%v/%v", origin, current, latest, targetBlocks, targetBlocks, targetBlocks)
}
} }

@ -20,7 +20,7 @@ package downloader
type SyncMode int type SyncMode int
const ( const (
FullSync SyncMode = iota // Synchronise the entire block-chain history from full blocks FullSync SyncMode = iota // Synchronise the entire blockchain history from full blocks
FastSync // Quikcly download the headers, full sync only at the chain head FastSync // Quickly download the headers, full sync only at the chain head
LightSync // Download only the headers and terminate afterwards LightSync // Download only the headers and terminate afterwards
) )

@ -124,6 +124,10 @@ func (p *peer) Reset() {
// Fetch61 sends a block retrieval request to the remote peer. // Fetch61 sends a block retrieval request to the remote peer.
func (p *peer) Fetch61(request *fetchRequest) error { func (p *peer) Fetch61(request *fetchRequest) error {
// Sanity check the protocol version
if p.version != 61 {
panic(fmt.Sprintf("block fetch [eth/61] requested on eth/%d", p.version))
}
// Short circuit if the peer is already fetching // Short circuit if the peer is already fetching
if !atomic.CompareAndSwapInt32(&p.blockIdle, 0, 1) { if !atomic.CompareAndSwapInt32(&p.blockIdle, 0, 1) {
return errAlreadyFetching return errAlreadyFetching
@ -142,6 +146,10 @@ func (p *peer) Fetch61(request *fetchRequest) error {
// FetchBodies sends a block body retrieval request to the remote peer. // FetchBodies sends a block body retrieval request to the remote peer.
func (p *peer) FetchBodies(request *fetchRequest) error { func (p *peer) FetchBodies(request *fetchRequest) error {
// Sanity check the protocol version
if p.version < 62 {
panic(fmt.Sprintf("body fetch [eth/62+] requested on eth/%d", p.version))
}
// Short circuit if the peer is already fetching // Short circuit if the peer is already fetching
if !atomic.CompareAndSwapInt32(&p.blockIdle, 0, 1) { if !atomic.CompareAndSwapInt32(&p.blockIdle, 0, 1) {
return errAlreadyFetching return errAlreadyFetching
@ -160,6 +168,10 @@ func (p *peer) FetchBodies(request *fetchRequest) error {
// FetchReceipts sends a receipt retrieval request to the remote peer. // FetchReceipts sends a receipt retrieval request to the remote peer.
func (p *peer) FetchReceipts(request *fetchRequest) error { func (p *peer) FetchReceipts(request *fetchRequest) error {
// Sanity check the protocol version
if p.version < 63 {
panic(fmt.Sprintf("body fetch [eth/63+] requested on eth/%d", p.version))
}
// Short circuit if the peer is already fetching // Short circuit if the peer is already fetching
if !atomic.CompareAndSwapInt32(&p.receiptIdle, 0, 1) { if !atomic.CompareAndSwapInt32(&p.receiptIdle, 0, 1) {
return errAlreadyFetching return errAlreadyFetching
@ -178,6 +190,10 @@ func (p *peer) FetchReceipts(request *fetchRequest) error {
// FetchNodeData sends a node state data retrieval request to the remote peer. // FetchNodeData sends a node state data retrieval request to the remote peer.
func (p *peer) FetchNodeData(request *fetchRequest) error { func (p *peer) FetchNodeData(request *fetchRequest) error {
// Sanity check the protocol version
if p.version < 63 {
panic(fmt.Sprintf("node data fetch [eth/63+] requested on eth/%d", p.version))
}
// Short circuit if the peer is already fetching // Short circuit if the peer is already fetching
if !atomic.CompareAndSwapInt32(&p.stateIdle, 0, 1) { if !atomic.CompareAndSwapInt32(&p.stateIdle, 0, 1) {
return errAlreadyFetching return errAlreadyFetching
@ -196,35 +212,35 @@ func (p *peer) FetchNodeData(request *fetchRequest) error {
// SetBlocksIdle sets the peer to idle, allowing it to execute new retrieval requests. // SetBlocksIdle sets the peer to idle, allowing it to execute new retrieval requests.
// Its block retrieval allowance will also be updated either up- or downwards, // Its block retrieval allowance will also be updated either up- or downwards,
// depending on whether the previous fetch completed in time or not. // depending on whether the previous fetch completed in time.
func (p *peer) SetBlocksIdle() { func (p *peer) SetBlocksIdle() {
p.setIdle(p.blockStarted, blockSoftTTL, blockHardTTL, MaxBlockFetch, &p.blockCapacity, &p.blockIdle) p.setIdle(p.blockStarted, blockSoftTTL, blockHardTTL, MaxBlockFetch, &p.blockCapacity, &p.blockIdle)
} }
// SetBodiesIdle sets the peer to idle, allowing it to execute new retrieval requests. // SetBodiesIdle sets the peer to idle, allowing it to execute new retrieval requests.
// Its block body retrieval allowance will also be updated either up- or downwards, // Its block body retrieval allowance will also be updated either up- or downwards,
// depending on whether the previous fetch completed in time or not. // depending on whether the previous fetch completed in time.
func (p *peer) SetBodiesIdle() { func (p *peer) SetBodiesIdle() {
p.setIdle(p.blockStarted, bodySoftTTL, bodyHardTTL, MaxBlockFetch, &p.blockCapacity, &p.blockIdle) p.setIdle(p.blockStarted, bodySoftTTL, bodyHardTTL, MaxBodyFetch, &p.blockCapacity, &p.blockIdle)
} }
// SetReceiptsIdle sets the peer to idle, allowing it to execute new retrieval requests. // SetReceiptsIdle sets the peer to idle, allowing it to execute new retrieval requests.
// Its receipt retrieval allowance will also be updated either up- or downwards, // Its receipt retrieval allowance will also be updated either up- or downwards,
// depending on whether the previous fetch completed in time or not. // depending on whether the previous fetch completed in time.
func (p *peer) SetReceiptsIdle() { func (p *peer) SetReceiptsIdle() {
p.setIdle(p.receiptStarted, receiptSoftTTL, receiptHardTTL, MaxReceiptFetch, &p.receiptCapacity, &p.receiptIdle) p.setIdle(p.receiptStarted, receiptSoftTTL, receiptHardTTL, MaxReceiptFetch, &p.receiptCapacity, &p.receiptIdle)
} }
// SetNodeDataIdle sets the peer to idle, allowing it to execute new retrieval // SetNodeDataIdle sets the peer to idle, allowing it to execute new retrieval
// requests. Its node data retrieval allowance will also be updated either up- or // requests. Its node data retrieval allowance will also be updated either up- or
// downwards, depending on whether the previous fetch completed in time or not. // downwards, depending on whether the previous fetch completed in time.
func (p *peer) SetNodeDataIdle() { func (p *peer) SetNodeDataIdle() {
p.setIdle(p.stateStarted, stateSoftTTL, stateSoftTTL, MaxStateFetch, &p.stateCapacity, &p.stateIdle) p.setIdle(p.stateStarted, stateSoftTTL, stateSoftTTL, MaxStateFetch, &p.stateCapacity, &p.stateIdle)
} }
// setIdle sets the peer to idle, allowing it to execute new retrieval requests. // setIdle sets the peer to idle, allowing it to execute new retrieval requests.
// Its data retrieval allowance will also be updated either up- or downwards, // Its data retrieval allowance will also be updated either up- or downwards,
// depending on whether the previous fetch completed in time or not. // depending on whether the previous fetch completed in time.
func (p *peer) setIdle(started time.Time, softTTL, hardTTL time.Duration, maxFetch int, capacity, idle *int32) { func (p *peer) setIdle(started time.Time, softTTL, hardTTL time.Duration, maxFetch int, capacity, idle *int32) {
// Update the peer's download allowance based on previous performance // Update the peer's download allowance based on previous performance
scale := 2.0 scale := 2.0

@ -56,9 +56,8 @@ type fetchRequest struct {
Time time.Time // Time when the request was made Time time.Time // Time when the request was made
} }
// fetchResult is the assembly collecting partial results from potentially more // fetchResult is a struct collecting partial results from data fetchers until
// than one fetcher routines, until all outstanding retrievals complete and the // all outstanding pieces complete and the result as a whole can be processed.
// result as a whole can be processed.
type fetchResult struct { type fetchResult struct {
Pending int // Number of data fetches still pending Pending int // Number of data fetches still pending
@ -89,7 +88,7 @@ type queue struct {
receiptPendPool map[string]*fetchRequest // [eth/63] Currently pending receipt retrieval operations receiptPendPool map[string]*fetchRequest // [eth/63] Currently pending receipt retrieval operations
receiptDonePool map[common.Hash]struct{} // [eth/63] Set of the completed receipt fetches receiptDonePool map[common.Hash]struct{} // [eth/63] Set of the completed receipt fetches
stateTaskIndex int // [eth/63] Counter indexing the added hashes to ensure prioritized retrieval order stateTaskIndex int // [eth/63] Counter indexing the added hashes to ensure prioritised retrieval order
stateTaskPool map[common.Hash]int // [eth/63] Pending node data retrieval tasks, mapping to their priority stateTaskPool map[common.Hash]int // [eth/63] Pending node data retrieval tasks, mapping to their priority
stateTaskQueue *prque.Prque // [eth/63] Priority queue of the hashes to fetch the node data for stateTaskQueue *prque.Prque // [eth/63] Priority queue of the hashes to fetch the node data for
statePendPool map[string]*fetchRequest // [eth/63] Currently pending node data retrieval operations statePendPool map[string]*fetchRequest // [eth/63] Currently pending node data retrieval operations
@ -97,10 +96,10 @@ type queue struct {
stateDatabase ethdb.Database // [eth/63] Trie database to populate during state reassembly stateDatabase ethdb.Database // [eth/63] Trie database to populate during state reassembly
stateScheduler *state.StateSync // [eth/63] State trie synchronisation scheduler and integrator stateScheduler *state.StateSync // [eth/63] State trie synchronisation scheduler and integrator
stateProcessors int32 // [eth/63] Number of currently running state processors stateProcessors int32 // [eth/63] Number of currently running state processors
stateSchedLock sync.RWMutex // [eth/63] Lock serializing access to the state scheduler stateSchedLock sync.RWMutex // [eth/63] Lock serialising access to the state scheduler
resultCache []*fetchResult // Downloaded but not yet delivered fetch results resultCache []*fetchResult // Downloaded but not yet delivered fetch results
resultOffset uint64 // Offset of the first cached fetch result in the block-chain resultOffset uint64 // Offset of the first cached fetch result in the block chain
lock sync.RWMutex lock sync.RWMutex
} }
@ -131,6 +130,9 @@ func (q *queue) Reset() {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
q.stateSchedLock.Lock()
defer q.stateSchedLock.Unlock()
q.mode = FullSync q.mode = FullSync
q.fastSyncPivot = 0 q.fastSyncPivot = 0
@ -233,9 +235,17 @@ func (q *queue) Idle() bool {
return (queued + pending + cached) == 0 return (queued + pending + cached) == 0
} }
// ThrottleBlocks checks if the download should be throttled (active block (body) // FastSyncPivot retrieves the currently used fast sync pivot point.
func (q *queue) FastSyncPivot() uint64 {
q.lock.RLock()
defer q.lock.RUnlock()
return q.fastSyncPivot
}
// ShouldThrottleBlocks checks if the download should be throttled (active block (body)
// fetches exceed block cache). // fetches exceed block cache).
func (q *queue) ThrottleBlocks() bool { func (q *queue) ShouldThrottleBlocks() bool {
q.lock.RLock() q.lock.RLock()
defer q.lock.RUnlock() defer q.lock.RUnlock()
@ -248,9 +258,9 @@ func (q *queue) ThrottleBlocks() bool {
return pending >= len(q.resultCache)-len(q.blockDonePool) return pending >= len(q.resultCache)-len(q.blockDonePool)
} }
// ThrottleReceipts checks if the download should be throttled (active receipt // ShouldThrottleReceipts checks if the download should be throttled (active receipt
// fetches exceed block cache). // fetches exceed block cache).
func (q *queue) ThrottleReceipts() bool { func (q *queue) ShouldThrottleReceipts() bool {
q.lock.RLock() q.lock.RLock()
defer q.lock.RUnlock() defer q.lock.RUnlock()
@ -269,7 +279,7 @@ func (q *queue) Schedule61(hashes []common.Hash, fifo bool) []common.Hash {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
// Insert all the hashes prioritized in the arrival order // Insert all the hashes prioritised in the arrival order
inserts := make([]common.Hash, 0, len(hashes)) inserts := make([]common.Hash, 0, len(hashes))
for _, hash := range hashes { for _, hash := range hashes {
// Skip anything we already have // Skip anything we already have
@ -297,10 +307,10 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
// Insert all the headers prioritized by the contained block number // Insert all the headers prioritised by the contained block number
inserts := make([]*types.Header, 0, len(headers)) inserts := make([]*types.Header, 0, len(headers))
for _, header := range headers { for _, header := range headers {
// Make sure chain order is honored and preserved throughout // Make sure chain order is honoured and preserved throughout
hash := header.Hash() hash := header.Hash()
if header.Number == nil || header.Number.Uint64() != from { if header.Number == nil || header.Number.Uint64() != from {
glog.V(logger.Warn).Infof("Header #%v [%x] broke chain ordering, expected %d", header.Number, hash[:4], from) glog.V(logger.Warn).Infof("Header #%v [%x] broke chain ordering, expected %d", header.Number, hash[:4], from)
@ -347,19 +357,29 @@ func (q *queue) GetHeadResult() *fetchResult {
q.lock.RLock() q.lock.RLock()
defer q.lock.RUnlock() defer q.lock.RUnlock()
// If there are no results pending, return nil
if len(q.resultCache) == 0 || q.resultCache[0] == nil { if len(q.resultCache) == 0 || q.resultCache[0] == nil {
return nil return nil
} }
// If the next result is still incomplete, return nil
if q.resultCache[0].Pending > 0 { if q.resultCache[0].Pending > 0 {
return nil return nil
} }
// If the next result is the fast sync pivot...
if q.mode == FastSync && q.resultCache[0].Header.Number.Uint64() == q.fastSyncPivot { if q.mode == FastSync && q.resultCache[0].Header.Number.Uint64() == q.fastSyncPivot {
// If the pivot state trie is still being pulled, return nil
if len(q.stateTaskPool) > 0 { if len(q.stateTaskPool) > 0 {
return nil return nil
} }
if q.PendingNodeData() > 0 { if q.PendingNodeData() > 0 {
return nil return nil
} }
// If the state is done, but not enough post-pivot headers were verified, stall...
for i := 0; i < fsHeaderForceVerify; i++ {
if i+1 >= len(q.resultCache) || q.resultCache[i+1] == nil {
return nil
}
}
} }
return q.resultCache[0] return q.resultCache[0]
} }
@ -372,7 +392,7 @@ func (q *queue) TakeResults() []*fetchResult {
// Accumulate all available results // Accumulate all available results
results := []*fetchResult{} results := []*fetchResult{}
for _, result := range q.resultCache { for i, result := range q.resultCache {
// Stop if no more results are ready // Stop if no more results are ready
if result == nil || result.Pending > 0 { if result == nil || result.Pending > 0 {
break break
@ -385,6 +405,16 @@ func (q *queue) TakeResults() []*fetchResult {
if q.PendingNodeData() > 0 { if q.PendingNodeData() > 0 {
break break
} }
// Even is state fetch is done, ensure post-pivot headers passed verifications
safe := true
for j := 0; j < fsHeaderForceVerify; j++ {
if i+j+1 >= len(q.resultCache) || q.resultCache[i+j+1] == nil {
safe = false
}
}
if !safe {
break
}
} }
// If we've just inserted the fast sync pivot, stop as the following batch needs different insertion // If we've just inserted the fast sync pivot, stop as the following batch needs different insertion
if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot+1 && len(results) > 0 { if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot+1 && len(results) > 0 {
@ -411,6 +441,9 @@ func (q *queue) TakeResults() []*fetchResult {
// ReserveBlocks reserves a set of block hashes for the given peer, skipping any // ReserveBlocks reserves a set of block hashes for the given peer, skipping any
// previously failed download. // previously failed download.
func (q *queue) ReserveBlocks(p *peer, count int) *fetchRequest { func (q *queue) ReserveBlocks(p *peer, count int) *fetchRequest {
q.lock.Lock()
defer q.lock.Unlock()
return q.reserveHashes(p, count, q.hashQueue, nil, q.blockPendPool, len(q.resultCache)-len(q.blockDonePool)) return q.reserveHashes(p, count, q.hashQueue, nil, q.blockPendPool, len(q.resultCache)-len(q.blockDonePool))
} }
@ -430,17 +463,21 @@ func (q *queue) ReserveNodeData(p *peer, count int) *fetchRequest {
} }
} }
} }
q.lock.Lock()
defer q.lock.Unlock()
return q.reserveHashes(p, count, q.stateTaskQueue, generator, q.statePendPool, count) return q.reserveHashes(p, count, q.stateTaskQueue, generator, q.statePendPool, count)
} }
// reserveHashes reserves a set of hashes for the given peer, skipping previously // reserveHashes reserves a set of hashes for the given peer, skipping previously
// failed ones. // failed ones.
//
// Note, this method expects the queue lock to be already held for writing. The
// reason the lock is not obtained in here is because the parameters already need
// to access the queue, so they already need a lock anyway.
func (q *queue) reserveHashes(p *peer, count int, taskQueue *prque.Prque, taskGen func(int), pendPool map[string]*fetchRequest, maxPending int) *fetchRequest { func (q *queue) reserveHashes(p *peer, count int, taskQueue *prque.Prque, taskGen func(int), pendPool map[string]*fetchRequest, maxPending int) *fetchRequest {
q.lock.Lock() // Short circuit if the peer's already downloading something (sanity check to
defer q.lock.Unlock() // not corrupt state)
// Short circuit if the peer's already downloading something (sanity check not
// to corrupt state)
if _, ok := pendPool[p.id]; ok { if _, ok := pendPool[p.id]; ok {
return nil return nil
} }
@ -492,30 +529,37 @@ func (q *queue) reserveHashes(p *peer, count int, taskQueue *prque.Prque, taskGe
// previously failed downloads. Beside the next batch of needed fetches, it also // previously failed downloads. Beside the next batch of needed fetches, it also
// returns a flag whether empty blocks were queued requiring processing. // returns a flag whether empty blocks were queued requiring processing.
func (q *queue) ReserveBodies(p *peer, count int) (*fetchRequest, bool, error) { func (q *queue) ReserveBodies(p *peer, count int) (*fetchRequest, bool, error) {
noop := func(header *types.Header) bool { isNoop := func(header *types.Header) bool {
return header.TxHash == types.EmptyRootHash && header.UncleHash == types.EmptyUncleHash return header.TxHash == types.EmptyRootHash && header.UncleHash == types.EmptyUncleHash
} }
return q.reserveHeaders(p, count, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, q.blockDonePool, noop) q.lock.Lock()
defer q.lock.Unlock()
return q.reserveHeaders(p, count, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, q.blockDonePool, isNoop)
} }
// ReserveReceipts reserves a set of receipt fetches for the given peer, skipping // ReserveReceipts reserves a set of receipt fetches for the given peer, skipping
// any previously failed downloads. Beside the next batch of needed fetches, it // any previously failed downloads. Beside the next batch of needed fetches, it
// also returns a flag whether empty receipts were queued requiring importing. // also returns a flag whether empty receipts were queued requiring importing.
func (q *queue) ReserveReceipts(p *peer, count int) (*fetchRequest, bool, error) { func (q *queue) ReserveReceipts(p *peer, count int) (*fetchRequest, bool, error) {
noop := func(header *types.Header) bool { isNoop := func(header *types.Header) bool {
return header.ReceiptHash == types.EmptyRootHash return header.ReceiptHash == types.EmptyRootHash
} }
return q.reserveHeaders(p, count, q.receiptTaskPool, q.receiptTaskQueue, q.receiptPendPool, q.receiptDonePool, noop) q.lock.Lock()
defer q.lock.Unlock()
return q.reserveHeaders(p, count, q.receiptTaskPool, q.receiptTaskQueue, q.receiptPendPool, q.receiptDonePool, isNoop)
} }
// reserveHeaders reserves a set of data download operations for a given peer, // reserveHeaders reserves a set of data download operations for a given peer,
// skipping any previously failed ones. This method is a generic version used // skipping any previously failed ones. This method is a generic version used
// by the individual special reservation functions. // by the individual special reservation functions.
//
// Note, this method expects the queue lock to be already held for writing. The
// reason the lock is not obtained in here is because the parameters already need
// to access the queue, so they already need a lock anyway.
func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque, func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque,
pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}, noop func(*types.Header) bool) (*fetchRequest, bool, error) { pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}, isNoop func(*types.Header) bool) (*fetchRequest, bool, error) {
q.lock.Lock()
defer q.lock.Unlock()
// Short circuit if the pool has been depleted, or if the peer's already // Short circuit if the pool has been depleted, or if the peer's already
// downloading something (sanity check not to corrupt state) // downloading something (sanity check not to corrupt state)
if taskQueue.Empty() { if taskQueue.Empty() {
@ -537,7 +581,7 @@ func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*typ
for proc := 0; proc < space && len(send) < count && !taskQueue.Empty(); proc++ { for proc := 0; proc < space && len(send) < count && !taskQueue.Empty(); proc++ {
header := taskQueue.PopItem().(*types.Header) header := taskQueue.PopItem().(*types.Header)
// If we're the first to request this task, initialize the result container // If we're the first to request this task, initialise the result container
index := int(header.Number.Int64() - int64(q.resultOffset)) index := int(header.Number.Int64() - int64(q.resultOffset))
if index >= len(q.resultCache) || index < 0 { if index >= len(q.resultCache) || index < 0 {
return nil, false, errInvalidChain return nil, false, errInvalidChain
@ -553,7 +597,7 @@ func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*typ
} }
} }
// If this fetch task is a noop, skip this fetch operation // If this fetch task is a noop, skip this fetch operation
if noop(header) { if isNoop(header) {
donePool[header.Hash()] = struct{}{} donePool[header.Hash()] = struct{}{}
delete(taskPool, header.Hash()) delete(taskPool, header.Hash())
@ -562,7 +606,7 @@ func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*typ
progress = true progress = true
continue continue
} }
// Otherwise if not a known unknown block, add to the retrieve list // Otherwise unless the peer is known not to have the data, add to the retrieve list
if p.ignored.Has(header.Hash()) { if p.ignored.Has(header.Hash()) {
skip = append(skip, header) skip = append(skip, header)
} else { } else {
@ -655,35 +699,48 @@ func (q *queue) Revoke(peerId string) {
} }
// ExpireBlocks checks for in flight requests that exceeded a timeout allowance, // ExpireBlocks checks for in flight requests that exceeded a timeout allowance,
// canceling them and returning the responsible peers for penalization. // canceling them and returning the responsible peers for penalisation.
func (q *queue) ExpireBlocks(timeout time.Duration) []string { func (q *queue) ExpireBlocks(timeout time.Duration) []string {
q.lock.Lock()
defer q.lock.Unlock()
return q.expire(timeout, q.blockPendPool, q.hashQueue, blockTimeoutMeter) return q.expire(timeout, q.blockPendPool, q.hashQueue, blockTimeoutMeter)
} }
// ExpireBodies checks for in flight block body requests that exceeded a timeout // ExpireBodies checks for in flight block body requests that exceeded a timeout
// allowance, canceling them and returning the responsible peers for penalization. // allowance, canceling them and returning the responsible peers for penalisation.
func (q *queue) ExpireBodies(timeout time.Duration) []string { func (q *queue) ExpireBodies(timeout time.Duration) []string {
q.lock.Lock()
defer q.lock.Unlock()
return q.expire(timeout, q.blockPendPool, q.blockTaskQueue, bodyTimeoutMeter) return q.expire(timeout, q.blockPendPool, q.blockTaskQueue, bodyTimeoutMeter)
} }
// ExpireReceipts checks for in flight receipt requests that exceeded a timeout // ExpireReceipts checks for in flight receipt requests that exceeded a timeout
// allowance, canceling them and returning the responsible peers for penalization. // allowance, canceling them and returning the responsible peers for penalisation.
func (q *queue) ExpireReceipts(timeout time.Duration) []string { func (q *queue) ExpireReceipts(timeout time.Duration) []string {
q.lock.Lock()
defer q.lock.Unlock()
return q.expire(timeout, q.receiptPendPool, q.receiptTaskQueue, receiptTimeoutMeter) return q.expire(timeout, q.receiptPendPool, q.receiptTaskQueue, receiptTimeoutMeter)
} }
// ExpireNodeData checks for in flight node data requests that exceeded a timeout // ExpireNodeData checks for in flight node data requests that exceeded a timeout
// allowance, canceling them and returning the responsible peers for penalization. // allowance, canceling them and returning the responsible peers for penalisation.
func (q *queue) ExpireNodeData(timeout time.Duration) []string { func (q *queue) ExpireNodeData(timeout time.Duration) []string {
q.lock.Lock()
defer q.lock.Unlock()
return q.expire(timeout, q.statePendPool, q.stateTaskQueue, stateTimeoutMeter) return q.expire(timeout, q.statePendPool, q.stateTaskQueue, stateTimeoutMeter)
} }
// expire is the generic check that move expired tasks from a pending pool back // expire is the generic check that move expired tasks from a pending pool back
// into a task pool, returning all entities caught with expired tasks. // into a task pool, returning all entities caught with expired tasks.
//
// Note, this method expects the queue lock to be already held for writing. The
// reason the lock is not obtained in here is because the parameters already need
// to access the queue, so they already need a lock anyway.
func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest, taskQueue *prque.Prque, timeoutMeter metrics.Meter) []string { func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest, taskQueue *prque.Prque, timeoutMeter metrics.Meter) []string {
q.lock.Lock()
defer q.lock.Unlock()
// Iterate over the expired requests and return each to the queue // Iterate over the expired requests and return each to the queue
peers := []string{} peers := []string{}
for id, request := range pendPool { for id, request := range pendPool {
@ -764,7 +821,7 @@ func (q *queue) DeliverBlocks(id string, blocks []*types.Block) error {
case len(errs) == 1 && (errs[0] == errInvalidChain || errs[0] == errInvalidBlock): case len(errs) == 1 && (errs[0] == errInvalidChain || errs[0] == errInvalidBlock):
return errs[0] return errs[0]
case len(errs) == len(request.Headers): case len(errs) == len(blocks):
return errStaleDelivery return errStaleDelivery
default: default:
@ -774,6 +831,9 @@ func (q *queue) DeliverBlocks(id string, blocks []*types.Block) error {
// DeliverBodies injects a block body retrieval response into the results queue. // DeliverBodies injects a block body retrieval response into the results queue.
func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLists [][]*types.Header) error { func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLists [][]*types.Header) error {
q.lock.Lock()
defer q.lock.Unlock()
reconstruct := func(header *types.Header, index int, result *fetchResult) error { reconstruct := func(header *types.Header, index int, result *fetchResult) error {
if types.DeriveSha(types.Transactions(txLists[index])) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { if types.DeriveSha(types.Transactions(txLists[index])) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash {
return errInvalidBody return errInvalidBody
@ -787,6 +847,9 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLi
// DeliverReceipts injects a receipt retrieval response into the results queue. // DeliverReceipts injects a receipt retrieval response into the results queue.
func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) error { func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) error {
q.lock.Lock()
defer q.lock.Unlock()
reconstruct := func(header *types.Header, index int, result *fetchResult) error { reconstruct := func(header *types.Header, index int, result *fetchResult) error {
if types.DeriveSha(types.Receipts(receiptList[index])) != header.ReceiptHash { if types.DeriveSha(types.Receipts(receiptList[index])) != header.ReceiptHash {
return errInvalidReceipt return errInvalidReceipt
@ -798,11 +861,12 @@ func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) error
} }
// deliver injects a data retrieval response into the results queue. // deliver injects a data retrieval response into the results queue.
//
// Note, this method expects the queue lock to be already held for writing. The
// reason the lock is not obtained in here is because the parameters already need
// to access the queue, so they already need a lock anyway.
func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque, pendPool map[string]*fetchRequest, func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque, pendPool map[string]*fetchRequest,
donePool map[common.Hash]struct{}, reqTimer metrics.Timer, results int, reconstruct func(header *types.Header, index int, result *fetchResult) error) error { donePool map[common.Hash]struct{}, reqTimer metrics.Timer, results int, reconstruct func(header *types.Header, index int, result *fetchResult) error) error {
q.lock.Lock()
defer q.lock.Unlock()
// Short circuit if the data was never requested // Short circuit if the data was never requested
request := pendPool[id] request := pendPool[id]
if request == nil { if request == nil {
@ -818,7 +882,10 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
} }
} }
// Assemble each of the results with their headers and retrieved data parts // Assemble each of the results with their headers and retrieved data parts
errs := make([]error, 0) var (
failure error
useful bool
)
for i, header := range request.Headers { for i, header := range request.Headers {
// Short circuit assembly if no more fetch results are found // Short circuit assembly if no more fetch results are found
if i >= results { if i >= results {
@ -827,15 +894,16 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
// Reconstruct the next result if contents match up // Reconstruct the next result if contents match up
index := int(header.Number.Int64() - int64(q.resultOffset)) index := int(header.Number.Int64() - int64(q.resultOffset))
if index >= len(q.resultCache) || index < 0 || q.resultCache[index] == nil { if index >= len(q.resultCache) || index < 0 || q.resultCache[index] == nil {
errs = []error{errInvalidChain} failure = errInvalidChain
break break
} }
if err := reconstruct(header, i, q.resultCache[index]); err != nil { if err := reconstruct(header, i, q.resultCache[index]); err != nil {
errs = []error{err} failure = err
break break
} }
donePool[header.Hash()] = struct{}{} donePool[header.Hash()] = struct{}{}
q.resultCache[index].Pending-- q.resultCache[index].Pending--
useful = true
// Clean up a successful fetch // Clean up a successful fetch
request.Headers[i] = nil request.Headers[i] = nil
@ -847,19 +915,16 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
taskQueue.Push(header, -float32(header.Number.Uint64())) taskQueue.Push(header, -float32(header.Number.Uint64()))
} }
} }
// If none of the blocks were good, it's a stale delivery // If none of the data was good, it's a stale delivery
switch { switch {
case len(errs) == 0: case failure == nil || failure == errInvalidChain:
return nil return failure
case len(errs) == 1 && (errs[0] == errInvalidChain || errs[0] == errInvalidBody || errs[0] == errInvalidReceipt):
return errs[0]
case len(errs) == len(request.Headers): case useful:
return errStaleDelivery return fmt.Errorf("partial failure: %v", failure)
default: default:
return fmt.Errorf("multiple failures: %v", errs) return errStaleDelivery
} }
} }
@ -876,7 +941,7 @@ func (q *queue) DeliverNodeData(id string, data [][]byte, callback func(error, i
stateReqTimer.UpdateSince(request.Time) stateReqTimer.UpdateSince(request.Time)
delete(q.statePendPool, id) delete(q.statePendPool, id)
// If no data was retrieved, mark them as unavailable for the origin peer // If no data was retrieved, mark their hashes as unavailable for the origin peer
if len(data) == 0 { if len(data) == 0 {
for hash, _ := range request.Hashes { for hash, _ := range request.Hashes {
request.Peer.ignored.Add(hash) request.Peer.ignored.Add(hash)
@ -955,9 +1020,6 @@ func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64) {
if q.resultOffset < offset { if q.resultOffset < offset {
q.resultOffset = offset q.resultOffset = offset
} }
q.fastSyncPivot = 0 q.fastSyncPivot = pivot
if mode == FastSync {
q.fastSyncPivot = pivot
}
q.mode = mode q.mode = mode
} }

@ -142,9 +142,11 @@ type Fetcher struct {
dropPeer peerDropFn // Drops a peer for misbehaving dropPeer peerDropFn // Drops a peer for misbehaving
// Testing hooks // Testing hooks
fetchingHook func([]common.Hash) // Method to call upon starting a block (eth/61) or header (eth/62) fetch announceChangeHook func(common.Hash, bool) // Method to call upon adding or deleting a hash from the announce list
completingHook func([]common.Hash) // Method to call upon starting a block body fetch (eth/62) queueChangeHook func(common.Hash, bool) // Method to call upon adding or deleting a block from the import queue
importedHook func(*types.Block) // Method to call upon successful block import (both eth/61 and eth/62) fetchingHook func([]common.Hash) // Method to call upon starting a block (eth/61) or header (eth/62) fetch
completingHook func([]common.Hash) // Method to call upon starting a block body fetch (eth/62)
importedHook func(*types.Block) // Method to call upon successful block import (both eth/61 and eth/62)
} }
// New creates a block fetcher to retrieve blocks based on hash announcements. // New creates a block fetcher to retrieve blocks based on hash announcements.
@ -324,11 +326,16 @@ func (f *Fetcher) loop() {
height := f.chainHeight() height := f.chainHeight()
for !f.queue.Empty() { for !f.queue.Empty() {
op := f.queue.PopItem().(*inject) op := f.queue.PopItem().(*inject)
if f.queueChangeHook != nil {
f.queueChangeHook(op.block.Hash(), false)
}
// If too high up the chain or phase, continue later // If too high up the chain or phase, continue later
number := op.block.NumberU64() number := op.block.NumberU64()
if number > height+1 { if number > height+1 {
f.queue.Push(op, -float32(op.block.NumberU64())) f.queue.Push(op, -float32(op.block.NumberU64()))
if f.queueChangeHook != nil {
f.queueChangeHook(op.block.Hash(), true)
}
break break
} }
// Otherwise if fresh and still unknown, try and import // Otherwise if fresh and still unknown, try and import
@ -372,6 +379,9 @@ func (f *Fetcher) loop() {
} }
f.announces[notification.origin] = count f.announces[notification.origin] = count
f.announced[notification.hash] = append(f.announced[notification.hash], notification) f.announced[notification.hash] = append(f.announced[notification.hash], notification)
if f.announceChangeHook != nil && len(f.announced[notification.hash]) == 1 {
f.announceChangeHook(notification.hash, true)
}
if len(f.announced) == 1 { if len(f.announced) == 1 {
f.rescheduleFetch(fetchTimer) f.rescheduleFetch(fetchTimer)
} }
@ -714,7 +724,9 @@ func (f *Fetcher) enqueue(peer string, block *types.Block) {
f.queues[peer] = count f.queues[peer] = count
f.queued[hash] = op f.queued[hash] = op
f.queue.Push(op, -float32(block.NumberU64())) f.queue.Push(op, -float32(block.NumberU64()))
if f.queueChangeHook != nil {
f.queueChangeHook(op.block.Hash(), true)
}
if glog.V(logger.Debug) { if glog.V(logger.Debug) {
glog.Infof("Peer %s: queued block #%d [%x…], total %v", peer, block.NumberU64(), hash.Bytes()[:4], f.queue.Size()) glog.Infof("Peer %s: queued block #%d [%x…], total %v", peer, block.NumberU64(), hash.Bytes()[:4], f.queue.Size())
} }
@ -781,7 +793,9 @@ func (f *Fetcher) forgetHash(hash common.Hash) {
} }
} }
delete(f.announced, hash) delete(f.announced, hash)
if f.announceChangeHook != nil {
f.announceChangeHook(hash, false)
}
// Remove any pending fetches and decrement the DOS counters // Remove any pending fetches and decrement the DOS counters
if announce := f.fetching[hash]; announce != nil { if announce := f.fetching[hash]; announce != nil {
f.announces[announce.origin]-- f.announces[announce.origin]--

@ -145,6 +145,9 @@ func (f *fetcherTester) insertChain(blocks types.Blocks) (int, error) {
// dropPeer is an emulator for the peer removal, simply accumulating the various // dropPeer is an emulator for the peer removal, simply accumulating the various
// peers dropped by the fetcher. // peers dropped by the fetcher.
func (f *fetcherTester) dropPeer(peer string) { func (f *fetcherTester) dropPeer(peer string) {
f.lock.Lock()
defer f.lock.Unlock()
f.drops[peer] = true f.drops[peer] = true
} }
@ -608,8 +611,11 @@ func TestDistantPropagationDiscarding(t *testing.T) {
// Create a tester and simulate a head block being the middle of the above chain // Create a tester and simulate a head block being the middle of the above chain
tester := newTester() tester := newTester()
tester.lock.Lock()
tester.hashes = []common.Hash{head} tester.hashes = []common.Hash{head}
tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
tester.lock.Unlock()
// Ensure that a block with a lower number than the threshold is discarded // Ensure that a block with a lower number than the threshold is discarded
tester.fetcher.Enqueue("lower", blocks[hashes[low]]) tester.fetcher.Enqueue("lower", blocks[hashes[low]])
@ -641,8 +647,11 @@ func testDistantAnnouncementDiscarding(t *testing.T, protocol int) {
// Create a tester and simulate a head block being the middle of the above chain // Create a tester and simulate a head block being the middle of the above chain
tester := newTester() tester := newTester()
tester.lock.Lock()
tester.hashes = []common.Hash{head} tester.hashes = []common.Hash{head}
tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
tester.lock.Unlock()
headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack) headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
bodyFetcher := tester.makeBodyFetcher(blocks, 0) bodyFetcher := tester.makeBodyFetcher(blocks, 0)
@ -687,14 +696,22 @@ func testInvalidNumberAnnouncement(t *testing.T, protocol int) {
tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), nil, headerFetcher, bodyFetcher) tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), nil, headerFetcher, bodyFetcher)
verifyImportEvent(t, imported, false) verifyImportEvent(t, imported, false)
if !tester.drops["bad"] { tester.lock.RLock()
dropped := tester.drops["bad"]
tester.lock.RUnlock()
if !dropped {
t.Fatalf("peer with invalid numbered announcement not dropped") t.Fatalf("peer with invalid numbered announcement not dropped")
} }
// Make sure a good announcement passes without a drop // Make sure a good announcement passes without a drop
tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), nil, headerFetcher, bodyFetcher) tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), nil, headerFetcher, bodyFetcher)
verifyImportEvent(t, imported, true) verifyImportEvent(t, imported, true)
if tester.drops["good"] { tester.lock.RLock()
dropped = tester.drops["good"]
tester.lock.RUnlock()
if dropped {
t.Fatalf("peer with valid numbered announcement dropped") t.Fatalf("peer with valid numbered announcement dropped")
} }
verifyImportDone(t, imported) verifyImportDone(t, imported)
@ -752,9 +769,15 @@ func testHashMemoryExhaustionAttack(t *testing.T, protocol int) {
// Create a tester with instrumented import hooks // Create a tester with instrumented import hooks
tester := newTester() tester := newTester()
imported := make(chan *types.Block) imported, announces := make(chan *types.Block), int32(0)
tester.fetcher.importedHook = func(block *types.Block) { imported <- block } tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
tester.fetcher.announceChangeHook = func(hash common.Hash, added bool) {
if added {
atomic.AddInt32(&announces, 1)
} else {
atomic.AddInt32(&announces, -1)
}
}
// Create a valid chain and an infinite junk chain // Create a valid chain and an infinite junk chain
targetBlocks := hashLimit + 2*maxQueueDist targetBlocks := hashLimit + 2*maxQueueDist
hashes, blocks := makeChain(targetBlocks, 0, genesis) hashes, blocks := makeChain(targetBlocks, 0, genesis)
@ -782,8 +805,8 @@ func testHashMemoryExhaustionAttack(t *testing.T, protocol int) {
tester.fetcher.Notify("attacker", attack[i], 1 /* don't distance drop */, time.Now(), nil, attackerHeaderFetcher, attackerBodyFetcher) tester.fetcher.Notify("attacker", attack[i], 1 /* don't distance drop */, time.Now(), nil, attackerHeaderFetcher, attackerBodyFetcher)
} }
} }
if len(tester.fetcher.announced) != hashLimit+maxQueueDist { if count := atomic.LoadInt32(&announces); count != hashLimit+maxQueueDist {
t.Fatalf("queued announce count mismatch: have %d, want %d", len(tester.fetcher.announced), hashLimit+maxQueueDist) t.Fatalf("queued announce count mismatch: have %d, want %d", count, hashLimit+maxQueueDist)
} }
// Wait for fetches to complete // Wait for fetches to complete
verifyImportCount(t, imported, maxQueueDist) verifyImportCount(t, imported, maxQueueDist)
@ -807,9 +830,15 @@ func TestBlockMemoryExhaustionAttack(t *testing.T) {
// Create a tester with instrumented import hooks // Create a tester with instrumented import hooks
tester := newTester() tester := newTester()
imported := make(chan *types.Block) imported, enqueued := make(chan *types.Block), int32(0)
tester.fetcher.importedHook = func(block *types.Block) { imported <- block } tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
tester.fetcher.queueChangeHook = func(hash common.Hash, added bool) {
if added {
atomic.AddInt32(&enqueued, 1)
} else {
atomic.AddInt32(&enqueued, -1)
}
}
// Create a valid chain and a batch of dangling (but in range) blocks // Create a valid chain and a batch of dangling (but in range) blocks
targetBlocks := hashLimit + 2*maxQueueDist targetBlocks := hashLimit + 2*maxQueueDist
hashes, blocks := makeChain(targetBlocks, 0, genesis) hashes, blocks := makeChain(targetBlocks, 0, genesis)
@ -825,7 +854,7 @@ func TestBlockMemoryExhaustionAttack(t *testing.T) {
tester.fetcher.Enqueue("attacker", block) tester.fetcher.Enqueue("attacker", block)
} }
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
if queued := tester.fetcher.queue.Size(); queued != blockLimit { if queued := atomic.LoadInt32(&enqueued); queued != blockLimit {
t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit) t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit)
} }
// Queue up a batch of valid blocks, and check that a new peer is allowed to do so // Queue up a batch of valid blocks, and check that a new peer is allowed to do so
@ -833,7 +862,7 @@ func TestBlockMemoryExhaustionAttack(t *testing.T) {
tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-3-i]]) tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-3-i]])
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
if queued := tester.fetcher.queue.Size(); queued != blockLimit+maxQueueDist-1 { if queued := atomic.LoadInt32(&enqueued); queued != blockLimit+maxQueueDist-1 {
t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit+maxQueueDist-1) t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit+maxQueueDist-1)
} }
// Insert the missing piece (and sanity check the import) // Insert the missing piece (and sanity check the import)

@ -16,9 +16,9 @@ import (
func makeReceipt(addr common.Address) *types.Receipt { func makeReceipt(addr common.Address) *types.Receipt {
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{ receipt.Logs = vm.Logs{
&vm.Log{Address: addr}, &vm.Log{Address: addr},
}) }
receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) receipt.Bloom = types.CreateBloom(types.Receipts{receipt})
return receipt return receipt
} }
@ -41,7 +41,7 @@ func BenchmarkMipmaps(b *testing.B) {
defer db.Close() defer db.Close()
genesis := core.WriteGenesisBlockForTesting(db, core.GenesisAccount{addr1, big.NewInt(1000000)}) genesis := core.WriteGenesisBlockForTesting(db, core.GenesisAccount{addr1, big.NewInt(1000000)})
chain := core.GenerateChain(genesis, db, 100010, func(i int, gen *core.BlockGen) { chain, receipts := core.GenerateChain(genesis, db, 100010, func(i int, gen *core.BlockGen) {
var receipts types.Receipts var receipts types.Receipts
switch i { switch i {
case 2403: case 2403:
@ -70,7 +70,7 @@ func BenchmarkMipmaps(b *testing.B) {
} }
core.WriteMipmapBloom(db, uint64(i+1), receipts) core.WriteMipmapBloom(db, uint64(i+1), receipts)
}) })
for _, block := range chain { for i, block := range chain {
core.WriteBlock(db, block) core.WriteBlock(db, block)
if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
b.Fatalf("failed to insert block number: %v", err) b.Fatalf("failed to insert block number: %v", err)
@ -78,11 +78,10 @@ func BenchmarkMipmaps(b *testing.B) {
if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil {
b.Fatalf("failed to insert block number: %v", err) b.Fatalf("failed to insert block number: %v", err)
} }
if err := core.PutBlockReceipts(db, block, block.Receipts()); err != nil { if err := core.PutBlockReceipts(db, block.Hash(), receipts[i]); err != nil {
b.Fatal("error writing block receipts:", err) b.Fatal("error writing block receipts:", err)
} }
} }
b.ResetTimer() b.ResetTimer()
filter := New(db) filter := New(db)
@ -118,47 +117,47 @@ func TestFilters(t *testing.T) {
defer db.Close() defer db.Close()
genesis := core.WriteGenesisBlockForTesting(db, core.GenesisAccount{addr, big.NewInt(1000000)}) genesis := core.WriteGenesisBlockForTesting(db, core.GenesisAccount{addr, big.NewInt(1000000)})
chain := core.GenerateChain(genesis, db, 1000, func(i int, gen *core.BlockGen) { chain, receipts := core.GenerateChain(genesis, db, 1000, func(i int, gen *core.BlockGen) {
var receipts types.Receipts var receipts types.Receipts
switch i { switch i {
case 1: case 1:
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{ receipt.Logs = vm.Logs{
&vm.Log{ &vm.Log{
Address: addr, Address: addr,
Topics: []common.Hash{hash1}, Topics: []common.Hash{hash1},
}, },
}) }
gen.AddUncheckedReceipt(receipt) gen.AddUncheckedReceipt(receipt)
receipts = types.Receipts{receipt} receipts = types.Receipts{receipt}
case 2: case 2:
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{ receipt.Logs = vm.Logs{
&vm.Log{ &vm.Log{
Address: addr, Address: addr,
Topics: []common.Hash{hash2}, Topics: []common.Hash{hash2},
}, },
}) }
gen.AddUncheckedReceipt(receipt) gen.AddUncheckedReceipt(receipt)
receipts = types.Receipts{receipt} receipts = types.Receipts{receipt}
case 998: case 998:
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{ receipt.Logs = vm.Logs{
&vm.Log{ &vm.Log{
Address: addr, Address: addr,
Topics: []common.Hash{hash3}, Topics: []common.Hash{hash3},
}, },
}) }
gen.AddUncheckedReceipt(receipt) gen.AddUncheckedReceipt(receipt)
receipts = types.Receipts{receipt} receipts = types.Receipts{receipt}
case 999: case 999:
receipt := types.NewReceipt(nil, new(big.Int)) receipt := types.NewReceipt(nil, new(big.Int))
receipt.SetLogs(vm.Logs{ receipt.Logs = vm.Logs{
&vm.Log{ &vm.Log{
Address: addr, Address: addr,
Topics: []common.Hash{hash4}, Topics: []common.Hash{hash4},
}, },
}) }
gen.AddUncheckedReceipt(receipt) gen.AddUncheckedReceipt(receipt)
receipts = types.Receipts{receipt} receipts = types.Receipts{receipt}
} }
@ -173,7 +172,7 @@ func TestFilters(t *testing.T) {
// by one // by one
core.WriteMipmapBloom(db, uint64(i+1), receipts) core.WriteMipmapBloom(db, uint64(i+1), receipts)
}) })
for _, block := range chain { for i, block := range chain {
core.WriteBlock(db, block) core.WriteBlock(db, block)
if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
t.Fatalf("failed to insert block number: %v", err) t.Fatalf("failed to insert block number: %v", err)
@ -181,7 +180,7 @@ func TestFilters(t *testing.T) {
if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil {
t.Fatalf("failed to insert block number: %v", err) t.Fatalf("failed to insert block number: %v", err)
} }
if err := core.PutBlockReceipts(db, block, block.Receipts()); err != nil { if err := core.PutBlockReceipts(db, block.Hash(), receipts[i]); err != nil {
t.Fatal("error writing block receipts:", err) t.Fatal("error writing block receipts:", err)
} }
} }

@ -84,6 +84,11 @@ type ProtocolManager struct {
// NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable
// with the ethereum network. // with the ethereum network.
func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool txPool, pow pow.PoW, blockchain *core.BlockChain, chaindb ethdb.Database) (*ProtocolManager, error) { func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool txPool, pow pow.PoW, blockchain *core.BlockChain, chaindb ethdb.Database) (*ProtocolManager, error) {
// Figure out whether to allow fast sync or not
if fastSync && blockchain.CurrentBlock().NumberU64() > 0 {
glog.V(logger.Info).Infof("blockchain not empty, fast sync disabled")
fastSync = false
}
// Create the protocol manager with the base fields // Create the protocol manager with the base fields
manager := &ProtocolManager{ manager := &ProtocolManager{
fastSync: fastSync, fastSync: fastSync,
@ -103,7 +108,7 @@ func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool
if fastSync && version < eth63 { if fastSync && version < eth63 {
continue continue
} }
// Compatible, initialize the sub-protocol // Compatible; initialise the sub-protocol
version := version // Closure for the run version := version // Closure for the run
manager.SubProtocols = append(manager.SubProtocols, p2p.Protocol{ manager.SubProtocols = append(manager.SubProtocols, p2p.Protocol{
Name: "eth", Name: "eth",
@ -120,13 +125,9 @@ func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool
return nil, errIncompatibleConfig return nil, errIncompatibleConfig
} }
// Construct the different synchronisation mechanisms // Construct the different synchronisation mechanisms
syncMode := downloader.FullSync manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader, blockchain.GetBlock,
if fastSync { blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, blockchain.GetTd,
syncMode = downloader.FastSync blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer)
}
manager.downloader = downloader.New(syncMode, chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader,
blockchain.GetBlock, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead,
blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer)
validator := func(block *types.Block, parent *types.Block) error { validator := func(block *types.Block, parent *types.Block) error {
return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false) return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false)

@ -443,7 +443,9 @@ func testGetNodeData(t *testing.T, protocol int) {
// Fetch for now the entire chain db // Fetch for now the entire chain db
hashes := []common.Hash{} hashes := []common.Hash{}
for _, key := range pm.chaindb.(*ethdb.MemDatabase).Keys() { for _, key := range pm.chaindb.(*ethdb.MemDatabase).Keys() {
hashes = append(hashes, common.BytesToHash(key)) if len(key) == len(common.Hash{}) {
hashes = append(hashes, common.BytesToHash(key))
}
} }
p2p.Send(peer.app, 0x0d, hashes) p2p.Send(peer.app, 0x0d, hashes)
msg, err := peer.app.ReadMsg() msg, err := peer.app.ReadMsg()

@ -101,7 +101,7 @@ func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) {
packets, traffic = reqBlockInPacketsMeter, reqBlockInTrafficMeter packets, traffic = reqBlockInPacketsMeter, reqBlockInTrafficMeter
case rw.version >= eth62 && msg.Code == BlockHeadersMsg: case rw.version >= eth62 && msg.Code == BlockHeadersMsg:
packets, traffic = reqBlockInPacketsMeter, reqBlockInTrafficMeter packets, traffic = reqHeaderInPacketsMeter, reqHeaderInTrafficMeter
case rw.version >= eth62 && msg.Code == BlockBodiesMsg: case rw.version >= eth62 && msg.Code == BlockBodiesMsg:
packets, traffic = reqBodyInPacketsMeter, reqBodyInTrafficMeter packets, traffic = reqBodyInPacketsMeter, reqBodyInTrafficMeter

@ -22,6 +22,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
@ -165,5 +166,20 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
return return
} }
// Otherwise try to sync with the downloader // Otherwise try to sync with the downloader
pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td()) mode := downloader.FullSync
if pm.fastSync {
mode = downloader.FastSync
}
pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), mode)
// If fast sync was enabled, and we synced up, disable it
if pm.fastSync {
for pm.downloader.Synchronising() {
time.Sleep(100 * time.Millisecond)
}
if pm.blockchain.CurrentBlock().NumberU64() > 0 {
glog.V(logger.Info).Infof("fast sync complete, auto disabling")
pm.fastSync = false
}
}
} }

@ -0,0 +1,53 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package eth
import (
"testing"
"time"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
)
// Tests that fast sync gets disabled as soon as a real block is successfully
// imported into the blockchain.
func TestFastSyncDisabling(t *testing.T) {
// Create a pristine protocol manager, check that fast sync is left enabled
pmEmpty := newTestProtocolManagerMust(t, true, 0, nil, nil)
if !pmEmpty.fastSync {
t.Fatalf("fast sync disabled on pristine blockchain")
}
// Create a full protocol manager, check that fast sync gets disabled
pmFull := newTestProtocolManagerMust(t, true, 1024, nil, nil)
if pmFull.fastSync {
t.Fatalf("fast sync not disabled on non-empty blockchain")
}
// Sync up the two peers
io1, io2 := p2p.MsgPipe()
go pmFull.handle(pmFull.newPeer(63, NetworkId, p2p.NewPeer(discover.NodeID{}, "empty", nil), io2))
go pmEmpty.handle(pmEmpty.newPeer(63, NetworkId, p2p.NewPeer(discover.NodeID{}, "full", nil), io1))
time.Sleep(250 * time.Millisecond)
pmEmpty.synchronise(pmEmpty.peers.BestPeer())
// Check that fast sync was disabled
if pmEmpty.fastSync {
t.Fatalf("fast sync not disabled after successful synchronisation")
}
}

@ -17,6 +17,7 @@
package ethdb package ethdb
import ( import (
"errors"
"fmt" "fmt"
"sync" "sync"
@ -56,7 +57,10 @@ func (db *MemDatabase) Get(key []byte) ([]byte, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
return db.db[string(key)], nil if entry, ok := db.db[string(key)]; ok {
return entry, nil
}
return nil, errors.New("not found")
} }
func (db *MemDatabase) Keys() [][]byte { func (db *MemDatabase) Keys() [][]byte {
@ -132,8 +136,8 @@ func (b *memBatch) Write() error {
b.lock.RLock() b.lock.RLock()
defer b.lock.RUnlock() defer b.lock.RUnlock()
b.db.lock.RLock() b.db.lock.Lock()
defer b.db.lock.RUnlock() defer b.db.lock.Unlock()
for _, kv := range b.writes { for _, kv := range b.writes {
b.db.db[string(kv.k)] = kv.v b.db.db[string(kv.k)] = kv.v

@ -168,9 +168,7 @@ func (self *ethApi) IsMining(req *shared.Request) (interface{}, error) {
} }
func (self *ethApi) IsSyncing(req *shared.Request) (interface{}, error) { func (self *ethApi) IsSyncing(req *shared.Request) (interface{}, error) {
current := self.ethereum.BlockChain().CurrentBlock().NumberU64() origin, current, height := self.ethereum.Downloader().Progress()
origin, height := self.ethereum.Downloader().Boundaries()
if current < height { if current < height {
return map[string]interface{}{ return map[string]interface{}{
"startingBlock": newHexNum(big.NewInt(int64(origin)).Bytes()), "startingBlock": newHexNum(big.NewInt(int64(origin)).Bytes()),

@ -31,7 +31,7 @@ type request struct {
object *node // Target node to populate with retrieved data (hashnode originally) object *node // Target node to populate with retrieved data (hashnode originally)
parents []*request // Parent state nodes referencing this entry (notify all upon completion) parents []*request // Parent state nodes referencing this entry (notify all upon completion)
depth int // Depth level within the trie the node is located to prioritize DFS depth int // Depth level within the trie the node is located to prioritise DFS
deps int // Number of dependencies before allowed to commit this node deps int // Number of dependencies before allowed to commit this node
callback TrieSyncLeafCallback // Callback to invoke if a leaf node it reached on this branch callback TrieSyncLeafCallback // Callback to invoke if a leaf node it reached on this branch

Loading…
Cancel
Save