diff --git a/core/block_processor.go b/core/block_processor.go index ba63508055..60f0258c44 100644 --- a/core/block_processor.go +++ b/core/block_processor.go @@ -383,6 +383,15 @@ func (sm *BlockProcessor) ValidateHeader(header *types.Header, checkPow, uncle b } } +// ValidateHeaderWithParent verifies the validity of a header, relying on the database and +// POW behind the block processor. +func (sm *BlockProcessor) ValidateHeaderWithParent(header, parent *types.Header, checkPow, uncle bool) error { + if sm.bc.HasHeader(header.Hash()) { + return nil + } + return ValidateHeader(sm.Pow, header, parent, checkPow, uncle) +} + // See YP section 4.3.4. "Block Header Validity" // Validates a header. Returns an error if the header is invalid. func ValidateHeader(pow pow.PoW, header *types.Header, parent *types.Header, checkPow, uncle bool) error { @@ -425,7 +434,7 @@ func ValidateHeader(pow pow.PoW, header *types.Header, parent *types.Header, che if checkPow { // Verify the nonce of the header. Return an error if it's not valid if !pow.Verify(types.NewBlockWithHeader(header)) { - return ValidationError("Header's nonce is invalid (= %x)", header.Nonce) + return &BlockNonceErr{Hash: header.Hash(), Number: header.Number, Nonce: header.Nonce.Uint64()} } } return nil diff --git a/core/blockchain.go b/core/blockchain.go index 6c8a24751f..3e7dfa9eec 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -22,6 +22,8 @@ import ( "fmt" "io" "math/big" + "math/rand" + "runtime" "sync" "sync/atomic" "time" @@ -671,7 +673,7 @@ func (self *BlockChain) writeHeader(header *types.Header) error { // should be done or not. The reason behind the optional check is because some // of the header retrieval mechanisms already need to verfy nonces, as well as // because nonces can be verified sparsely, not needing to check each. -func (self *BlockChain) InsertHeaderChain(chain []*types.Header, verify bool) (int, error) { +func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) { self.wg.Add(1) defer self.wg.Done() @@ -683,16 +685,85 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, verify bool) (i stats := struct{ processed, ignored int }{} start := time.Now() - // Start the parallel nonce verifier, with a fake nonce if not requested - verifier := self.pow - if !verify { - verifier = FakePow{} + // Generate the list of headers that should be POW verified + verify := make([]bool, len(chain)) + for i := 0; i < len(verify)/checkFreq; i++ { + index := i*checkFreq + rand.Intn(checkFreq) + if index >= len(verify) { + index = len(verify) - 1 + } + verify[index] = true } - nonceAbort, nonceResults := verifyNoncesFromHeaders(verifier, chain) - defer close(nonceAbort) + verify[len(verify)-1] = true // Last should always be verified to avoid junk + + // Create the header verification task queue and worker functions + tasks := make(chan int, len(chain)) + for i := 0; i < len(chain); i++ { + tasks <- i + } + close(tasks) - // Iterate over the headers, inserting any new ones - complete := make([]bool, len(chain)) + errs, failed := make([]error, len(tasks)), int32(0) + process := func(worker int) { + for index := range tasks { + header, hash := chain[index], chain[index].Hash() + + // Short circuit insertion if shutting down or processing failed + if atomic.LoadInt32(&self.procInterrupt) == 1 { + return + } + if atomic.LoadInt32(&failed) > 0 { + return + } + // Short circuit if the header is bad or already known + if BadHashes[hash] { + errs[index] = BadHashError(hash) + atomic.AddInt32(&failed, 1) + return + } + if self.HasHeader(hash) { + continue + } + // Verify that the header honors the chain parameters + checkPow := verify[index] + + var err error + if index == 0 { + err = self.processor.ValidateHeader(header, checkPow, false) + } else { + err = self.processor.ValidateHeaderWithParent(header, chain[index-1], checkPow, false) + } + if err != nil { + errs[index] = err + atomic.AddInt32(&failed, 1) + return + } + } + } + // Start as many worker threads as goroutines allowed + pending := new(sync.WaitGroup) + for i := 0; i < runtime.GOMAXPROCS(0); i++ { + pending.Add(1) + go func(id int) { + defer pending.Done() + process(id) + }(i) + } + pending.Wait() + + // If anything failed, report + if atomic.LoadInt32(&self.procInterrupt) == 1 { + glog.V(logger.Debug).Infoln("premature abort during receipt chain processing") + return 0, nil + } + if failed > 0 { + for i, err := range errs { + if err != nil { + return i, err + } + } + } + // All headers passed verification, import them into the database for i, header := range chain { // Short circuit insertion if shutting down if atomic.LoadInt32(&self.procInterrupt) == 1 { @@ -701,24 +772,7 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, verify bool) (i } hash := header.Hash() - // Accumulate verification results until the next header is verified - for !complete[i] { - if res := <-nonceResults; res.valid { - complete[res.index] = true - } else { - header := chain[res.index] - return res.index, &BlockNonceErr{ - Hash: header.Hash(), - Number: new(big.Int).Set(header.Number), - Nonce: header.Nonce.Uint64(), - } - } - } - if BadHashes[hash] { - glog.V(logger.Error).Infof("bad header %d [%x…], known bad hash", header.Number, hash) - return i, BadHashError(hash) - } - // Write the header to the chain and get the status + // If the header's already known, skip it, otherwise store if self.HasHeader(hash) { stats.ignored++ continue @@ -743,76 +797,116 @@ func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain defer self.wg.Done() // Collect some import statistics to report on - stats := struct{ processed, ignored int }{} + stats := struct{ processed, ignored int32 }{} start := time.Now() - // Iterate over the blocks and receipts, inserting any new ones + // Create the block importing task queue and worker functions + tasks := make(chan int, len(blockChain)) for i := 0; i < len(blockChain) && i < len(receiptChain); i++ { - block, receipts := blockChain[i], receiptChain[i] + tasks <- i + } + close(tasks) - // Short circuit insertion if shutting down - if atomic.LoadInt32(&self.procInterrupt) == 1 { - glog.V(logger.Debug).Infoln("premature abort during receipt chain processing") - break - } - // Short circuit if the owner header is unknown - if !self.HasHeader(block.Hash()) { - glog.V(logger.Debug).Infof("containing header #%d [%x…] unknown", block.Number(), block.Hash().Bytes()[:4]) - return i, fmt.Errorf("containing header #%d [%x…] unknown", block.Number(), block.Hash().Bytes()[:4]) - } - // Skip if the entire data is already known - if self.HasBlock(block.Hash()) { - stats.ignored++ - continue - } - // Compute all the non-consensus fields of the receipts - transactions, logIndex := block.Transactions(), uint(0) - for j := 0; j < len(receipts); j++ { - // The transaction hash can be retrieved from the transaction itself - receipts[j].TxHash = transactions[j].Hash() - - // The contract address can be derived from the transaction itself - if MessageCreatesContract(transactions[j]) { - from, _ := transactions[j].From() - receipts[j].ContractAddress = crypto.CreateAddress(from, transactions[j].Nonce()) + errs, failed := make([]error, len(tasks)), int32(0) + process := func(worker int) { + for index := range tasks { + block, receipts := blockChain[index], receiptChain[index] + + // Short circuit insertion if shutting down or processing failed + if atomic.LoadInt32(&self.procInterrupt) == 1 { + return } - // The used gas can be calculated based on previous receipts - if j == 0 { - receipts[j].GasUsed = new(big.Int).Set(receipts[j].CumulativeGasUsed) - } else { - receipts[j].GasUsed = new(big.Int).Sub(receipts[j].CumulativeGasUsed, receipts[j-1].CumulativeGasUsed) + if atomic.LoadInt32(&failed) > 0 { + return } - // The derived log fields can simply be set from the block and transaction - for k := 0; k < len(receipts[j].Logs); k++ { - receipts[j].Logs[k].BlockNumber = block.NumberU64() - receipts[j].Logs[k].BlockHash = block.Hash() - receipts[j].Logs[k].TxHash = receipts[j].TxHash - receipts[j].Logs[k].TxIndex = uint(j) - receipts[j].Logs[k].Index = logIndex - logIndex++ + // Short circuit if the owner header is unknown + if !self.HasHeader(block.Hash()) { + errs[index] = fmt.Errorf("containing header #%d [%x…] unknown", block.Number(), block.Hash().Bytes()[:4]) + atomic.AddInt32(&failed, 1) + return } + // Skip if the entire data is already known + if self.HasBlock(block.Hash()) { + atomic.AddInt32(&stats.ignored, 1) + continue + } + // Compute all the non-consensus fields of the receipts + transactions, logIndex := block.Transactions(), uint(0) + for j := 0; j < len(receipts); j++ { + // The transaction hash can be retrieved from the transaction itself + receipts[j].TxHash = transactions[j].Hash() + + // The contract address can be derived from the transaction itself + if MessageCreatesContract(transactions[j]) { + from, _ := transactions[j].From() + receipts[j].ContractAddress = crypto.CreateAddress(from, transactions[j].Nonce()) + } + // The used gas can be calculated based on previous receipts + if j == 0 { + receipts[j].GasUsed = new(big.Int).Set(receipts[j].CumulativeGasUsed) + } else { + receipts[j].GasUsed = new(big.Int).Sub(receipts[j].CumulativeGasUsed, receipts[j-1].CumulativeGasUsed) + } + // The derived log fields can simply be set from the block and transaction + for k := 0; k < len(receipts[j].Logs); k++ { + receipts[j].Logs[k].BlockNumber = block.NumberU64() + receipts[j].Logs[k].BlockHash = block.Hash() + receipts[j].Logs[k].TxHash = receipts[j].TxHash + receipts[j].Logs[k].TxIndex = uint(j) + receipts[j].Logs[k].Index = logIndex + logIndex++ + } + } + // Write all the data out into the database + if err := WriteBody(self.chainDb, block.Hash(), &types.Body{block.Transactions(), block.Uncles()}); err != nil { + errs[index] = fmt.Errorf("failed to write block body: %v", err) + atomic.AddInt32(&failed, 1) + glog.Fatal(errs[index]) + return + } + if err := PutBlockReceipts(self.chainDb, block.Hash(), receipts); err != nil { + errs[index] = fmt.Errorf("failed to write block receipts: %v", err) + atomic.AddInt32(&failed, 1) + glog.Fatal(errs[index]) + return + } + atomic.AddInt32(&stats.processed, 1) } - // Write all the data out into the database - if err := WriteBody(self.chainDb, block.Hash(), &types.Body{block.Transactions(), block.Uncles()}); err != nil { - glog.Fatalf("failed to write block body: %v", err) - return i, err - } - if err := PutBlockReceipts(self.chainDb, block.Hash(), receipts); err != nil { - glog.Fatalf("failed to write block receipts: %v", err) - return i, err - } - // Update the head fast sync block if better - self.mu.Lock() - if self.GetTd(self.currentFastBlock.Hash()).Cmp(self.GetTd(block.Hash())) < 0 { - if err := WriteHeadFastBlockHash(self.chainDb, block.Hash()); err != nil { - glog.Fatalf("failed to update head fast block hash: %v", err) + } + // Start as many worker threads as goroutines allowed + pending := new(sync.WaitGroup) + for i := 0; i < runtime.GOMAXPROCS(0); i++ { + pending.Add(1) + go func(id int) { + defer pending.Done() + process(id) + }(i) + } + pending.Wait() + + // If anything failed, report + if atomic.LoadInt32(&self.procInterrupt) == 1 { + glog.V(logger.Debug).Infoln("premature abort during receipt chain processing") + return 0, nil + } + if failed > 0 { + for i, err := range errs { + if err != nil { + return i, err } - self.currentFastBlock = block } - self.mu.Unlock() - - stats.processed++ } + // Update the head fast sync block if better + self.mu.Lock() + head := blockChain[len(errs)-1] + if self.GetTd(self.currentFastBlock.Hash()).Cmp(self.GetTd(head.Hash())) < 0 { + if err := WriteHeadFastBlockHash(self.chainDb, head.Hash()); err != nil { + glog.Fatalf("failed to update head fast block hash: %v", err) + } + self.currentFastBlock = head + } + self.mu.Unlock() + // Report some public statistics so the user has a clue what's going on first, last := blockChain[0], blockChain[len(blockChain)-1] glog.V(logger.Info).Infof("imported %d receipt(s) (%d ignored) in %v. #%d [%x… / %x…]", stats.processed, stats.ignored, diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 93c2128bc5..a614aaa2fc 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -94,7 +94,7 @@ func testFork(t *testing.T, processor *BlockProcessor, i, n int, full bool, comp } } else { headerChainB = makeHeaderChain(processor2.bc.CurrentHeader(), n, db, forkSeed) - if _, err := processor2.bc.InsertHeaderChain(headerChainB, true); err != nil { + if _, err := processor2.bc.InsertHeaderChain(headerChainB, 1); err != nil { t.Fatalf("failed to insert forking chain: %v", err) } } @@ -415,7 +415,9 @@ func TestChainMultipleInsertions(t *testing.T) { type bproc struct{} -func (bproc) Process(*types.Block) (vm.Logs, types.Receipts, error) { return nil, nil, nil } +func (bproc) Process(*types.Block) (vm.Logs, types.Receipts, error) { return nil, nil, nil } +func (bproc) ValidateHeader(*types.Header, bool, bool) error { return nil } +func (bproc) ValidateHeaderWithParent(*types.Header, *types.Header, bool, bool) error { return nil } func makeHeaderChainWithDiff(genesis *types.Block, d []int, seed byte) []*types.Header { blocks := makeBlockChainWithDiff(genesis, d, seed) @@ -492,8 +494,8 @@ func testReorg(t *testing.T, first, second []int, td int64, full bool) { bc.InsertChain(makeBlockChainWithDiff(genesis, first, 11)) bc.InsertChain(makeBlockChainWithDiff(genesis, second, 22)) } else { - bc.InsertHeaderChain(makeHeaderChainWithDiff(genesis, first, 11), false) - bc.InsertHeaderChain(makeHeaderChainWithDiff(genesis, second, 22), false) + bc.InsertHeaderChain(makeHeaderChainWithDiff(genesis, first, 11), 1) + bc.InsertHeaderChain(makeHeaderChainWithDiff(genesis, second, 22), 1) } // Check that the chain is valid number and link wise if full { @@ -543,7 +545,7 @@ func testBadHashes(t *testing.T, full bool) { } else { headers := makeHeaderChainWithDiff(genesis, []int{1, 2, 4}, 10) BadHashes[headers[2].Hash()] = true - _, err = bc.InsertHeaderChain(headers, true) + _, err = bc.InsertHeaderChain(headers, 1) } if !IsBadHashError(err) { t.Errorf("error mismatch: want: BadHashError, have: %v", err) @@ -575,7 +577,7 @@ func testReorgBadHashes(t *testing.T, full bool) { BadHashes[blocks[3].Header().Hash()] = true defer func() { delete(BadHashes, blocks[3].Header().Hash()) }() } else { - if _, err := bc.InsertHeaderChain(headers, true); err != nil { + if _, err := bc.InsertHeaderChain(headers, 1); err != nil { t.Fatalf("failed to import headers: %v", err) } if bc.CurrentHeader().Hash() != headers[3].Hash() { @@ -631,6 +633,8 @@ func testInsertNonceError(t *testing.T, full bool) { failHash = blocks[failAt].Hash() processor.bc.pow = failPow{failNum} + processor.Pow = failPow{failNum} + failRes, err = processor.bc.InsertChain(blocks) } else { headers := makeHeaderChain(processor.bc.CurrentHeader(), i, db, 0) @@ -640,7 +644,9 @@ func testInsertNonceError(t *testing.T, full bool) { failHash = headers[failAt].Hash() processor.bc.pow = failPow{failNum} - failRes, err = processor.bc.InsertHeaderChain(headers, true) + processor.Pow = failPow{failNum} + + failRes, err = processor.bc.InsertHeaderChain(headers, 1) } // Check that the returned error indicates the nonce failure. if failRes != failAt { @@ -714,12 +720,13 @@ func TestFastVsFullChains(t *testing.T) { fastDb, _ := ethdb.NewMemDatabase() WriteGenesisBlockForTesting(fastDb, GenesisAccount{address, funds}) fast, _ := NewBlockChain(fastDb, FakePow{}, new(event.TypeMux)) + fast.SetProcessor(NewBlockProcessor(fastDb, FakePow{}, fast, new(event.TypeMux))) headers := make([]*types.Header, len(blocks)) for i, block := range blocks { headers[i] = block.Header() } - if n, err := fast.InsertHeaderChain(headers, true); err != nil { + if n, err := fast.InsertHeaderChain(headers, 1); err != nil { t.Fatalf("failed to insert header %d: %v", n, err) } if n, err := fast.InsertReceiptChain(blocks, receipts); err != nil { @@ -796,12 +803,13 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { fastDb, _ := ethdb.NewMemDatabase() WriteGenesisBlockForTesting(fastDb, GenesisAccount{address, funds}) fast, _ := NewBlockChain(fastDb, FakePow{}, new(event.TypeMux)) + fast.SetProcessor(NewBlockProcessor(fastDb, FakePow{}, fast, new(event.TypeMux))) headers := make([]*types.Header, len(blocks)) for i, block := range blocks { headers[i] = block.Header() } - if n, err := fast.InsertHeaderChain(headers, true); err != nil { + if n, err := fast.InsertHeaderChain(headers, 1); err != nil { t.Fatalf("failed to insert header %d: %v", n, err) } if n, err := fast.InsertReceiptChain(blocks, receipts); err != nil { @@ -813,8 +821,9 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { lightDb, _ := ethdb.NewMemDatabase() WriteGenesisBlockForTesting(lightDb, GenesisAccount{address, funds}) light, _ := NewBlockChain(lightDb, FakePow{}, new(event.TypeMux)) + light.SetProcessor(NewBlockProcessor(lightDb, FakePow{}, light, new(event.TypeMux))) - if n, err := light.InsertHeaderChain(headers, true); err != nil { + if n, err := light.InsertHeaderChain(headers, 1); err != nil { t.Fatalf("failed to insert header %d: %v", n, err) } assert(t, "light", light, height, 0, 0) diff --git a/core/chain_makers.go b/core/chain_makers.go index 31b2627afe..1d41b40918 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -239,7 +239,7 @@ func newCanonical(n int, full bool) (ethdb.Database, *BlockProcessor, error) { } // Header-only chain requested headers := makeHeaderChain(genesis.Header(), n, db, canonicalSeed) - _, err := blockchain.InsertHeaderChain(headers, true) + _, err := blockchain.InsertHeaderChain(headers, 1) return db, processor, err } diff --git a/core/error.go b/core/error.go index 5e32124a7e..af42cd19c5 100644 --- a/core/error.go +++ b/core/error.go @@ -111,7 +111,7 @@ type BlockNonceErr struct { } func (err *BlockNonceErr) Error() string { - return fmt.Sprintf("block %d (%v) nonce is invalid (got %d)", err.Number, err.Hash, err.Nonce) + return fmt.Sprintf("nonce for #%d [%x…] is invalid (got %d)", err.Number, err.Hash, err.Nonce) } // IsBlockNonceErr returns true for invalid block nonce errors. diff --git a/core/state/sync.go b/core/state/sync.go index e9bebe8ee7..5a388886cf 100644 --- a/core/state/sync.go +++ b/core/state/sync.go @@ -21,78 +21,51 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) -type StateSync struct { - db ethdb.Database - sync *trie.TrieSync - codeReqs map[common.Hash]struct{} // requested but not yet written to database - codeReqList []common.Hash // requested since last Missing -} +// StateSync is the main state synchronisation scheduler, which provides yet the +// unknown state hashes to retrieve, accepts node data associated with said hashes +// and reconstructs the state database step by step until all is done. +type StateSync trie.TrieSync -var sha3_nil = common.BytesToHash(sha3.NewKeccak256().Sum(nil)) +// NewStateSync create a new state trie download scheduler. +func NewStateSync(root common.Hash, database ethdb.Database) *StateSync { + // Pre-declare the result syncer t + var syncer *trie.TrieSync -func NewStateSync(root common.Hash, db ethdb.Database) *StateSync { - ss := &StateSync{ - db: db, - codeReqs: make(map[common.Hash]struct{}), - } - ss.codeReqs[sha3_nil] = struct{}{} // never request the nil hash - ss.sync = trie.NewTrieSync(root, db, ss.leafFound) - return ss -} + callback := func(leaf []byte, parent common.Hash) error { + var obj struct { + Nonce uint64 + Balance *big.Int + Root common.Hash + CodeHash []byte + } + if err := rlp.Decode(bytes.NewReader(leaf), &obj); err != nil { + return err + } + syncer.AddSubTrie(obj.Root, 64, parent, nil) + syncer.AddRawEntry(common.BytesToHash(obj.CodeHash), 64, parent) -func (self *StateSync) leafFound(leaf []byte, parent common.Hash) error { - var obj struct { - Nonce uint64 - Balance *big.Int - Root common.Hash - CodeHash []byte - } - if err := rlp.Decode(bytes.NewReader(leaf), &obj); err != nil { - return err + return nil } - self.sync.AddSubTrie(obj.Root, 64, parent, nil) + syncer = trie.NewTrieSync(root, database, callback) + return (*StateSync)(syncer) +} - codehash := common.BytesToHash(obj.CodeHash) - if _, ok := self.codeReqs[codehash]; !ok { - code, _ := self.db.Get(obj.CodeHash) - if code == nil { - self.codeReqs[codehash] = struct{}{} - self.codeReqList = append(self.codeReqList, codehash) - } - } - return nil +// Missing retrieves the known missing nodes from the state trie for retrieval. +func (s *StateSync) Missing(max int) []common.Hash { + return (*trie.TrieSync)(s).Missing(max) } -func (self *StateSync) Missing(max int) []common.Hash { - cr := len(self.codeReqList) - gh := 0 - if max != 0 { - if cr > max { - cr = max - } - gh = max - cr - } - list := append(self.sync.Missing(gh), self.codeReqList[:cr]...) - self.codeReqList = self.codeReqList[cr:] - return list +// Process injects a batch of retrieved trie nodes data. +func (s *StateSync) Process(list []trie.SyncResult) (int, error) { + return (*trie.TrieSync)(s).Process(list) } -func (self *StateSync) Process(list []trie.SyncResult) error { - for i := 0; i < len(list); i++ { - if _, ok := self.codeReqs[list[i].Hash]; ok { // code data, not a node - self.db.Put(list[i].Hash[:], list[i].Data) - delete(self.codeReqs, list[i].Hash) - list[i] = list[len(list)-1] - list = list[:len(list)-1] - i-- - } - } - _, err := self.sync.Process(list) - return err +// Pending returns the number of state entries currently pending for download. +func (s *StateSync) Pending() int { + return (*trie.TrieSync)(s).Pending() } diff --git a/core/state/sync_test.go b/core/state/sync_test.go index f6afe8bd83..f0376d484b 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -115,8 +115,8 @@ func testIterativeStateSync(t *testing.T, batch int) { } results[i] = trie.SyncResult{hash, data} } - if err := sched.Process(results); err != nil { - t.Fatalf("failed to process results: %v", err) + if index, err := sched.Process(results); err != nil { + t.Fatalf("failed to process result #%d: %v", index, err) } queue = append(queue[:0], sched.Missing(batch)...) } @@ -145,8 +145,8 @@ func TestIterativeDelayedStateSync(t *testing.T) { } results[i] = trie.SyncResult{hash, data} } - if err := sched.Process(results); err != nil { - t.Fatalf("failed to process results: %v", err) + if index, err := sched.Process(results); err != nil { + t.Fatalf("failed to process result #%d: %v", index, err) } queue = append(queue[len(results):], sched.Missing(0)...) } @@ -183,8 +183,8 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { results = append(results, trie.SyncResult{hash, data}) } // Feed the retrieved results back and queue new tasks - if err := sched.Process(results); err != nil { - t.Fatalf("failed to process results: %v", err) + if index, err := sched.Process(results); err != nil { + t.Fatalf("failed to process result #%d: %v", index, err) } queue = make(map[common.Hash]struct{}) for _, hash := range sched.Missing(batch) { @@ -226,8 +226,8 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { } } // Feed the retrieved results back and queue new tasks - if err := sched.Process(results); err != nil { - t.Fatalf("failed to process results: %v", err) + if index, err := sched.Process(results); err != nil { + t.Fatalf("failed to process result #%d: %v", index, err) } for _, hash := range sched.Missing(0) { queue[hash] = struct{}{} diff --git a/core/types/common.go b/core/types/common.go index 29019a1b47..fe682f98a3 100644 --- a/core/types/common.go +++ b/core/types/common.go @@ -20,4 +20,6 @@ import "github.com/ethereum/go-ethereum/core/vm" type BlockProcessor interface { Process(*Block) (vm.Logs, Receipts, error) + ValidateHeader(*Header, bool, bool) error + ValidateHeaderWithParent(*Header, *Header, bool, bool) error } diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 96177ae8ae..e19b70dfdd 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -830,7 +830,7 @@ func (d *Downloader) fetchBlocks61(from uint64) error { } // If there's nothing more to fetch, wait or terminate if d.queue.PendingBlocks() == 0 { - if d.queue.InFlight() == 0 && finished { + if !d.queue.InFlightBlocks() && finished { glog.V(logger.Debug).Infof("Block fetching completed") return nil } @@ -864,7 +864,7 @@ func (d *Downloader) fetchBlocks61(from uint64) error { } // Make sure that we have peers available for fetching. If all peers have been tried // and all failed throw an error - if !throttled && d.queue.InFlight() == 0 && len(idles) == total { + if !throttled && !d.queue.InFlightBlocks() && len(idles) == total { return errPeersUnavailable } } @@ -1124,7 +1124,7 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { glog.V(logger.Detail).Infof("%v: schedule %d headers from #%d", p, len(headers), from) if d.mode == FastSync || d.mode == LightSync { - if n, err := d.insertHeaders(headers, false); err != nil { + if n, err := d.insertHeaders(headers, headerCheckFrequency); err != nil { glog.V(logger.Debug).Infof("%v: invalid header #%d [%x…]: %v", p, headers[n].Number, headers[n].Hash().Bytes()[:4], err) return errInvalidChain } @@ -1194,8 +1194,8 @@ func (d *Downloader) fetchBodies(from uint64) error { setIdle = func(p *peer) { p.SetBlocksIdle() } ) err := d.fetchParts(errCancelBodyFetch, d.bodyCh, deliver, d.bodyWakeCh, expire, - d.queue.PendingBlocks, d.queue.ThrottleBlocks, d.queue.ReserveBodies, d.bodyFetchHook, - fetch, d.queue.CancelBodies, capacity, getIdles, setIdle, "Body") + d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ThrottleBlocks, d.queue.ReserveBodies, + d.bodyFetchHook, fetch, d.queue.CancelBodies, capacity, getIdles, setIdle, "Body") glog.V(logger.Debug).Infof("Block body download terminated: %v", err) return err @@ -1218,8 +1218,8 @@ func (d *Downloader) fetchReceipts(from uint64) error { setIdle = func(p *peer) { p.SetReceiptsIdle() } ) err := d.fetchParts(errCancelReceiptFetch, d.receiptCh, deliver, d.receiptWakeCh, expire, - d.queue.PendingReceipts, d.queue.ThrottleReceipts, d.queue.ReserveReceipts, d.receiptFetchHook, - fetch, d.queue.CancelReceipts, capacity, d.peers.ReceiptIdlePeers, setIdle, "Receipt") + d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ThrottleReceipts, d.queue.ReserveReceipts, + d.receiptFetchHook, fetch, d.queue.CancelReceipts, capacity, d.peers.ReceiptIdlePeers, setIdle, "Receipt") glog.V(logger.Debug).Infof("Receipt download terminated: %v", err) return err @@ -1234,15 +1234,29 @@ func (d *Downloader) fetchNodeData() error { var ( deliver = func(packet dataPack) error { start := time.Now() - done, found, err := d.queue.DeliverNodeData(packet.PeerId(), packet.(*statePack).states) - - d.syncStatsLock.Lock() - totalDone, totalKnown := d.syncStatsStateDone+uint64(done), d.syncStatsStateTotal+uint64(found) - d.syncStatsStateDone, d.syncStatsStateTotal = totalDone, totalKnown - d.syncStatsLock.Unlock() + return d.queue.DeliverNodeData(packet.PeerId(), packet.(*statePack).states, func(err error, delivered int) { + if err != nil { + // If the node data processing failed, the root hash is very wrong, abort + glog.V(logger.Error).Infof("peer %d: state processing failed: %v", packet.PeerId(), err) + d.cancel() + return + } + // Processing succeeded, notify state fetcher and processor of continuation + if d.queue.PendingNodeData() == 0 { + go d.process() + } else { + select { + case d.stateWakeCh <- true: + default: + } + } + // Log a message to the user and return + d.syncStatsLock.Lock() + defer d.syncStatsLock.Unlock() - glog.V(logger.Info).Infof("imported %d [%d / %d] state entries in %v.", done, totalDone, totalKnown, time.Since(start)) - return err + d.syncStatsStateDone += uint64(delivered) + glog.V(logger.Info).Infof("imported %d state entries in %v: processed %d in total", delivered, time.Since(start), d.syncStatsStateDone) + }) } expire = func() []string { return d.queue.ExpireNodeData(stateHardTTL) } throttle = func() bool { return false } @@ -1254,8 +1268,8 @@ func (d *Downloader) fetchNodeData() error { setIdle = func(p *peer) { p.SetNodeDataIdle() } ) err := d.fetchParts(errCancelReceiptFetch, d.stateCh, deliver, d.stateWakeCh, expire, - d.queue.PendingNodeData, throttle, reserve, nil, fetch, d.queue.CancelNodeData, - capacity, d.peers.ReceiptIdlePeers, setIdle, "State") + d.queue.PendingNodeData, d.queue.InFlightNodeData, throttle, reserve, nil, fetch, + d.queue.CancelNodeData, capacity, d.peers.ReceiptIdlePeers, setIdle, "State") glog.V(logger.Debug).Infof("Node state data download terminated: %v", err) return err @@ -1265,8 +1279,9 @@ func (d *Downloader) fetchNodeData() error { // peers, reserving a chunk of fetch requests for each, waiting for delivery and // also periodically checking for timeouts. func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(packet dataPack) error, wakeCh chan bool, - expire func() []string, pending func() int, throttle func() bool, reserve func(*peer, int) (*fetchRequest, bool, error), fetchHook func([]*types.Header), - fetch func(*peer, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peer) int, idle func() ([]*peer, int), setIdle func(*peer), kind string) error { + expire func() []string, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peer, int) (*fetchRequest, bool, error), + fetchHook func([]*types.Header), fetch func(*peer, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peer) int, + idle func() ([]*peer, int), setIdle func(*peer), kind string) error { // Create a ticker to detect expired retreival tasks ticker := time.NewTicker(100 * time.Millisecond) @@ -1378,14 +1393,14 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv } // If there's nothing more to fetch, wait or terminate if pending() == 0 { - if d.queue.InFlight() == 0 && finished { + if !inFlight() && finished { glog.V(logger.Debug).Infof("%s fetching completed", kind) return nil } break } // Send a download request to all idle peers, until throttled - progressed, throttled := false, false + progressed, throttled, running := false, false, inFlight() idles, total := idle() for _, peer := range idles { @@ -1423,10 +1438,11 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv glog.V(logger.Error).Infof("%v: %s fetch failed, rescheduling", peer, strings.ToLower(kind)) cancel(request) } + running = true } // Make sure that we have peers available for fetching. If all peers have been tried // and all failed throw an error - if !progressed && !throttled && d.queue.InFlight() == 0 && len(idles) == total { + if !progressed && !throttled && !running && len(idles) == total && pending() > 0 { return errPeersUnavailable } } @@ -1514,12 +1530,12 @@ func (d *Downloader) process() { ) switch { case len(headers) > 0: - index, err = d.insertHeaders(headers, true) + index, err = d.insertHeaders(headers, headerCheckFrequency) case len(receipts) > 0: index, err = d.insertReceipts(blocks, receipts) if err == nil && blocks[len(blocks)-1].NumberU64() == d.queue.fastSyncPivot { - err = d.commitHeadBlock(blocks[len(blocks)-1].Hash()) + index, err = len(blocks)-1, d.commitHeadBlock(blocks[len(blocks)-1].Hash()) } default: index, err = d.insertBlocks(blocks) diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 8944ae4b0b..0e60371b3b 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -268,7 +268,7 @@ func (dl *downloadTester) getTd(hash common.Hash) *big.Int { } // insertHeaders injects a new batch of headers into the simulated chain. -func (dl *downloadTester) insertHeaders(headers []*types.Header, verify bool) (int, error) { +func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) (int, error) { dl.lock.Lock() defer dl.lock.Unlock() @@ -1262,7 +1262,7 @@ func testForkedSyncBoundaries(t *testing.T, protocol int, mode SyncMode) { pending.Wait() // Simulate a successful sync above the fork - tester.downloader.syncStatsOrigin = tester.downloader.syncStatsHeight + tester.downloader.syncStatsChainOrigin = tester.downloader.syncStatsChainHeight // Synchronise with the second fork and check boundary resets tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB) diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 942ed0d63f..bb8d892cdd 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "time" "github.com/ethereum/go-ethereum/common" @@ -93,8 +94,10 @@ type queue struct { stateTaskQueue *prque.Prque // [eth/63] Priority queue of the hashes to fetch the node data for statePendPool map[string]*fetchRequest // [eth/63] Currently pending node data retrieval operations - stateDatabase ethdb.Database // [eth/63] Trie database to populate during state reassembly - stateScheduler *state.StateSync // [eth/63] State trie synchronisation scheduler and integrator + stateDatabase ethdb.Database // [eth/63] Trie database to populate during state reassembly + stateScheduler *state.StateSync // [eth/63] State trie synchronisation scheduler and integrator + stateProcessors int32 // [eth/63] Number of currently running state processors + stateSchedLock sync.RWMutex // [eth/63] Lock serializing access to the state scheduler resultCache []*fetchResult // Downloaded but not yet delivered fetch results resultOffset uint64 // Offset of the first cached fetch result in the block-chain @@ -175,18 +178,40 @@ func (q *queue) PendingReceipts() int { // PendingNodeData retrieves the number of node data entries pending for retrieval. func (q *queue) PendingNodeData() int { + q.stateSchedLock.RLock() + defer q.stateSchedLock.RUnlock() + + if q.stateScheduler != nil { + return q.stateScheduler.Pending() + } + return 0 +} + +// InFlightBlocks retrieves whether there are block fetch requests currently in +// flight. +func (q *queue) InFlightBlocks() bool { q.lock.RLock() defer q.lock.RUnlock() - return q.stateTaskQueue.Size() + return len(q.blockPendPool) > 0 } -// InFlight retrieves the number of fetch requests currently in flight. -func (q *queue) InFlight() int { +// InFlightReceipts retrieves whether there are receipt fetch requests currently +// in flight. +func (q *queue) InFlightReceipts() bool { q.lock.RLock() defer q.lock.RUnlock() - return len(q.blockPendPool) + len(q.receiptPendPool) + len(q.statePendPool) + return len(q.receiptPendPool) > 0 +} + +// InFlightNodeData retrieves whether there are node data entry fetch requests +// currently in flight. +func (q *queue) InFlightNodeData() bool { + q.lock.RLock() + defer q.lock.RUnlock() + + return len(q.statePendPool)+int(atomic.LoadInt32(&q.stateProcessors)) > 0 } // Idle returns if the queue is fully idle or has some data still inside. This @@ -199,6 +224,12 @@ func (q *queue) Idle() bool { pending := len(q.blockPendPool) + len(q.receiptPendPool) + len(q.statePendPool) cached := len(q.blockDonePool) + len(q.receiptDonePool) + q.stateSchedLock.RLock() + if q.stateScheduler != nil { + queued += q.stateScheduler.Pending() + } + q.stateSchedLock.RUnlock() + return (queued + pending + cached) == 0 } @@ -299,12 +330,9 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header { } if q.mode == FastSync && header.Number.Uint64() == q.fastSyncPivot { // Pivoting point of the fast sync, retrieve the state tries + q.stateSchedLock.Lock() q.stateScheduler = state.NewStateSync(header.Root, q.stateDatabase) - for _, hash := range q.stateScheduler.Missing(0) { - q.stateTaskPool[hash] = q.stateTaskIndex - q.stateTaskQueue.Push(hash, -float32(q.stateTaskIndex)) - q.stateTaskIndex++ - } + q.stateSchedLock.Unlock() } inserts = append(inserts, header) q.headerHead = hash @@ -325,8 +353,13 @@ func (q *queue) GetHeadResult() *fetchResult { if q.resultCache[0].Pending > 0 { return nil } - if q.mode == FastSync && q.resultCache[0].Header.Number.Uint64() == q.fastSyncPivot && len(q.stateTaskPool) > 0 { - return nil + if q.mode == FastSync && q.resultCache[0].Header.Number.Uint64() == q.fastSyncPivot { + if len(q.stateTaskPool) > 0 { + return nil + } + if q.PendingNodeData() > 0 { + return nil + } } return q.resultCache[0] } @@ -345,8 +378,13 @@ func (q *queue) TakeResults() []*fetchResult { break } // The fast sync pivot block may only be processed after state fetch completes - if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot && len(q.stateTaskPool) > 0 { - break + if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot { + if len(q.stateTaskPool) > 0 { + break + } + if q.PendingNodeData() > 0 { + break + } } // If we've just inserted the fast sync pivot, stop as the following batch needs different insertion if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot+1 && len(results) > 0 { @@ -373,26 +411,34 @@ func (q *queue) TakeResults() []*fetchResult { // ReserveBlocks reserves a set of block hashes for the given peer, skipping any // previously failed download. func (q *queue) ReserveBlocks(p *peer, count int) *fetchRequest { - return q.reserveHashes(p, count, q.hashQueue, q.blockPendPool, len(q.resultCache)-len(q.blockDonePool)) + return q.reserveHashes(p, count, q.hashQueue, nil, q.blockPendPool, len(q.resultCache)-len(q.blockDonePool)) } // ReserveNodeData reserves a set of node data hashes for the given peer, skipping // any previously failed download. func (q *queue) ReserveNodeData(p *peer, count int) *fetchRequest { - return q.reserveHashes(p, count, q.stateTaskQueue, q.statePendPool, 0) + // Create a task generator to fetch status-fetch tasks if all schedules ones are done + generator := func(max int) { + q.stateSchedLock.Lock() + defer q.stateSchedLock.Unlock() + + for _, hash := range q.stateScheduler.Missing(max) { + q.stateTaskPool[hash] = q.stateTaskIndex + q.stateTaskQueue.Push(hash, -float32(q.stateTaskIndex)) + q.stateTaskIndex++ + } + } + return q.reserveHashes(p, count, q.stateTaskQueue, generator, q.statePendPool, count) } // reserveHashes reserves a set of hashes for the given peer, skipping previously // failed ones. -func (q *queue) reserveHashes(p *peer, count int, taskQueue *prque.Prque, pendPool map[string]*fetchRequest, maxPending int) *fetchRequest { +func (q *queue) reserveHashes(p *peer, count int, taskQueue *prque.Prque, taskGen func(int), pendPool map[string]*fetchRequest, maxPending int) *fetchRequest { q.lock.Lock() defer q.lock.Unlock() - // Short circuit if the pool has been depleted, or if the peer's already - // downloading something (sanity check not to corrupt state) - if taskQueue.Empty() { - return nil - } + // Short circuit if the peer's already downloading something (sanity check not + // to corrupt state) if _, ok := pendPool[p.id]; ok { return nil } @@ -403,6 +449,13 @@ func (q *queue) reserveHashes(p *peer, count int, taskQueue *prque.Prque, pendPo allowance -= len(request.Hashes) } } + // If there's a task generator, ask it to fill our task queue + if taskGen != nil && taskQueue.Size() < allowance { + taskGen(allowance - taskQueue.Size()) + } + if taskQueue.Empty() { + return nil + } // Retrieve a batch of hashes, skipping previously failed ones send := make(map[common.Hash]int) skip := make(map[common.Hash]int) @@ -809,14 +862,14 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ } // DeliverNodeData injects a node state data retrieval response into the queue. -func (q *queue) DeliverNodeData(id string, data [][]byte) (int, int, error) { +func (q *queue) DeliverNodeData(id string, data [][]byte, callback func(error, int)) error { q.lock.Lock() defer q.lock.Unlock() // Short circuit if the data was never requested request := q.statePendPool[id] if request == nil { - return 0, 0, errNoFetchesPending + return errNoFetchesPending } stateReqTimer.UpdateSince(request.Time) delete(q.statePendPool, id) @@ -829,7 +882,7 @@ func (q *queue) DeliverNodeData(id string, data [][]byte) (int, int, error) { } // Iterate over the downloaded data and verify each of them errs := make([]error, 0) - processed := 0 + process := []trie.SyncResult{} for _, blob := range data { // Skip any blocks that were not requested hash := common.BytesToHash(crypto.Sha3(blob)) @@ -837,41 +890,58 @@ func (q *queue) DeliverNodeData(id string, data [][]byte) (int, int, error) { errs = append(errs, fmt.Errorf("non-requested state data %x", hash)) continue } - // Inject the next state trie item into the database - if err := q.stateScheduler.Process([]trie.SyncResult{{hash, blob}}); err != nil { - errs = []error{err} - break - } - processed++ + // Inject the next state trie item into the processing queue + process = append(process, trie.SyncResult{hash, blob}) delete(request.Hashes, hash) delete(q.stateTaskPool, hash) } + // Start the asynchronous node state data injection + atomic.AddInt32(&q.stateProcessors, 1) + go func() { + defer atomic.AddInt32(&q.stateProcessors, -1) + q.deliverNodeData(process, callback) + }() // Return all failed or missing fetches to the queue for hash, index := range request.Hashes { q.stateTaskQueue.Push(hash, float32(index)) } - // Also enqueue any newly required state trie nodes - discovered := 0 - if len(q.stateTaskPool) < maxQueuedStates { - for _, hash := range q.stateScheduler.Missing(4 * MaxStateFetch) { - q.stateTaskPool[hash] = q.stateTaskIndex - q.stateTaskQueue.Push(hash, -float32(q.stateTaskIndex)) - q.stateTaskIndex++ - discovered++ - } - } // If none of the data items were good, it's a stale delivery switch { case len(errs) == 0: - return processed, discovered, nil + return nil case len(errs) == len(request.Hashes): - return processed, discovered, errStaleDelivery + return errStaleDelivery default: - return processed, discovered, fmt.Errorf("multiple failures: %v", errs) + return fmt.Errorf("multiple failures: %v", errs) + } +} + +// deliverNodeData is the asynchronous node data processor that injects a batch +// of sync results into the state scheduler. +func (q *queue) deliverNodeData(results []trie.SyncResult, callback func(error, int)) { + // Process results one by one to permit task fetches in between + for i, result := range results { + q.stateSchedLock.Lock() + + if q.stateScheduler == nil { + // Syncing aborted since this async delivery started, bail out + q.stateSchedLock.Unlock() + callback(errNoFetchesPending, i) + return + } + if _, err := q.stateScheduler.Process([]trie.SyncResult{result}); err != nil { + // Processing a state result failed, bail out + q.stateSchedLock.Unlock() + callback(err, i) + return + } + // Item processing succeeded, release the lock (temporarily) + q.stateSchedLock.Unlock() } + callback(nil, len(results)) } // Prepare configures the result cache to allow accepting and caching inbound diff --git a/eth/downloader/types.go b/eth/downloader/types.go index 221ef38f66..60d9a2b120 100644 --- a/eth/downloader/types.go +++ b/eth/downloader/types.go @@ -52,7 +52,7 @@ type headBlockCommitterFn func(common.Hash) error type tdRetrievalFn func(common.Hash) *big.Int // headerChainInsertFn is a callback type to insert a batch of headers into the local chain. -type headerChainInsertFn func([]*types.Header, bool) (int, error) +type headerChainInsertFn func([]*types.Header, int) (int, error) // blockChainInsertFn is a callback type to insert a batch of blocks into the local chain. type blockChainInsertFn func(types.Blocks) (int, error) diff --git a/ethdb/memory_database.go b/ethdb/memory_database.go index 81911f23fa..330834fa4b 100644 --- a/ethdb/memory_database.go +++ b/ethdb/memory_database.go @@ -18,6 +18,7 @@ package ethdb import ( "fmt" + "sync" "github.com/ethereum/go-ethereum/common" ) @@ -26,29 +27,42 @@ import ( * This is a test memory database. Do not use for any production it does not get persisted */ type MemDatabase struct { - db map[string][]byte + db map[string][]byte + lock sync.RWMutex } func NewMemDatabase() (*MemDatabase, error) { - db := &MemDatabase{db: make(map[string][]byte)} - - return db, nil + return &MemDatabase{ + db: make(map[string][]byte), + }, nil } func (db *MemDatabase) Put(key []byte, value []byte) error { + db.lock.Lock() + defer db.lock.Unlock() + db.db[string(key)] = common.CopyBytes(value) return nil } func (db *MemDatabase) Set(key []byte, value []byte) { + db.lock.Lock() + defer db.lock.Unlock() + db.Put(key, value) } func (db *MemDatabase) Get(key []byte) ([]byte, error) { + db.lock.RLock() + defer db.lock.RUnlock() + return db.db[string(key)], nil } func (db *MemDatabase) Keys() [][]byte { + db.lock.RLock() + defer db.lock.RUnlock() + keys := [][]byte{} for key, _ := range db.db { keys = append(keys, []byte(key)) @@ -65,12 +79,17 @@ func (db *MemDatabase) GetKeys() []*common.Key { */ func (db *MemDatabase) Delete(key []byte) error { - delete(db.db, string(key)) + db.lock.Lock() + defer db.lock.Unlock() + delete(db.db, string(key)) return nil } func (db *MemDatabase) Print() { + db.lock.RLock() + defer db.lock.RUnlock() + for key, val := range db.db { fmt.Printf("%x(%d): ", key, len(key)) node := common.NewValueFromBytes(val) @@ -83,11 +102,9 @@ func (db *MemDatabase) Close() { func (db *MemDatabase) LastKnownTD() []byte { data, _ := db.Get([]byte("LastKnownTotalDifficulty")) - if len(data) == 0 || data == nil { data = []byte{0x0} } - return data } @@ -100,16 +117,26 @@ type kv struct{ k, v []byte } type memBatch struct { db *MemDatabase writes []kv + lock sync.RWMutex } -func (w *memBatch) Put(key, value []byte) error { - w.writes = append(w.writes, kv{key, common.CopyBytes(value)}) +func (b *memBatch) Put(key, value []byte) error { + b.lock.Lock() + defer b.lock.Unlock() + + b.writes = append(b.writes, kv{key, common.CopyBytes(value)}) return nil } -func (w *memBatch) Write() error { - for _, kv := range w.writes { - w.db.db[string(kv.k)] = kv.v +func (b *memBatch) Write() error { + b.lock.RLock() + defer b.lock.RUnlock() + + b.db.lock.RLock() + defer b.db.lock.RUnlock() + + for _, kv := range b.writes { + b.db.db[string(kv.k)] = kv.v } return nil } diff --git a/trie/sync.go b/trie/sync.go index 65cfd6ed8a..bb112fb621 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -50,15 +51,15 @@ type TrieSyncLeafCallback func(leaf []byte, parent common.Hash) error // TrieSync is the main state trie synchronisation scheduler, which provides yet // unknown trie hashes to retrieve, accepts node data associated with said hashes -// and reconstructs the trie steb by step until all is done. +// and reconstructs the trie step by step until all is done. type TrieSync struct { - database Database // State database for storing all the assembled node data + database ethdb.Database // State database for storing all the assembled node data requests map[common.Hash]*request // Pending requests pertaining to a key hash queue *prque.Prque // Priority queue with the pending requests } // NewTrieSync creates a new trie data download scheduler. -func NewTrieSync(root common.Hash, database Database, callback TrieSyncLeafCallback) *TrieSync { +func NewTrieSync(root common.Hash, database ethdb.Database, callback TrieSyncLeafCallback) *TrieSync { ts := &TrieSync{ database: database, requests: make(map[common.Hash]*request), @@ -70,10 +71,14 @@ func NewTrieSync(root common.Hash, database Database, callback TrieSyncLeafCallb // AddSubTrie registers a new trie to the sync code, rooted at the designated parent. func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback TrieSyncLeafCallback) { - // Short circuit if the trie is empty + // Short circuit if the trie is empty or already known if root == emptyRoot { return } + blob, _ := s.database.Get(root.Bytes()) + if local, err := decodeNode(blob); local != nil && err == nil { + return + } // Assemble the new sub-trie sync request node := node(hashNode(root.Bytes())) req := &request{ @@ -94,6 +99,35 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c s.schedule(req) } +// AddRawEntry schedules the direct retrieval of a state entry that should not be +// interpreted as a trie node, but rather accepted and stored into the database +// as is. This method's goal is to support misc state metadata retrievals (e.g. +// contract code). +func (s *TrieSync) AddRawEntry(hash common.Hash, depth int, parent common.Hash) { + // Short circuit if the entry is empty or already known + if hash == emptyState { + return + } + if blob, _ := s.database.Get(hash.Bytes()); blob != nil { + return + } + // Assemble the new sub-trie sync request + req := &request{ + hash: hash, + depth: depth, + } + // If this sub-trie has a designated parent, link them together + if parent != (common.Hash{}) { + ancestor := s.requests[parent] + if ancestor == nil { + panic(fmt.Sprintf("raw-entry ancestor not found: %x", parent)) + } + ancestor.deps++ + req.parents = append(req.parents, ancestor) + } + s.schedule(req) +} + // Missing retrieves the known missing nodes from the trie for retrieval. func (s *TrieSync) Missing(max int) []common.Hash { requests := []common.Hash{} @@ -111,6 +145,12 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) { if request == nil { return i, fmt.Errorf("not requested: %x", item.Hash) } + // If the item is a raw entry request, commit directly + if request.object == nil { + request.data = item.Data + s.commit(request, nil) + continue + } // Decode the node data content and update the request node, err := decodeNode(item.Data) if err != nil { @@ -125,7 +165,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) { return i, err } if len(requests) == 0 && request.deps == 0 { - s.commit(request) + s.commit(request, nil) continue } request.deps += len(requests) @@ -136,6 +176,11 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) { return 0, nil } +// Pending returns the number of state entries currently pending for download. +func (s *TrieSync) Pending() int { + return len(s.requests) +} + // schedule inserts a new state retrieval request into the fetch queue. If there // is already a pending request for this node, the new request will be discarded // and only a parent reference added to the old one. @@ -213,9 +258,16 @@ func (s *TrieSync) children(req *request) ([]*request, error) { // commit finalizes a retrieval request and stores it into the database. If any // of the referencing parent requests complete due to this commit, they are also // committed themselves. -func (s *TrieSync) commit(req *request) error { +func (s *TrieSync) commit(req *request, batch ethdb.Batch) (err error) { + // Create a new batch if none was specified + if batch == nil { + batch = s.database.NewBatch() + defer func() { + err = batch.Write() + }() + } // Write the node content to disk - if err := s.database.Put(req.hash[:], req.data); err != nil { + if err := batch.Put(req.hash[:], req.data); err != nil { return err } delete(s.requests, req.hash) @@ -224,7 +276,7 @@ func (s *TrieSync) commit(req *request) error { for _, parent := range req.parents { parent.deps-- if parent.deps == 0 { - if err := s.commit(parent); err != nil { + if err := s.commit(parent, batch); err != nil { return err } } diff --git a/trie/trie.go b/trie/trie.go index aa8d39fe2f..a3a383fb58 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -24,6 +24,7 @@ import ( "hash" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" @@ -35,8 +36,12 @@ const defaultCacheCapacity = 800 var ( // The global cache stores decoded trie nodes by hash as they get loaded. globalCache = newARC(defaultCacheCapacity) + // This is the known root hash of an empty trie. emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + + // This is the known hash of an empty state trie entry. + emptyState = crypto.Sha3Hash(nil) ) var ErrMissingRoot = errors.New("missing root node")