les: code refactoring (#14416)

This commit does various code refactorings:

- generalizes and moves the request retrieval/timeout/resend logic out of LesOdr
  (will be used by a subsequent PR)
- reworks the peer management logic so that all services can register with
  peerSet to get notified about added/dropped peers (also gets rid of the ugly
  getAllPeers callback in requestDistributor)
- moves peerSet, LesOdr, requestDistributor and retrieveManager initialization
  out of ProtocolManager because I believe they do not really belong there and the
  whole init process was ugly and ad-hoc
pull/14677/head
Felföldi Zsolt 7 years ago committed by Felix Lange
parent 60e27b51bc
commit a5d08c893d
  1. 40
      les/backend.go
  2. 58
      les/distributor.go
  3. 10
      les/distributor_test.go
  4. 11
      les/fetcher.go
  5. 207
      les/handler.go
  6. 16
      les/handler_test.go
  7. 52
      les/helper_test.go
  8. 195
      les/odr.go
  9. 28
      les/odr_test.go
  10. 38
      les/peer.go
  11. 27
      les/request_test.go
  12. 395
      les/retrieve.go
  13. 21
      les/server.go
  14. 26
      les/serverpool.go
  15. 16
      les/txrelay.go

@ -19,6 +19,7 @@ package les
import ( import (
"fmt" "fmt"
"sync"
"time" "time"
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
@ -38,6 +39,7 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
rpc "github.com/ethereum/go-ethereum/rpc" rpc "github.com/ethereum/go-ethereum/rpc"
) )
@ -49,9 +51,13 @@ type LightEthereum struct {
// Channel for shutting down the service // Channel for shutting down the service
shutdownChan chan bool shutdownChan chan bool
// Handlers // Handlers
peers *peerSet
txPool *light.TxPool txPool *light.TxPool
blockchain *light.LightChain blockchain *light.LightChain
protocolManager *ProtocolManager protocolManager *ProtocolManager
serverPool *serverPool
reqDist *requestDistributor
retriever *retrieveManager
// DB interfaces // DB interfaces
chainDb ethdb.Database // Block chain database chainDb ethdb.Database // Block chain database
@ -63,6 +69,9 @@ type LightEthereum struct {
networkId uint64 networkId uint64
netRPCService *ethapi.PublicNetAPI netRPCService *ethapi.PublicNetAPI
quitSync chan struct{}
wg sync.WaitGroup
} }
func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
@ -76,20 +85,26 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
} }
log.Info("Initialised chain configuration", "config", chainConfig) log.Info("Initialised chain configuration", "config", chainConfig)
odr := NewLesOdr(chainDb) peers := newPeerSet()
relay := NewLesTxRelay() quitSync := make(chan struct{})
eth := &LightEthereum{ eth := &LightEthereum{
odr: odr,
relay: relay,
chainDb: chainDb,
chainConfig: chainConfig, chainConfig: chainConfig,
chainDb: chainDb,
eventMux: ctx.EventMux, eventMux: ctx.EventMux,
peers: peers,
reqDist: newRequestDistributor(peers, quitSync),
accountManager: ctx.AccountManager, accountManager: ctx.AccountManager,
engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb), engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb),
shutdownChan: make(chan bool), shutdownChan: make(chan bool),
networkId: config.NetworkId, networkId: config.NetworkId,
} }
if eth.blockchain, err = light.NewLightChain(odr, eth.chainConfig, eth.engine, eth.eventMux); err != nil {
eth.relay = NewLesTxRelay(peers, eth.reqDist)
eth.serverPool = newServerPool(chainDb, quitSync, &eth.wg)
eth.retriever = newRetrieveManager(peers, eth.reqDist, eth.serverPool)
eth.odr = NewLesOdr(chainDb, eth.retriever)
if eth.blockchain, err = light.NewLightChain(eth.odr, eth.chainConfig, eth.engine, eth.eventMux); err != nil {
return nil, err return nil, err
} }
// Rewind the chain in case of an incompatible config upgrade. // Rewind the chain in case of an incompatible config upgrade.
@ -100,13 +115,9 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
} }
eth.txPool = light.NewTxPool(eth.chainConfig, eth.eventMux, eth.blockchain, eth.relay) eth.txPool = light.NewTxPool(eth.chainConfig, eth.eventMux, eth.blockchain, eth.relay)
lightSync := config.SyncMode == downloader.LightSync if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, true, config.NetworkId, eth.eventMux, eth.engine, eth.peers, eth.blockchain, nil, chainDb, eth.odr, eth.relay, quitSync, &eth.wg); err != nil {
if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, lightSync, config.NetworkId, eth.eventMux, eth.engine, eth.blockchain, nil, chainDb, odr, relay); err != nil {
return nil, err return nil, err
} }
relay.ps = eth.protocolManager.peers
relay.reqDist = eth.protocolManager.reqDist
eth.ApiBackend = &LesApiBackend{eth, nil} eth.ApiBackend = &LesApiBackend{eth, nil}
gpoParams := config.GPO gpoParams := config.GPO
if gpoParams.Default == nil { if gpoParams.Default == nil {
@ -116,6 +127,10 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
return eth, nil return eth, nil
} }
func lesTopic(genesisHash common.Hash) discv5.Topic {
return discv5.Topic("LES@" + common.Bytes2Hex(genesisHash.Bytes()[0:8]))
}
type LightDummyAPI struct{} type LightDummyAPI struct{}
// Etherbase is the address that mining rewards will be send to // Etherbase is the address that mining rewards will be send to
@ -188,7 +203,8 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
func (s *LightEthereum) Start(srvr *p2p.Server) error { func (s *LightEthereum) Start(srvr *p2p.Server) error {
log.Warn("Light client mode is an experimental feature") log.Warn("Light client mode is an experimental feature")
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId) s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId)
s.protocolManager.Start(srvr) s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash()))
s.protocolManager.Start()
return nil return nil
} }

