Merge pull request #2097 from karalabe/block-state-checks

core, eth/downloader: ensure state presence in ancestor lookup
pull/2110/head
Jeffrey Wilcke 9 years ago
commit 32226f1b0c
  1. 13
      core/blockchain.go
  2. 1
      core/state/statedb.go
  3. 22
      eth/downloader/downloader.go
  4. 17
      eth/downloader/downloader_test.go
  5. 4
      eth/downloader/types.go
  6. 7
      eth/handler.go

@ -587,6 +587,19 @@ func (bc *BlockChain) HasBlock(hash common.Hash) bool {
return bc.GetBlock(hash) != nil return bc.GetBlock(hash) != nil
} }
// HasBlockAndState checks if a block and associated state trie is fully present
// in the database or not, caching it if present.
func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool {
// Check first that the block itself is known
block := bc.GetBlock(hash)
if block == nil {
return false
}
// Ensure the associated state is also present
_, err := state.New(block.Root(), bc.chainDb)
return err == nil
}
// GetBlock retrieves a block from the database by hash, caching it if found. // GetBlock retrieves a block from the database by hash, caching it if found.
func (self *BlockChain) GetBlock(hash common.Hash) *types.Block { func (self *BlockChain) GetBlock(hash common.Hash) *types.Block {
// Short circuit if the block's already in the cache, retrieve otherwise // Short circuit if the block's already in the cache, retrieve otherwise

@ -57,7 +57,6 @@ type StateDB struct {
func New(root common.Hash, db ethdb.Database) (*StateDB, error) { func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
tr, err := trie.NewSecure(root, db) tr, err := trie.NewSecure(root, db)
if err != nil { if err != nil {
glog.Errorf("can't create state trie with root %x: %v", root[:], err)
return nil, err return nil, err
} }
return &StateDB{ return &StateDB{

@ -113,7 +113,7 @@ type Downloader struct {
// Callbacks // Callbacks
hasHeader headerCheckFn // Checks if a header is present in the chain hasHeader headerCheckFn // Checks if a header is present in the chain
hasBlock blockCheckFn // Checks if a block is present in the chain hasBlockAndState blockAndStateCheckFn // Checks if a block and associated state is present in the chain
getHeader headerRetrievalFn // Retrieves a header from the chain getHeader headerRetrievalFn // Retrieves a header from the chain
getBlock blockRetrievalFn // Retrieves a block from the chain getBlock blockRetrievalFn // Retrieves a block from the chain
headHeader headHeaderRetrievalFn // Retrieves the head header from the chain headHeader headHeaderRetrievalFn // Retrieves the head header from the chain
@ -156,10 +156,10 @@ type Downloader struct {
} }
// New creates a new downloader to fetch hashes and blocks from remote peers. // New creates a new downloader to fetch hashes and blocks from remote peers.
func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn, func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlockAndState blockAndStateCheckFn,
getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, headFastBlock headFastBlockRetrievalFn, getHeader headerRetrievalFn, getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn,
commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn, headFastBlock headFastBlockRetrievalFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn,
insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader { insertBlocks blockChainInsertFn, insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader {
return &Downloader{ return &Downloader{
mode: FullSync, mode: FullSync,
@ -167,7 +167,7 @@ func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, ha
queue: newQueue(stateDb), queue: newQueue(stateDb),
peers: newPeerSet(), peers: newPeerSet(),
hasHeader: hasHeader, hasHeader: hasHeader,
hasBlock: hasBlock, hasBlockAndState: hasBlockAndState,
getHeader: getHeader, getHeader: getHeader,
getBlock: getBlock, getBlock: getBlock,
headHeader: headHeader, headHeader: headHeader,
@ -564,7 +564,7 @@ func (d *Downloader) findAncestor61(p *peer) (uint64, error) {
// Check if a common ancestor was found // Check if a common ancestor was found
finished = true finished = true
for i := len(hashes) - 1; i >= 0; i-- { for i := len(hashes) - 1; i >= 0; i-- {
if d.hasBlock(hashes[i]) { if d.hasBlockAndState(hashes[i]) {
number, hash = uint64(from)+uint64(i), hashes[i] number, hash = uint64(from)+uint64(i), hashes[i]
break break
} }
@ -620,11 +620,11 @@ func (d *Downloader) findAncestor61(p *peer) (uint64, error) {
arrived = true arrived = true
// Modify the search interval based on the response // Modify the search interval based on the response
block := d.getBlock(hashes[0]) if !d.hasBlockAndState(hashes[0]) {
if block == nil {
end = check end = check
break break
} }
block := d.getBlock(hashes[0]) // this doesn't check state, hence the above explicit check
if block.NumberU64() != check { if block.NumberU64() != check {
glog.V(logger.Debug).Infof("%v: non requested hash #%d [%x…], instead of #%d", p, block.NumberU64(), block.Hash().Bytes()[:4], check) glog.V(logger.Debug).Infof("%v: non requested hash #%d [%x…], instead of #%d", p, block.NumberU64(), block.Hash().Bytes()[:4], check)
return 0, errBadPeer return 0, errBadPeer
@ -989,7 +989,7 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) {
// Check if a common ancestor was found // Check if a common ancestor was found
finished = true finished = true
for i := len(headers) - 1; i >= 0; i-- { for i := len(headers) - 1; i >= 0; i-- {
if (d.mode != LightSync && d.hasBlock(headers[i].Hash())) || (d.mode == LightSync && d.hasHeader(headers[i].Hash())) { if (d.mode != LightSync && d.hasBlockAndState(headers[i].Hash())) || (d.mode == LightSync && d.hasHeader(headers[i].Hash())) {
number, hash = headers[i].Number.Uint64(), headers[i].Hash() number, hash = headers[i].Number.Uint64(), headers[i].Hash()
break break
} }
@ -1045,7 +1045,7 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) {
arrived = true arrived = true
// Modify the search interval based on the response // Modify the search interval based on the response
if (d.mode == FullSync && !d.hasBlock(headers[0].Hash())) || (d.mode != FullSync && !d.hasHeader(headers[0].Hash())) { if (d.mode == FullSync && !d.hasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.hasHeader(headers[0].Hash())) {
end = check end = check
break break
} }

@ -153,6 +153,8 @@ func newTester() *downloadTester {
peerChainTds: make(map[string]map[common.Hash]*big.Int), peerChainTds: make(map[string]map[common.Hash]*big.Int),
} }
tester.stateDb, _ = ethdb.NewMemDatabase() tester.stateDb, _ = ethdb.NewMemDatabase()
tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00})
tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd, tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer) tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)
@ -180,9 +182,14 @@ func (dl *downloadTester) hasHeader(hash common.Hash) bool {
return dl.getHeader(hash) != nil return dl.getHeader(hash) != nil
} }
// hasBlock checks if a block is present in the testers canonical chain. // hasBlock checks if a block and associated state is present in the testers canonical chain.
func (dl *downloadTester) hasBlock(hash common.Hash) bool { func (dl *downloadTester) hasBlock(hash common.Hash) bool {
return dl.getBlock(hash) != nil block := dl.getBlock(hash)
if block == nil {
return false
}
_, err := dl.stateDb.Get(block.Root().Bytes())
return err == nil
} }
// getHeader retrieves a header from the testers canonical chain. // getHeader retrieves a header from the testers canonical chain.
@ -295,8 +302,10 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
defer dl.lock.Unlock() defer dl.lock.Unlock()
for i, block := range blocks { for i, block := range blocks {
if _, ok := dl.ownBlocks[block.ParentHash()]; !ok { if parent, ok := dl.ownBlocks[block.ParentHash()]; !ok {
return i, errors.New("unknown parent") return i, errors.New("unknown parent")
} else if _, err := dl.stateDb.Get(parent.Root().Bytes()); err != nil {
return i, fmt.Errorf("unknown parent state %x: %v", parent.Root(), err)
} }
if _, ok := dl.ownHeaders[block.Hash()]; !ok { if _, ok := dl.ownHeaders[block.Hash()]; !ok {
dl.ownHashes = append(dl.ownHashes, block.Hash()) dl.ownHashes = append(dl.ownHashes, block.Hash())
@ -1103,6 +1112,8 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
} }
// Tests that upon detecting an invalid header, the recent ones are rolled back // Tests that upon detecting an invalid header, the recent ones are rolled back
// for various failure scenarios. Afterwards a full sync is attempted to make
// sure no state was corrupted.
func TestInvalidHeaderRollback63Fast(t *testing.T) { testInvalidHeaderRollback(t, 63, FastSync) } func TestInvalidHeaderRollback63Fast(t *testing.T) { testInvalidHeaderRollback(t, 63, FastSync) }
func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) } func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) }
func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) } func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) }

@ -27,8 +27,8 @@ import (
// headerCheckFn is a callback type for verifying a header's presence in the local chain. // headerCheckFn is a callback type for verifying a header's presence in the local chain.
type headerCheckFn func(common.Hash) bool type headerCheckFn func(common.Hash) bool
// blockCheckFn is a callback type for verifying a block's presence in the local chain. // blockAndStateCheckFn is a callback type for verifying block and associated states' presence in the local chain.
type blockCheckFn func(common.Hash) bool type blockAndStateCheckFn func(common.Hash) bool
// headerRetrievalFn is a callback type for retrieving a header from the local chain. // headerRetrievalFn is a callback type for retrieving a header from the local chain.
type headerRetrievalFn func(common.Hash) *types.Header type headerRetrievalFn func(common.Hash) *types.Header

@ -138,9 +138,10 @@ func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool
return nil, errIncompatibleConfig return nil, errIncompatibleConfig
} }
// Construct the different synchronisation mechanisms // Construct the different synchronisation mechanisms
manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader, blockchain.GetBlock, manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlockAndState, blockchain.GetHeader,
blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, blockchain.GetTd, blockchain.GetBlock, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead,
blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer) blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback,
manager.removePeer)
validator := func(block *types.Block, parent *types.Block) error { validator := func(block *types.Block, parent *types.Block) error {
return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false) return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false)

Loading…
Cancel
Save