diff --git a/core/blockchain.go b/core/blockchain.go index ea26fa0345..34832252a7 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1524,6 +1524,18 @@ func (bc *BlockChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []com return bc.hc.GetBlockHashesFromHash(hash, max) } +// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or +// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the +// number of blocks to be individually checked before we reach the canonical chain. +// +// Note: ancestor == 0 returns the same block, 1 returns its parent and so on. +func (bc *BlockChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) { + bc.chainmu.Lock() + defer bc.chainmu.Unlock() + + return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical) +} + // GetHeaderByNumber retrieves a block header from the database by number, // caching it (associated with its hash) if found. func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header { diff --git a/core/headerchain.go b/core/headerchain.go index 2ac0cccc72..6e759ed1c1 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -307,6 +307,43 @@ func (hc *HeaderChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []co return chain } +// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or +// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the +// number of blocks to be individually checked before we reach the canonical chain. +// +// Note: ancestor == 0 returns the same block, 1 returns its parent and so on. +func (hc *HeaderChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) { + if ancestor > number { + return common.Hash{}, 0 + } + if ancestor == 1 { + // in this case it is cheaper to just read the header + if header := hc.GetHeader(hash, number); header != nil { + return header.ParentHash, number - 1 + } else { + return common.Hash{}, 0 + } + } + for ancestor != 0 { + if rawdb.ReadCanonicalHash(hc.chainDb, number) == hash { + number -= ancestor + return rawdb.ReadCanonicalHash(hc.chainDb, number), number + } + if *maxNonCanonical == 0 { + return common.Hash{}, 0 + } + *maxNonCanonical-- + ancestor-- + header := hc.GetHeader(hash, number) + if header == nil { + return common.Hash{}, 0 + } + hash = header.ParentHash + number-- + } + return hash, number +} + // GetTd retrieves a block's total difficulty in the canonical chain from the // database by hash and number, caching it if found. func (hc *HeaderChain) GetTd(hash common.Hash, number uint64) *big.Int { diff --git a/eth/handler.go b/eth/handler.go index 918d71088d..a46e7f13cb 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -340,6 +340,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "%v: %v", msg, err) } hashMode := query.Origin.Hash != (common.Hash{}) + first := true + maxNonCanonical := uint64(100) // Gather headers until the fetch or network limits is reached var ( @@ -351,31 +353,36 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Retrieve the next header satisfying the query var origin *types.Header if hashMode { - origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) + if first { + first = false + origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) + if origin != nil { + query.Origin.Number = origin.Number.Uint64() + } + } else { + origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) + } } else { origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) } if origin == nil { break } - number := origin.Number.Uint64() headers = append(headers, origin) bytes += estHeaderRlpSize // Advance to the next header of the query switch { - case query.Origin.Hash != (common.Hash{}) && query.Reverse: + case hashMode && query.Reverse: // Hash based traversal towards the genesis block - for i := 0; i < int(query.Skip)+1; i++ { - if header := pm.blockchain.GetHeader(query.Origin.Hash, number); header != nil { - query.Origin.Hash = header.ParentHash - number-- - } else { - unknown = true - break - } + ancestor := query.Skip + 1 + if ancestor == 0 { + unknown = true + } else { + query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) + unknown = (query.Origin.Hash == common.Hash{}) } - case query.Origin.Hash != (common.Hash{}) && !query.Reverse: + case hashMode && !query.Reverse: // Hash based traversal towards the leaf block var ( current = origin.Number.Uint64() @@ -387,8 +394,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { unknown = true } else { if header := pm.blockchain.GetHeaderByNumber(next); header != nil { - if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash { - query.Origin.Hash = header.Hash() + nextHash := header.Hash() + expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) + if expOldHash == query.Origin.Hash { + query.Origin.Hash, query.Origin.Number = nextHash, next } else { unknown = true } diff --git a/les/handler.go b/les/handler.go index 38f810d721..a1c16cb875 100644 --- a/les/handler.go +++ b/les/handler.go @@ -83,7 +83,7 @@ type BlockChain interface { InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) Rollback(chain []common.Hash) GetHeaderByNumber(number uint64) *types.Header - GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash + GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) Genesis() *types.Block SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription } @@ -419,6 +419,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } hashMode := query.Origin.Hash != (common.Hash{}) + first := true + maxNonCanonical := uint64(100) // Gather headers until the fetch or network limits is reached var ( @@ -430,14 +432,21 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Retrieve the next header satisfying the query var origin *types.Header if hashMode { - origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) + if first { + first = false + origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) + if origin != nil { + query.Origin.Number = origin.Number.Uint64() + } + } else { + origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) + } } else { origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) } if origin == nil { break } - number := origin.Number.Uint64() headers = append(headers, origin) bytes += estHeaderRlpSize @@ -445,14 +454,12 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { switch { case hashMode && query.Reverse: // Hash based traversal towards the genesis block - for i := 0; i < int(query.Skip)+1; i++ { - if header := pm.blockchain.GetHeader(query.Origin.Hash, number); header != nil { - query.Origin.Hash = header.ParentHash - number-- - } else { - unknown = true - break - } + ancestor := query.Skip + 1 + if ancestor == 0 { + unknown = true + } else { + query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) + unknown = (query.Origin.Hash == common.Hash{}) } case hashMode && !query.Reverse: // Hash based traversal towards the leaf block @@ -466,8 +473,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { unknown = true } else { if header := pm.blockchain.GetHeaderByNumber(next); header != nil { - if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash { - query.Origin.Hash = header.Hash() + nextHash := header.Hash() + expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) + if expOldHash == query.Origin.Hash { + query.Origin.Hash, query.Origin.Number = nextHash, next } else { unknown = true } diff --git a/light/lightchain.go b/light/lightchain.go index 9d0a4e4f73..30b9bd89a6 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -433,6 +433,18 @@ func (self *LightChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []c return self.hc.GetBlockHashesFromHash(hash, max) } +// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or +// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the +// number of blocks to be individually checked before we reach the canonical chain. +// +// Note: ancestor == 0 returns the same block, 1 returns its parent and so on. +func (bc *LightChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) { + bc.chainmu.Lock() + defer bc.chainmu.Unlock() + + return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical) +} + // GetHeaderByNumber retrieves a block header from the database by number, // caching it (associated with its hash) if found. func (self *LightChain) GetHeaderByNumber(number uint64) *types.Header {