@ -34,11 +34,11 @@ var ErrNoPeers = errors.New("no suitable peers available")
type requestDistributor struct { type requestDistributor struct {
reqQueue *list.List reqQueue *list.List
lastReqOrder uint64 lastReqOrder uint64
peers map[distPeer]struct{}
peerLock sync.RWMutex
stopChn, loopChn chan struct{} stopChn, loopChn chan struct{}
loopNextSent bool loopNextSent bool
lock sync.Mutex lock sync.Mutex
getAllPeers func() map[distPeer]struct{}
} }
// distPeer is an LES server peer interface for the request distributor. // distPeer is an LES server peer interface for the request distributor.
@ -71,15 +71,39 @@ type distReq struct {
} }
// newRequestDistributor creates a new request distributor // newRequestDistributor creates a new request distributor
func newRequestDistributor(getAllPeers func() map[distPeer]struct{}, stopChn chan struct{}) *requestDistributor { func newRequestDistributor(peers *peerSet, stopChn chan struct{}) *requestDistributor {
r := &requestDistributor{ d := &requestDistributor{
reqQueue: list.New(), reqQueue: list.New(),
loopChn: make(chan struct{}, 2), loopChn: make(chan struct{}, 2),
stopChn: stopChn, stopChn: stopChn,
getAllPeers: getAllPeers, peers: make(map[distPeer]struct{}),
}
if peers != nil {
peers.notify(d)
} }
go r.loop() go d.loop()
return r return d
}
// registerPeer implements peerSetNotify
func (d *requestDistributor) registerPeer(p *peer) {
d.peerLock.Lock()
d.peers[p] = struct{}{}
d.peerLock.Unlock()
}
// unregisterPeer implements peerSetNotify
func (d *requestDistributor) unregisterPeer(p *peer) {
d.peerLock.Lock()
delete(d.peers, p)
d.peerLock.Unlock()
}
// registerTestPeer adds a new test peer
func (d *requestDistributor) registerTestPeer(p distPeer) {
d.peerLock.Lock()
d.peers[p] = struct{}{}
d.peerLock.Unlock()
} }
// distMaxWait is the maximum waiting time after which further necessary waiting // distMaxWait is the maximum waiting time after which further necessary waiting
@ -152,8 +176,7 @@ func (sp selectPeerItem) Weight() int64 {
// nextRequest returns the next possible request from any peer, along with the // nextRequest returns the next possible request from any peer, along with the
// associated peer and necessary waiting time // associated peer and necessary waiting time
func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) { func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
peers := d.getAllPeers() checkedPeers := make(map[distPeer]struct{})
elem := d.reqQueue.Front() elem := d.reqQueue.Front()
var ( var (
bestPeer distPeer bestPeer distPeer
@ -162,11 +185,14 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
sel *weightedRandomSelect sel *weightedRandomSelect
) )
for (len(peers) > 0 || elem == d.reqQueue.Front()) && elem != nil { d.peerLock.RLock()
defer d.peerLock.RUnlock()
for (len(d.peers) > 0 || elem == d.reqQueue.Front()) && elem != nil {
req := elem.Value.(*distReq) req := elem.Value.(*distReq)
canSend := false canSend := false
for peer, _ := range peers { for peer, _ := range d.peers {
if peer.canQueue() && req.canSend(peer) { if _, ok := checkedPeers[peer]; !ok && peer.canQueue() && req.canSend(peer) {
canSend = true canSend = true
cost := req.getCost(peer) cost := req.getCost(peer)
wait, bufRemain := peer.waitBefore(cost) wait, bufRemain := peer.waitBefore(cost)
@ -182,7 +208,7 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
bestWait = wait bestWait = wait
} }
} }
delete(peers, peer) checkedPeers[peer] = struct{}{}
} }
} }
next := elem.Next() next := elem.Next()

