diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index e4d1392d0a..6ac58140a3 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -114,21 +114,11 @@ type Downloader struct { syncStatsState stateSyncStats syncStatsLock sync.RWMutex // Lock protecting the sync stats fields + lightchain LightChain + blockchain BlockChain + // Callbacks - hasHeader headerCheckFn // Checks if a header is present in the chain - hasBlockAndState blockAndStateCheckFn // Checks if a block and associated state is present in the chain - getHeader headerRetrievalFn // Retrieves a header from the chain - getBlock blockRetrievalFn // Retrieves a block from the chain - headHeader headHeaderRetrievalFn // Retrieves the head header from the chain - headBlock headBlockRetrievalFn // Retrieves the head block from the chain - headFastBlock headFastBlockRetrievalFn // Retrieves the head fast-sync block from the chain - commitHeadBlock headBlockCommitterFn // Commits a manually assembled block as the chain head - getTd tdRetrievalFn // Retrieves the TD of a block from the chain - insertHeaders headerChainInsertFn // Injects a batch of headers into the chain - insertBlocks blockChainInsertFn // Injects a batch of blocks into the chain - insertReceipts receiptChainInsertFn // Injects a batch of blocks and their receipts into the chain - rollback chainRollbackFn // Removes a batch of recently added chain links - dropPeer peerDropFn // Drops a peer for misbehaving + dropPeer peerDropFn // Drops a peer for misbehaving // Status synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing @@ -163,45 +153,80 @@ type Downloader struct { chainInsertHook func([]*fetchResult) // Method to call upon inserting a chain of blocks (possibly in multiple invocations) } +// LightChain encapsulates functions required to synchronise a light chain. +type LightChain interface { + // HasHeader verifies a header's presence in the local chain. + HasHeader(common.Hash) bool + + // GetHeaderByHash retrieves a header from the local chain. + GetHeaderByHash(common.Hash) *types.Header + + // CurrentHeader retrieves the head header from the local chain. + CurrentHeader() *types.Header + + // GetTdByHash returns the total difficulty of a local block. + GetTdByHash(common.Hash) *big.Int + + // InsertHeaderChain inserts a batch of headers into the local chain. + InsertHeaderChain([]*types.Header, int) (int, error) + + // Rollback removes a few recently added elements from the local chain. + Rollback([]common.Hash) +} + +// BlockChain encapsulates functions required to sync a (full or fast) blockchain. +type BlockChain interface { + LightChain + + // HasBlockAndState verifies block and associated states' presence in the local chain. + HasBlockAndState(common.Hash) bool + + // GetBlockByHash retrieves a block from the local chain. + GetBlockByHash(common.Hash) *types.Block + + // CurrentBlock retrieves the head block from the local chain. + CurrentBlock() *types.Block + + // CurrentFastBlock retrieves the head fast block from the local chain. + CurrentFastBlock() *types.Block + + // FastSyncCommitHead directly commits the head block to a certain entity. + FastSyncCommitHead(common.Hash) error + + // InsertChain inserts a batch of blocks into the local chain. + InsertChain(types.Blocks) (int, error) + + // InsertReceiptChain inserts a batch of receipts into the local chain. + InsertReceiptChain(types.Blocks, []types.Receipts) (int, error) +} + // New creates a new downloader to fetch hashes and blocks from remote peers. -func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlockAndState blockAndStateCheckFn, - getHeader headerRetrievalFn, getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, - headFastBlock headFastBlockRetrievalFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, - insertBlocks blockChainInsertFn, insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader { +func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn) *Downloader { + if lightchain == nil { + lightchain = chain + } dl := &Downloader{ - mode: mode, - mux: mux, - queue: newQueue(), - peers: newPeerSet(), - stateDB: stateDb, - rttEstimate: uint64(rttMaxEstimate), - rttConfidence: uint64(1000000), - hasHeader: hasHeader, - hasBlockAndState: hasBlockAndState, - getHeader: getHeader, - getBlock: getBlock, - headHeader: headHeader, - headBlock: headBlock, - headFastBlock: headFastBlock, - commitHeadBlock: commitHeadBlock, - getTd: getTd, - insertHeaders: insertHeaders, - insertBlocks: insertBlocks, - insertReceipts: insertReceipts, - rollback: rollback, - dropPeer: dropPeer, - headerCh: make(chan dataPack, 1), - bodyCh: make(chan dataPack, 1), - receiptCh: make(chan dataPack, 1), - bodyWakeCh: make(chan bool, 1), - receiptWakeCh: make(chan bool, 1), - headerProcCh: make(chan []*types.Header, 1), - quitCh: make(chan struct{}), - // for stateFetcher + mode: mode, + stateDB: stateDb, + mux: mux, + queue: newQueue(), + peers: newPeerSet(), + rttEstimate: uint64(rttMaxEstimate), + rttConfidence: uint64(1000000), + blockchain: chain, + lightchain: lightchain, + dropPeer: dropPeer, + headerCh: make(chan dataPack, 1), + bodyCh: make(chan dataPack, 1), + receiptCh: make(chan dataPack, 1), + bodyWakeCh: make(chan bool, 1), + receiptWakeCh: make(chan bool, 1), + headerProcCh: make(chan []*types.Header, 1), + quitCh: make(chan struct{}), + stateCh: make(chan dataPack), stateSyncStart: make(chan *stateSync), trackStateReq: make(chan *stateReq), - stateCh: make(chan dataPack), } go dl.qosTuner() go dl.stateFetcher() @@ -223,11 +248,11 @@ func (d *Downloader) Progress() ethereum.SyncProgress { current := uint64(0) switch d.mode { case FullSync: - current = d.headBlock().NumberU64() + current = d.blockchain.CurrentBlock().NumberU64() case FastSync: - current = d.headFastBlock().NumberU64() + current = d.blockchain.CurrentFastBlock().NumberU64() case LightSync: - current = d.headHeader().Number.Uint64() + current = d.lightchain.CurrentHeader().Number.Uint64() } return ethereum.SyncProgress{ StartingBlock: d.syncStatsChainOrigin, @@ -245,13 +270,11 @@ func (d *Downloader) Synchronising() bool { // RegisterPeer injects a new download peer into the set of block source to be // used for fetching hashes and blocks from. -func (d *Downloader) RegisterPeer(id string, version int, currentHead currentHeadRetrievalFn, - getRelHeaders relativeHeaderFetcherFn, getAbsHeaders absoluteHeaderFetcherFn, getBlockBodies blockBodyFetcherFn, - getReceipts receiptFetcherFn, getNodeData stateFetcherFn) error { +func (d *Downloader) RegisterPeer(id string, version int, peer Peer) error { logger := log.New("peer", id) logger.Trace("Registering sync peer") - if err := d.peers.Register(newPeer(id, version, currentHead, getRelHeaders, getAbsHeaders, getBlockBodies, getReceipts, getNodeData, logger)); err != nil { + if err := d.peers.Register(newPeerConnection(id, version, peer, logger)); err != nil { logger.Error("Failed to register sync peer", "err", err) return err } @@ -260,6 +283,11 @@ func (d *Downloader) RegisterPeer(id string, version int, currentHead currentHea return nil } +// RegisterLightPeer injects a light client peer, wrapping it so it appears as a regular peer. +func (d *Downloader) RegisterLightPeer(id string, version int, peer LightPeer) error { + return d.RegisterPeer(id, version, &lightPeerWrapper{peer}) +} + // UnregisterPeer remove a peer from the known list, preventing any action from // the specified peer. An effort is also made to return any pending fetches into // the queue. @@ -371,7 +399,7 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode // syncWithPeer starts a block synchronization based on the hash chain from the // specified peer and head hash. -func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err error) { +func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.Int) (err error) { d.mux.Post(StartEvent{}) defer func() { // reset on error @@ -524,12 +552,12 @@ func (d *Downloader) Terminate() { // fetchHeight retrieves the head header of the remote peer to aid in estimating // the total time a pending synchronisation would take. -func (d *Downloader) fetchHeight(p *peer) (*types.Header, error) { +func (d *Downloader) fetchHeight(p *peerConnection) (*types.Header, error) { p.log.Debug("Retrieving remote chain height") // Request the advertised remote head block and wait for the response - head, _ := p.currentHead() - go p.getRelHeaders(head, 1, 0, false) + head, _ := p.peer.Head() + go p.peer.RequestHeadersByHash(head, 1, 0, false) ttl := d.requestTTL() timeout := time.After(ttl) @@ -570,15 +598,15 @@ func (d *Downloader) fetchHeight(p *peer) (*types.Header, error) { // on the correct chain, checking the top N links should already get us a match. // In the rare scenario when we ended up on a long reorganisation (i.e. none of // the head links match), we do a binary search to find the common ancestor. -func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { +func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, error) { // Figure out the valid ancestor range to prevent rewrite attacks - floor, ceil := int64(-1), d.headHeader().Number.Uint64() + floor, ceil := int64(-1), d.lightchain.CurrentHeader().Number.Uint64() p.log.Debug("Looking for common ancestor", "local", ceil, "remote", height) if d.mode == FullSync { - ceil = d.headBlock().NumberU64() + ceil = d.blockchain.CurrentBlock().NumberU64() } else if d.mode == FastSync { - ceil = d.headFastBlock().NumberU64() + ceil = d.blockchain.CurrentFastBlock().NumberU64() } if ceil >= MaxForkAncestry { floor = int64(ceil - MaxForkAncestry) @@ -598,7 +626,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { if count > limit { count = limit } - go p.getAbsHeaders(uint64(from), count, 15, false) + go p.peer.RequestHeadersByNumber(uint64(from), count, 15, false) // Wait for the remote response to the head fetch number, hash := uint64(0), common.Hash{} @@ -638,7 +666,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { continue } // Otherwise check if we already know the header or not - if (d.mode == FullSync && d.hasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.hasHeader(headers[i].Hash())) { + if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash())) { number, hash = headers[i].Number.Uint64(), headers[i].Hash() // If every header is known, even future ones, the peer straight out lied about its head @@ -680,7 +708,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { ttl := d.requestTTL() timeout := time.After(ttl) - go p.getAbsHeaders(uint64(check), 1, 0, false) + go p.peer.RequestHeadersByNumber(uint64(check), 1, 0, false) // Wait until a reply arrives to this request for arrived := false; !arrived; { @@ -703,11 +731,11 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { arrived = true // Modify the search interval based on the response - if (d.mode == FullSync && !d.hasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.hasHeader(headers[0].Hash())) { + if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash())) { end = check break } - header := d.getHeader(headers[0].Hash()) // Independent of sync mode, header surely exists + header := d.lightchain.GetHeaderByHash(headers[0].Hash()) // Independent of sync mode, header surely exists if header.Number.Uint64() != check { p.log.Debug("Received non requested header", "number", header.Number, "hash", header.Hash(), "request", check) return 0, errBadPeer @@ -741,7 +769,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { // other peers are only accepted if they map cleanly to the skeleton. If no one // can fill in the skeleton - not even the origin peer - it's assumed invalid and // the origin is dropped. -func (d *Downloader) fetchHeaders(p *peer, from uint64) error { +func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error { p.log.Debug("Directing header downloads", "origin", from) defer p.log.Debug("Header download terminated") @@ -761,10 +789,10 @@ func (d *Downloader) fetchHeaders(p *peer, from uint64) error { if skeleton { p.log.Trace("Fetching skeleton headers", "count", MaxHeaderFetch, "from", from) - go p.getAbsHeaders(from+uint64(MaxHeaderFetch)-1, MaxSkeletonSize, MaxHeaderFetch-1, false) + go p.peer.RequestHeadersByNumber(from+uint64(MaxHeaderFetch)-1, MaxSkeletonSize, MaxHeaderFetch-1, false) } else { p.log.Trace("Fetching full headers", "count", MaxHeaderFetch, "from", from) - go p.getAbsHeaders(from, MaxHeaderFetch, 0, false) + go p.peer.RequestHeadersByNumber(from, MaxHeaderFetch, 0, false) } } // Start pulling the header chain skeleton until all is done @@ -866,12 +894,12 @@ func (d *Downloader) fillHeaderSkeleton(from uint64, skeleton []*types.Header) ( } expire = func() map[string]int { return d.queue.ExpireHeaders(d.requestTTL()) } throttle = func() bool { return false } - reserve = func(p *peer, count int) (*fetchRequest, bool, error) { + reserve = func(p *peerConnection, count int) (*fetchRequest, bool, error) { return d.queue.ReserveHeaders(p, count), false, nil } - fetch = func(p *peer, req *fetchRequest) error { return p.FetchHeaders(req.From, MaxHeaderFetch) } - capacity = func(p *peer) int { return p.HeaderCapacity(d.requestRTT()) } - setIdle = func(p *peer, accepted int) { p.SetHeadersIdle(accepted) } + fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchHeaders(req.From, MaxHeaderFetch) } + capacity = func(p *peerConnection) int { return p.HeaderCapacity(d.requestRTT()) } + setIdle = func(p *peerConnection, accepted int) { p.SetHeadersIdle(accepted) } ) err := d.fetchParts(errCancelHeaderFetch, d.headerCh, deliver, d.queue.headerContCh, expire, d.queue.PendingHeaders, d.queue.InFlightHeaders, throttle, reserve, @@ -895,9 +923,9 @@ func (d *Downloader) fetchBodies(from uint64) error { return d.queue.DeliverBodies(pack.peerId, pack.transactions, pack.uncles) } expire = func() map[string]int { return d.queue.ExpireBodies(d.requestTTL()) } - fetch = func(p *peer, req *fetchRequest) error { return p.FetchBodies(req) } - capacity = func(p *peer) int { return p.BlockCapacity(d.requestRTT()) } - setIdle = func(p *peer, accepted int) { p.SetBodiesIdle(accepted) } + fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchBodies(req) } + capacity = func(p *peerConnection) int { return p.BlockCapacity(d.requestRTT()) } + setIdle = func(p *peerConnection, accepted int) { p.SetBodiesIdle(accepted) } ) err := d.fetchParts(errCancelBodyFetch, d.bodyCh, deliver, d.bodyWakeCh, expire, d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ShouldThrottleBlocks, d.queue.ReserveBodies, @@ -919,9 +947,9 @@ func (d *Downloader) fetchReceipts(from uint64) error { return d.queue.DeliverReceipts(pack.peerId, pack.receipts) } expire = func() map[string]int { return d.queue.ExpireReceipts(d.requestTTL()) } - fetch = func(p *peer, req *fetchRequest) error { return p.FetchReceipts(req) } - capacity = func(p *peer) int { return p.ReceiptCapacity(d.requestRTT()) } - setIdle = func(p *peer, accepted int) { p.SetReceiptsIdle(accepted) } + fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchReceipts(req) } + capacity = func(p *peerConnection) int { return p.ReceiptCapacity(d.requestRTT()) } + setIdle = func(p *peerConnection, accepted int) { p.SetReceiptsIdle(accepted) } ) err := d.fetchParts(errCancelReceiptFetch, d.receiptCh, deliver, d.receiptWakeCh, expire, d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ShouldThrottleReceipts, d.queue.ReserveReceipts, @@ -957,9 +985,9 @@ func (d *Downloader) fetchReceipts(from uint64) error { // - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) // - kind: textual label of the type being downloaded to display in log mesages func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool, - expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peer, int) (*fetchRequest, bool, error), - fetchHook func([]*types.Header), fetch func(*peer, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peer) int, - idle func() ([]*peer, int), setIdle func(*peer, int), kind string) error { + expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, error), + fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int, + idle func() ([]*peerConnection, int), setIdle func(*peerConnection, int), kind string) error { // Create a ticker to detect expired retrieval tasks ticker := time.NewTicker(100 * time.Millisecond) @@ -1124,23 +1152,19 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { for i, header := range rollback { hashes[i] = header.Hash() } - lastHeader, lastFastBlock, lastBlock := d.headHeader().Number, common.Big0, common.Big0 - if d.headFastBlock != nil { - lastFastBlock = d.headFastBlock().Number() + lastHeader, lastFastBlock, lastBlock := d.lightchain.CurrentHeader().Number, common.Big0, common.Big0 + if d.mode != LightSync { + lastFastBlock = d.blockchain.CurrentFastBlock().Number() + lastBlock = d.blockchain.CurrentBlock().Number() } - if d.headBlock != nil { - lastBlock = d.headBlock().Number() - } - d.rollback(hashes) + d.lightchain.Rollback(hashes) curFastBlock, curBlock := common.Big0, common.Big0 - if d.headFastBlock != nil { - curFastBlock = d.headFastBlock().Number() - } - if d.headBlock != nil { - curBlock = d.headBlock().Number() + if d.mode != LightSync { + curFastBlock = d.blockchain.CurrentFastBlock().Number() + curBlock = d.blockchain.CurrentBlock().Number() } log.Warn("Rolled back headers", "count", len(hashes), - "header", fmt.Sprintf("%d->%d", lastHeader, d.headHeader().Number), + "header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number), "fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock), "block", fmt.Sprintf("%d->%d", lastBlock, curBlock)) @@ -1190,7 +1214,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { // L: Request new headers up from 11 (R's TD was higher, it must have something) // R: Nothing to give if d.mode != LightSync { - if !gotHeaders && td.Cmp(d.getTd(d.headBlock().Hash())) > 0 { + if !gotHeaders && td.Cmp(d.blockchain.GetTdByHash(d.blockchain.CurrentBlock().Hash())) > 0 { return errStallingPeer } } @@ -1202,7 +1226,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { // queued for processing when the header download completes. However, as long as the // peer gave us something useful, we're already happy/progressed (above check). if d.mode == FastSync || d.mode == LightSync { - if td.Cmp(d.getTd(d.headHeader().Hash())) > 0 { + if td.Cmp(d.lightchain.GetTdByHash(d.lightchain.CurrentHeader().Hash())) > 0 { return errStallingPeer } } @@ -1232,7 +1256,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { // Collect the yet unknown headers to mark them as uncertain unknown := make([]*types.Header, 0, len(headers)) for _, header := range chunk { - if !d.hasHeader(header.Hash()) { + if !d.lightchain.HasHeader(header.Hash()) { unknown = append(unknown, header) } } @@ -1241,7 +1265,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { if chunk[len(chunk)-1].Number.Uint64()+uint64(fsHeaderForceVerify) > pivot { frequency = 1 } - if n, err := d.insertHeaders(chunk, frequency); err != nil { + if n, err := d.lightchain.InsertHeaderChain(chunk, frequency); err != nil { // If some headers were inserted, add them too to the rollback list if n > 0 { rollback = append(rollback, chunk[:n]...) @@ -1328,7 +1352,7 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error { for i, result := range results[:items] { blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) } - if index, err := d.insertBlocks(blocks); err != nil { + if index, err := d.blockchain.InsertChain(blocks); err != nil { log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err) return errInvalidChain } @@ -1368,6 +1392,7 @@ func (d *Downloader) processFastSyncContent(latest *types.Header) error { stateSync.Cancel() if err := d.commitPivotBlock(P); err != nil { return err + } } if err := d.importBlockResults(afterP); err != nil { @@ -1416,7 +1441,7 @@ func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *state blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) receipts[i] = result.Receipts } - if index, err := d.insertReceipts(blocks, receipts); err != nil { + if index, err := d.blockchain.InsertReceiptChain(blocks, receipts); err != nil { log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err) return errInvalidChain } @@ -1434,10 +1459,10 @@ func (d *Downloader) commitPivotBlock(result *fetchResult) error { return err } log.Debug("Committing fast sync pivot as new head", "number", b.Number(), "hash", b.Hash()) - if _, err := d.insertReceipts([]*types.Block{b}, []types.Receipts{result.Receipts}); err != nil { + if _, err := d.blockchain.InsertReceiptChain([]*types.Block{b}, []types.Receipts{result.Receipts}); err != nil { return err } - return d.commitHeadBlock(b.Hash()) + return d.blockchain.FastSyncCommitHead(b.Hash()) } // DeliverHeaders injects a new batch of block headers received from a remote diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 1fb5a0910e..b354682a14 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -96,9 +96,7 @@ func newTester() *downloadTester { tester.stateDb, _ = ethdb.NewMemDatabase() tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00}) - tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, - tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd, - tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer) + tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester, nil, tester.dropPeer) return tester } @@ -218,14 +216,14 @@ func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error { return err } -// hasHeader checks if a header is present in the testers canonical chain. -func (dl *downloadTester) hasHeader(hash common.Hash) bool { - return dl.getHeader(hash) != nil +// HasHeader checks if a header is present in the testers canonical chain. +func (dl *downloadTester) HasHeader(hash common.Hash) bool { + return dl.GetHeaderByHash(hash) != nil } -// hasBlock checks if a block and associated state is present in the testers canonical chain. -func (dl *downloadTester) hasBlock(hash common.Hash) bool { - block := dl.getBlock(hash) +// HasBlockAndState checks if a block and associated state is present in the testers canonical chain. +func (dl *downloadTester) HasBlockAndState(hash common.Hash) bool { + block := dl.GetBlockByHash(hash) if block == nil { return false } @@ -233,24 +231,24 @@ func (dl *downloadTester) hasBlock(hash common.Hash) bool { return err == nil } -// getHeader retrieves a header from the testers canonical chain. -func (dl *downloadTester) getHeader(hash common.Hash) *types.Header { +// GetHeader retrieves a header from the testers canonical chain. +func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header { dl.lock.RLock() defer dl.lock.RUnlock() return dl.ownHeaders[hash] } -// getBlock retrieves a block from the testers canonical chain. -func (dl *downloadTester) getBlock(hash common.Hash) *types.Block { +// GetBlock retrieves a block from the testers canonical chain. +func (dl *downloadTester) GetBlockByHash(hash common.Hash) *types.Block { dl.lock.RLock() defer dl.lock.RUnlock() return dl.ownBlocks[hash] } -// headHeader retrieves the current head header from the canonical chain. -func (dl *downloadTester) headHeader() *types.Header { +// CurrentHeader retrieves the current head header from the canonical chain. +func (dl *downloadTester) CurrentHeader() *types.Header { dl.lock.RLock() defer dl.lock.RUnlock() @@ -262,8 +260,8 @@ func (dl *downloadTester) headHeader() *types.Header { return dl.genesis.Header() } -// headBlock retrieves the current head block from the canonical chain. -func (dl *downloadTester) headBlock() *types.Block { +// CurrentBlock retrieves the current head block from the canonical chain. +func (dl *downloadTester) CurrentBlock() *types.Block { dl.lock.RLock() defer dl.lock.RUnlock() @@ -277,8 +275,8 @@ func (dl *downloadTester) headBlock() *types.Block { return dl.genesis } -// headFastBlock retrieves the current head fast-sync block from the canonical chain. -func (dl *downloadTester) headFastBlock() *types.Block { +// CurrentFastBlock retrieves the current head fast-sync block from the canonical chain. +func (dl *downloadTester) CurrentFastBlock() *types.Block { dl.lock.RLock() defer dl.lock.RUnlock() @@ -290,26 +288,26 @@ func (dl *downloadTester) headFastBlock() *types.Block { return dl.genesis } -// commitHeadBlock manually sets the head block to a given hash. -func (dl *downloadTester) commitHeadBlock(hash common.Hash) error { +// FastSyncCommitHead manually sets the head block to a given hash. +func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error { // For now only check that the state trie is correct - if block := dl.getBlock(hash); block != nil { + if block := dl.GetBlockByHash(hash); block != nil { _, err := trie.NewSecure(block.Root(), dl.stateDb, 0) return err } return fmt.Errorf("non existent block: %x", hash[:4]) } -// getTd retrieves the block's total difficulty from the canonical chain. -func (dl *downloadTester) getTd(hash common.Hash) *big.Int { +// GetTdByHash retrieves the block's total difficulty from the canonical chain. +func (dl *downloadTester) GetTdByHash(hash common.Hash) *big.Int { dl.lock.RLock() defer dl.lock.RUnlock() return dl.ownChainTd[hash] } -// insertHeaders injects a new batch of headers into the simulated chain. -func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) (int, error) { +// InsertHeaderChain injects a new batch of headers into the simulated chain. +func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq int) (int, error) { dl.lock.Lock() defer dl.lock.Unlock() @@ -337,8 +335,8 @@ func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) return len(headers), nil } -// insertBlocks injects a new batch of blocks into the simulated chain. -func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) { +// InsertChain injects a new batch of blocks into the simulated chain. +func (dl *downloadTester) InsertChain(blocks types.Blocks) (int, error) { dl.lock.Lock() defer dl.lock.Unlock() @@ -359,8 +357,8 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) { return len(blocks), nil } -// insertReceipts injects a new batch of receipts into the simulated chain. -func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.Receipts) (int, error) { +// InsertReceiptChain injects a new batch of receipts into the simulated chain. +func (dl *downloadTester) InsertReceiptChain(blocks types.Blocks, receipts []types.Receipts) (int, error) { dl.lock.Lock() defer dl.lock.Unlock() @@ -377,8 +375,8 @@ func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.R return len(blocks), nil } -// rollback removes some recently added elements from the chain. -func (dl *downloadTester) rollback(hashes []common.Hash) { +// Rollback removes some recently added elements from the chain. +func (dl *downloadTester) Rollback(hashes []common.Hash) { dl.lock.Lock() defer dl.lock.Unlock() @@ -406,14 +404,7 @@ func (dl *downloadTester) newSlowPeer(id string, version int, hashes []common.Ha defer dl.lock.Unlock() var err error - switch version { - case 62: - err = dl.downloader.RegisterPeer(id, version, dl.peerCurrentHeadFn(id), dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), nil, nil) - case 63: - err = dl.downloader.RegisterPeer(id, version, dl.peerCurrentHeadFn(id), dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay), dl.peerGetNodeDataFn(id, delay)) - case 64: - err = dl.downloader.RegisterPeer(id, version, dl.peerCurrentHeadFn(id), dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay), dl.peerGetNodeDataFn(id, delay)) - } + err = dl.downloader.RegisterPeer(id, version, &downloadTesterPeer{dl, id, delay}) if err == nil { // Assign the owned hashes, headers and blocks to the peer (deep copy) dl.peerHashes[id] = make([]common.Hash, len(hashes)) @@ -471,139 +462,133 @@ func (dl *downloadTester) dropPeer(id string) { dl.downloader.UnregisterPeer(id) } -// peerCurrentHeadFn constructs a function to retrieve a peer's current head hash +type downloadTesterPeer struct { + dl *downloadTester + id string + delay time.Duration +} + +// Head constructs a function to retrieve a peer's current head hash // and total difficulty. -func (dl *downloadTester) peerCurrentHeadFn(id string) func() (common.Hash, *big.Int) { - return func() (common.Hash, *big.Int) { - dl.lock.RLock() - defer dl.lock.RUnlock() +func (dlp *downloadTesterPeer) Head() (common.Hash, *big.Int) { + dlp.dl.lock.RLock() + defer dlp.dl.lock.RUnlock() - return dl.peerHashes[id][0], nil - } + return dlp.dl.peerHashes[dlp.id][0], nil } -// peerGetRelHeadersFn constructs a GetBlockHeaders function based on a hashed +// RequestHeadersByHash constructs a GetBlockHeaders function based on a hashed // origin; associated with a particular peer in the download tester. The returned // function can be used to retrieve batches of headers from the particular peer. -func (dl *downloadTester) peerGetRelHeadersFn(id string, delay time.Duration) func(common.Hash, int, int, bool) error { - return func(origin common.Hash, amount int, skip int, reverse bool) error { - // Find the canonical number of the hash - dl.lock.RLock() - number := uint64(0) - for num, hash := range dl.peerHashes[id] { - if hash == origin { - number = uint64(len(dl.peerHashes[id]) - num - 1) - break - } +func (dlp *downloadTesterPeer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error { + // Find the canonical number of the hash + dlp.dl.lock.RLock() + number := uint64(0) + for num, hash := range dlp.dl.peerHashes[dlp.id] { + if hash == origin { + number = uint64(len(dlp.dl.peerHashes[dlp.id]) - num - 1) + break } - dl.lock.RUnlock() - - // Use the absolute header fetcher to satisfy the query - return dl.peerGetAbsHeadersFn(id, delay)(number, amount, skip, reverse) } + dlp.dl.lock.RUnlock() + + // Use the absolute header fetcher to satisfy the query + return dlp.RequestHeadersByNumber(number, amount, skip, reverse) } -// peerGetAbsHeadersFn constructs a GetBlockHeaders function based on a numbered +// RequestHeadersByNumber constructs a GetBlockHeaders function based on a numbered // origin; associated with a particular peer in the download tester. The returned // function can be used to retrieve batches of headers from the particular peer. -func (dl *downloadTester) peerGetAbsHeadersFn(id string, delay time.Duration) func(uint64, int, int, bool) error { - return func(origin uint64, amount int, skip int, reverse bool) error { - time.Sleep(delay) - - dl.lock.RLock() - defer dl.lock.RUnlock() - - // Gather the next batch of headers - hashes := dl.peerHashes[id] - headers := dl.peerHeaders[id] - result := make([]*types.Header, 0, amount) - for i := 0; i < amount && len(hashes)-int(origin)-1-i*(skip+1) >= 0; i++ { - if header, ok := headers[hashes[len(hashes)-int(origin)-1-i*(skip+1)]]; ok { - result = append(result, header) - } +func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { + time.Sleep(dlp.delay) + + dlp.dl.lock.RLock() + defer dlp.dl.lock.RUnlock() + + // Gather the next batch of headers + hashes := dlp.dl.peerHashes[dlp.id] + headers := dlp.dl.peerHeaders[dlp.id] + result := make([]*types.Header, 0, amount) + for i := 0; i < amount && len(hashes)-int(origin)-1-i*(skip+1) >= 0; i++ { + if header, ok := headers[hashes[len(hashes)-int(origin)-1-i*(skip+1)]]; ok { + result = append(result, header) } - // Delay delivery a bit to allow attacks to unfold - go func() { - time.Sleep(time.Millisecond) - dl.downloader.DeliverHeaders(id, result) - }() - return nil } + // Delay delivery a bit to allow attacks to unfold + go func() { + time.Sleep(time.Millisecond) + dlp.dl.downloader.DeliverHeaders(dlp.id, result) + }() + return nil } -// peerGetBodiesFn constructs a getBlockBodies method associated with a particular +// RequestBodies constructs a getBlockBodies method associated with a particular // peer in the download tester. The returned function can be used to retrieve // batches of block bodies from the particularly requested peer. -func (dl *downloadTester) peerGetBodiesFn(id string, delay time.Duration) func([]common.Hash) error { - return func(hashes []common.Hash) error { - time.Sleep(delay) +func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash) error { + time.Sleep(dlp.delay) - dl.lock.RLock() - defer dl.lock.RUnlock() + dlp.dl.lock.RLock() + defer dlp.dl.lock.RUnlock() - blocks := dl.peerBlocks[id] + blocks := dlp.dl.peerBlocks[dlp.id] - transactions := make([][]*types.Transaction, 0, len(hashes)) - uncles := make([][]*types.Header, 0, len(hashes)) + transactions := make([][]*types.Transaction, 0, len(hashes)) + uncles := make([][]*types.Header, 0, len(hashes)) - for _, hash := range hashes { - if block, ok := blocks[hash]; ok { - transactions = append(transactions, block.Transactions()) - uncles = append(uncles, block.Uncles()) - } + for _, hash := range hashes { + if block, ok := blocks[hash]; ok { + transactions = append(transactions, block.Transactions()) + uncles = append(uncles, block.Uncles()) } - go dl.downloader.DeliverBodies(id, transactions, uncles) - - return nil } + go dlp.dl.downloader.DeliverBodies(dlp.id, transactions, uncles) + + return nil } -// peerGetReceiptsFn constructs a getReceipts method associated with a particular +// RequestReceipts constructs a getReceipts method associated with a particular // peer in the download tester. The returned function can be used to retrieve // batches of block receipts from the particularly requested peer. -func (dl *downloadTester) peerGetReceiptsFn(id string, delay time.Duration) func([]common.Hash) error { - return func(hashes []common.Hash) error { - time.Sleep(delay) +func (dlp *downloadTesterPeer) RequestReceipts(hashes []common.Hash) error { + time.Sleep(dlp.delay) - dl.lock.RLock() - defer dl.lock.RUnlock() + dlp.dl.lock.RLock() + defer dlp.dl.lock.RUnlock() - receipts := dl.peerReceipts[id] + receipts := dlp.dl.peerReceipts[dlp.id] - results := make([][]*types.Receipt, 0, len(hashes)) - for _, hash := range hashes { - if receipt, ok := receipts[hash]; ok { - results = append(results, receipt) - } + results := make([][]*types.Receipt, 0, len(hashes)) + for _, hash := range hashes { + if receipt, ok := receipts[hash]; ok { + results = append(results, receipt) } - go dl.downloader.DeliverReceipts(id, results) - - return nil } + go dlp.dl.downloader.DeliverReceipts(dlp.id, results) + + return nil } -// peerGetNodeDataFn constructs a getNodeData method associated with a particular +// RequestNodeData constructs a getNodeData method associated with a particular // peer in the download tester. The returned function can be used to retrieve // batches of node state data from the particularly requested peer. -func (dl *downloadTester) peerGetNodeDataFn(id string, delay time.Duration) func([]common.Hash) error { - return func(hashes []common.Hash) error { - time.Sleep(delay) - - dl.lock.RLock() - defer dl.lock.RUnlock() - - results := make([][]byte, 0, len(hashes)) - for _, hash := range hashes { - if data, err := dl.peerDb.Get(hash.Bytes()); err == nil { - if !dl.peerMissingStates[id][hash] { - results = append(results, data) - } +func (dlp *downloadTesterPeer) RequestNodeData(hashes []common.Hash) error { + time.Sleep(dlp.delay) + + dlp.dl.lock.RLock() + defer dlp.dl.lock.RUnlock() + + results := make([][]byte, 0, len(hashes)) + for _, hash := range hashes { + if data, err := dlp.dl.peerDb.Get(hash.Bytes()); err == nil { + if !dlp.dl.peerMissingStates[dlp.id][hash] { + results = append(results, data) } } - go dl.downloader.DeliverNodeData(id, results) - - return nil } + go dlp.dl.downloader.DeliverNodeData(dlp.id, results) + + return nil } // assertOwnChain checks if the local chain contains the correct number of items @@ -1212,7 +1197,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { if err := tester.sync("fast-attack", nil, mode); err == nil { t.Fatalf("succeeded fast attacker synchronisation") } - if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch { + if head := tester.CurrentHeader().Number.Int64(); int(head) > MaxHeaderFetch { t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch) } // Attempt to sync with an attacker that feeds junk during the block import phase. @@ -1226,11 +1211,11 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { if err := tester.sync("block-attack", nil, mode); err == nil { t.Fatalf("succeeded block attacker synchronisation") } - if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch { + if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch { t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch) } if mode == FastSync { - if head := tester.headBlock().NumberU64(); head != 0 { + if head := tester.CurrentBlock().NumberU64(); head != 0 { t.Errorf("fast sync pivot block #%d not rolled back", head) } } @@ -1251,11 +1236,11 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { if err := tester.sync("withhold-attack", nil, mode); err == nil { t.Fatalf("succeeded withholding attacker synchronisation") } - if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch { + if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch { t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch) } if mode == FastSync { - if head := tester.headBlock().NumberU64(); head != 0 { + if head := tester.CurrentBlock().NumberU64(); head != 0 { t.Errorf("fast sync pivot block #%d not rolled back", head) } } @@ -1670,6 +1655,48 @@ func TestDeliverHeadersHang64Full(t *testing.T) { testDeliverHeadersHang(t, 64, func TestDeliverHeadersHang64Fast(t *testing.T) { testDeliverHeadersHang(t, 64, FastSync) } func TestDeliverHeadersHang64Light(t *testing.T) { testDeliverHeadersHang(t, 64, LightSync) } +type floodingTestPeer struct { + peer Peer + tester *downloadTester +} + +func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() } +func (ftp *floodingTestPeer) RequestHeadersByHash(hash common.Hash, count int, skip int, reverse bool) error { + return ftp.peer.RequestHeadersByHash(hash, count, skip, reverse) +} +func (ftp *floodingTestPeer) RequestBodies(hashes []common.Hash) error { + return ftp.peer.RequestBodies(hashes) +} +func (ftp *floodingTestPeer) RequestReceipts(hashes []common.Hash) error { + return ftp.peer.RequestReceipts(hashes) +} +func (ftp *floodingTestPeer) RequestNodeData(hashes []common.Hash) error { + return ftp.peer.RequestNodeData(hashes) +} + +func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int, reverse bool) error { + deliveriesDone := make(chan struct{}, 500) + for i := 0; i < cap(deliveriesDone); i++ { + peer := fmt.Sprintf("fake-peer%d", i) + go func() { + ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}}) + deliveriesDone <- struct{}{} + }() + } + // Deliver the actual requested headers. + go ftp.peer.RequestHeadersByNumber(from, count, skip, reverse) + // None of the extra deliveries should block. + timeout := time.After(15 * time.Second) + for i := 0; i < cap(deliveriesDone); i++ { + select { + case <-deliveriesDone: + case <-timeout: + panic("blocked") + } + } + return nil +} + func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) { t.Parallel() @@ -1677,7 +1704,6 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) { defer master.terminate() hashes, headers, blocks, receipts := master.makeChain(5, 0, master.genesis, nil, false) - fakeHeads := []*types.Header{{}, {}, {}, {}} for i := 0; i < 200; i++ { tester := newTester() tester.peerDb = master.peerDb @@ -1685,29 +1711,11 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) { tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Whenever the downloader requests headers, flood it with // a lot of unrequested header deliveries. - tester.downloader.peers.peers["peer"].getAbsHeaders = func(from uint64, count, skip int, reverse bool) error { - deliveriesDone := make(chan struct{}, 500) - for i := 0; i < cap(deliveriesDone); i++ { - peer := fmt.Sprintf("fake-peer%d", i) - go func() { - tester.downloader.DeliverHeaders(peer, fakeHeads) - deliveriesDone <- struct{}{} - }() - } - // Deliver the actual requested headers. - impl := tester.peerGetAbsHeadersFn("peer", 0) - go impl(from, count, skip, reverse) - // None of the extra deliveries should block. - timeout := time.After(15 * time.Second) - for i := 0; i < cap(deliveriesDone); i++ { - select { - case <-deliveriesDone: - case <-timeout: - panic("blocked") - } - } - return nil + tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{ + tester.downloader.peers.peers["peer"].peer, + tester, } + if err := tester.sync("peer", nil, mode); err != nil { t.Errorf("sync failed: %v", err) } @@ -1739,7 +1747,7 @@ func testFastCriticalRestarts(t *testing.T, protocol int, progress bool) { for i := 0; i < fsPivotInterval; i++ { tester.peerMissingStates["peer"][headers[hashes[fsMinFullBlocks+i]].Root] = true } - tester.downloader.peers.peers["peer"].getNodeData = tester.peerGetNodeDataFn("peer", 500*time.Millisecond) // Enough to reach the critical section + (tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).delay = 500 * time.Millisecond // Enough to reach the critical section // Synchronise with the peer a few times and make sure they fail until the retry limit for i := 0; i < int(fsCriticalTrials)-1; i++ { @@ -1758,7 +1766,7 @@ func testFastCriticalRestarts(t *testing.T, protocol int, progress bool) { tester.lock.Lock() tester.peerHeaders["peer"][hashes[fsMinFullBlocks-1]] = headers[hashes[fsMinFullBlocks-1]] tester.peerMissingStates["peer"] = map[common.Hash]bool{tester.downloader.fsPivotLock.Root: true} - tester.downloader.peers.peers["peer"].getNodeData = tester.peerGetNodeDataFn("peer", 0) + (tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).delay = 0 tester.lock.Unlock() } } diff --git a/eth/downloader/peer.go b/eth/downloader/peer.go index dc8b097728..d0dc9a8aa6 100644 --- a/eth/downloader/peer.go +++ b/eth/downloader/peer.go @@ -39,24 +39,14 @@ const ( measurementImpact = 0.1 // The impact a single measurement has on a peer's final throughput value. ) -// Head hash and total difficulty retriever for -type currentHeadRetrievalFn func() (common.Hash, *big.Int) - -// Block header and body fetchers belonging to eth/62 and above -type relativeHeaderFetcherFn func(common.Hash, int, int, bool) error -type absoluteHeaderFetcherFn func(uint64, int, int, bool) error -type blockBodyFetcherFn func([]common.Hash) error -type receiptFetcherFn func([]common.Hash) error -type stateFetcherFn func([]common.Hash) error - var ( errAlreadyFetching = errors.New("already fetching blocks from peer") errAlreadyRegistered = errors.New("peer is already registered") errNotRegistered = errors.New("peer is not registered") ) -// peer represents an active peer from which hashes and blocks are retrieved. -type peer struct { +// peerConnection represents an active peer from which hashes and blocks are retrieved. +type peerConnection struct { id string // Unique identifier of the peer headerIdle int32 // Current header activity state of the peer (idle = 0, active = 1) @@ -78,37 +68,57 @@ type peer struct { lacking map[common.Hash]struct{} // Set of hashes not to request (didn't have previously) - currentHead currentHeadRetrievalFn // Method to fetch the currently known head of the peer - - getRelHeaders relativeHeaderFetcherFn // [eth/62] Method to retrieve a batch of headers from an origin hash - getAbsHeaders absoluteHeaderFetcherFn // [eth/62] Method to retrieve a batch of headers from an absolute position - getBlockBodies blockBodyFetcherFn // [eth/62] Method to retrieve a batch of block bodies - - getReceipts receiptFetcherFn // [eth/63] Method to retrieve a batch of block transaction receipts - getNodeData stateFetcherFn // [eth/63] Method to retrieve a batch of state trie data + peer Peer version int // Eth protocol version number to switch strategies log log.Logger // Contextual logger to add extra infos to peer logs lock sync.RWMutex } -// newPeer create a new downloader peer, with specific hash and block retrieval -// mechanisms. -func newPeer(id string, version int, currentHead currentHeadRetrievalFn, - getRelHeaders relativeHeaderFetcherFn, getAbsHeaders absoluteHeaderFetcherFn, getBlockBodies blockBodyFetcherFn, - getReceipts receiptFetcherFn, getNodeData stateFetcherFn, logger log.Logger) *peer { +// LightPeer encapsulates the methods required to synchronise with a remote light peer. +type LightPeer interface { + Head() (common.Hash, *big.Int) + RequestHeadersByHash(common.Hash, int, int, bool) error + RequestHeadersByNumber(uint64, int, int, bool) error +} + +// Peer encapsulates the methods required to synchronise with a remote full peer. +type Peer interface { + LightPeer + RequestBodies([]common.Hash) error + RequestReceipts([]common.Hash) error + RequestNodeData([]common.Hash) error +} + +// lightPeerWrapper wraps a LightPeer struct, stubbing out the Peer-only methods. +type lightPeerWrapper struct { + peer LightPeer +} + +func (w *lightPeerWrapper) Head() (common.Hash, *big.Int) { return w.peer.Head() } +func (w *lightPeerWrapper) RequestHeadersByHash(h common.Hash, amount int, skip int, reverse bool) error { + return w.peer.RequestHeadersByHash(h, amount, skip, reverse) +} +func (w *lightPeerWrapper) RequestHeadersByNumber(i uint64, amount int, skip int, reverse bool) error { + return w.peer.RequestHeadersByNumber(i, amount, skip, reverse) +} +func (w *lightPeerWrapper) RequestBodies([]common.Hash) error { + panic("RequestBodies not supported in light client mode sync") +} +func (w *lightPeerWrapper) RequestReceipts([]common.Hash) error { + panic("RequestReceipts not supported in light client mode sync") +} +func (w *lightPeerWrapper) RequestNodeData([]common.Hash) error { + panic("RequestNodeData not supported in light client mode sync") +} - return &peer{ +// newPeerConnection creates a new downloader peer. +func newPeerConnection(id string, version int, peer Peer, logger log.Logger) *peerConnection { + return &peerConnection{ id: id, lacking: make(map[common.Hash]struct{}), - currentHead: currentHead, - getRelHeaders: getRelHeaders, - getAbsHeaders: getAbsHeaders, - getBlockBodies: getBlockBodies, - - getReceipts: getReceipts, - getNodeData: getNodeData, + peer: peer, version: version, log: logger, @@ -116,7 +126,7 @@ func newPeer(id string, version int, currentHead currentHeadRetrievalFn, } // Reset clears the internal state of a peer entity. -func (p *peer) Reset() { +func (p *peerConnection) Reset() { p.lock.Lock() defer p.lock.Unlock() @@ -134,7 +144,7 @@ func (p *peer) Reset() { } // FetchHeaders sends a header retrieval request to the remote peer. -func (p *peer) FetchHeaders(from uint64, count int) error { +func (p *peerConnection) FetchHeaders(from uint64, count int) error { // Sanity check the protocol version if p.version < 62 { panic(fmt.Sprintf("header fetch [eth/62+] requested on eth/%d", p.version)) @@ -146,13 +156,13 @@ func (p *peer) FetchHeaders(from uint64, count int) error { p.headerStarted = time.Now() // Issue the header retrieval request (absolut upwards without gaps) - go p.getAbsHeaders(from, count, 0, false) + go p.peer.RequestHeadersByNumber(from, count, 0, false) return nil } // FetchBodies sends a block body retrieval request to the remote peer. -func (p *peer) FetchBodies(request *fetchRequest) error { +func (p *peerConnection) FetchBodies(request *fetchRequest) error { // Sanity check the protocol version if p.version < 62 { panic(fmt.Sprintf("body fetch [eth/62+] requested on eth/%d", p.version)) @@ -168,13 +178,13 @@ func (p *peer) FetchBodies(request *fetchRequest) error { for _, header := range request.Headers { hashes = append(hashes, header.Hash()) } - go p.getBlockBodies(hashes) + go p.peer.RequestBodies(hashes) return nil } // FetchReceipts sends a receipt retrieval request to the remote peer. -func (p *peer) FetchReceipts(request *fetchRequest) error { +func (p *peerConnection) FetchReceipts(request *fetchRequest) error { // Sanity check the protocol version if p.version < 63 { panic(fmt.Sprintf("body fetch [eth/63+] requested on eth/%d", p.version)) @@ -190,13 +200,13 @@ func (p *peer) FetchReceipts(request *fetchRequest) error { for _, header := range request.Headers { hashes = append(hashes, header.Hash()) } - go p.getReceipts(hashes) + go p.peer.RequestReceipts(hashes) return nil } // FetchNodeData sends a node state data retrieval request to the remote peer. -func (p *peer) FetchNodeData(hashes []common.Hash) error { +func (p *peerConnection) FetchNodeData(hashes []common.Hash) error { // Sanity check the protocol version if p.version < 63 { panic(fmt.Sprintf("node data fetch [eth/63+] requested on eth/%d", p.version)) @@ -206,48 +216,50 @@ func (p *peer) FetchNodeData(hashes []common.Hash) error { return errAlreadyFetching } p.stateStarted = time.Now() - go p.getNodeData(hashes) + + go p.peer.RequestNodeData(hashes) + return nil } // SetHeadersIdle sets the peer to idle, allowing it to execute new header retrieval // requests. Its estimated header retrieval throughput is updated with that measured // just now. -func (p *peer) SetHeadersIdle(delivered int) { +func (p *peerConnection) SetHeadersIdle(delivered int) { p.setIdle(p.headerStarted, delivered, &p.headerThroughput, &p.headerIdle) } // SetBlocksIdle sets the peer to idle, allowing it to execute new block retrieval // requests. Its estimated block retrieval throughput is updated with that measured // just now. -func (p *peer) SetBlocksIdle(delivered int) { +func (p *peerConnection) SetBlocksIdle(delivered int) { p.setIdle(p.blockStarted, delivered, &p.blockThroughput, &p.blockIdle) } // SetBodiesIdle sets the peer to idle, allowing it to execute block body retrieval // requests. Its estimated body retrieval throughput is updated with that measured // just now. -func (p *peer) SetBodiesIdle(delivered int) { +func (p *peerConnection) SetBodiesIdle(delivered int) { p.setIdle(p.blockStarted, delivered, &p.blockThroughput, &p.blockIdle) } // SetReceiptsIdle sets the peer to idle, allowing it to execute new receipt // retrieval requests. Its estimated receipt retrieval throughput is updated // with that measured just now. -func (p *peer) SetReceiptsIdle(delivered int) { +func (p *peerConnection) SetReceiptsIdle(delivered int) { p.setIdle(p.receiptStarted, delivered, &p.receiptThroughput, &p.receiptIdle) } // SetNodeDataIdle sets the peer to idle, allowing it to execute new state trie // data retrieval requests. Its estimated state retrieval throughput is updated // with that measured just now. -func (p *peer) SetNodeDataIdle(delivered int) { +func (p *peerConnection) SetNodeDataIdle(delivered int) { p.setIdle(p.stateStarted, delivered, &p.stateThroughput, &p.stateIdle) } // setIdle sets the peer to idle, allowing it to execute new retrieval requests. // Its estimated retrieval throughput is updated with that measured just now. -func (p *peer) setIdle(started time.Time, delivered int, throughput *float64, idle *int32) { +func (p *peerConnection) setIdle(started time.Time, delivered int, throughput *float64, idle *int32) { // Irrelevant of the scaling, make sure the peer ends up idle defer atomic.StoreInt32(idle, 0) @@ -274,7 +286,7 @@ func (p *peer) setIdle(started time.Time, delivered int, throughput *float64, id // HeaderCapacity retrieves the peers header download allowance based on its // previously discovered throughput. -func (p *peer) HeaderCapacity(targetRTT time.Duration) int { +func (p *peerConnection) HeaderCapacity(targetRTT time.Duration) int { p.lock.RLock() defer p.lock.RUnlock() @@ -283,7 +295,7 @@ func (p *peer) HeaderCapacity(targetRTT time.Duration) int { // BlockCapacity retrieves the peers block download allowance based on its // previously discovered throughput. -func (p *peer) BlockCapacity(targetRTT time.Duration) int { +func (p *peerConnection) BlockCapacity(targetRTT time.Duration) int { p.lock.RLock() defer p.lock.RUnlock() @@ -292,7 +304,7 @@ func (p *peer) BlockCapacity(targetRTT time.Duration) int { // ReceiptCapacity retrieves the peers receipt download allowance based on its // previously discovered throughput. -func (p *peer) ReceiptCapacity(targetRTT time.Duration) int { +func (p *peerConnection) ReceiptCapacity(targetRTT time.Duration) int { p.lock.RLock() defer p.lock.RUnlock() @@ -301,7 +313,7 @@ func (p *peer) ReceiptCapacity(targetRTT time.Duration) int { // NodeDataCapacity retrieves the peers state download allowance based on its // previously discovered throughput. -func (p *peer) NodeDataCapacity(targetRTT time.Duration) int { +func (p *peerConnection) NodeDataCapacity(targetRTT time.Duration) int { p.lock.RLock() defer p.lock.RUnlock() @@ -311,7 +323,7 @@ func (p *peer) NodeDataCapacity(targetRTT time.Duration) int { // MarkLacking appends a new entity to the set of items (blocks, receipts, states) // that a peer is known not to have (i.e. have been requested before). If the // set reaches its maximum allowed capacity, items are randomly dropped off. -func (p *peer) MarkLacking(hash common.Hash) { +func (p *peerConnection) MarkLacking(hash common.Hash) { p.lock.Lock() defer p.lock.Unlock() @@ -326,7 +338,7 @@ func (p *peer) MarkLacking(hash common.Hash) { // Lacks retrieves whether the hash of a blockchain item is on the peers lacking // list (i.e. whether we know that the peer does not have it). -func (p *peer) Lacks(hash common.Hash) bool { +func (p *peerConnection) Lacks(hash common.Hash) bool { p.lock.RLock() defer p.lock.RUnlock() @@ -337,7 +349,7 @@ func (p *peer) Lacks(hash common.Hash) bool { // peerSet represents the collection of active peer participating in the chain // download procedure. type peerSet struct { - peers map[string]*peer + peers map[string]*peerConnection newPeerFeed event.Feed lock sync.RWMutex } @@ -345,11 +357,11 @@ type peerSet struct { // newPeerSet creates a new peer set top track the active download sources. func newPeerSet() *peerSet { return &peerSet{ - peers: make(map[string]*peer), + peers: make(map[string]*peerConnection), } } -func (ps *peerSet) SubscribeNewPeers(ch chan<- *peer) event.Subscription { +func (ps *peerSet) SubscribeNewPeers(ch chan<- *peerConnection) event.Subscription { return ps.newPeerFeed.Subscribe(ch) } @@ -370,7 +382,7 @@ func (ps *peerSet) Reset() { // The method also sets the starting throughput values of the new peer to the // average of all existing peers, to give it a realistic chance of being used // for data retrievals. -func (ps *peerSet) Register(p *peer) error { +func (ps *peerSet) Register(p *peerConnection) error { // Retrieve the current median RTT as a sane default p.rtt = ps.medianRTT() @@ -417,7 +429,7 @@ func (ps *peerSet) Unregister(id string) error { } // Peer retrieves the registered peer with the given id. -func (ps *peerSet) Peer(id string) *peer { +func (ps *peerSet) Peer(id string) *peerConnection { ps.lock.RLock() defer ps.lock.RUnlock() @@ -433,11 +445,11 @@ func (ps *peerSet) Len() int { } // AllPeers retrieves a flat list of all the peers within the set. -func (ps *peerSet) AllPeers() []*peer { +func (ps *peerSet) AllPeers() []*peerConnection { ps.lock.RLock() defer ps.lock.RUnlock() - list := make([]*peer, 0, len(ps.peers)) + list := make([]*peerConnection, 0, len(ps.peers)) for _, p := range ps.peers { list = append(list, p) } @@ -446,11 +458,11 @@ func (ps *peerSet) AllPeers() []*peer { // HeaderIdlePeers retrieves a flat list of all the currently header-idle peers // within the active peer set, ordered by their reputation. -func (ps *peerSet) HeaderIdlePeers() ([]*peer, int) { - idle := func(p *peer) bool { +func (ps *peerSet) HeaderIdlePeers() ([]*peerConnection, int) { + idle := func(p *peerConnection) bool { return atomic.LoadInt32(&p.headerIdle) == 0 } - throughput := func(p *peer) float64 { + throughput := func(p *peerConnection) float64 { p.lock.RLock() defer p.lock.RUnlock() return p.headerThroughput @@ -460,11 +472,11 @@ func (ps *peerSet) HeaderIdlePeers() ([]*peer, int) { // BodyIdlePeers retrieves a flat list of all the currently body-idle peers within // the active peer set, ordered by their reputation. -func (ps *peerSet) BodyIdlePeers() ([]*peer, int) { - idle := func(p *peer) bool { +func (ps *peerSet) BodyIdlePeers() ([]*peerConnection, int) { + idle := func(p *peerConnection) bool { return atomic.LoadInt32(&p.blockIdle) == 0 } - throughput := func(p *peer) float64 { + throughput := func(p *peerConnection) float64 { p.lock.RLock() defer p.lock.RUnlock() return p.blockThroughput @@ -474,11 +486,11 @@ func (ps *peerSet) BodyIdlePeers() ([]*peer, int) { // ReceiptIdlePeers retrieves a flat list of all the currently receipt-idle peers // within the active peer set, ordered by their reputation. -func (ps *peerSet) ReceiptIdlePeers() ([]*peer, int) { - idle := func(p *peer) bool { +func (ps *peerSet) ReceiptIdlePeers() ([]*peerConnection, int) { + idle := func(p *peerConnection) bool { return atomic.LoadInt32(&p.receiptIdle) == 0 } - throughput := func(p *peer) float64 { + throughput := func(p *peerConnection) float64 { p.lock.RLock() defer p.lock.RUnlock() return p.receiptThroughput @@ -488,11 +500,11 @@ func (ps *peerSet) ReceiptIdlePeers() ([]*peer, int) { // NodeDataIdlePeers retrieves a flat list of all the currently node-data-idle // peers within the active peer set, ordered by their reputation. -func (ps *peerSet) NodeDataIdlePeers() ([]*peer, int) { - idle := func(p *peer) bool { +func (ps *peerSet) NodeDataIdlePeers() ([]*peerConnection, int) { + idle := func(p *peerConnection) bool { return atomic.LoadInt32(&p.stateIdle) == 0 } - throughput := func(p *peer) float64 { + throughput := func(p *peerConnection) float64 { p.lock.RLock() defer p.lock.RUnlock() return p.stateThroughput @@ -503,11 +515,11 @@ func (ps *peerSet) NodeDataIdlePeers() ([]*peer, int) { // idlePeers retrieves a flat list of all currently idle peers satisfying the // protocol version constraints, using the provided function to check idleness. // The resulting set of peers are sorted by their measure throughput. -func (ps *peerSet) idlePeers(minProtocol, maxProtocol int, idleCheck func(*peer) bool, throughput func(*peer) float64) ([]*peer, int) { +func (ps *peerSet) idlePeers(minProtocol, maxProtocol int, idleCheck func(*peerConnection) bool, throughput func(*peerConnection) float64) ([]*peerConnection, int) { ps.lock.RLock() defer ps.lock.RUnlock() - idle, total := make([]*peer, 0, len(ps.peers)), 0 + idle, total := make([]*peerConnection, 0, len(ps.peers)), 0 for _, p := range ps.peers { if p.version >= minProtocol && p.version <= maxProtocol { if idleCheck(p) { diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 8a7735d673..6926f1d8c8 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -41,7 +41,7 @@ var ( // fetchRequest is a currently running data retrieval operation. type fetchRequest struct { - Peer *peer // Peer to which the request was sent + Peer *peerConnection // Peer to which the request was sent From uint64 // [eth/62] Requested chain element index (used for skeleton fills only) Hashes map[common.Hash]int // [eth/61] Requested hashes with their insertion index (priority) Headers []*types.Header // [eth/62] Requested headers, sorted by request order @@ -391,7 +391,7 @@ func (q *queue) countProcessableItems() int { // ReserveHeaders reserves a set of headers for the given peer, skipping any // previously failed batches. -func (q *queue) ReserveHeaders(p *peer, count int) *fetchRequest { +func (q *queue) ReserveHeaders(p *peerConnection, count int) *fetchRequest { q.lock.Lock() defer q.lock.Unlock() @@ -432,7 +432,7 @@ func (q *queue) ReserveHeaders(p *peer, count int) *fetchRequest { // ReserveBodies reserves a set of body fetches for the given peer, skipping any // previously failed downloads. Beside the next batch of needed fetches, it also // returns a flag whether empty blocks were queued requiring processing. -func (q *queue) ReserveBodies(p *peer, count int) (*fetchRequest, bool, error) { +func (q *queue) ReserveBodies(p *peerConnection, count int) (*fetchRequest, bool, error) { isNoop := func(header *types.Header) bool { return header.TxHash == types.EmptyRootHash && header.UncleHash == types.EmptyUncleHash } @@ -445,7 +445,7 @@ func (q *queue) ReserveBodies(p *peer, count int) (*fetchRequest, bool, error) { // ReserveReceipts reserves a set of receipt fetches for the given peer, skipping // any previously failed downloads. Beside the next batch of needed fetches, it // also returns a flag whether empty receipts were queued requiring importing. -func (q *queue) ReserveReceipts(p *peer, count int) (*fetchRequest, bool, error) { +func (q *queue) ReserveReceipts(p *peerConnection, count int) (*fetchRequest, bool, error) { isNoop := func(header *types.Header) bool { return header.ReceiptHash == types.EmptyRootHash } @@ -462,7 +462,7 @@ func (q *queue) ReserveReceipts(p *peer, count int) (*fetchRequest, bool, error) // Note, this method expects the queue lock to be already held for writing. The // reason the lock is not obtained in here is because the parameters already need // to access the queue, so they already need a lock anyway. -func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque, +func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque, pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}, isNoop func(*types.Header) bool) (*fetchRequest, bool, error) { // Short circuit if the pool has been depleted, or if the peer's already // downloading something (sanity check not to corrupt state) diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index 4e66120393..a5ce8c42df 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -37,7 +37,7 @@ type stateReq struct { tasks map[common.Hash]*stateTask // Download tasks to track previous attempts timeout time.Duration // Maximum round trip time for this to complete timer *time.Timer // Timer to fire when the RTT timeout expires - peer *peer // Peer that we're requesting from + peer *peerConnection // Peer that we're requesting from response [][]byte // Response data of the peer (nil for timeouts) } @@ -246,7 +246,7 @@ func (s *stateSync) Cancel() error { // and timeouts. func (s *stateSync) loop() error { // Listen for new peer events to assign tasks to them - newPeer := make(chan *peer, 1024) + newPeer := make(chan *peerConnection, 1024) peerSub := s.d.peers.SubscribeNewPeers(newPeer) defer peerSub.Unsubscribe() diff --git a/eth/downloader/types.go b/eth/downloader/types.go index e105104864..3f30ea9dd1 100644 --- a/eth/downloader/types.go +++ b/eth/downloader/types.go @@ -18,51 +18,10 @@ package downloader import ( "fmt" - "math/big" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" ) -// headerCheckFn is a callback type for verifying a header's presence in the local chain. -type headerCheckFn func(common.Hash) bool - -// blockAndStateCheckFn is a callback type for verifying block and associated states' presence in the local chain. -type blockAndStateCheckFn func(common.Hash) bool - -// headerRetrievalFn is a callback type for retrieving a header from the local chain. -type headerRetrievalFn func(common.Hash) *types.Header - -// blockRetrievalFn is a callback type for retrieving a block from the local chain. -type blockRetrievalFn func(common.Hash) *types.Block - -// headHeaderRetrievalFn is a callback type for retrieving the head header from the local chain. -type headHeaderRetrievalFn func() *types.Header - -// headBlockRetrievalFn is a callback type for retrieving the head block from the local chain. -type headBlockRetrievalFn func() *types.Block - -// headFastBlockRetrievalFn is a callback type for retrieving the head fast block from the local chain. -type headFastBlockRetrievalFn func() *types.Block - -// headBlockCommitterFn is a callback for directly committing the head block to a certain entity. -type headBlockCommitterFn func(common.Hash) error - -// tdRetrievalFn is a callback type for retrieving the total difficulty of a local block. -type tdRetrievalFn func(common.Hash) *big.Int - -// headerChainInsertFn is a callback type to insert a batch of headers into the local chain. -type headerChainInsertFn func([]*types.Header, int) (int, error) - -// blockChainInsertFn is a callback type to insert a batch of blocks into the local chain. -type blockChainInsertFn func(types.Blocks) (int, error) - -// receiptChainInsertFn is a callback type to insert a batch of receipts into the local chain. -type receiptChainInsertFn func(types.Blocks, []types.Receipts) (int, error) - -// chainRollbackFn is a callback type to remove a few recently added elements from the local chain. -type chainRollbackFn func([]common.Hash) - // peerDropFn is a callback type for dropping a peer detected as malicious. type peerDropFn func(id string) diff --git a/eth/handler.go b/eth/handler.go index 1af9e755ba..b2422d71c4 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -157,10 +157,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne return nil, errIncompatibleConfig } // Construct the different synchronisation mechanisms - manager.downloader = downloader.New(mode, chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlockAndState, blockchain.GetHeaderByHash, - blockchain.GetBlockByHash, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, - blockchain.GetTdByHash, blockchain.InsertHeaderChain, manager.blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, - manager.removePeer) + manager.downloader = downloader.New(mode, chaindb, manager.eventMux, blockchain, nil, manager.removePeer) validator := func(header *types.Header) error { return engine.VerifyHeader(blockchain, header, true) @@ -268,7 +265,7 @@ func (pm *ProtocolManager) handle(p *peer) error { defer pm.removePeer(p.id) // Register the peer in the downloader. If the downloader considers it banned, we disconnect - if err := pm.downloader.RegisterPeer(p.id, p.version, p.Head, p.RequestHeadersByHash, p.RequestHeadersByNumber, p.RequestBodies, p.RequestReceipts, p.RequestNodeData); err != nil { + if err := pm.downloader.RegisterPeer(p.id, p.version, p); err != nil { return err } // Propagate existing transactions. new transactions appearing diff --git a/les/handler.go b/les/handler.go index 77bc077a2e..39045ecbea 100644 --- a/les/handler.go +++ b/les/handler.go @@ -206,9 +206,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network } if lightSync { - manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, blockchain.HasHeader, nil, blockchain.GetHeaderByHash, - nil, blockchain.CurrentHeader, nil, nil, nil, blockchain.GetTdByHash, - blockchain.InsertHeaderChain, nil, nil, blockchain.Rollback, removePeer) + manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, nil, blockchain, removePeer) manager.peers.notify((*downloaderPeerNotify)(manager)) manager.fetcher = newLightFetcher(manager) } @@ -840,57 +838,70 @@ func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo { // downloaderPeerNotify implements peerSetNotify type downloaderPeerNotify ProtocolManager -func (d *downloaderPeerNotify) registerPeer(p *peer) { - pm := (*ProtocolManager)(d) +type peerConnection struct { + manager *ProtocolManager + peer *peer +} - requestHeadersByHash := func(origin common.Hash, amount int, skip int, reverse bool) error { - reqID := genReqID() - rq := &distReq{ - getCost: func(dp distPeer) uint64 { - peer := dp.(*peer) - return peer.GetRequestCost(GetBlockHeadersMsg, amount) - }, - canSend: func(dp distPeer) bool { - return dp.(*peer) == p - }, - request: func(dp distPeer) func() { - peer := dp.(*peer) - cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) - peer.fcServer.QueueRequest(reqID, cost) - return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) } - }, - } - _, ok := <-pm.reqDist.queue(rq) - if !ok { - return ErrNoPeers - } - return nil +func (pc *peerConnection) Head() (common.Hash, *big.Int) { + return pc.peer.HeadAndTd() +} + +func (pc *peerConnection) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error { + reqID := genReqID() + rq := &distReq{ + getCost: func(dp distPeer) uint64 { + peer := dp.(*peer) + return peer.GetRequestCost(GetBlockHeadersMsg, amount) + }, + canSend: func(dp distPeer) bool { + return dp.(*peer) == pc.peer + }, + request: func(dp distPeer) func() { + peer := dp.(*peer) + cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) + peer.fcServer.QueueRequest(reqID, cost) + return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) } + }, } - requestHeadersByNumber := func(origin uint64, amount int, skip int, reverse bool) error { - reqID := genReqID() - rq := &distReq{ - getCost: func(dp distPeer) uint64 { - peer := dp.(*peer) - return peer.GetRequestCost(GetBlockHeadersMsg, amount) - }, - canSend: func(dp distPeer) bool { - return dp.(*peer) == p - }, - request: func(dp distPeer) func() { - peer := dp.(*peer) - cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) - peer.fcServer.QueueRequest(reqID, cost) - return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) } - }, - } - _, ok := <-pm.reqDist.queue(rq) - if !ok { - return ErrNoPeers - } - return nil + _, ok := <-pc.manager.reqDist.queue(rq) + if !ok { + return ErrNoPeers + } + return nil +} + +func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { + reqID := genReqID() + rq := &distReq{ + getCost: func(dp distPeer) uint64 { + peer := dp.(*peer) + return peer.GetRequestCost(GetBlockHeadersMsg, amount) + }, + canSend: func(dp distPeer) bool { + return dp.(*peer) == pc.peer + }, + request: func(dp distPeer) func() { + peer := dp.(*peer) + cost := peer.GetRequestCost(GetBlockHeadersMsg, amount) + peer.fcServer.QueueRequest(reqID, cost) + return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) } + }, } + _, ok := <-pc.manager.reqDist.queue(rq) + if !ok { + return ErrNoPeers + } + return nil +} - pm.downloader.RegisterPeer(p.id, ethVersion, p.HeadAndTd, requestHeadersByHash, requestHeadersByNumber, nil, nil, nil) +func (d *downloaderPeerNotify) registerPeer(p *peer) { + pm := (*ProtocolManager)(d) + pc := &peerConnection{ + manager: pm, + peer: p, + } + pm.downloader.RegisterLightPeer(p.id, ethVersion, pc) } func (d *downloaderPeerNotify) unregisterPeer(p *peer) {