@ -122,20 +122,14 @@ func testRequestDistributor(t *testing.T, resend bool) {
stop := make(chan struct{}) stop := make(chan struct{})
defer close(stop) defer close(stop)
dist := newRequestDistributor(nil, stop)
var peers [testDistPeerCount]*testDistPeer var peers [testDistPeerCount]*testDistPeer
for i, _ := range peers { for i, _ := range peers {
peers[i] = &testDistPeer{} peers[i] = &testDistPeer{}
go peers[i].worker(t, !resend, stop) go peers[i].worker(t, !resend, stop)
dist.registerTestPeer(peers[i])
} }
dist := newRequestDistributor(func() map[distPeer]struct{} {
m := make(map[distPeer]struct{})
for _, peer := range peers {
m[peer] = struct{}{}
}
return m
}, stop)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 1; i <= testDistReqCount; i++ { for i := 1; i <= testDistReqCount; i++ {

@ -116,6 +116,7 @@ func newLightFetcher(pm *ProtocolManager) *lightFetcher {
syncDone: make(chan *peer), syncDone: make(chan *peer),
maxConfirmedTd: big.NewInt(0), maxConfirmedTd: big.NewInt(0),
} }
pm.peers.notify(f)
go f.syncLoop() go f.syncLoop()
return f return f
} }
@ -209,8 +210,8 @@ func (f *lightFetcher) syncLoop() {
} }
} }
// addPeer adds a new peer to the fetcher's peer set // registerPeer adds a new peer to the fetcher's peer set
func (f *lightFetcher) addPeer(p *peer) { func (f *lightFetcher) registerPeer(p *peer) {
p.lock.Lock() p.lock.Lock()
p.hasBlock = func(hash common.Hash, number uint64) bool { p.hasBlock = func(hash common.Hash, number uint64) bool {
return f.peerHasBlock(p, hash, number) return f.peerHasBlock(p, hash, number)
@ -223,8 +224,8 @@ func (f *lightFetcher) addPeer(p *peer) {
f.peers[p] = &fetcherPeerInfo{nodeByHash: make(map[common.Hash]*fetcherTreeNode)} f.peers[p] = &fetcherPeerInfo{nodeByHash: make(map[common.Hash]*fetcherTreeNode)}
} }
// removePeer removes a new peer from the fetcher's peer set // unregisterPeer removes a new peer from the fetcher's peer set
func (f *lightFetcher) removePeer(p *peer) { func (f *lightFetcher) unregisterPeer(p *peer) {
p.lock.Lock() p.lock.Lock()
p.hasBlock = nil p.hasBlock = nil
p.lock.Unlock() p.lock.Unlock()
@ -416,7 +417,7 @@ func (f *lightFetcher) nextRequest() (*distReq, uint64) {
f.syncing = bestSyncing f.syncing = bestSyncing
var rq *distReq var rq *distReq
reqID := getNextReqID() reqID := genReqID()
if f.syncing { if f.syncing {
rq = &distReq{ rq = &distReq{
getCost: func(dp distPeer) uint64 { getCost: func(dp distPeer) uint64 {

@ -102,7 +102,9 @@ type ProtocolManager struct {
odr *LesOdr odr *LesOdr
server *LesServer server *LesServer
serverPool *serverPool serverPool *serverPool
lesTopic discv5.Topic
reqDist *requestDistributor reqDist *requestDistributor
retriever *retrieveManager
downloader *downloader.Downloader downloader *downloader.Downloader
fetcher *lightFetcher fetcher *lightFetcher
@ -123,12 +125,12 @@ type ProtocolManager struct {
// wait group is used for graceful shutdowns during downloading // wait group is used for graceful shutdowns during downloading
// and processing // and processing
wg sync.WaitGroup wg *sync.WaitGroup
} }
// NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable
// with the ethereum network. // with the ethereum network.
func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, networkId uint64, mux *event.TypeMux, engine consensus.Engine, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay) (*ProtocolManager, error) { func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, networkId uint64, mux *event.TypeMux, engine consensus.Engine, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay, quitSync chan struct{}, wg *sync.WaitGroup) (*ProtocolManager, error) {
// Create the protocol manager with the base fields // Create the protocol manager with the base fields
manager := &ProtocolManager{ manager := &ProtocolManager{
lightSync: lightSync, lightSync: lightSync,
@ -136,15 +138,20 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network
blockchain: blockchain, blockchain: blockchain,
chainConfig: chainConfig, chainConfig: chainConfig,
chainDb: chainDb, chainDb: chainDb,
odr: odr,
networkId: networkId, networkId: networkId,
txpool: txpool, txpool: txpool,
txrelay: txrelay, txrelay: txrelay,
odr: odr, peers: peers,
peers: newPeerSet(),
newPeerCh: make(chan *peer), newPeerCh: make(chan *peer),
quitSync: make(chan struct{}), quitSync: quitSync,
wg: wg,
noMorePeers: make(chan struct{}), noMorePeers: make(chan struct{}),
} }
if odr != nil {
manager.retriever = odr.retriever
manager.reqDist = odr.retriever.dist
}
// Initiate a sub-protocol for every implemented version we can handle // Initiate a sub-protocol for every implemented version we can handle
manager.SubProtocols = make([]p2p.Protocol, 0, len(ProtocolVersions)) manager.SubProtocols = make([]p2p.Protocol, 0, len(ProtocolVersions))
for i, version := range ProtocolVersions { for i, version := range ProtocolVersions {
@ -202,84 +209,22 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network
manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, blockchain.HasHeader, nil, blockchain.GetHeaderByHash, manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, blockchain.HasHeader, nil, blockchain.GetHeaderByHash,
nil, blockchain.CurrentHeader, nil, nil, nil, blockchain.GetTdByHash, nil, blockchain.CurrentHeader, nil, nil, nil, blockchain.GetTdByHash,
blockchain.InsertHeaderChain, nil, nil, blockchain.Rollback, removePeer) blockchain.InsertHeaderChain, nil, nil, blockchain.Rollback, removePeer)
manager.peers.notify((*downloaderPeerNotify)(manager))
manager.fetcher = newLightFetcher(manager)
} }
manager.reqDist = newRequestDistributor(func() map[distPeer]struct{} {
m := make(map[distPeer]struct{})
peers := manager.peers.AllPeers()
for _, peer := range peers {
m[peer] = struct{}{}
}
return m
}, manager.quitSync)
if odr != nil {
odr.removePeer = removePeer
odr.reqDist = manager.reqDist
}
/*validator := func(block *types.Block, parent *types.Block) error {
return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false)
}
heighter := func() uint64 {
return chainman.LastBlockNumberU64()
}
manager.fetcher = fetcher.New(chainman.GetBlockNoOdr, validator, nil, heighter, chainman.InsertChain, manager.removePeer)
*/
return manager, nil return manager, nil
} }
// removePeer initiates disconnection from a peer by removing it from the peer set
func (pm *ProtocolManager) removePeer(id string) { func (pm *ProtocolManager) removePeer(id string) {
// Short circuit if the peer was already removed pm.peers.Unregister(id)
peer := pm.peers.Peer(id)
if peer == nil {
return
}
log.Debug("Removing light Ethereum peer", "peer", id)
if err := pm.peers.Unregister(id); err != nil {
if err == errNotRegistered {
return
}
}
// Unregister the peer from the downloader and Ethereum peer set
if pm.lightSync {
pm.downloader.UnregisterPeer(id)
if pm.txrelay != nil {
pm.txrelay.removePeer(id)
}
if pm.fetcher != nil {
pm.fetcher.removePeer(peer)
}
}
// Hard disconnect at the networking layer
if peer != nil {
peer.Peer.Disconnect(p2p.DiscUselessPeer)
}
} }
func (pm *ProtocolManager) Start(srvr *p2p.Server) { func (pm *ProtocolManager) Start() {
var topicDisc *discv5.Network
if srvr != nil {
topicDisc = srvr.DiscV5
}
lesTopic := discv5.Topic("LES@" + common.Bytes2Hex(pm.blockchain.Genesis().Hash().Bytes()[0:8]))
if pm.lightSync { if pm.lightSync {
// start sync handler
if srvr != nil { // srvr is nil during testing
pm.serverPool = newServerPool(pm.chainDb, []byte("serverPool/"), srvr, lesTopic, pm.quitSync, &pm.wg)
pm.odr.serverPool = pm.serverPool
pm.fetcher = newLightFetcher(pm)
}
go pm.syncer() go pm.syncer()
} else { } else {
if topicDisc != nil {
go func() {
logger := log.New("topic", lesTopic)
logger.Info("Starting topic registration")
defer logger.Info("Terminated topic registration")
topicDisc.RegisterTopic(lesTopic, pm.quitSync)
}()
}
go func() { go func() {
for range pm.newPeerCh { for range pm.newPeerCh {
} }
@ -342,65 +287,10 @@ func (pm *ProtocolManager) handle(p *peer) error {
}() }()
// Register the peer in the downloader. If the downloader considers it banned, we disconnect // Register the peer in the downloader. If the downloader considers it banned, we disconnect
if pm.lightSync { if pm.lightSync {
requestHeadersByHash := func(origin common.Hash, amount int, skip int, reverse bool) error {
reqID := getNextReqID()
rq := &distReq{
getCost: func(dp distPeer) uint64 {
peer := dp.(*peer)
return peer.GetRequestCost(GetBlockHeadersMsg, amount)
},
canSend: func(dp distPeer) bool {
return dp.(*peer) == p
},
request: func(dp distPeer) func() {
peer := dp.(*peer)
cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
peer.fcServer.QueueRequest(reqID, cost)
return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) }
},
}
_, ok := <-pm.reqDist.queue(rq)
if !ok {
return ErrNoPeers
}
return nil
}
requestHeadersByNumber := func(origin uint64, amount int, skip int, reverse bool) error {
reqID := getNextReqID()
rq := &distReq{
getCost: func(dp distPeer) uint64 {
peer := dp.(*peer)
return peer.GetRequestCost(GetBlockHeadersMsg, amount)
},
canSend: func(dp distPeer) bool {
return dp.(*peer) == p
},
request: func(dp distPeer) func() {
peer := dp.(*peer)
cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
peer.fcServer.QueueRequest(reqID, cost)
return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) }
},
}
_, ok := <-pm.reqDist.queue(rq)
if !ok {
return ErrNoPeers
}
return nil
}
if err := pm.downloader.RegisterPeer(p.id, ethVersion, p.HeadAndTd,
requestHeadersByHash, requestHeadersByNumber, nil, nil, nil); err != nil {
return err
}
if pm.txrelay != nil {
pm.txrelay.addPeer(p)
}
p.lock.Lock() p.lock.Lock()
head := p.headInfo head := p.headInfo
p.lock.Unlock() p.lock.Unlock()
if pm.fetcher != nil { if pm.fetcher != nil {
pm.fetcher.addPeer(p)
pm.fetcher.announce(p, head) pm.fetcher.announce(p, head)
} }
@ -926,7 +816,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
} }
if deliverMsg != nil { if deliverMsg != nil {
err := pm.odr.Deliver(p, deliverMsg) err := pm.retriever.deliver(p, deliverMsg)
if err != nil { if err != nil {
p.responseErrors++ p.responseErrors++
if p.responseErrors > maxResponseErrors { if p.responseErrors > maxResponseErrors {
@ -946,3 +836,64 @@ func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo {
Head: self.blockchain.LastBlockHash(), Head: self.blockchain.LastBlockHash(),
} }
} }
// downloaderPeerNotify implements peerSetNotify
type downloaderPeerNotify ProtocolManager
func (d *downloaderPeerNotify) registerPeer(p *peer) {
pm := (*ProtocolManager)(d)
requestHeadersByHash := func(origin common.Hash, amount int, skip int, reverse bool) error {
reqID := genReqID()
rq := &distReq{
getCost: func(dp distPeer) uint64 {
peer := dp.(*peer)
return peer.GetRequestCost(GetBlockHeadersMsg, amount)
},
canSend: func(dp distPeer) bool {
return dp.(*peer) == p
},
request: func(dp distPeer) func() {
peer := dp.(*peer)
cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
peer.fcServer.QueueRequest(reqID, cost)
return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) }
},
}
_, ok := <-pm.reqDist.queue(rq)
if !ok {
return ErrNoPeers
}
return nil
}
requestHeadersByNumber := func(origin uint64, amount int, skip int, reverse bool) error {
reqID := genReqID()
rq := &distReq{
getCost: func(dp distPeer) uint64 {
peer := dp.(*peer)
return peer.GetRequestCost(GetBlockHeadersMsg, amount)
},
canSend: func(dp distPeer) bool {
return dp.(*peer) == p
},
request: func(dp distPeer) func() {
peer := dp.(*peer)
cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
peer.fcServer.QueueRequest(reqID, cost)
return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) }
},
}
_, ok := <-pm.reqDist.queue(rq)
if !ok {
return ErrNoPeers
}
return nil
}
pm.downloader.RegisterPeer(p.id, ethVersion, p.HeadAndTd, requestHeadersByHash, requestHeadersByNumber, nil, nil, nil)
}
func (d *downloaderPeerNotify) unregisterPeer(p *peer) {
pm := (*ProtocolManager)(d)
pm.downloader.UnregisterPeer(p.id)
}

@ -25,6 +25,7 @@ import (
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
@ -42,7 +43,8 @@ func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{}
func TestGetBlockHeadersLes1(t *testing.T) { testGetBlockHeaders(t, 1) } func TestGetBlockHeadersLes1(t *testing.T) { testGetBlockHeaders(t, 1) }
func testGetBlockHeaders(t *testing.T, protocol int) { func testGetBlockHeaders(t *testing.T, protocol int) {
pm, _, _ := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil) db, _ := ethdb.NewMemDatabase()
pm := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil, nil, nil, db)
bc := pm.blockchain.(*core.BlockChain) bc := pm.blockchain.(*core.BlockChain)
peer, _ := newTestPeer(t, "peer", protocol, pm, true) peer, _ := newTestPeer(t, "peer", protocol, pm, true)
defer peer.close() defer peer.close()
@ -170,7 +172,8 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
func TestGetBlockBodiesLes1(t *testing.T) { testGetBlockBodies(t, 1) } func TestGetBlockBodiesLes1(t *testing.T) { testGetBlockBodies(t, 1) }
func testGetBlockBodies(t *testing.T, protocol int) { func testGetBlockBodies(t *testing.T, protocol int) {
pm, _, _ := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil) db, _ := ethdb.NewMemDatabase()
pm := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil, nil, nil, db)
bc := pm.blockchain.(*core.BlockChain) bc := pm.blockchain.(*core.BlockChain)
peer, _ := newTestPeer(t, "peer", protocol, pm, true) peer, _ := newTestPeer(t, "peer", protocol, pm, true)
defer peer.close() defer peer.close()
@ -246,7 +249,8 @@ func TestGetCodeLes1(t *testing.T) { testGetCode(t, 1) }
func testGetCode(t *testing.T, protocol int) { func testGetCode(t *testing.T, protocol int) {
// Assemble the test environment // Assemble the test environment
pm, _, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) db, _ := ethdb.NewMemDatabase()
pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db)
bc := pm.blockchain.(*core.BlockChain) bc := pm.blockchain.(*core.BlockChain)
peer, _ := newTestPeer(t, "peer", protocol, pm, true) peer, _ := newTestPeer(t, "peer", protocol, pm, true)
defer peer.close() defer peer.close()
@ -278,7 +282,8 @@ func TestGetReceiptLes1(t *testing.T) { testGetReceipt(t, 1) }
func testGetReceipt(t *testing.T, protocol int) { func testGetReceipt(t *testing.T, protocol int) {
// Assemble the test environment // Assemble the test environment
pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) db, _ := ethdb.NewMemDatabase()
pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db)
bc := pm.blockchain.(*core.BlockChain) bc := pm.blockchain.(*core.BlockChain)
peer, _ := newTestPeer(t, "peer", protocol, pm, true) peer, _ := newTestPeer(t, "peer", protocol, pm, true)
defer peer.close() defer peer.close()
@ -304,7 +309,8 @@ func TestGetProofsLes1(t *testing.T) { testGetReceipt(t, 1) }
func testGetProofs(t *testing.T, protocol int) { func testGetProofs(t *testing.T, protocol int) {
// Assemble the test environment // Assemble the test environment
pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) db, _ := ethdb.NewMemDatabase()
pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db)
bc := pm.blockchain.(*core.BlockChain) bc := pm.blockchain.(*core.BlockChain)
peer, _ := newTestPeer(t, "peer", protocol, pm, true) peer, _ := newTestPeer(t, "peer", protocol, pm, true)
defer peer.close() defer peer.close()

@ -25,7 +25,6 @@ import (
"math/big" "math/big"
"sync" "sync"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/consensus/ethash"
@ -132,22 +131,22 @@ func testRCL() RequestCostList {
// newTestProtocolManager creates a new protocol manager for testing purposes, // newTestProtocolManager creates a new protocol manager for testing purposes,
// with the given number of blocks already known, and potential notification // with the given number of blocks already known, and potential notification
// channels for different events. // channels for different events.
func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *core.BlockGen)) (*ProtocolManager, ethdb.Database, *LesOdr, error) { func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *core.BlockGen), peers *peerSet, odr *LesOdr, db ethdb.Database) (*ProtocolManager, error) {
var ( var (
evmux = new(event.TypeMux) evmux = new(event.TypeMux)
engine = ethash.NewFaker() engine = ethash.NewFaker()
db, _ = ethdb.NewMemDatabase()
gspec = core.Genesis{ gspec = core.Genesis{
Config: params.TestChainConfig, Config: params.TestChainConfig,
Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}, Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}},
} }
genesis = gspec.MustCommit(db) genesis = gspec.MustCommit(db)
odr *LesOdr chain BlockChain
chain BlockChain
) )
if peers == nil {
peers = newPeerSet()
}
if lightSync { if lightSync {
odr = NewLesOdr(db)
chain, _ = light.NewLightChain(odr, gspec.Config, engine, evmux) chain, _ = light.NewLightChain(odr, gspec.Config, engine, evmux)
} else { } else {
blockchain, _ := core.NewBlockChain(db, gspec.Config, engine, evmux, vm.Config{}) blockchain, _ := core.NewBlockChain(db, gspec.Config, engine, evmux, vm.Config{})
@ -158,9 +157,9 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor
chain = blockchain chain = blockchain
} }
pm, err := NewProtocolManager(gspec.Config, lightSync, NetworkId, evmux, engine, chain, nil, db, odr, nil) pm, err := NewProtocolManager(gspec.Config, lightSync, NetworkId, evmux, engine, peers, chain, nil, db, odr, nil, make(chan struct{}), new(sync.WaitGroup))
if err != nil { if err != nil {
return nil, nil, nil, err return nil, err
} }
if !lightSync { if !lightSync {
srv := &LesServer{protocolManager: pm} srv := &LesServer{protocolManager: pm}
@ -174,20 +173,20 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor
srv.fcManager = flowcontrol.NewClientManager(50, 10, 1000000000) srv.fcManager = flowcontrol.NewClientManager(50, 10, 1000000000)
srv.fcCostStats = newCostStats(nil) srv.fcCostStats = newCostStats(nil)
} }
pm.Start(nil) pm.Start()
return pm, db, odr, nil return pm, nil
} }
// newTestProtocolManagerMust creates a new protocol manager for testing purposes, // newTestProtocolManagerMust creates a new protocol manager for testing purposes,
// with the given number of blocks already known, and potential notification // with the given number of blocks already known, and potential notification
// channels for different events. In case of an error, the constructor force- // channels for different events. In case of an error, the constructor force-
// fails the test. // fails the test.
func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, generator func(int, *core.BlockGen)) (*ProtocolManager, ethdb.Database, *LesOdr) { func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, generator func(int, *core.BlockGen), peers *peerSet, odr *LesOdr, db ethdb.Database) *ProtocolManager {
pm, db, odr, err := newTestProtocolManager(lightSync, blocks, generator) pm, err := newTestProtocolManager(lightSync, blocks, generator, peers, odr, db)
if err != nil { if err != nil {
t.Fatalf("Failed to create protocol manager: %v", err) t.Fatalf("Failed to create protocol manager: %v", err)
} }
return pm, db, odr return pm
} }
// testTxPool is a fake, helper transaction pool for testing purposes // testTxPool is a fake, helper transaction pool for testing purposes
@ -342,30 +341,3 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
func (p *testPeer) close() { func (p *testPeer) close() {
p.app.Close() p.app.Close()
} }
type testServerPool struct {
peer *peer
lock sync.RWMutex
}
func (p *testServerPool) setPeer(peer *peer) {
p.lock.Lock()
defer p.lock.Unlock()
p.peer = peer
}
func (p *testServerPool) getAllPeers() map[distPeer]struct{} {
p.lock.RLock()
defer p.lock.RUnlock()
m := make(map[distPeer]struct{})
if p.peer != nil {
m[p.peer] = struct{}{}
}
return m
}
func (p *testServerPool) adjustResponseTime(*poolEntry, time.Duration, bool) {
}

@ -18,45 +18,24 @@ package les
import ( import (
"context" "context"
"crypto/rand"
"encoding/binary"
"sync"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
) )
var ( // LesOdr implements light.OdrBackend
softRequestTimeout = time.Millisecond * 500
hardRequestTimeout = time.Second * 10
)
// peerDropFn is a callback type for dropping a peer detected as malicious.
type peerDropFn func(id string)
type odrPeerSelector interface {
adjustResponseTime(*poolEntry, time.Duration, bool)
}
type LesOdr struct { type LesOdr struct {
light.OdrBackend db ethdb.Database
db ethdb.Database stop chan struct{}
stop chan struct{} retriever *retrieveManager
removePeer peerDropFn
mlock, clock sync.Mutex
sentReqs map[uint64]*sentReq
serverPool odrPeerSelector
reqDist *requestDistributor
} }
func NewLesOdr(db ethdb.Database) *LesOdr { func NewLesOdr(db ethdb.Database, retriever *retrieveManager) *LesOdr {
return &LesOdr{ return &LesOdr{
db: db, db: db,
stop: make(chan struct{}), retriever: retriever,
sentReqs: make(map[uint64]*sentReq), stop: make(chan struct{}),
} }
} }
@ -68,17 +47,6 @@ func (odr *LesOdr) Database() ethdb.Database {
return odr.db return odr.db
} }
// validatorFunc is a function that processes a message.
type validatorFunc func(ethdb.Database, *Msg) error
// sentReq is a request waiting for an answer that satisfies its valFunc
type sentReq struct {
valFunc validatorFunc
sentTo map[*peer]chan struct{}
lock sync.RWMutex // protects acces to sentTo
answered chan struct{} // closed and set to nil when any peer answers it
}
const ( const (
MsgBlockBodies = iota MsgBlockBodies = iota
MsgCode MsgCode
@ -94,156 +62,29 @@ type Msg struct {
Obj interface{} Obj interface{}
} }
// Deliver is called by the LES protocol manager to deliver ODR reply messages to waiting requests // Retrieve tries to fetch an object from the LES network.
func (self *LesOdr) Deliver(peer *peer, msg *Msg) error { // If the network retrieval was successful, it stores the object in local db.
var delivered chan struct{} func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) {
self.mlock.Lock() lreq := LesRequest(req)
req, ok := self.sentReqs[msg.ReqID]
self.mlock.Unlock()
if ok {
req.lock.Lock()
delivered, ok = req.sentTo[peer]
req.lock.Unlock()
}
if !ok {
return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID)
}
if err := req.valFunc(self.db, msg); err != nil {
peer.Log().Warn("Invalid odr response", "err", err)
return errResp(ErrInvalidResponse, "reqID = %v", msg.ReqID)
}
close(delivered)
req.lock.Lock()
delete(req.sentTo, peer)
if req.answered != nil {
close(req.answered)
req.answered = nil
}
req.lock.Unlock()
return nil
}
func (self *LesOdr) requestPeer(req *sentReq, peer *peer, delivered, timeout chan struct{}, reqWg *sync.WaitGroup) {
stime := mclock.Now()
defer func() {
req.lock.Lock()
delete(req.sentTo, peer)
req.lock.Unlock()
reqWg.Done()
}()
select {
case <-delivered:
if self.serverPool != nil {
self.serverPool.adjustResponseTime(peer.poolEntry, time.Duration(mclock.Now()-stime), false)
}
return
case <-time.After(softRequestTimeout):
close(timeout)
case <-self.stop:
return
}
select {
case <-delivered:
case <-time.After(hardRequestTimeout):
peer.Log().Debug("Request timed out hard")
go self.removePeer(peer.id)
case <-self.stop:
return
}
if self.serverPool != nil {
self.serverPool.adjustResponseTime(peer.poolEntry, time.Duration(mclock.Now()-stime), true)
}
}
// networkRequest sends a request to known peers until an answer is received
// or the context is cancelled
func (self *LesOdr) networkRequest(ctx context.Context, lreq LesOdrRequest) error {
answered := make(chan struct{})
req := &sentReq{
valFunc: lreq.Validate,
sentTo: make(map[*peer]chan struct{}),
answered: answered, // reply delivered by any peer
}
exclude := make(map[*peer]struct{})
reqWg := new(sync.WaitGroup)
reqWg.Add(1)
defer reqWg.Done()
var timeout chan struct{} reqID := genReqID()
reqID := getNextReqID()
rq := &distReq{ rq := &distReq{
getCost: func(dp distPeer) uint64 { getCost: func(dp distPeer) uint64 {
return lreq.GetCost(dp.(*peer)) return lreq.GetCost(dp.(*peer))
}, },
canSend: func(dp distPeer) bool { canSend: func(dp distPeer) bool {
p := dp.(*peer) p := dp.(*peer)
_, ok := exclude[p] return lreq.CanSend(p)
return !ok && lreq.CanSend(p)
}, },
request: func(dp distPeer) func() { request: func(dp distPeer) func() {
p := dp.(*peer) p := dp.(*peer)
exclude[p] = struct{}{}
delivered := make(chan struct{})
timeout = make(chan struct{})
req.lock.Lock()
req.sentTo[p] = delivered
req.lock.Unlock()
reqWg.Add(1)
cost := lreq.GetCost(p) cost := lreq.GetCost(p)
p.fcServer.QueueRequest(reqID, cost) p.fcServer.QueueRequest(reqID, cost)
go self.requestPeer(req, p, delivered, timeout, reqWg)
return func() { lreq.Request(reqID, p) } return func() { lreq.Request(reqID, p) }
}, },
} }
self.mlock.Lock() if err = self.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(self.db, msg) }); err == nil {
self.sentReqs[reqID] = req
self.mlock.Unlock()
go func() {
reqWg.Wait()
self.mlock.Lock()
delete(self.sentReqs, reqID)
self.mlock.Unlock()
}()
for {
peerChn := self.reqDist.queue(rq)
select {
case <-ctx.Done():
self.reqDist.cancel(rq)
return ctx.Err()
case <-answered:
self.reqDist.cancel(rq)
return nil
case _, ok := <-peerChn:
if !ok {
return ErrNoPeers
}
}
select {
case <-ctx.Done():
return ctx.Err()
case <-answered:
return nil
case <-timeout:
}
}
}
// Retrieve tries to fetch an object from the LES network.
// If the network retrieval was successful, it stores the object in local db.
func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) {
lreq := LesRequest(req)
err = self.networkRequest(ctx, lreq)
if err == nil {
// retrieved from network, store in db // retrieved from network, store in db
req.StoreResult(self.db) req.StoreResult(self.db)
} else { } else {
@ -251,9 +92,3 @@ func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err err
} }
return return
} }
func getNextReqID() uint64 {
var rnd [8]byte
rand.Read(rnd[:])
return binary.BigEndian.Uint64(rnd[:])
}

@ -158,15 +158,15 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) {
// Assemble the test environment // Assemble the test environment
pm, db, odr := newTestProtocolManagerMust(t, false, 4, testChainGen) peers := newPeerSet()
lpm, ldb, odr := newTestProtocolManagerMust(t, true, 0, nil) dist := newRequestDistributor(peers, make(chan struct{}))
rm := newRetrieveManager(peers, dist, nil)
db, _ := ethdb.NewMemDatabase()
ldb, _ := ethdb.NewMemDatabase()
odr := NewLesOdr(ldb, rm)
pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db)
lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb)
_, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm)
pool := &testServerPool{}
lpm.reqDist = newRequestDistributor(pool.getAllPeers, lpm.quitSync)
odr.reqDist = lpm.reqDist
pool.setPeer(lpeer)
odr.serverPool = pool
lpeer.hasBlock = func(common.Hash, uint64) bool { return true }
select { select {
case <-time.After(time.Millisecond * 100): case <-time.After(time.Millisecond * 100):
case err := <-err1: case err := <-err1:
@ -198,13 +198,19 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) {
} }
// temporarily remove peer to test odr fails // temporarily remove peer to test odr fails
pool.setPeer(nil)
// expect retrievals to fail (except genesis block) without a les peer // expect retrievals to fail (except genesis block) without a les peer
peers.Unregister(lpeer.id)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
test(expFail) test(expFail)
pool.setPeer(lpeer)
// expect all retrievals to pass // expect all retrievals to pass
peers.Register(lpeer)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
lpeer.lock.Lock()
lpeer.hasBlock = func(common.Hash, uint64) bool { return true }
lpeer.lock.Unlock()
test(5) test(5)
pool.setPeer(nil)
// still expect all retrievals to pass, now data should be cached locally // still expect all retrievals to pass, now data should be cached locally
peers.Unregister(lpeer.id)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
test(5) test(5)
} }

@ -166,9 +166,9 @@ func (p *peer) GetRequestCost(msgcode uint64, amount int) uint64 {
// HasBlock checks if the peer has a given block // HasBlock checks if the peer has a given block
func (p *peer) HasBlock(hash common.Hash, number uint64) bool { func (p *peer) HasBlock(hash common.Hash, number uint64) bool {
p.lock.RLock() p.lock.RLock()
hashBlock := p.hasBlock hasBlock := p.hasBlock
p.lock.RUnlock() p.lock.RUnlock()
return hashBlock != nil && hashBlock(hash, number) return hasBlock != nil && hasBlock(hash, number)
} }
// SendAnnounce announces the availability of a number of blocks through // SendAnnounce announces the availability of a number of blocks through
@ -433,12 +433,20 @@ func (p *peer) String() string {
) )
} }
// peerSetNotify is a callback interface to notify services about added or
// removed peers
type peerSetNotify interface {
registerPeer(*peer)
unregisterPeer(*peer)
}
// peerSet represents the collection of active peers currently participating in // peerSet represents the collection of active peers currently participating in
// the Light Ethereum sub-protocol. // the Light Ethereum sub-protocol.
type peerSet struct { type peerSet struct {
peers map[string]*peer peers map[string]*peer
lock sync.RWMutex lock sync.RWMutex
closed bool notifyList []peerSetNotify
closed bool
} }
// newPeerSet creates a new peer set to track the active participants. // newPeerSet creates a new peer set to track the active participants.
@ -448,6 +456,17 @@ func newPeerSet() *peerSet {
} }
} }
// notify adds a service to be notified about added or removed peers
func (ps *peerSet) notify(n peerSetNotify) {
ps.lock.Lock()
defer ps.lock.Unlock()
ps.notifyList = append(ps.notifyList, n)
for _, p := range ps.peers {
go n.registerPeer(p)
}
}
// Register injects a new peer into the working set, or returns an error if the // Register injects a new peer into the working set, or returns an error if the
// peer is already known. // peer is already known.
func (ps *peerSet) Register(p *peer) error { func (ps *peerSet) Register(p *peer) error {
@ -462,11 +481,14 @@ func (ps *peerSet) Register(p *peer) error {
} }
ps.peers[p.id] = p ps.peers[p.id] = p
p.sendQueue = newExecQueue(100) p.sendQueue = newExecQueue(100)
for _, n := range ps.notifyList {
go n.registerPeer(p)
}
return nil return nil
} }
// Unregister removes a remote peer from the active set, disabling any further // Unregister removes a remote peer from the active set, disabling any further
// actions to/from that particular entity. // actions to/from that particular entity. It also initiates disconnection at the networking layer.
func (ps *peerSet) Unregister(id string) error { func (ps *peerSet) Unregister(id string) error {
ps.lock.Lock() ps.lock.Lock()
defer ps.lock.Unlock() defer ps.lock.Unlock()
@ -474,7 +496,11 @@ func (ps *peerSet) Unregister(id string) error {
if p, ok := ps.peers[id]; !ok { if p, ok := ps.peers[id]; !ok {
return errNotRegistered return errNotRegistered
} else { } else {
for _, n := range ps.notifyList {
go n.unregisterPeer(p)
}
p.sendQueue.quit() p.sendQueue.quit()
p.Peer.Disconnect(p2p.DiscUselessPeer)
} }
delete(ps.peers, id) delete(ps.peers, id)
return nil return nil

@ -68,15 +68,16 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.Odr
func testAccess(t *testing.T, protocol int, fn accessTestFn) { func testAccess(t *testing.T, protocol int, fn accessTestFn) {
// Assemble the test environment // Assemble the test environment
pm, db, _ := newTestProtocolManagerMust(t, false, 4, testChainGen) peers := newPeerSet()
lpm, ldb, odr := newTestProtocolManagerMust(t, true, 0, nil) dist := newRequestDistributor(peers, make(chan struct{}))
rm := newRetrieveManager(peers, dist, nil)
db, _ := ethdb.NewMemDatabase()
ldb, _ := ethdb.NewMemDatabase()
odr := NewLesOdr(ldb, rm)
pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db)
lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb)
_, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm)
pool := &testServerPool{}
lpm.reqDist = newRequestDistributor(pool.getAllPeers, lpm.quitSync)
odr.reqDist = lpm.reqDist
pool.setPeer(lpeer)
odr.serverPool = pool
lpeer.hasBlock = func(common.Hash, uint64) bool { return true }
select { select {
case <-time.After(time.Millisecond * 100): case <-time.After(time.Millisecond * 100):
case err := <-err1: case err := <-err1:
@ -108,10 +109,16 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) {
} }
// temporarily remove peer to test odr fails // temporarily remove peer to test odr fails
pool.setPeer(nil) peers.Unregister(lpeer.id)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
// expect retrievals to fail (except genesis block) without a les peer // expect retrievals to fail (except genesis block) without a les peer
test(0) test(0)
pool.setPeer(lpeer)
peers.Register(lpeer)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
lpeer.lock.Lock()
lpeer.hasBlock = func(common.Hash, uint64) bool { return true }
lpeer.lock.Unlock()
// expect all retrievals to pass // expect all retrievals to pass
test(5) test(5)
} }

@ -0,0 +1,395 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// Package light implements on-demand retrieval capable state and chain objects
// for the Ethereum Light Client.
package les
import (
"context"
"crypto/rand"
"encoding/binary"
"sync"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
)
var (
retryQueue = time.Millisecond * 100
softRequestTimeout = time.Millisecond * 500
hardRequestTimeout = time.Second * 10
)
// retrieveManager is a layer on top of requestDistributor which takes care of
// matching replies by request ID and handles timeouts and resends if necessary.
type retrieveManager struct {
dist *requestDistributor
peers *peerSet
serverPool peerSelector
lock sync.RWMutex
sentReqs map[uint64]*sentReq
}
// validatorFunc is a function that processes a reply message
type validatorFunc func(distPeer, *Msg) error
// peerSelector receives feedback info about response times and timeouts
type peerSelector interface {
adjustResponseTime(*poolEntry, time.Duration, bool)
}
// sentReq represents a request sent and tracked by retrieveManager
type sentReq struct {
rm *retrieveManager
req *distReq
id uint64
validate validatorFunc
eventsCh chan reqPeerEvent
stopCh chan struct{}
stopped bool
err error
lock sync.RWMutex // protect access to sentTo map
sentTo map[distPeer]sentReqToPeer
reqQueued bool // a request has been queued but not sent
reqSent bool // a request has been sent but not timed out
reqSrtoCount int // number of requests that reached soft (but not hard) timeout
}
// sentReqToPeer notifies the request-from-peer goroutine (tryRequest) about a response
// delivered by the given peer. Only one delivery is allowed per request per peer,
// after which delivered is set to true, the validity of the response is sent on the
// valid channel and no more responses are accepted.
type sentReqToPeer struct {
delivered bool
valid chan bool
}
// reqPeerEvent is sent by the request-from-peer goroutine (tryRequest) to the
// request state machine (retrieveLoop) through the eventsCh channel.
type reqPeerEvent struct {
event int
peer distPeer
}
const (
rpSent = iota // if peer == nil, not sent (no suitable peers)
rpSoftTimeout
rpHardTimeout
rpDeliveredValid
rpDeliveredInvalid
)
// newRetrieveManager creates the retrieve manager
func newRetrieveManager(peers *peerSet, dist *requestDistributor, serverPool peerSelector) *retrieveManager {
return &retrieveManager{
peers: peers,
dist: dist,
serverPool: serverPool,
sentReqs: make(map[uint64]*sentReq),
}
}
// retrieve sends a request (to multiple peers if necessary) and waits for an answer
// that is delivered through the deliver function and successfully validated by the
// validator callback. It returns when a valid answer is delivered or the context is
// cancelled.
func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *distReq, val validatorFunc) error {
sentReq := rm.sendReq(reqID, req, val)
select {
case <-sentReq.stopCh:
case <-ctx.Done():
sentReq.stop(ctx.Err())
}
return sentReq.getError()
}
// sendReq starts a process that keeps trying to retrieve a valid answer for a
// request from any suitable peers until stopped or succeeded.
func (rm *retrieveManager) sendReq(reqID uint64, req *distReq, val validatorFunc) *sentReq {
r := &sentReq{
rm: rm,
req: req,
id: reqID,
sentTo: make(map[distPeer]sentReqToPeer),
stopCh: make(chan struct{}),
eventsCh: make(chan reqPeerEvent, 10),
validate: val,
}
canSend := req.canSend
req.canSend = func(p distPeer) bool {
// add an extra check to canSend: the request has not been sent to the same peer before
r.lock.RLock()
_, sent := r.sentTo[p]
r.lock.RUnlock()
return !sent && canSend(p)
}
request := req.request
req.request = func(p distPeer) func() {
// before actually sending the request, put an entry into the sentTo map
r.lock.Lock()
r.sentTo[p] = sentReqToPeer{false, make(chan bool, 1)}
r.lock.Unlock()
return request(p)
}
rm.lock.Lock()
rm.sentReqs[reqID] = r
rm.lock.Unlock()
go r.retrieveLoop()
return r
}
// deliver is called by the LES protocol manager to deliver reply messages to waiting requests
func (rm *retrieveManager) deliver(peer distPeer, msg *Msg) error {
rm.lock.RLock()
req, ok := rm.sentReqs[msg.ReqID]
rm.lock.RUnlock()
if ok {
return req.deliver(peer, msg)
}
return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID)
}
// reqStateFn represents a state of the retrieve loop state machine
type reqStateFn func() reqStateFn
// retrieveLoop is the retrieval state machine event loop
func (r *sentReq) retrieveLoop() {
go r.tryRequest()
r.reqQueued = true
state := r.stateRequesting
for state != nil {
state = state()
}
r.rm.lock.Lock()
delete(r.rm.sentReqs, r.id)
r.rm.lock.Unlock()
}
// stateRequesting: a request has been queued or sent recently; when it reaches soft timeout,
// a new request is sent to a new peer
func (r *sentReq) stateRequesting() reqStateFn {
select {
case ev := <-r.eventsCh:
r.update(ev)
switch ev.event {
case rpSent:
if ev.peer == nil {
// request send failed, no more suitable peers
if r.waiting() {
// we are already waiting for sent requests which may succeed so keep waiting
return r.stateNoMorePeers
}
// nothing to wait for, no more peers to ask, return with error
r.stop(ErrNoPeers)
// no need to go to stopped state because waiting() already returned false
return nil
}
case rpSoftTimeout:
// last request timed out, try asking a new peer
go r.tryRequest()
r.reqQueued = true
return r.stateRequesting
case rpDeliveredValid:
r.stop(nil)
return r.stateStopped
}
return r.stateRequesting
case <-r.stopCh:
return r.stateStopped
}
}
// stateNoMorePeers: could not send more requests because no suitable peers are available.
// Peers may become suitable for a certain request later or new peers may appear so we
// keep trying.
func (r *sentReq) stateNoMorePeers() reqStateFn {
select {
case <-time.After(retryQueue):
go r.tryRequest()
r.reqQueued = true
return r.stateRequesting
case ev := <-r.eventsCh:
r.update(ev)
if ev.event == rpDeliveredValid {
r.stop(nil)
return r.stateStopped
}
return r.stateNoMorePeers
case <-r.stopCh:
return r.stateStopped
}
}
// stateStopped: request succeeded or cancelled, just waiting for some peers
// to either answer or time out hard
func (r *sentReq) stateStopped() reqStateFn {
for r.waiting() {
r.update(<-r.eventsCh)
}
return nil
}
// update updates the queued/sent flags and timed out peers counter according to the event
func (r *sentReq) update(ev reqPeerEvent) {
switch ev.event {
case rpSent:
r.reqQueued = false
if ev.peer != nil {
r.reqSent = true
}
case rpSoftTimeout:
r.reqSent = false
r.reqSrtoCount++
case rpHardTimeout, rpDeliveredValid, rpDeliveredInvalid:
r.reqSrtoCount--
}
}
// waiting returns true if the retrieval mechanism is waiting for an answer from
// any peer
func (r *sentReq) waiting() bool {
return r.reqQueued || r.reqSent || r.reqSrtoCount > 0
}
// tryRequest tries to send the request to a new peer and waits for it to either
// succeed or time out if it has been sent. It also sends the appropriate reqPeerEvent
// messages to the request's event channel.
func (r *sentReq) tryRequest() {
sent := r.rm.dist.queue(r.req)
var p distPeer
select {
case p = <-sent:
case <-r.stopCh:
if r.rm.dist.cancel(r.req) {
p = nil
} else {
p = <-sent
}
}
r.eventsCh <- reqPeerEvent{rpSent, p}
if p == nil {
return
}
reqSent := mclock.Now()
srto, hrto := false, false
r.lock.RLock()
s, ok := r.sentTo[p]
r.lock.RUnlock()
if !ok {
panic(nil)
}
defer func() {
// send feedback to server pool and remove peer if hard timeout happened
pp, ok := p.(*peer)
if ok && r.rm.serverPool != nil {
respTime := time.Duration(mclock.Now() - reqSent)
r.rm.serverPool.adjustResponseTime(pp.poolEntry, respTime, srto)
}
if hrto {
pp.Log().Debug("Request timed out hard")
if r.rm.peers != nil {
r.rm.peers.Unregister(pp.id)
}
}
r.lock.Lock()
delete(r.sentTo, p)
r.lock.Unlock()
}()
select {
case ok := <-s.valid:
if ok {
r.eventsCh <- reqPeerEvent{rpDeliveredValid, p}
} else {
r.eventsCh <- reqPeerEvent{rpDeliveredInvalid, p}
}
return
case <-time.After(softRequestTimeout):
srto = true
r.eventsCh <- reqPeerEvent{rpSoftTimeout, p}
}
select {
case ok := <-s.valid:
if ok {
r.eventsCh <- reqPeerEvent{rpDeliveredValid, p}
} else {
r.eventsCh <- reqPeerEvent{rpDeliveredInvalid, p}
}
case <-time.After(hardRequestTimeout):
hrto = true
r.eventsCh <- reqPeerEvent{rpHardTimeout, p}
}
}
// deliver a reply belonging to this request
func (r *sentReq) deliver(peer distPeer, msg *Msg) error {
r.lock.Lock()
defer r.lock.Unlock()
s, ok := r.sentTo[peer]
if !ok || s.delivered {
return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID)
}
valid := r.validate(peer, msg) == nil
r.sentTo[peer] = sentReqToPeer{true, s.valid}
s.valid <- valid
if !valid {
return errResp(ErrInvalidResponse, "reqID = %v", msg.ReqID)
}
return nil
}
// stop stops the retrieval process and sets an error code that will be returned
// by getError
func (r *sentReq) stop(err error) {
r.lock.Lock()
if !r.stopped {
r.stopped = true
r.err = err
close(r.stopCh)
}
r.lock.Unlock()
}
// getError returns any retrieval error (either internally generated or set by the
// stop function) after stopCh has been closed
func (r *sentReq) getError() error {
return r.err
}
// genReqID generates a new random request ID
func genReqID() uint64 {
var rnd [8]byte
rand.Read(rnd[:])
return binary.BigEndian.Uint64(rnd[:])
}

@ -32,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
) )
@ -41,17 +42,24 @@ type LesServer struct {
fcManager *flowcontrol.ClientManager // nil if our node is client only fcManager *flowcontrol.ClientManager // nil if our node is client only
fcCostStats *requestCostStats fcCostStats *requestCostStats
defParams *flowcontrol.ServerParams defParams *flowcontrol.ServerParams
lesTopic discv5.Topic
quitSync chan struct{}
stopped bool stopped bool
} }
func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) { func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) {
pm, err := NewProtocolManager(eth.BlockChain().Config(), false, config.NetworkId, eth.EventMux(), eth.Engine(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil) quitSync := make(chan struct{})
pm, err := NewProtocolManager(eth.BlockChain().Config(), false, config.NetworkId, eth.EventMux(), eth.Engine(), newPeerSet(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil, quitSync, new(sync.WaitGroup))
if err != nil { if err != nil {
return nil, err return nil, err
} }
pm.blockLoop() pm.blockLoop()
srv := &LesServer{protocolManager: pm} srv := &LesServer{
protocolManager: pm,
quitSync: quitSync,
lesTopic: lesTopic(eth.BlockChain().Genesis().Hash()),
}
pm.server = srv pm.server = srv
srv.defParams = &flowcontrol.ServerParams{ srv.defParams = &flowcontrol.ServerParams{
@ -69,7 +77,14 @@ func (s *LesServer) Protocols() []p2p.Protocol {
// Start starts the LES server // Start starts the LES server
func (s *LesServer) Start(srvr *p2p.Server) { func (s *LesServer) Start(srvr *p2p.Server) {
s.protocolManager.Start(srvr) s.protocolManager.Start()
go func() {
logger := log.New("topic", s.lesTopic)
logger.Info("Starting topic registration")
defer logger.Info("Terminated topic registration")
srvr.DiscV5.RegisterTopic(s.lesTopic, s.quitSync)
}()
} }
// Stop stops the LES service // Stop stops the LES service

@ -102,6 +102,8 @@ type serverPool struct {
wg *sync.WaitGroup wg *sync.WaitGroup
connWg sync.WaitGroup connWg sync.WaitGroup
topic discv5.Topic
discSetPeriod chan time.Duration discSetPeriod chan time.Duration
discNodes chan *discv5.Node discNodes chan *discv5.Node
discLookups chan bool discLookups chan bool
@ -118,11 +120,9 @@ type serverPool struct {
} }
// newServerPool creates a new serverPool instance // newServerPool creates a new serverPool instance
func newServerPool(db ethdb.Database, dbPrefix []byte, server *p2p.Server, topic discv5.Topic, quit chan struct{}, wg *sync.WaitGroup) *serverPool { func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup) *serverPool {
pool := &serverPool{ pool := &serverPool{
db: db, db: db,
dbKey: append(dbPrefix, []byte(topic)...),
server: server,
quit: quit, quit: quit,
wg: wg, wg: wg,
entries: make(map[discover.NodeID]*poolEntry), entries: make(map[discover.NodeID]*poolEntry),
@ -135,19 +135,25 @@ func newServerPool(db ethdb.Database, dbPrefix []byte, server *p2p.Server, topic
} }
pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry) pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry)
pool.newQueue = newPoolEntryQueue(maxNewEntries, pool.removeEntry) pool.newQueue = newPoolEntryQueue(maxNewEntries, pool.removeEntry)
wg.Add(1) return pool
}
func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
pool.server = server
pool.topic = topic
pool.dbKey = append([]byte("serverPool/"), []byte(topic)...)
pool.wg.Add(1)
pool.loadNodes() pool.loadNodes()
pool.checkDial()
go pool.eventLoop()
pool.checkDial()
if pool.server.DiscV5 != nil { if pool.server.DiscV5 != nil {
pool.discSetPeriod = make(chan time.Duration, 1) pool.discSetPeriod = make(chan time.Duration, 1)
pool.discNodes = make(chan *discv5.Node, 100) pool.discNodes = make(chan *discv5.Node, 100)
pool.discLookups = make(chan bool, 100) pool.discLookups = make(chan bool, 100)
go pool.server.DiscV5.SearchTopic(topic, pool.discSetPeriod, pool.discNodes, pool.discLookups) go pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, pool.discNodes, pool.discLookups)
} }
go pool.eventLoop()
return pool
} }
// connect should be called upon any incoming connection. If the connection has been // connect should be called upon any incoming connection. If the connection has been
@ -485,7 +491,7 @@ func (pool *serverPool) checkDial() {
// dial initiates a new connection // dial initiates a new connection
func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) { func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) {
if entry.state != psNotConnected { if pool.server == nil || entry.state != psNotConnected {
return return
} }
entry.state = psDialed entry.state = psDialed

@ -39,26 +39,28 @@ type LesTxRelay struct {
reqDist *requestDistributor reqDist *requestDistributor
} }
func NewLesTxRelay() *LesTxRelay { func NewLesTxRelay(ps *peerSet, reqDist *requestDistributor) *LesTxRelay {
return &LesTxRelay{ r := &LesTxRelay{
txSent: make(map[common.Hash]*ltrInfo), txSent: make(map[common.Hash]*ltrInfo),
txPending: make(map[common.Hash]struct{}), txPending: make(map[common.Hash]struct{}),
ps: ps,
reqDist: reqDist,
} }
ps.notify(r)
return r
} }
func (self *LesTxRelay) addPeer(p *peer) { func (self *LesTxRelay) registerPeer(p *peer) {
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
self.ps.Register(p)
self.peerList = self.ps.AllPeers() self.peerList = self.ps.AllPeers()
} }
func (self *LesTxRelay) removePeer(id string) { func (self *LesTxRelay) unregisterPeer(p *peer) {
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
self.ps.Unregister(id)
self.peerList = self.ps.AllPeers() self.peerList = self.ps.AllPeers()
} }
@ -112,7 +114,7 @@ func (self *LesTxRelay) send(txs types.Transactions, count int) {
pp := p pp := p
ll := list ll := list
reqID := getNextReqID() reqID := genReqID()
rq := &distReq{ rq := &distReq{
getCost: func(dp distPeer) uint64 { getCost: func(dp distPeer) uint64 {
peer := dp.(*peer) peer := dp.(*peer)

Loading…
Cancel
Save