les: handler separation (#19639)

les: handler separation
pull/19491/head
gary rong 5 years ago committed by Felföldi Zsolt
parent 4aee0d1994
commit 2ed729d38e
  1. 22
      core/blockchain.go
  2. 14
      les/api.go
  3. 2
      les/api_backend.go
  4. 18
      les/api_test.go
  5. 47
      les/benchmark.go
  6. 3
      les/bloombits.go
  7. 131
      les/client.go
  8. 401
      les/client_handler.go
  9. 69
      les/commons.go
  10. 11
      les/costtracker.go
  11. 37
      les/distributor.go
  12. 2
      les/distributor_test.go
  13. 75
      les/fetcher.go
  14. 168
      les/fetcher_test.go
  15. 1293
      les/handler.go
  16. 198
      les/handler_test.go
  17. 90
      les/metrics.go
  18. 5
      les/odr.go
  19. 38
      les/odr_test.go
  20. 48
      les/peer.go
  21. 80
      les/peer_test.go
  22. 28
      les/request_test.go
  23. 340
      les/server.go
  24. 921
      les/server_handler.go
  25. 61
      les/serverpool.go
  26. 71
      les/sync.go
  27. 17
      les/sync_test.go
  28. 448
      les/test_helper.go
  29. 226
      les/ulc_test.go
  30. 6
      light/odr_util.go
  31. 6
      light/postprocess.go

@ -75,6 +75,7 @@ const (
bodyCacheLimit = 256 bodyCacheLimit = 256
blockCacheLimit = 256 blockCacheLimit = 256
receiptsCacheLimit = 32 receiptsCacheLimit = 32
txLookupCacheLimit = 1024
maxFutureBlocks = 256 maxFutureBlocks = 256
maxTimeFutureBlocks = 30 maxTimeFutureBlocks = 30
badBlockLimit = 10 badBlockLimit = 10
@ -155,6 +156,7 @@ type BlockChain struct {
bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format
receiptsCache *lru.Cache // Cache for the most recent receipts per block receiptsCache *lru.Cache // Cache for the most recent receipts per block
blockCache *lru.Cache // Cache for the most recent entire blocks blockCache *lru.Cache // Cache for the most recent entire blocks
txLookupCache *lru.Cache // Cache for the most recent transaction lookup data.
futureBlocks *lru.Cache // future blocks are blocks added for later processing futureBlocks *lru.Cache // future blocks are blocks added for later processing
quit chan struct{} // blockchain quit channel quit chan struct{} // blockchain quit channel
@ -189,6 +191,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
bodyRLPCache, _ := lru.New(bodyCacheLimit) bodyRLPCache, _ := lru.New(bodyCacheLimit)
receiptsCache, _ := lru.New(receiptsCacheLimit) receiptsCache, _ := lru.New(receiptsCacheLimit)
blockCache, _ := lru.New(blockCacheLimit) blockCache, _ := lru.New(blockCacheLimit)
txLookupCache, _ := lru.New(txLookupCacheLimit)
futureBlocks, _ := lru.New(maxFutureBlocks) futureBlocks, _ := lru.New(maxFutureBlocks)
badBlocks, _ := lru.New(badBlockLimit) badBlocks, _ := lru.New(badBlockLimit)
@ -204,6 +207,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
bodyRLPCache: bodyRLPCache, bodyRLPCache: bodyRLPCache,
receiptsCache: receiptsCache, receiptsCache: receiptsCache,
blockCache: blockCache, blockCache: blockCache,
txLookupCache: txLookupCache,
futureBlocks: futureBlocks, futureBlocks: futureBlocks,
engine: engine, engine: engine,
vmConfig: vmConfig, vmConfig: vmConfig,
@ -440,6 +444,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
bc.bodyRLPCache.Purge() bc.bodyRLPCache.Purge()
bc.receiptsCache.Purge() bc.receiptsCache.Purge()
bc.blockCache.Purge() bc.blockCache.Purge()
bc.txLookupCache.Purge()
bc.futureBlocks.Purge() bc.futureBlocks.Purge()
return bc.loadLastState() return bc.loadLastState()
@ -921,6 +926,7 @@ func (bc *BlockChain) truncateAncient(head uint64) error {
bc.bodyRLPCache.Purge() bc.bodyRLPCache.Purge()
bc.receiptsCache.Purge() bc.receiptsCache.Purge()
bc.blockCache.Purge() bc.blockCache.Purge()
bc.txLookupCache.Purge()
bc.futureBlocks.Purge() bc.futureBlocks.Purge()
log.Info("Rewind ancient data", "number", head) log.Info("Rewind ancient data", "number", head)
@ -2151,6 +2157,22 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header {
return bc.hc.GetHeaderByNumber(number) return bc.hc.GetHeaderByNumber(number)
} }
// GetTransactionLookup retrieves the lookup associate with the given transaction
// hash from the cache or database.
func (bc *BlockChain) GetTransactionLookup(hash common.Hash) *rawdb.LegacyTxLookupEntry {
// Short circuit if the txlookup already in the cache, retrieve otherwise
if lookup, exist := bc.txLookupCache.Get(hash); exist {
return lookup.(*rawdb.LegacyTxLookupEntry)
}
tx, blockHash, blockNumber, txIndex := rawdb.ReadTransaction(bc.db, hash)
if tx == nil {
return nil
}
lookup := &rawdb.LegacyTxLookupEntry{BlockHash: blockHash, BlockIndex: blockNumber, Index: txIndex}
bc.txLookupCache.Add(hash, lookup)
return lookup
}
// Config retrieves the chain's fork configuration. // Config retrieves the chain's fork configuration.
func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig } func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig }

@ -30,15 +30,11 @@ var (
// PrivateLightAPI provides an API to access the LES light server or light client. // PrivateLightAPI provides an API to access the LES light server or light client.
type PrivateLightAPI struct { type PrivateLightAPI struct {
backend *lesCommons backend *lesCommons
reg *checkpointOracle
} }
// NewPrivateLightAPI creates a new LES service API. // NewPrivateLightAPI creates a new LES service API.
func NewPrivateLightAPI(backend *lesCommons, reg *checkpointOracle) *PrivateLightAPI { func NewPrivateLightAPI(backend *lesCommons) *PrivateLightAPI {
return &PrivateLightAPI{ return &PrivateLightAPI{backend: backend}
backend: backend,
reg: reg,
}
} }
// LatestCheckpoint returns the latest local checkpoint package. // LatestCheckpoint returns the latest local checkpoint package.
@ -67,7 +63,7 @@ func (api *PrivateLightAPI) LatestCheckpoint() ([4]string, error) {
// result[2], 32 bytes hex encoded latest section bloom trie root hash // result[2], 32 bytes hex encoded latest section bloom trie root hash
func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) { func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) {
var res [3]string var res [3]string
cp := api.backend.getLocalCheckpoint(index) cp := api.backend.localCheckpoint(index)
if cp.Empty() { if cp.Empty() {
return res, errNoCheckpoint return res, errNoCheckpoint
} }
@ -77,8 +73,8 @@ func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) {
// GetCheckpointContractAddress returns the contract contract address in hex format. // GetCheckpointContractAddress returns the contract contract address in hex format.
func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) { func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) {
if api.reg == nil { if api.backend.oracle == nil {
return "", errNotActivated return "", errNotActivated
} }
return api.reg.config.Address.Hex(), nil return api.backend.oracle.config.Address.Hex(), nil
} }

@ -54,7 +54,7 @@ func (b *LesApiBackend) CurrentBlock() *types.Block {
} }
func (b *LesApiBackend) SetHead(number uint64) { func (b *LesApiBackend) SetHead(number uint64) {
b.eth.protocolManager.downloader.Cancel() b.eth.handler.downloader.Cancel()
b.eth.blockchain.SetHead(number) b.eth.blockchain.SetHead(number)
} }

@ -78,19 +78,16 @@ func TestCapacityAPI10(t *testing.T) {
// while connected and going back and forth between free and priority mode with // while connected and going back and forth between free and priority mode with
// the supplied API calls is also thoroughly tested. // the supplied API calls is also thoroughly tested.
func testCapacityAPI(t *testing.T, clientCount int) { func testCapacityAPI(t *testing.T, clientCount int) {
// Skip test if no data dir specified
if testServerDataDir == "" { if testServerDataDir == "" {
// Skip test if no data dir specified
return return
} }
for !testSim(t, 1, clientCount, []string{testServerDataDir}, nil, func(ctx context.Context, net *simulations.Network, servers []*simulations.Node, clients []*simulations.Node) bool { for !testSim(t, 1, clientCount, []string{testServerDataDir}, nil, func(ctx context.Context, net *simulations.Network, servers []*simulations.Node, clients []*simulations.Node) bool {
if len(servers) != 1 { if len(servers) != 1 {
t.Fatalf("Invalid number of servers: %d", len(servers)) t.Fatalf("Invalid number of servers: %d", len(servers))
} }
server := servers[0] server := servers[0]
clientRpcClients := make([]*rpc.Client, len(clients))
serverRpcClient, err := server.Client() serverRpcClient, err := server.Client()
if err != nil { if err != nil {
t.Fatalf("Failed to obtain rpc client: %v", err) t.Fatalf("Failed to obtain rpc client: %v", err)
@ -105,13 +102,13 @@ func testCapacityAPI(t *testing.T, clientCount int) {
} }
freeIdx := rand.Intn(len(clients)) freeIdx := rand.Intn(len(clients))
clientRpcClients := make([]*rpc.Client, len(clients))
for i, client := range clients { for i, client := range clients {
var err error var err error
clientRpcClients[i], err = client.Client() clientRpcClients[i], err = client.Client()
if err != nil { if err != nil {
t.Fatalf("Failed to obtain rpc client: %v", err) t.Fatalf("Failed to obtain rpc client: %v", err)
} }
t.Log("connecting client", i) t.Log("connecting client", i)
if i != freeIdx { if i != freeIdx {
setCapacity(ctx, t, serverRpcClient, client.ID(), testCap/uint64(len(clients))) setCapacity(ctx, t, serverRpcClient, client.ID(), testCap/uint64(len(clients)))
@ -138,10 +135,13 @@ func testCapacityAPI(t *testing.T, clientCount int) {
reqCount := make([]uint64, len(clientRpcClients)) reqCount := make([]uint64, len(clientRpcClients))
// Send light request like crazy.
for i, c := range clientRpcClients { for i, c := range clientRpcClients {
wg.Add(1) wg.Add(1)
i, c := i, c i, c := i, c
go func() { go func() {
defer wg.Done()
queue := make(chan struct{}, 100) queue := make(chan struct{}, 100)
reqCount[i] = 0 reqCount[i] = 0
for { for {
@ -149,10 +149,8 @@ func testCapacityAPI(t *testing.T, clientCount int) {
case queue <- struct{}{}: case queue <- struct{}{}:
select { select {
case <-stop: case <-stop:
wg.Done()
return return
case <-ctx.Done(): case <-ctx.Done():
wg.Done()
return return
default: default:
wg.Add(1) wg.Add(1)
@ -169,10 +167,8 @@ func testCapacityAPI(t *testing.T, clientCount int) {
}() }()
} }
case <-stop: case <-stop:
wg.Done()
return return
case <-ctx.Done(): case <-ctx.Done():
wg.Done()
return return
} }
} }
@ -313,12 +309,10 @@ func getHead(ctx context.Context, t *testing.T, client *rpc.Client) (uint64, com
} }
func testRequest(ctx context.Context, t *testing.T, client *rpc.Client) bool { func testRequest(ctx context.Context, t *testing.T, client *rpc.Client) bool {
//res := make(map[string]interface{})
var res string var res string
var addr common.Address var addr common.Address
rand.Read(addr[:]) rand.Read(addr[:])
c, _ := context.WithTimeout(ctx, time.Second*12) c, _ := context.WithTimeout(ctx, time.Second*12)
// if err := client.CallContext(ctx, &res, "eth_getProof", addr, nil, "latest"); err != nil {
err := client.CallContext(c, &res, "eth_getBalance", addr, "latest") err := client.CallContext(c, &res, "eth_getBalance", addr, "latest")
if err != nil { if err != nil {
t.Log("request error:", err) t.Log("request error:", err)
@ -418,7 +412,6 @@ func NewNetwork() (*simulations.Network, func(), error) {
adapterTeardown() adapterTeardown()
net.Shutdown() net.Shutdown()
} }
return net, teardown, nil return net, teardown, nil
} }
@ -516,7 +509,6 @@ func newLesServerService(ctx *adapters.ServiceContext) (node.Service, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
server, err := NewLesServer(ethereum, &config) server, err := NewLesServer(ethereum, &config)
if err != nil { if err != nil {
return nil, err return nil, err

@ -39,7 +39,7 @@ import (
// requestBenchmark is an interface for different randomized request generators // requestBenchmark is an interface for different randomized request generators
type requestBenchmark interface { type requestBenchmark interface {
// init initializes the generator for generating the given number of randomized requests // init initializes the generator for generating the given number of randomized requests
init(pm *ProtocolManager, count int) error init(h *serverHandler, count int) error
// request initiates sending a single request to the given peer // request initiates sending a single request to the given peer
request(peer *peer, index int) error request(peer *peer, index int) error
} }
@ -52,10 +52,10 @@ type benchmarkBlockHeaders struct {
hashes []common.Hash hashes []common.Hash
} }
func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error { func (b *benchmarkBlockHeaders) init(h *serverHandler, count int) error {
d := int64(b.amount-1) * int64(b.skip+1) d := int64(b.amount-1) * int64(b.skip+1)
b.offset = 0 b.offset = 0
b.randMax = pm.blockchain.CurrentHeader().Number.Int64() + 1 - d b.randMax = h.blockchain.CurrentHeader().Number.Int64() + 1 - d
if b.randMax < 0 { if b.randMax < 0 {
return fmt.Errorf("chain is too short") return fmt.Errorf("chain is too short")
} }
@ -65,7 +65,7 @@ func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error {
if b.byHash { if b.byHash {
b.hashes = make([]common.Hash, count) b.hashes = make([]common.Hash, count)
for i := range b.hashes { for i := range b.hashes {
b.hashes[i] = rawdb.ReadCanonicalHash(pm.chainDb, uint64(b.offset+rand.Int63n(b.randMax))) b.hashes[i] = rawdb.ReadCanonicalHash(h.chainDb, uint64(b.offset+rand.Int63n(b.randMax)))
} }
} }
return nil return nil
@ -85,11 +85,11 @@ type benchmarkBodiesOrReceipts struct {
hashes []common.Hash hashes []common.Hash
} }
func (b *benchmarkBodiesOrReceipts) init(pm *ProtocolManager, count int) error { func (b *benchmarkBodiesOrReceipts) init(h *serverHandler, count int) error {
randMax := pm.blockchain.CurrentHeader().Number.Int64() + 1 randMax := h.blockchain.CurrentHeader().Number.Int64() + 1
b.hashes = make([]common.Hash, count) b.hashes = make([]common.Hash, count)
for i := range b.hashes { for i := range b.hashes {
b.hashes[i] = rawdb.ReadCanonicalHash(pm.chainDb, uint64(rand.Int63n(randMax))) b.hashes[i] = rawdb.ReadCanonicalHash(h.chainDb, uint64(rand.Int63n(randMax)))
} }
return nil return nil
} }
@ -108,8 +108,8 @@ type benchmarkProofsOrCode struct {
headHash common.Hash headHash common.Hash
} }
func (b *benchmarkProofsOrCode) init(pm *ProtocolManager, count int) error { func (b *benchmarkProofsOrCode) init(h *serverHandler, count int) error {
b.headHash = pm.blockchain.CurrentHeader().Hash() b.headHash = h.blockchain.CurrentHeader().Hash()
return nil return nil
} }
@ -130,11 +130,11 @@ type benchmarkHelperTrie struct {
sectionCount, headNum uint64 sectionCount, headNum uint64
} }
func (b *benchmarkHelperTrie) init(pm *ProtocolManager, count int) error { func (b *benchmarkHelperTrie) init(h *serverHandler, count int) error {
if b.bloom { if b.bloom {
b.sectionCount, b.headNum, _ = pm.server.bloomTrieIndexer.Sections() b.sectionCount, b.headNum, _ = h.server.bloomTrieIndexer.Sections()
} else { } else {
b.sectionCount, _, _ = pm.server.chtIndexer.Sections() b.sectionCount, _, _ = h.server.chtIndexer.Sections()
b.headNum = b.sectionCount*params.CHTFrequency - 1 b.headNum = b.sectionCount*params.CHTFrequency - 1
} }
if b.sectionCount == 0 { if b.sectionCount == 0 {
@ -170,7 +170,7 @@ type benchmarkTxSend struct {
txs types.Transactions txs types.Transactions
} }
func (b *benchmarkTxSend) init(pm *ProtocolManager, count int) error { func (b *benchmarkTxSend) init(h *serverHandler, count int) error {
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()
addr := crypto.PubkeyToAddress(key.PublicKey) addr := crypto.PubkeyToAddress(key.PublicKey)
signer := types.NewEIP155Signer(big.NewInt(18)) signer := types.NewEIP155Signer(big.NewInt(18))
@ -196,7 +196,7 @@ func (b *benchmarkTxSend) request(peer *peer, index int) error {
// benchmarkTxStatus implements requestBenchmark // benchmarkTxStatus implements requestBenchmark
type benchmarkTxStatus struct{} type benchmarkTxStatus struct{}
func (b *benchmarkTxStatus) init(pm *ProtocolManager, count int) error { func (b *benchmarkTxStatus) init(h *serverHandler, count int) error {
return nil return nil
} }
@ -217,7 +217,7 @@ type benchmarkSetup struct {
// runBenchmark runs a benchmark cycle for all benchmark types in the specified // runBenchmark runs a benchmark cycle for all benchmark types in the specified
// number of passes // number of passes
func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount int, targetTime time.Duration) []*benchmarkSetup { func (h *serverHandler) runBenchmark(benchmarks []requestBenchmark, passCount int, targetTime time.Duration) []*benchmarkSetup {
setup := make([]*benchmarkSetup, len(benchmarks)) setup := make([]*benchmarkSetup, len(benchmarks))
for i, b := range benchmarks { for i, b := range benchmarks {
setup[i] = &benchmarkSetup{req: b} setup[i] = &benchmarkSetup{req: b}
@ -239,7 +239,7 @@ func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount
if next.totalTime > 0 { if next.totalTime > 0 {
count = int(uint64(next.totalCount) * uint64(targetTime) / uint64(next.totalTime)) count = int(uint64(next.totalCount) * uint64(targetTime) / uint64(next.totalTime))
} }
if err := pm.measure(next, count); err != nil { if err := h.measure(next, count); err != nil {
next.err = err next.err = err
} }
} }
@ -275,14 +275,15 @@ func (m *meteredPipe) WriteMsg(msg p2p.Msg) error {
// measure runs a benchmark for a single type in a single pass, with the given // measure runs a benchmark for a single type in a single pass, with the given
// number of requests // number of requests
func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { func (h *serverHandler) measure(setup *benchmarkSetup, count int) error {
clientPipe, serverPipe := p2p.MsgPipe() clientPipe, serverPipe := p2p.MsgPipe()
clientMeteredPipe := &meteredPipe{rw: clientPipe} clientMeteredPipe := &meteredPipe{rw: clientPipe}
serverMeteredPipe := &meteredPipe{rw: serverPipe} serverMeteredPipe := &meteredPipe{rw: serverPipe}
var id enode.ID var id enode.ID
rand.Read(id[:]) rand.Read(id[:])
clientPeer := pm.newPeer(lpv2, NetworkId, p2p.NewPeer(id, "client", nil), clientMeteredPipe)
serverPeer := pm.newPeer(lpv2, NetworkId, p2p.NewPeer(id, "server", nil), serverMeteredPipe) clientPeer := newPeer(lpv2, NetworkId, false, p2p.NewPeer(id, "client", nil), clientMeteredPipe)
serverPeer := newPeer(lpv2, NetworkId, false, p2p.NewPeer(id, "server", nil), serverMeteredPipe)
serverPeer.sendQueue = newExecQueue(count) serverPeer.sendQueue = newExecQueue(count)
serverPeer.announceType = announceTypeNone serverPeer.announceType = announceTypeNone
serverPeer.fcCosts = make(requestCostTable) serverPeer.fcCosts = make(requestCostTable)
@ -291,10 +292,10 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
serverPeer.fcCosts[code] = c serverPeer.fcCosts[code] = c
} }
serverPeer.fcParams = flowcontrol.ServerParams{BufLimit: 1, MinRecharge: 1} serverPeer.fcParams = flowcontrol.ServerParams{BufLimit: 1, MinRecharge: 1}
serverPeer.fcClient = flowcontrol.NewClientNode(pm.server.fcManager, serverPeer.fcParams) serverPeer.fcClient = flowcontrol.NewClientNode(h.server.fcManager, serverPeer.fcParams)
defer serverPeer.fcClient.Disconnect() defer serverPeer.fcClient.Disconnect()
if err := setup.req.init(pm, count); err != nil { if err := setup.req.init(h, count); err != nil {
return err return err
} }
@ -311,7 +312,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
}() }()
go func() { go func() {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if err := pm.handleMsg(serverPeer); err != nil { if err := h.handleMsg(serverPeer); err != nil {
errCh <- err errCh <- err
return return
} }
@ -336,7 +337,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
if err != nil { if err != nil {
return err return err
} }
case <-pm.quitSync: case <-h.closeCh:
clientPipe.Close() clientPipe.Close()
serverPipe.Close() serverPipe.Close()
return fmt.Errorf("Benchmark cancelled") return fmt.Errorf("Benchmark cancelled")

@ -46,9 +46,10 @@ const (
func (eth *LightEthereum) startBloomHandlers(sectionSize uint64) { func (eth *LightEthereum) startBloomHandlers(sectionSize uint64) {
for i := 0; i < bloomServiceThreads; i++ { for i := 0; i < bloomServiceThreads; i++ {
go func() { go func() {
defer eth.wg.Done()
for { for {
select { select {
case <-eth.shutdownChan: case <-eth.closeCh:
return return
case request := <-eth.bloomRequests: case request := <-eth.bloomRequests:

@ -19,8 +19,6 @@ package les
import ( import (
"fmt" "fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind"
@ -42,7 +40,7 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
@ -50,33 +48,23 @@ import (
type LightEthereum struct { type LightEthereum struct {
lesCommons lesCommons
odr *LesOdr
chainConfig *params.ChainConfig
// Channel for shutting down the service
shutdownChan chan bool
// Handlers
peers *peerSet
txPool *light.TxPool
blockchain *light.LightChain
serverPool *serverPool
reqDist *requestDistributor reqDist *requestDistributor
retriever *retrieveManager retriever *retrieveManager
odr *LesOdr
relay *lesTxRelay relay *lesTxRelay
handler *clientHandler
txPool *light.TxPool
blockchain *light.LightChain
serverPool *serverPool
bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests
bloomIndexer *core.ChainIndexer bloomIndexer *core.ChainIndexer // Bloom indexer operating during block imports
ApiBackend *LesApiBackend
ApiBackend *LesApiBackend
eventMux *event.TypeMux eventMux *event.TypeMux
engine consensus.Engine engine consensus.Engine
accountManager *accounts.Manager accountManager *accounts.Manager
netRPCService *ethapi.PublicNetAPI
networkId uint64
netRPCService *ethapi.PublicNetAPI
wg sync.WaitGroup
} }
func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
@ -91,26 +79,24 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
log.Info("Initialised chain configuration", "config", chainConfig) log.Info("Initialised chain configuration", "config", chainConfig)
peers := newPeerSet() peers := newPeerSet()
quitSync := make(chan struct{})
leth := &LightEthereum{ leth := &LightEthereum{
lesCommons: lesCommons{ lesCommons: lesCommons{
chainDb: chainDb, genesis: genesisHash,
config: config, config: config,
iConfig: light.DefaultClientIndexerConfig, chainConfig: chainConfig,
iConfig: light.DefaultClientIndexerConfig,
chainDb: chainDb,
peers: peers,
closeCh: make(chan struct{}),
}, },
chainConfig: chainConfig,
eventMux: ctx.EventMux, eventMux: ctx.EventMux,
peers: peers, reqDist: newRequestDistributor(peers, &mclock.System{}),
reqDist: newRequestDistributor(peers, quitSync, &mclock.System{}),
accountManager: ctx.AccountManager, accountManager: ctx.AccountManager,
engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb), engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb),
shutdownChan: make(chan bool),
networkId: config.NetworkId,
bloomRequests: make(chan chan *bloombits.Retrieval), bloomRequests: make(chan chan *bloombits.Retrieval),
bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations), bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
serverPool: newServerPool(chainDb, config.UltraLightServers),
} }
leth.serverPool = newServerPool(chainDb, quitSync, &leth.wg, leth.config.UltraLightServers)
leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool)
leth.relay = newLesTxRelay(peers, leth.retriever) leth.relay = newLesTxRelay(peers, leth.retriever)
@ -128,11 +114,26 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil { if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil {
return nil, err return nil, err
} }
leth.chainReader = leth.blockchain
leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay)
// Set up checkpoint oracle.
oracle := config.CheckpointOracle
if oracle == nil {
oracle = params.CheckpointOracles[genesisHash]
}
leth.oracle = newCheckpointOracle(oracle, leth.localCheckpoint)
// Note: AddChildIndexer starts the update process for the child // Note: AddChildIndexer starts the update process for the child
leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer) leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer)
leth.chtIndexer.Start(leth.blockchain) leth.chtIndexer.Start(leth.blockchain)
leth.bloomIndexer.Start(leth.blockchain) leth.bloomIndexer.Start(leth.blockchain)
leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
if leth.handler.ulc != nil {
log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
leth.blockchain.DisableCheckFreq()
}
// Rewind the chain in case of an incompatible config upgrade. // Rewind the chain in case of an incompatible config upgrade.
if compat, ok := genesisErr.(*params.ConfigCompatError); ok { if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
log.Warn("Rewinding chain to upgrade configuration", "err", compat) log.Warn("Rewinding chain to upgrade configuration", "err", compat)
@ -140,41 +141,16 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig) rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig)
} }
leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay)
leth.ApiBackend = &LesApiBackend{ctx.ExtRPCEnabled(), leth, nil} leth.ApiBackend = &LesApiBackend{ctx.ExtRPCEnabled(), leth, nil}
gpoParams := config.GPO gpoParams := config.GPO
if gpoParams.Default == nil { if gpoParams.Default == nil {
gpoParams.Default = config.Miner.GasPrice gpoParams.Default = config.Miner.GasPrice
} }
leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams)
oracle := config.CheckpointOracle
if oracle == nil {
oracle = params.CheckpointOracles[genesisHash]
}
registrar := newCheckpointOracle(oracle, leth.getLocalCheckpoint)
if leth.protocolManager, err = NewProtocolManager(leth.chainConfig, checkpoint, light.DefaultClientIndexerConfig, config.UltraLightServers, config.UltraLightFraction, true, config.NetworkId, leth.eventMux, leth.peers, leth.blockchain, nil, chainDb, leth.odr, leth.serverPool, registrar, quitSync, &leth.wg, nil); err != nil {
return nil, err
}
if leth.protocolManager.ulc != nil {
log.Warn("Ultra light client is enabled", "servers", len(config.UltraLightServers), "fraction", config.UltraLightFraction)
leth.blockchain.DisableCheckFreq()
}
return leth, nil return leth, nil
} }
func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic {
var name string
switch protocolVersion {
case lpv2:
name = "LES2"
default:
panic(nil)
}
return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8]))
}
type LightDummyAPI struct{} type LightDummyAPI struct{}
// Etherbase is the address that mining rewards will be send to // Etherbase is the address that mining rewards will be send to
@ -209,7 +185,7 @@ func (s *LightEthereum) APIs() []rpc.API {
}, { }, {
Namespace: "eth", Namespace: "eth",
Version: "1.0", Version: "1.0",
Service: downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux), Service: downloader.NewPublicDownloaderAPI(s.handler.downloader, s.eventMux),
Public: true, Public: true,
}, { }, {
Namespace: "eth", Namespace: "eth",
@ -224,7 +200,7 @@ func (s *LightEthereum) APIs() []rpc.API {
}, { }, {
Namespace: "les", Namespace: "les",
Version: "1.0", Version: "1.0",
Service: NewPrivateLightAPI(&s.lesCommons, s.protocolManager.reg), Service: NewPrivateLightAPI(&s.lesCommons),
Public: false, Public: false,
}, },
}...) }...)
@ -238,54 +214,63 @@ func (s *LightEthereum) BlockChain() *light.LightChain { return s.blockchai
func (s *LightEthereum) TxPool() *light.TxPool { return s.txPool } func (s *LightEthereum) TxPool() *light.TxPool { return s.txPool }
func (s *LightEthereum) Engine() consensus.Engine { return s.engine } func (s *LightEthereum) Engine() consensus.Engine { return s.engine }
func (s *LightEthereum) LesVersion() int { return int(ClientProtocolVersions[0]) } func (s *LightEthereum) LesVersion() int { return int(ClientProtocolVersions[0]) }
func (s *LightEthereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader } func (s *LightEthereum) Downloader() *downloader.Downloader { return s.handler.downloader }
func (s *LightEthereum) EventMux() *event.TypeMux { return s.eventMux } func (s *LightEthereum) EventMux() *event.TypeMux { return s.eventMux }
// Protocols implements node.Service, returning all the currently configured // Protocols implements node.Service, returning all the currently configured
// network protocols to start. // network protocols to start.
func (s *LightEthereum) Protocols() []p2p.Protocol { func (s *LightEthereum) Protocols() []p2p.Protocol {
return s.makeProtocols(ClientProtocolVersions) return s.makeProtocols(ClientProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} {
if p := s.peers.Peer(peerIdToString(id)); p != nil {
return p.Info()
}
return nil
})
} }
// Start implements node.Service, starting all internal goroutines needed by the // Start implements node.Service, starting all internal goroutines needed by the
// Ethereum protocol implementation. // light ethereum protocol implementation.
func (s *LightEthereum) Start(srvr *p2p.Server) error { func (s *LightEthereum) Start(srvr *p2p.Server) error {
log.Warn("Light client mode is an experimental feature") log.Warn("Light client mode is an experimental feature")
// Start bloom request workers.
s.wg.Add(bloomServiceThreads)
s.startBloomHandlers(params.BloomBitsBlocksClient) s.startBloomHandlers(params.BloomBitsBlocksClient)
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId)
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId)
// clients are searching for the first advertised protocol in the list // clients are searching for the first advertised protocol in the list
protocolVersion := AdvertiseProtocolVersions[0] protocolVersion := AdvertiseProtocolVersions[0]
s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion)) s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion))
s.protocolManager.Start(s.config.LightPeers)
return nil return nil
} }
// Stop implements node.Service, terminating all internal goroutines used by the // Stop implements node.Service, terminating all internal goroutines used by the
// Ethereum protocol. // Ethereum protocol.
func (s *LightEthereum) Stop() error { func (s *LightEthereum) Stop() error {
close(s.closeCh)
s.peers.Close()
s.reqDist.close()
s.odr.Stop() s.odr.Stop()
s.relay.Stop() s.relay.Stop()
s.bloomIndexer.Close() s.bloomIndexer.Close()
s.chtIndexer.Close() s.chtIndexer.Close()
s.blockchain.Stop() s.blockchain.Stop()
s.protocolManager.Stop() s.handler.stop()
s.txPool.Stop() s.txPool.Stop()
s.engine.Close() s.engine.Close()
s.eventMux.Stop() s.eventMux.Stop()
s.serverPool.stop()
time.Sleep(time.Millisecond * 200)
s.chainDb.Close() s.chainDb.Close()
close(s.shutdownChan) s.wg.Wait()
log.Info("Light ethereum stopped")
return nil return nil
} }
// SetClient sets the rpc client and binds the registrar contract. // SetClient sets the rpc client and binds the registrar contract.
func (s *LightEthereum) SetContractBackend(backend bind.ContractBackend) { func (s *LightEthereum) SetContractBackend(backend bind.ContractBackend) {
// Short circuit if registrar is nil if s.oracle == nil {
if s.protocolManager.reg == nil {
return return
} }
s.protocolManager.reg.start(backend) s.oracle.start(backend)
} }

@ -0,0 +1,401 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les
import (
"math/big"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/params"
)
// clientHandler is responsible for receiving and processing all incoming server
// responses.
type clientHandler struct {
ulc *ulc
checkpoint *params.TrustedCheckpoint
fetcher *lightFetcher
downloader *downloader.Downloader
backend *LightEthereum
closeCh chan struct{}
wg sync.WaitGroup // WaitGroup used to track all connected peers.
}
func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler {
handler := &clientHandler{
backend: backend,
closeCh: make(chan struct{}),
}
if ulcServers != nil {
ulc, err := newULC(ulcServers, ulcFraction)
if err != nil {
log.Error("Failed to initialize ultra light client")
}
handler.ulc = ulc
log.Info("Enable ultra light client mode")
}
var height uint64
if checkpoint != nil {
height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1
}
handler.fetcher = newLightFetcher(handler)
handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer)
handler.backend.peers.notify((*downloaderPeerNotify)(handler))
return handler
}
func (h *clientHandler) stop() {
close(h.closeCh)
h.downloader.Terminate()
h.fetcher.close()
h.wg.Wait()
}
// runPeer is the p2p protocol run function for the given version.
func (h *clientHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
trusted := false
if h.ulc != nil {
trusted = h.ulc.trusted(p.ID())
}
peer := newPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version)))
peer.poolEntry = h.backend.serverPool.connect(peer, peer.Node())
if peer.poolEntry == nil {
return p2p.DiscRequested
}
h.wg.Add(1)
defer h.wg.Done()
err := h.handle(peer)
h.backend.serverPool.disconnect(peer.poolEntry)
return err
}
func (h *clientHandler) handle(p *peer) error {
if h.backend.peers.Len() >= h.backend.config.LightPeers && !p.Peer.Info().Network.Trusted {
return p2p.DiscTooManyPeers
}
p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
// Execute the LES handshake
var (
head = h.backend.blockchain.CurrentHeader()
hash = head.Hash()
number = head.Number.Uint64()
td = h.backend.blockchain.GetTd(hash, number)
)
if err := p.Handshake(td, hash, number, h.backend.blockchain.Genesis().Hash(), nil); err != nil {
p.Log().Debug("Light Ethereum handshake failed", "err", err)
return err
}
// Register the peer locally
if err := h.backend.peers.Register(p); err != nil {
p.Log().Error("Light Ethereum peer registration failed", "err", err)
return err
}
serverConnectionGauge.Update(int64(h.backend.peers.Len()))
connectedAt := mclock.Now()
defer func() {
h.backend.peers.Unregister(p.id)
connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
serverConnectionGauge.Update(int64(h.backend.peers.Len()))
}()
h.fetcher.announce(p, p.headInfo)
// pool entry can be nil during the unit test.
if p.poolEntry != nil {
h.backend.serverPool.registered(p.poolEntry)
}
// Spawn a main loop to handle all incoming messages.
for {
if err := h.handleMsg(p); err != nil {
p.Log().Debug("Light Ethereum message handling failed", "err", err)
p.fcServer.DumpLogs()
return err
}
}
}
// handleMsg is invoked whenever an inbound message is received from a remote
// peer. The remote connection is torn down upon returning any error.
func (h *clientHandler) handleMsg(p *peer) error {
// Read the next message from the remote peer, and ensure it's fully consumed
msg, err := p.rw.ReadMsg()
if err != nil {
return err
}
p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
if msg.Size > ProtocolMaxMsgSize {
return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
}
defer msg.Discard()
var deliverMsg *Msg
// Handle the message depending on its contents
switch msg.Code {
case AnnounceMsg:
p.Log().Trace("Received announce message")
var req announceData
if err := msg.Decode(&req); err != nil {
return errResp(ErrDecode, "%v: %v", msg, err)
}
if err := req.sanityCheck(); err != nil {
return err
}
update, size := req.Update.decode()
if p.rejectUpdate(size) {
return errResp(ErrRequestRejected, "")
}
p.updateFlowControl(update)
if req.Hash != (common.Hash{}) {
if p.announceType == announceTypeNone {
return errResp(ErrUnexpectedResponse, "")
}
if p.announceType == announceTypeSigned {
if err := req.checkSignature(p.ID(), update); err != nil {
p.Log().Trace("Invalid announcement signature", "err", err)
return err
}
p.Log().Trace("Valid announcement signature")
}
p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth)
h.fetcher.announce(p, &req)
}
case BlockHeadersMsg:
p.Log().Trace("Received block header response message")
var resp struct {
ReqID, BV uint64
Headers []*types.Header
}
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
if h.fetcher.requestedID(resp.ReqID) {
h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers)
} else {
if err := h.downloader.DeliverHeaders(p.id, resp.Headers); err != nil {
log.Debug("Failed to deliver headers", "err", err)
}
}
case BlockBodiesMsg:
p.Log().Trace("Received block bodies response")
var resp struct {
ReqID, BV uint64
Data []*types.Body
}
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
deliverMsg = &Msg{
MsgType: MsgBlockBodies,
ReqID: resp.ReqID,
Obj: resp.Data,
}
case CodeMsg:
p.Log().Trace("Received code response")
var resp struct {
ReqID, BV uint64
Data [][]byte
}
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
deliverMsg = &Msg{
MsgType: MsgCode,
ReqID: resp.ReqID,
Obj: resp.Data,
}
case ReceiptsMsg:
p.Log().Trace("Received receipts response")
var resp struct {
ReqID, BV uint64
Receipts []types.Receipts
}
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
deliverMsg = &Msg{
MsgType: MsgReceipts,
ReqID: resp.ReqID,
Obj: resp.Receipts,
}
case ProofsV2Msg:
p.Log().Trace("Received les/2 proofs response")
var resp struct {
ReqID, BV uint64
Data light.NodeList
}
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
deliverMsg = &Msg{
MsgType: MsgProofsV2,
ReqID: resp.ReqID,
Obj: resp.Data,
}
case HelperTrieProofsMsg:
p.Log().Trace("Received helper trie proof response")
var resp struct {
ReqID, BV uint64
Data HelperTrieResps
}
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
deliverMsg = &Msg{
MsgType: MsgHelperTrieProofs,
ReqID: resp.ReqID,
Obj: resp.Data,
}
case TxStatusMsg:
p.Log().Trace("Received tx status response")
var resp struct {
ReqID, BV uint64
Status []light.TxStatus
}
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
deliverMsg = &Msg{
MsgType: MsgTxStatus,
ReqID: resp.ReqID,
Obj: resp.Status,
}
case StopMsg:
p.freezeServer(true)
h.backend.retriever.frozen(p)
p.Log().Debug("Service stopped")
case ResumeMsg:
var bv uint64
if err := msg.Decode(&bv); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ResumeFreeze(bv)
p.freezeServer(false)
p.Log().Debug("Service resumed")
default:
p.Log().Trace("Received invalid message", "code", msg.Code)
return errResp(ErrInvalidMsgCode, "%v", msg.Code)
}
// Deliver the received response to retriever.
if deliverMsg != nil {
if err := h.backend.retriever.deliver(p, deliverMsg); err != nil {
p.responseErrors++
if p.responseErrors > maxResponseErrors {
return err
}
}
}
return nil
}
func (h *clientHandler) removePeer(id string) {
h.backend.peers.Unregister(id)
}
type peerConnection struct {
handler *clientHandler
peer *peer
}
func (pc *peerConnection) Head() (common.Hash, *big.Int) {
return pc.peer.HeadAndTd()
}
func (pc *peerConnection) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
rq := &distReq{
getCost: func(dp distPeer) uint64 {
peer := dp.(*peer)
return peer.GetRequestCost(GetBlockHeadersMsg, amount)
},
canSend: func(dp distPeer) bool {
return dp.(*peer) == pc.peer
},
request: func(dp distPeer) func() {
reqID := genReqID()
peer := dp.(*peer)
cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
peer.fcServer.QueuedRequest(reqID, cost)
return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) }
},
}
_, ok := <-pc.handler.backend.reqDist.queue(rq)
if !ok {
return light.ErrNoPeers
}
return nil
}
func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
rq := &distReq{
getCost: func(dp distPeer) uint64 {
peer := dp.(*peer)
return peer.GetRequestCost(GetBlockHeadersMsg, amount)
},
canSend: func(dp distPeer) bool {
return dp.(*peer) == pc.peer
},
request: func(dp distPeer) func() {
reqID := genReqID()
peer := dp.(*peer)
cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
peer.fcServer.QueuedRequest(reqID, cost)
return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) }
},
}
_, ok := <-pc.handler.backend.reqDist.queue(rq)
if !ok {
return light.ErrNoPeers
}
return nil
}
// downloaderPeerNotify implements peerSetNotify
type downloaderPeerNotify clientHandler
func (d *downloaderPeerNotify) registerPeer(p *peer) {
h := (*clientHandler)(d)
pc := &peerConnection{
handler: h,
peer: p,
}
h.downloader.RegisterLightPeer(p.id, ethVersion, pc)
}
func (d *downloaderPeerNotify) unregisterPeer(p *peer) {
h := (*clientHandler)(d)
h.downloader.UnregisterPeer(p.id)
}

@ -17,25 +17,56 @@
package les package les
import ( import (
"fmt"
"math/big" "math/big"
"sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
) )
func errResp(code errCode, format string, v ...interface{}) error {
return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...))
}
func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic {
var name string
switch protocolVersion {
case lpv2:
name = "LES2"
default:
panic(nil)
}
return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8]))
}
type chainReader interface {
CurrentHeader() *types.Header
}
// lesCommons contains fields needed by both server and client. // lesCommons contains fields needed by both server and client.
type lesCommons struct { type lesCommons struct {
genesis common.Hash
config *eth.Config config *eth.Config
chainConfig *params.ChainConfig
iConfig *light.IndexerConfig iConfig *light.IndexerConfig
chainDb ethdb.Database chainDb ethdb.Database
protocolManager *ProtocolManager peers *peerSet
chainReader chainReader
chtIndexer, bloomTrieIndexer *core.ChainIndexer chtIndexer, bloomTrieIndexer *core.ChainIndexer
oracle *checkpointOracle
closeCh chan struct{}
wg sync.WaitGroup
} }
// NodeInfo represents a short summary of the Ethereum sub-protocol metadata // NodeInfo represents a short summary of the Ethereum sub-protocol metadata
@ -50,7 +81,7 @@ type NodeInfo struct {
} }
// makeProtocols creates protocol descriptors for the given LES versions. // makeProtocols creates protocol descriptors for the given LES versions.
func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol { func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol {
protos := make([]p2p.Protocol, len(versions)) protos := make([]p2p.Protocol, len(versions))
for i, version := range versions { for i, version := range versions {
version := version version := version
@ -59,15 +90,10 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol {
Version: version, Version: version,
Length: ProtocolLengths[version], Length: ProtocolLengths[version],
NodeInfo: c.nodeInfo, NodeInfo: c.nodeInfo,
Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
return c.protocolManager.runPeer(version, p, rw) return runPeer(version, peer, rw)
},
PeerInfo: func(id enode.ID) interface{} {
if p := c.protocolManager.peers.Peer(peerIdToString(id)); p != nil {
return p.Info()
}
return nil
}, },
PeerInfo: peerInfo,
} }
} }
return protos return protos
@ -75,22 +101,21 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol {
// nodeInfo retrieves some protocol metadata about the running host node. // nodeInfo retrieves some protocol metadata about the running host node.
func (c *lesCommons) nodeInfo() interface{} { func (c *lesCommons) nodeInfo() interface{} {
chain := c.protocolManager.blockchain head := c.chainReader.CurrentHeader()
head := chain.CurrentHeader()
hash := head.Hash() hash := head.Hash()
return &NodeInfo{ return &NodeInfo{
Network: c.config.NetworkId, Network: c.config.NetworkId,
Difficulty: chain.GetTd(hash, head.Number.Uint64()), Difficulty: rawdb.ReadTd(c.chainDb, hash, head.Number.Uint64()),
Genesis: chain.Genesis().Hash(), Genesis: c.genesis,
Config: chain.Config(), Config: c.chainConfig,
Head: chain.CurrentHeader().Hash(), Head: hash,
CHT: c.latestLocalCheckpoint(), CHT: c.latestLocalCheckpoint(),
} }
} }
// latestLocalCheckpoint finds the common stored section index and returns a set of // latestLocalCheckpoint finds the common stored section index and returns a set
// post-processed trie roots (CHT and BloomTrie) associated with // of post-processed trie roots (CHT and BloomTrie) associated with the appropriate
// the appropriate section index and head hash as a local checkpoint package. // section index and head hash as a local checkpoint package.
func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint { func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint {
sections, _, _ := c.chtIndexer.Sections() sections, _, _ := c.chtIndexer.Sections()
sections2, _, _ := c.bloomTrieIndexer.Sections() sections2, _, _ := c.bloomTrieIndexer.Sections()
@ -102,15 +127,15 @@ func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint {
// No checkpoint information can be provided. // No checkpoint information can be provided.
return params.TrustedCheckpoint{} return params.TrustedCheckpoint{}
} }
return c.getLocalCheckpoint(sections - 1) return c.localCheckpoint(sections - 1)
} }
// getLocalCheckpoint returns a set of post-processed trie roots (CHT and BloomTrie) // localCheckpoint returns a set of post-processed trie roots (CHT and BloomTrie)
// associated with the appropriate head hash by specific section index. // associated with the appropriate head hash by specific section index.
// //
// The returned checkpoint is only the checkpoint generated by the local indexers, // The returned checkpoint is only the checkpoint generated by the local indexers,
// not the stable checkpoint registered in the registrar contract. // not the stable checkpoint registered in the registrar contract.
func (c *lesCommons) getLocalCheckpoint(index uint64) params.TrustedCheckpoint { func (c *lesCommons) localCheckpoint(index uint64) params.TrustedCheckpoint {
sectionHead := c.chtIndexer.SectionHead(index) sectionHead := c.chtIndexer.SectionHead(index)
return params.TrustedCheckpoint{ return params.TrustedCheckpoint{
SectionIndex: index, SectionIndex: index,

@ -81,7 +81,8 @@ var (
) )
const ( const (
maxCostFactor = 2 // ratio of maximum and average cost estimates maxCostFactor = 2 // ratio of maximum and average cost estimates
bufLimitRatio = 6000 // fixed bufLimit/MRR ratio
gfUsageThreshold = 0.5 gfUsageThreshold = 0.5
gfUsageTC = time.Second gfUsageTC = time.Second
gfRaiseTC = time.Second * 200 gfRaiseTC = time.Second * 200
@ -127,6 +128,10 @@ type costTracker struct {
totalRechargeCh chan uint64 totalRechargeCh chan uint64
stats map[uint64][]uint64 // Used for testing purpose. stats map[uint64][]uint64 // Used for testing purpose.
// TestHooks
testing bool // Disable real cost evaluation for testing purpose.
testCostList RequestCostList // Customized cost table for testing purpose.
} }
// newCostTracker creates a cost tracker and loads the cost factor statistics from the database. // newCostTracker creates a cost tracker and loads the cost factor statistics from the database.
@ -265,8 +270,9 @@ func (ct *costTracker) gfLoop() {
select { select {
case r := <-ct.reqInfoCh: case r := <-ct.reqInfoCh:
requestServedMeter.Mark(int64(r.servingTime)) requestServedMeter.Mark(int64(r.servingTime))
requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor))
requestServedTimer.Update(time.Duration(r.servingTime)) requestServedTimer.Update(time.Duration(r.servingTime))
requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor))
requestEstimatedTimer.Update(time.Duration(r.avgTimeCost / factor))
relativeCostHistogram.Update(int64(r.avgTimeCost / factor / r.servingTime)) relativeCostHistogram.Update(int64(r.avgTimeCost / factor / r.servingTime))
now := mclock.Now() now := mclock.Now()
@ -323,7 +329,6 @@ func (ct *costTracker) gfLoop() {
} }
recentServedGauge.Update(int64(recentTime)) recentServedGauge.Update(int64(recentTime))
recentEstimatedGauge.Update(int64(recentAvg)) recentEstimatedGauge.Update(int64(recentAvg))
totalRechargeGauge.Update(int64(totalRecharge))
case <-saveTicker.C: case <-saveTicker.C:
saveCostFactor() saveCostFactor()

@ -28,14 +28,17 @@ import (
// suitable peers, obeying flow control rules and prioritizing them in creation // suitable peers, obeying flow control rules and prioritizing them in creation
// order (even when a resend is necessary). // order (even when a resend is necessary).
type requestDistributor struct { type requestDistributor struct {
clock mclock.Clock clock mclock.Clock
reqQueue *list.List reqQueue *list.List
lastReqOrder uint64 lastReqOrder uint64
peers map[distPeer]struct{} peers map[distPeer]struct{}
peerLock sync.RWMutex peerLock sync.RWMutex
stopChn, loopChn chan struct{} loopChn chan struct{}
loopNextSent bool loopNextSent bool
lock sync.Mutex lock sync.Mutex
closeCh chan struct{}
wg sync.WaitGroup
} }
// distPeer is an LES server peer interface for the request distributor. // distPeer is an LES server peer interface for the request distributor.
@ -66,20 +69,22 @@ type distReq struct {
sentChn chan distPeer sentChn chan distPeer
element *list.Element element *list.Element
waitForPeers mclock.AbsTime waitForPeers mclock.AbsTime
enterQueue mclock.AbsTime
} }
// newRequestDistributor creates a new request distributor // newRequestDistributor creates a new request distributor
func newRequestDistributor(peers *peerSet, stopChn chan struct{}, clock mclock.Clock) *requestDistributor { func newRequestDistributor(peers *peerSet, clock mclock.Clock) *requestDistributor {
d := &requestDistributor{ d := &requestDistributor{
clock: clock, clock: clock,
reqQueue: list.New(), reqQueue: list.New(),
loopChn: make(chan struct{}, 2), loopChn: make(chan struct{}, 2),
stopChn: stopChn, closeCh: make(chan struct{}),
peers: make(map[distPeer]struct{}), peers: make(map[distPeer]struct{}),
} }
if peers != nil { if peers != nil {
peers.notify(d) peers.notify(d)
} }
d.wg.Add(1)
go d.loop() go d.loop()
return d return d
} }
@ -115,9 +120,10 @@ const waitForPeers = time.Second * 3
// main event loop // main event loop
func (d *requestDistributor) loop() { func (d *requestDistributor) loop() {
defer d.wg.Done()
for { for {
select { select {
case <-d.stopChn: case <-d.closeCh:
d.lock.Lock() d.lock.Lock()
elem := d.reqQueue.Front() elem := d.reqQueue.Front()
for elem != nil { for elem != nil {
@ -140,6 +146,7 @@ func (d *requestDistributor) loop() {
send := req.request(peer) send := req.request(peer)
if send != nil { if send != nil {
peer.queueSend(send) peer.queueSend(send)
requestSendDelay.Update(time.Duration(d.clock.Now() - req.enterQueue))
} }
chn <- peer chn <- peer
close(chn) close(chn)
@ -249,6 +256,9 @@ func (d *requestDistributor) queue(r *distReq) chan distPeer {
r.reqOrder = d.lastReqOrder r.reqOrder = d.lastReqOrder
r.waitForPeers = d.clock.Now() + mclock.AbsTime(waitForPeers) r.waitForPeers = d.clock.Now() + mclock.AbsTime(waitForPeers)
} }
// Assign the timestamp when the request is queued no matter it's
// a new one or re-queued one.
r.enterQueue = d.clock.Now()
back := d.reqQueue.Back() back := d.reqQueue.Back()
if back == nil || r.reqOrder > back.Value.(*distReq).reqOrder { if back == nil || r.reqOrder > back.Value.(*distReq).reqOrder {
@ -294,3 +304,8 @@ func (d *requestDistributor) remove(r *distReq) {
r.element = nil r.element = nil
} }
} }
func (d *requestDistributor) close() {
close(d.closeCh)
d.wg.Wait()
}

@ -121,7 +121,7 @@ func testRequestDistributor(t *testing.T, resend bool) {
stop := make(chan struct{}) stop := make(chan struct{})
defer close(stop) defer close(stop)
dist := newRequestDistributor(nil, stop, &mclock.System{}) dist := newRequestDistributor(nil, &mclock.System{})
var peers [testDistPeerCount]*testDistPeer var peers [testDistPeerCount]*testDistPeer
for i := range peers { for i := range peers {
peers[i] = &testDistPeer{} peers[i] = &testDistPeer{}

@ -40,9 +40,8 @@ const (
// ODR system to ensure that we only request data related to a certain block from peers who have already processed // ODR system to ensure that we only request data related to a certain block from peers who have already processed
// and announced that block. // and announced that block.
type lightFetcher struct { type lightFetcher struct {
pm *ProtocolManager handler *clientHandler
odr *LesOdr chain *light.LightChain
chain lightChain
lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests
maxConfirmedTd *big.Int maxConfirmedTd *big.Int
@ -58,13 +57,9 @@ type lightFetcher struct {
requestTriggered bool requestTriggered bool
requestTrigger chan struct{} requestTrigger chan struct{}
lastTrustedHeader *types.Header lastTrustedHeader *types.Header
}
// lightChain extends the BlockChain interface by locking. closeCh chan struct{}
type lightChain interface { wg sync.WaitGroup
BlockChain
LockChain()
UnlockChain()
} }
// fetcherPeerInfo holds fetcher-specific information about each active peer // fetcherPeerInfo holds fetcher-specific information about each active peer
@ -114,32 +109,37 @@ type fetchResponse struct {
} }
// newLightFetcher creates a new light fetcher // newLightFetcher creates a new light fetcher
func newLightFetcher(pm *ProtocolManager) *lightFetcher { func newLightFetcher(h *clientHandler) *lightFetcher {
f := &lightFetcher{ f := &lightFetcher{
pm: pm, handler: h,
chain: pm.blockchain.(*light.LightChain), chain: h.backend.blockchain,
odr: pm.odr,
peers: make(map[*peer]*fetcherPeerInfo), peers: make(map[*peer]*fetcherPeerInfo),
deliverChn: make(chan fetchResponse, 100), deliverChn: make(chan fetchResponse, 100),
requested: make(map[uint64]fetchRequest), requested: make(map[uint64]fetchRequest),
timeoutChn: make(chan uint64), timeoutChn: make(chan uint64),
requestTrigger: make(chan struct{}, 1), requestTrigger: make(chan struct{}, 1),
syncDone: make(chan *peer), syncDone: make(chan *peer),
closeCh: make(chan struct{}),
maxConfirmedTd: big.NewInt(0), maxConfirmedTd: big.NewInt(0),
} }
pm.peers.notify(f) h.backend.peers.notify(f)
f.pm.wg.Add(1) f.wg.Add(1)
go f.syncLoop() go f.syncLoop()
return f return f
} }
func (f *lightFetcher) close() {
close(f.closeCh)
f.wg.Wait()
}
// syncLoop is the main event loop of the light fetcher // syncLoop is the main event loop of the light fetcher
func (f *lightFetcher) syncLoop() { func (f *lightFetcher) syncLoop() {
defer f.pm.wg.Done() defer f.wg.Done()
for { for {
select { select {
case <-f.pm.quitSync: case <-f.closeCh:
return return
// request loop keeps running until no further requests are necessary or possible // request loop keeps running until no further requests are necessary or possible
case <-f.requestTrigger: case <-f.requestTrigger:
@ -156,7 +156,7 @@ func (f *lightFetcher) syncLoop() {
f.lock.Unlock() f.lock.Unlock()
if rq != nil { if rq != nil {
if _, ok := <-f.pm.reqDist.queue(rq); ok { if _, ok := <-f.handler.backend.reqDist.queue(rq); ok {
if syncing { if syncing {
f.lock.Lock() f.lock.Lock()
f.syncing = true f.syncing = true
@ -187,9 +187,9 @@ func (f *lightFetcher) syncLoop() {
} }
f.reqMu.Unlock() f.reqMu.Unlock()
if ok { if ok {
f.pm.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true) f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true)
req.peer.Log().Debug("Fetching data timed out hard") req.peer.Log().Debug("Fetching data timed out hard")
go f.pm.removePeer(req.peer.id) go f.handler.removePeer(req.peer.id)
} }
case resp := <-f.deliverChn: case resp := <-f.deliverChn:
f.reqMu.Lock() f.reqMu.Lock()
@ -202,12 +202,12 @@ func (f *lightFetcher) syncLoop() {
} }
f.reqMu.Unlock() f.reqMu.Unlock()
if ok { if ok {
f.pm.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout) f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout)
} }
f.lock.Lock() f.lock.Lock()
if !ok || !(f.syncing || f.processResponse(req, resp)) { if !ok || !(f.syncing || f.processResponse(req, resp)) {
resp.peer.Log().Debug("Failed processing response") resp.peer.Log().Debug("Failed processing response")
go f.pm.removePeer(resp.peer.id) go f.handler.removePeer(resp.peer.id)
} }
f.lock.Unlock() f.lock.Unlock()
case p := <-f.syncDone: case p := <-f.syncDone:
@ -264,7 +264,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) {
if fp.lastAnnounced != nil && head.Td.Cmp(fp.lastAnnounced.td) <= 0 { if fp.lastAnnounced != nil && head.Td.Cmp(fp.lastAnnounced.td) <= 0 {
// announced tds should be strictly monotonic // announced tds should be strictly monotonic
p.Log().Debug("Received non-monotonic td", "current", head.Td, "previous", fp.lastAnnounced.td) p.Log().Debug("Received non-monotonic td", "current", head.Td, "previous", fp.lastAnnounced.td)
go f.pm.removePeer(p.id) go f.handler.removePeer(p.id)
return return
} }
@ -297,7 +297,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) {
// if one of root's children is canonical, keep it, delete other branches and root itself // if one of root's children is canonical, keep it, delete other branches and root itself
var newRoot *fetcherTreeNode var newRoot *fetcherTreeNode
for i, nn := range fp.root.children { for i, nn := range fp.root.children {
if rawdb.ReadCanonicalHash(f.pm.chainDb, nn.number) == nn.hash { if rawdb.ReadCanonicalHash(f.handler.backend.chainDb, nn.number) == nn.hash {
fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...) fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...)
nn.parent = nil nn.parent = nil
newRoot = nn newRoot = nn
@ -390,7 +390,7 @@ func (f *lightFetcher) peerHasBlock(p *peer, hash common.Hash, number uint64, ha
// //
// when syncing, just check if it is part of the known chain, there is nothing better we // when syncing, just check if it is part of the known chain, there is nothing better we
// can do since we do not know the most recent block hash yet // can do since we do not know the most recent block hash yet
return rawdb.ReadCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && rawdb.ReadCanonicalHash(f.pm.chainDb, number) == hash return rawdb.ReadCanonicalHash(f.handler.backend.chainDb, fp.root.number) == fp.root.hash && rawdb.ReadCanonicalHash(f.handler.backend.chainDb, number) == hash
} }
// requestAmount calculates the amount of headers to be downloaded starting // requestAmount calculates the amount of headers to be downloaded starting
@ -453,8 +453,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6
if f.checkKnownNode(p, n) || n.requested { if f.checkKnownNode(p, n) || n.requested {
continue continue
} }
// if ulc mode is disabled, isTrustedHash returns true
//if ulc mode is disabled, isTrustedHash returns true
amount := f.requestAmount(p, n) amount := f.requestAmount(p, n)
if (bestTd == nil || n.td.Cmp(bestTd) > 0 || amount < bestAmount) && (f.isTrustedHash(hash) || f.maxConfirmedTd.Int64() == 0) { if (bestTd == nil || n.td.Cmp(bestTd) > 0 || amount < bestAmount) && (f.isTrustedHash(hash) || f.maxConfirmedTd.Int64() == 0) {
bestHash = hash bestHash = hash
@ -470,7 +469,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6
// isTrustedHash checks if the block can be trusted by the minimum trusted fraction. // isTrustedHash checks if the block can be trusted by the minimum trusted fraction.
func (f *lightFetcher) isTrustedHash(hash common.Hash) bool { func (f *lightFetcher) isTrustedHash(hash common.Hash) bool {
// If ultra light cliet mode is disabled, trust all hashes // If ultra light cliet mode is disabled, trust all hashes
if f.pm.ulc == nil { if f.handler.ulc == nil {
return true return true
} }
// Ultra light enabled, only trust after enough confirmations // Ultra light enabled, only trust after enough confirmations
@ -480,7 +479,7 @@ func (f *lightFetcher) isTrustedHash(hash common.Hash) bool {
agreed++ agreed++
} }
} }
return 100*agreed/len(f.pm.ulc.keys) >= f.pm.ulc.fraction return 100*agreed/len(f.handler.ulc.keys) >= f.handler.ulc.fraction
} }
func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq { func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq {
@ -500,14 +499,14 @@ func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq {
return fp != nil && fp.nodeByHash[bestHash] != nil return fp != nil && fp.nodeByHash[bestHash] != nil
}, },
request: func(dp distPeer) func() { request: func(dp distPeer) func() {
if f.pm.ulc != nil { if f.handler.ulc != nil {
// Keep last trusted header before sync // Keep last trusted header before sync
f.setLastTrustedHeader(f.chain.CurrentHeader()) f.setLastTrustedHeader(f.chain.CurrentHeader())
} }
go func() { go func() {
p := dp.(*peer) p := dp.(*peer)
p.Log().Debug("Synchronisation started") p.Log().Debug("Synchronisation started")
f.pm.synchronise(p) f.handler.synchronise(p)
f.syncDone <- p f.syncDone <- p
}() }()
return nil return nil
@ -607,7 +606,7 @@ func (f *lightFetcher) newHeaders(headers []*types.Header, tds []*big.Int) {
for p, fp := range f.peers { for p, fp := range f.peers {
if !f.checkAnnouncedHeaders(fp, headers, tds) { if !f.checkAnnouncedHeaders(fp, headers, tds) {
p.Log().Debug("Inconsistent announcement") p.Log().Debug("Inconsistent announcement")
go f.pm.removePeer(p.id) go f.handler.removePeer(p.id)
} }
if fp.confirmedTd != nil && (maxTd == nil || maxTd.Cmp(fp.confirmedTd) > 0) { if fp.confirmedTd != nil && (maxTd == nil || maxTd.Cmp(fp.confirmedTd) > 0) {
maxTd = fp.confirmedTd maxTd = fp.confirmedTd
@ -705,7 +704,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) {
node = fp.lastAnnounced node = fp.lastAnnounced
td *big.Int td *big.Int
) )
if f.pm.ulc != nil { if f.handler.ulc != nil {
// Roll back untrusted blocks // Roll back untrusted blocks
h, unapproved := f.lastTrustedTreeNode(p) h, unapproved := f.lastTrustedTreeNode(p)
f.chain.Rollback(unapproved) f.chain.Rollback(unapproved)
@ -721,7 +720,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) {
// Now node is the latest downloaded/approved header after syncing // Now node is the latest downloaded/approved header after syncing
if node == nil { if node == nil {
p.Log().Debug("Synchronisation failed") p.Log().Debug("Synchronisation failed")
go f.pm.removePeer(p.id) go f.handler.removePeer(p.id)
return return
} }
header := f.chain.GetHeader(node.hash, node.number) header := f.chain.GetHeader(node.hash, node.number)
@ -741,7 +740,7 @@ func (f *lightFetcher) lastTrustedTreeNode(p *peer) (*types.Header, []common.Has
if canonical.Number.Uint64() > f.lastTrustedHeader.Number.Uint64() { if canonical.Number.Uint64() > f.lastTrustedHeader.Number.Uint64() {
canonical = f.chain.GetHeaderByNumber(f.lastTrustedHeader.Number.Uint64()) canonical = f.chain.GetHeaderByNumber(f.lastTrustedHeader.Number.Uint64())
} }
commonAncestor := rawdb.FindCommonAncestor(f.pm.chainDb, canonical, f.lastTrustedHeader) commonAncestor := rawdb.FindCommonAncestor(f.handler.backend.chainDb, canonical, f.lastTrustedHeader)
if commonAncestor == nil { if commonAncestor == nil {
log.Error("Common ancestor of last trusted header and canonical header is nil", "canonical hash", canonical.Hash(), "trusted hash", f.lastTrustedHeader.Hash()) log.Error("Common ancestor of last trusted header and canonical header is nil", "canonical hash", canonical.Hash(), "trusted hash", f.lastTrustedHeader.Hash())
return current, unapprovedHashes return current, unapprovedHashes
@ -787,7 +786,7 @@ func (f *lightFetcher) checkKnownNode(p *peer, n *fetcherTreeNode) bool {
} }
if !f.checkAnnouncedHeaders(fp, []*types.Header{header}, []*big.Int{td}) { if !f.checkAnnouncedHeaders(fp, []*types.Header{header}, []*big.Int{td}) {
p.Log().Debug("Inconsistent announcement") p.Log().Debug("Inconsistent announcement")
go f.pm.removePeer(p.id) go f.handler.removePeer(p.id)
} }
if fp.confirmedTd != nil { if fp.confirmedTd != nil {
f.updateMaxConfirmedTd(fp.confirmedTd) f.updateMaxConfirmedTd(fp.confirmedTd)
@ -880,12 +879,12 @@ func (f *lightFetcher) checkUpdateStats(p *peer, newEntry *updateStatsEntry) {
fp.firstUpdateStats = newEntry fp.firstUpdateStats = newEntry
} }
for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) { for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) {
f.pm.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout) f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout)
fp.firstUpdateStats = fp.firstUpdateStats.next fp.firstUpdateStats = fp.firstUpdateStats.next
} }
if fp.confirmedTd != nil { if fp.confirmedTd != nil {
for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 { for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 {
f.pm.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time)) f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time))
fp.firstUpdateStats = fp.firstUpdateStats.next fp.firstUpdateStats = fp.firstUpdateStats.next
} }
} }

@ -1,168 +0,0 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les
import (
"math/big"
"testing"
"net"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
)
func TestFetcherULCPeerSelector(t *testing.T) {
id1 := newNodeID(t).ID()
id2 := newNodeID(t).ID()
id3 := newNodeID(t).ID()
id4 := newNodeID(t).ID()
ftn1 := &fetcherTreeNode{
hash: common.HexToHash("1"),
td: big.NewInt(1),
}
ftn2 := &fetcherTreeNode{
hash: common.HexToHash("2"),
td: big.NewInt(2),
parent: ftn1,
}
ftn3 := &fetcherTreeNode{
hash: common.HexToHash("3"),
td: big.NewInt(3),
parent: ftn2,
}
lf := lightFetcher{
pm: &ProtocolManager{
ulc: &ulc{
keys: map[string]bool{
id1.String(): true,
id2.String(): true,
id3.String(): true,
id4.String(): true,
},
fraction: 70,
},
},
maxConfirmedTd: ftn1.td,
peers: map[*peer]*fetcherPeerInfo{
{
id: "peer1",
Peer: p2p.NewPeer(id1, "peer1", []p2p.Cap{}),
trusted: true,
}: {
nodeByHash: map[common.Hash]*fetcherTreeNode{
ftn1.hash: ftn1,
ftn2.hash: ftn2,
},
},
{
Peer: p2p.NewPeer(id2, "peer2", []p2p.Cap{}),
id: "peer2",
trusted: true,
}: {
nodeByHash: map[common.Hash]*fetcherTreeNode{
ftn1.hash: ftn1,
ftn2.hash: ftn2,
},
},
{
id: "peer3",
Peer: p2p.NewPeer(id3, "peer3", []p2p.Cap{}),
trusted: true,
}: {
nodeByHash: map[common.Hash]*fetcherTreeNode{
ftn1.hash: ftn1,
ftn2.hash: ftn2,
ftn3.hash: ftn3,
},
},
{
id: "peer4",
Peer: p2p.NewPeer(id4, "peer4", []p2p.Cap{}),
trusted: true,
}: {
nodeByHash: map[common.Hash]*fetcherTreeNode{
ftn1.hash: ftn1,
},
},
},
chain: &lightChainStub{
tds: map[common.Hash]*big.Int{},
headers: map[common.Hash]*types.Header{
ftn1.hash: {},
ftn2.hash: {},
ftn3.hash: {},
},
},
}
bestHash, bestAmount, bestTD, sync := lf.findBestRequest()
if bestTD == nil {
t.Fatal("Empty result")
}
if bestTD.Cmp(ftn2.td) != 0 {
t.Fatal("bad td", bestTD)
}
if bestHash != ftn2.hash {
t.Fatal("bad hash", bestTD)
}
_, _ = bestAmount, sync
}
type lightChainStub struct {
BlockChain
tds map[common.Hash]*big.Int
headers map[common.Hash]*types.Header
insertHeaderChainAssertFunc func(chain []*types.Header, checkFreq int) (int, error)
}
func (l *lightChainStub) GetHeader(hash common.Hash, number uint64) *types.Header {
if h, ok := l.headers[hash]; ok {
return h
}
return nil
}
func (l *lightChainStub) LockChain() {}
func (l *lightChainStub) UnlockChain() {}
func (l *lightChainStub) GetTd(hash common.Hash, number uint64) *big.Int {
if td, ok := l.tds[hash]; ok {
return td
}
return nil
}
func (l *lightChainStub) InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) {
return l.insertHeaderChainAssertFunc(chain, checkFreq)
}
func newNodeID(t *testing.T) *enode.Node {
key, err := crypto.GenerateKey()
if err != nil {
t.Fatal("generate key err:", err)
}
return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000)
}

File diff suppressed because it is too large Load Diff

@ -48,11 +48,13 @@ func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{}
// Tests that block headers can be retrieved from a remote chain based on user queries. // Tests that block headers can be retrieved from a remote chain based on user queries.
func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) } func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) }
func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) }
func testGetBlockHeaders(t *testing.T, protocol int) { func testGetBlockHeaders(t *testing.T, protocol int) {
server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil) server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain)
bc := server.handler.blockchain
// Create a "random" unknown hash for testing // Create a "random" unknown hash for testing
var unknown common.Hash var unknown common.Hash
@ -114,10 +116,10 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
[]common.Hash{bc.CurrentBlock().Hash()}, []common.Hash{bc.CurrentBlock().Hash()},
}, },
// Ensure protocol limits are honored // Ensure protocol limits are honored
/*{ //{
&getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, // &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true},
bc.GetBlockHashesFromHash(bc.CurrentBlock().Hash(), limit), // []common.Hash{},
},*/ //},
// Check that requesting more than available is handled gracefully // Check that requesting more than available is handled gracefully
{ {
&getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3},
@ -165,9 +167,10 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
} }
// Send the hash request and verify the response // Send the hash request and verify the response
reqID++ reqID++
cost := server.tPeer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount))
sendRequest(server.tPeer.app, GetBlockHeadersMsg, reqID, cost, tt.query) cost := server.peer.peer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount))
if err := expectResponse(server.tPeer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil { sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, cost, tt.query)
if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil {
t.Errorf("test %d: headers mismatch: %v", i, err) t.Errorf("test %d: headers mismatch: %v", i, err)
} }
} }
@ -175,11 +178,13 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
// Tests that block contents can be retrieved from a remote chain based on their hashes. // Tests that block contents can be retrieved from a remote chain based on their hashes.
func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) } func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) }
func TestGetBlockBodiesLes3(t *testing.T) { testGetBlockBodies(t, 3) }
func testGetBlockBodies(t *testing.T, protocol int) { func testGetBlockBodies(t *testing.T, protocol int) {
server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil) server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain)
bc := server.handler.blockchain
// Create a batch of tests for various scenarios // Create a batch of tests for various scenarios
limit := MaxBodyFetch limit := MaxBodyFetch
@ -239,10 +244,11 @@ func testGetBlockBodies(t *testing.T, protocol int) {
} }
} }
reqID++ reqID++
// Send the hash request and verify the response // Send the hash request and verify the response
cost := server.tPeer.GetRequestCost(GetBlockBodiesMsg, len(hashes)) cost := server.peer.peer.GetRequestCost(GetBlockBodiesMsg, len(hashes))
sendRequest(server.tPeer.app, GetBlockBodiesMsg, reqID, cost, hashes) sendRequest(server.peer.app, GetBlockBodiesMsg, reqID, cost, hashes)
if err := expectResponse(server.tPeer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil { if err := expectResponse(server.peer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil {
t.Errorf("test %d: bodies mismatch: %v", i, err) t.Errorf("test %d: bodies mismatch: %v", i, err)
} }
} }
@ -250,12 +256,13 @@ func testGetBlockBodies(t *testing.T, protocol int) {
// Tests that the contract codes can be retrieved based on account addresses. // Tests that the contract codes can be retrieved based on account addresses.
func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) } func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) }
func TestGetCodeLes3(t *testing.T) { testGetCode(t, 3) }
func testGetCode(t *testing.T, protocol int) { func testGetCode(t *testing.T, protocol int) {
// Assemble the test environment // Assemble the test environment
server, tearDown := newServerEnv(t, 4, protocol, nil) server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain) bc := server.handler.blockchain
var codereqs []*CodeReq var codereqs []*CodeReq
var codes [][]byte var codes [][]byte
@ -271,9 +278,9 @@ func testGetCode(t *testing.T, protocol int) {
} }
} }
cost := server.tPeer.GetRequestCost(GetCodeMsg, len(codereqs)) cost := server.peer.peer.GetRequestCost(GetCodeMsg, len(codereqs))
sendRequest(server.tPeer.app, GetCodeMsg, 42, cost, codereqs) sendRequest(server.peer.app, GetCodeMsg, 42, cost, codereqs)
if err := expectResponse(server.tPeer.app, CodeMsg, 42, testBufLimit, codes); err != nil { if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, codes); err != nil {
t.Errorf("codes mismatch: %v", err) t.Errorf("codes mismatch: %v", err)
} }
} }
@ -283,18 +290,18 @@ func TestGetStaleCodeLes2(t *testing.T) { testGetStaleCode(t, 2) }
func TestGetStaleCodeLes3(t *testing.T) { testGetStaleCode(t, 3) } func TestGetStaleCodeLes3(t *testing.T) { testGetStaleCode(t, 3) }
func testGetStaleCode(t *testing.T, protocol int) { func testGetStaleCode(t *testing.T, protocol int) {
server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil) server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain) bc := server.handler.blockchain
check := func(number uint64, expected [][]byte) { check := func(number uint64, expected [][]byte) {
req := &CodeReq{ req := &CodeReq{
BHash: bc.GetHeaderByNumber(number).Hash(), BHash: bc.GetHeaderByNumber(number).Hash(),
AccKey: crypto.Keccak256(testContractAddr[:]), AccKey: crypto.Keccak256(testContractAddr[:]),
} }
cost := server.tPeer.GetRequestCost(GetCodeMsg, 1) cost := server.peer.peer.GetRequestCost(GetCodeMsg, 1)
sendRequest(server.tPeer.app, GetCodeMsg, 42, cost, []*CodeReq{req}) sendRequest(server.peer.app, GetCodeMsg, 42, cost, []*CodeReq{req})
if err := expectResponse(server.tPeer.app, CodeMsg, 42, testBufLimit, expected); err != nil { if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, expected); err != nil {
t.Errorf("codes mismatch: %v", err) t.Errorf("codes mismatch: %v", err)
} }
} }
@ -305,12 +312,14 @@ func testGetStaleCode(t *testing.T, protocol int) {
// Tests that the transaction receipts can be retrieved based on hashes. // Tests that the transaction receipts can be retrieved based on hashes.
func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) } func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) }
func TestGetReceiptLes3(t *testing.T) { testGetReceipt(t, 3) }
func testGetReceipt(t *testing.T, protocol int) { func testGetReceipt(t *testing.T, protocol int) {
// Assemble the test environment // Assemble the test environment
server, tearDown := newServerEnv(t, 4, protocol, nil) server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain)
bc := server.handler.blockchain
// Collect the hashes to request, and the response to expect // Collect the hashes to request, and the response to expect
var receipts []types.Receipts var receipts []types.Receipts
@ -322,26 +331,28 @@ func testGetReceipt(t *testing.T, protocol int) {
receipts = append(receipts, rawdb.ReadRawReceipts(server.db, block.Hash(), block.NumberU64())) receipts = append(receipts, rawdb.ReadRawReceipts(server.db, block.Hash(), block.NumberU64()))
} }
// Send the hash request and verify the response // Send the hash request and verify the response
cost := server.tPeer.GetRequestCost(GetReceiptsMsg, len(hashes)) cost := server.peer.peer.GetRequestCost(GetReceiptsMsg, len(hashes))
sendRequest(server.tPeer.app, GetReceiptsMsg, 42, cost, hashes) sendRequest(server.peer.app, GetReceiptsMsg, 42, cost, hashes)
if err := expectResponse(server.tPeer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil { if err := expectResponse(server.peer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil {
t.Errorf("receipts mismatch: %v", err) t.Errorf("receipts mismatch: %v", err)
} }
} }
// Tests that trie merkle proofs can be retrieved // Tests that trie merkle proofs can be retrieved
func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) } func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) }
func TestGetProofsLes3(t *testing.T) { testGetProofs(t, 3) }
func testGetProofs(t *testing.T, protocol int) { func testGetProofs(t *testing.T, protocol int) {
// Assemble the test environment // Assemble the test environment
server, tearDown := newServerEnv(t, 4, protocol, nil) server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain)
bc := server.handler.blockchain
var proofreqs []ProofReq var proofreqs []ProofReq
proofsV2 := light.NewNodeSet() proofsV2 := light.NewNodeSet()
accounts := []common.Address{bankAddr, userAddr1, userAddr2, {}} accounts := []common.Address{bankAddr, userAddr1, userAddr2, signerAddr, {}}
for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
header := bc.GetHeaderByNumber(i) header := bc.GetHeaderByNumber(i)
trie, _ := trie.New(header.Root, trie.NewDatabase(server.db)) trie, _ := trie.New(header.Root, trie.NewDatabase(server.db))
@ -356,9 +367,9 @@ func testGetProofs(t *testing.T, protocol int) {
} }
} }
// Send the proof request and verify the response // Send the proof request and verify the response
cost := server.tPeer.GetRequestCost(GetProofsV2Msg, len(proofreqs)) cost := server.peer.peer.GetRequestCost(GetProofsV2Msg, len(proofreqs))
sendRequest(server.tPeer.app, GetProofsV2Msg, 42, cost, proofreqs) sendRequest(server.peer.app, GetProofsV2Msg, 42, cost, proofreqs)
if err := expectResponse(server.tPeer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil { if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil {
t.Errorf("proofs mismatch: %v", err) t.Errorf("proofs mismatch: %v", err)
} }
} }
@ -368,9 +379,9 @@ func TestGetStaleProofLes2(t *testing.T) { testGetStaleProof(t, 2) }
func TestGetStaleProofLes3(t *testing.T) { testGetStaleProof(t, 3) } func TestGetStaleProofLes3(t *testing.T) { testGetStaleProof(t, 3) }
func testGetStaleProof(t *testing.T, protocol int) { func testGetStaleProof(t *testing.T, protocol int) {
server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil) server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain) bc := server.handler.blockchain
check := func(number uint64, wantOK bool) { check := func(number uint64, wantOK bool) {
var ( var (
@ -381,8 +392,8 @@ func testGetStaleProof(t *testing.T, protocol int) {
BHash: header.Hash(), BHash: header.Hash(),
Key: account, Key: account,
} }
cost := server.tPeer.GetRequestCost(GetProofsV2Msg, 1) cost := server.peer.peer.GetRequestCost(GetProofsV2Msg, 1)
sendRequest(server.tPeer.app, GetProofsV2Msg, 42, cost, []*ProofReq{req}) sendRequest(server.peer.app, GetProofsV2Msg, 42, cost, []*ProofReq{req})
var expected []rlp.RawValue var expected []rlp.RawValue
if wantOK { if wantOK {
@ -391,7 +402,7 @@ func testGetStaleProof(t *testing.T, protocol int) {
t.Prove(account, 0, proofsV2) t.Prove(account, 0, proofsV2)
expected = proofsV2.NodeList() expected = proofsV2.NodeList()
} }
if err := expectResponse(server.tPeer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil { if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil {
t.Errorf("codes mismatch: %v", err) t.Errorf("codes mismatch: %v", err)
} }
} }
@ -402,6 +413,7 @@ func testGetStaleProof(t *testing.T, protocol int) {
// Tests that CHT proofs can be correctly retrieved. // Tests that CHT proofs can be correctly retrieved.
func TestGetCHTProofsLes2(t *testing.T) { testGetCHTProofs(t, 2) } func TestGetCHTProofsLes2(t *testing.T) { testGetCHTProofs(t, 2) }
func TestGetCHTProofsLes3(t *testing.T) { testGetCHTProofs(t, 3) }
func testGetCHTProofs(t *testing.T, protocol int) { func testGetCHTProofs(t *testing.T, protocol int) {
config := light.TestServerIndexerConfig config := light.TestServerIndexerConfig
@ -415,9 +427,10 @@ func testGetCHTProofs(t *testing.T, protocol int) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
} }
server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers) server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain)
bc := server.handler.blockchain
// Assemble the proofs from the different protocols // Assemble the proofs from the different protocols
header := bc.GetHeaderByNumber(config.ChtSize - 1) header := bc.GetHeaderByNumber(config.ChtSize - 1)
@ -440,15 +453,18 @@ func testGetCHTProofs(t *testing.T, protocol int) {
AuxReq: auxHeader, AuxReq: auxHeader,
}} }}
// Send the proof request and verify the response // Send the proof request and verify the response
cost := server.tPeer.GetRequestCost(GetHelperTrieProofsMsg, len(requestsV2)) cost := server.peer.peer.GetRequestCost(GetHelperTrieProofsMsg, len(requestsV2))
sendRequest(server.tPeer.app, GetHelperTrieProofsMsg, 42, cost, requestsV2) sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, cost, requestsV2)
if err := expectResponse(server.tPeer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil { if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil {
t.Errorf("proofs mismatch: %v", err) t.Errorf("proofs mismatch: %v", err)
} }
} }
func TestGetBloombitsProofsLes2(t *testing.T) { testGetBloombitsProofs(t, 2) }
func TestGetBloombitsProofsLes3(t *testing.T) { testGetBloombitsProofs(t, 3) }
// Tests that bloombits proofs can be correctly retrieved. // Tests that bloombits proofs can be correctly retrieved.
func TestGetBloombitsProofs(t *testing.T) { func testGetBloombitsProofs(t *testing.T, protocol int) {
config := light.TestServerIndexerConfig config := light.TestServerIndexerConfig
waitIndexers := func(cIndexer, bIndexer, btIndexer *core.ChainIndexer) { waitIndexers := func(cIndexer, bIndexer, btIndexer *core.ChainIndexer) {
@ -460,9 +476,10 @@ func TestGetBloombitsProofs(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
} }
server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), 2, waitIndexers) server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0)
defer tearDown() defer tearDown()
bc := server.pm.blockchain.(*core.BlockChain)
bc := server.handler.blockchain
// Request and verify each bit of the bloom bits proofs // Request and verify each bit of the bloom bits proofs
for bit := 0; bit < 2048; bit++ { for bit := 0; bit < 2048; bit++ {
@ -485,43 +502,39 @@ func TestGetBloombitsProofs(t *testing.T) {
trie.Prove(key, 0, &proofs.Proofs) trie.Prove(key, 0, &proofs.Proofs)
// Send the proof request and verify the response // Send the proof request and verify the response
cost := server.tPeer.GetRequestCost(GetHelperTrieProofsMsg, len(requests)) cost := server.peer.peer.GetRequestCost(GetHelperTrieProofsMsg, len(requests))
sendRequest(server.tPeer.app, GetHelperTrieProofsMsg, 42, cost, requests) sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, cost, requests)
if err := expectResponse(server.tPeer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil { if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil {
t.Errorf("bit %d: proofs mismatch: %v", bit, err) t.Errorf("bit %d: proofs mismatch: %v", bit, err)
} }
} }
} }
func TestTransactionStatusLes2(t *testing.T) { func TestTransactionStatusLes2(t *testing.T) { testTransactionStatus(t, 2) }
server, tearDown := newServerEnv(t, 0, 2, nil) func TestTransactionStatusLes3(t *testing.T) { testTransactionStatus(t, 3) }
func testTransactionStatus(t *testing.T, protocol int) {
server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0)
defer tearDown() defer tearDown()
server.pm.addTxsSync = true server.handler.addTxsSync = true
chain := server.pm.blockchain.(*core.BlockChain) chain := server.handler.blockchain
config := core.DefaultTxPoolConfig
config.Journal = ""
txpool := core.NewTxPool(config, params.TestChainConfig, chain)
server.pm.txpool = txpool
peer, _ := newTestPeer(t, "peer", 2, server.pm, true, 0)
defer peer.close()
var reqID uint64 var reqID uint64
test := func(tx *types.Transaction, send bool, expStatus light.TxStatus) { test := func(tx *types.Transaction, send bool, expStatus light.TxStatus) {
reqID++ reqID++
if send { if send {
cost := server.tPeer.GetRequestCost(SendTxV2Msg, 1) cost := server.peer.peer.GetRequestCost(SendTxV2Msg, 1)
sendRequest(server.tPeer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx}) sendRequest(server.peer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx})
} else { } else {
cost := server.tPeer.GetRequestCost(GetTxStatusMsg, 1) cost := server.peer.peer.GetRequestCost(GetTxStatusMsg, 1)
sendRequest(server.tPeer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()}) sendRequest(server.peer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()})
} }
if err := expectResponse(server.tPeer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil { if err := expectResponse(server.peer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil {
t.Errorf("transaction status mismatch") t.Errorf("transaction status mismatch")
} }
} }
signer := types.HomesteadSigner{} signer := types.HomesteadSigner{}
// test error status by sending an underpriced transaction // test error status by sending an underpriced transaction
@ -551,18 +564,22 @@ func TestTransactionStatusLes2(t *testing.T) {
} }
// wait until TxPool processes the inserted block // wait until TxPool processes the inserted block
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
if pending, _ := txpool.Stats(); pending == 1 { if pending, _ := server.handler.txpool.Stats(); pending == 1 {
break break
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
if pending, _ := txpool.Stats(); pending != 1 { if pending, _ := server.handler.txpool.Stats(); pending != 1 {
t.Fatalf("pending count mismatch: have %d, want 1", pending) t.Fatalf("pending count mismatch: have %d, want 1", pending)
} }
// Discard new block announcement
msg, _ := server.peer.app.ReadMsg()
msg.Discard()
// check if their status is included now // check if their status is included now
block1hash := rawdb.ReadCanonicalHash(server.db, 1) block1hash := rawdb.ReadCanonicalHash(server.db, 1)
test(tx1, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}}) test(tx1, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}})
test(tx2, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}}) test(tx2, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}})
// create a reorg that rolls them back // create a reorg that rolls them back
@ -572,46 +589,46 @@ func TestTransactionStatusLes2(t *testing.T) {
} }
// wait until TxPool processes the reorg // wait until TxPool processes the reorg
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
if pending, _ := txpool.Stats(); pending == 3 { if pending, _ := server.handler.txpool.Stats(); pending == 3 {
break break
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
if pending, _ := txpool.Stats(); pending != 3 { if pending, _ := server.handler.txpool.Stats(); pending != 3 {
t.Fatalf("pending count mismatch: have %d, want 3", pending) t.Fatalf("pending count mismatch: have %d, want 3", pending)
} }
// Discard new block announcement
msg, _ = server.peer.app.ReadMsg()
msg.Discard()
// check if their status is pending again // check if their status is pending again
test(tx1, false, light.TxStatus{Status: core.TxStatusPending}) test(tx1, false, light.TxStatus{Status: core.TxStatusPending})
test(tx2, false, light.TxStatus{Status: core.TxStatusPending}) test(tx2, false, light.TxStatus{Status: core.TxStatusPending})
} }
func TestStopResumeLes3(t *testing.T) { func TestStopResumeLes3(t *testing.T) {
db := rawdb.NewMemoryDatabase() server, tearDown := newServerEnv(t, 0, 3, nil, true, true, testBufLimit/10)
clock := &mclock.Simulated{} defer tearDown()
testCost := testBufLimit / 10
pm, _, err := newTestProtocolManager(false, 0, nil, nil, nil, db, nil, 0, testCost, clock)
if err != nil {
t.Fatalf("Failed to create protocol manager: %v", err)
}
peer, _ := newTestPeer(t, "peer", 3, pm, true, testCost)
defer peer.close()
expBuf := testBufLimit server.handler.server.costTracker.testing = true
var reqID uint64
header := pm.blockchain.CurrentHeader() var (
reqID uint64
expBuf = testBufLimit
testCost = testBufLimit / 10
)
header := server.handler.blockchain.CurrentHeader()
req := func() { req := func() {
reqID++ reqID++
sendRequest(peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1}) sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1})
} }
for i := 1; i <= 5; i++ { for i := 1; i <= 5; i++ {
// send requests while we still have enough buffer and expect a response // send requests while we still have enough buffer and expect a response
for expBuf >= testCost { for expBuf >= testCost {
req() req()
expBuf -= testCost expBuf -= testCost
if err := expectResponse(peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil { if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil {
t.Fatalf("expected response and failed: %v", err) t.Errorf("expected response and failed: %v", err)
} }
} }
// send some more requests in excess and expect a single StopMsg // send some more requests in excess and expect a single StopMsg
@ -620,15 +637,16 @@ func TestStopResumeLes3(t *testing.T) {
req() req()
c-- c--
} }
if err := p2p.ExpectMsg(peer.app, StopMsg, nil); err != nil { if err := p2p.ExpectMsg(server.peer.app, StopMsg, nil); err != nil {
t.Errorf("expected StopMsg and failed: %v", err) t.Errorf("expected StopMsg and failed: %v", err)
} }
// wait until the buffer is recharged by half of the limit // wait until the buffer is recharged by half of the limit
wait := testBufLimit / testBufRecharge / 2 wait := testBufLimit / testBufRecharge / 2
clock.Run(time.Millisecond * time.Duration(wait)) server.clock.(*mclock.Simulated).Run(time.Millisecond * time.Duration(wait))
// expect a ResumeMsg with the partially recharged buffer value // expect a ResumeMsg with the partially recharged buffer value
expBuf += testBufRecharge * wait expBuf += testBufRecharge * wait
if err := p2p.ExpectMsg(peer.app, ResumeMsg, expBuf); err != nil { if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, expBuf); err != nil {
t.Errorf("expected ResumeMsg and failed: %v", err) t.Errorf("expected ResumeMsg and failed: %v", err)
} }
} }

@ -22,31 +22,73 @@ import (
) )
var ( var (
miscInPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets", nil) miscInPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/total", nil)
miscInTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic", nil) miscInTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/total", nil)
miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets", nil) miscInHeaderPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/header", nil)
miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic", nil) miscInHeaderTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/header", nil)
miscInBodyPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/body", nil)
connectionTimer = metrics.NewRegisteredTimer("les/connectionTime", nil) miscInBodyTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/body", nil)
miscInCodePacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/code", nil)
totalConnectedGauge = metrics.NewRegisteredGauge("les/server/totalConnected", nil) miscInCodeTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/code", nil)
totalCapacityGauge = metrics.NewRegisteredGauge("les/server/totalCapacity", nil) miscInReceiptPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/receipt", nil)
totalRechargeGauge = metrics.NewRegisteredGauge("les/server/totalRecharge", nil) miscInReceiptTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/receipt", nil)
blockProcessingTimer = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil) miscInTrieProofPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/proof", nil)
requestServedTimer = metrics.NewRegisteredTimer("les/server/requestServed", nil) miscInTrieProofTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/proof", nil)
requestServedMeter = metrics.NewRegisteredMeter("les/server/totalRequestServed", nil) miscInHelperTriePacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/helperTrie", nil)
requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/totalRequestEstimated", nil) miscInHelperTrieTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/helperTrie", nil)
relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/relativeCost", nil, metrics.NewExpDecaySample(1028, 0.015)) miscInTxsPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txs", nil)
recentServedGauge = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil) miscInTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txs", nil)
recentEstimatedGauge = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil) miscInTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txStatus", nil)
sqServedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil) miscInTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txStatus", nil)
sqQueuedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/queued", nil)
miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/total", nil)
miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/total", nil)
miscOutHeaderPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/header", nil)
miscOutHeaderTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/header", nil)
miscOutBodyPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/body", nil)
miscOutBodyTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/body", nil)
miscOutCodePacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/code", nil)
miscOutCodeTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/code", nil)
miscOutReceiptPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/receipt", nil)
miscOutReceiptTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/receipt", nil)
miscOutTrieProofPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/proof", nil)
miscOutTrieProofTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/proof", nil)
miscOutHelperTriePacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/helperTrie", nil)
miscOutHelperTrieTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/helperTrie", nil)
miscOutTxsPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txs", nil)
miscOutTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txs", nil)
miscOutTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txStatus", nil)
miscOutTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txStatus", nil)
connectionTimer = metrics.NewRegisteredTimer("les/connection/duration", nil)
serverConnectionGauge = metrics.NewRegisteredGauge("les/connection/server", nil)
clientConnectionGauge = metrics.NewRegisteredGauge("les/connection/client", nil)
totalCapacityGauge = metrics.NewRegisteredGauge("les/server/totalCapacity", nil)
totalRechargeGauge = metrics.NewRegisteredGauge("les/server/totalRecharge", nil)
totalConnectedGauge = metrics.NewRegisteredGauge("les/server/totalConnected", nil)
blockProcessingTimer = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil)
requestServedMeter = metrics.NewRegisteredMeter("les/server/req/avgServedTime", nil)
requestServedTimer = metrics.NewRegisteredTimer("les/server/req/servedTime", nil)
requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/req/avgEstimatedTime", nil)
requestEstimatedTimer = metrics.NewRegisteredTimer("les/server/req/estimatedTime", nil)
relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/req/relative", nil, metrics.NewExpDecaySample(1028, 0.015))
recentServedGauge = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil)
recentEstimatedGauge = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil)
sqServedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil)
sqQueuedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/queued", nil)
clientConnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/connected", nil) clientConnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/connected", nil)
clientRejectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/rejected", nil) clientRejectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/rejected", nil)
clientKickedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/kicked", nil) clientKickedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/kicked", nil)
clientDisconnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/disconnected", nil) clientDisconnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/disconnected", nil)
clientFreezeMeter = metrics.NewRegisteredMeter("les/server/clientEvent/freeze", nil) clientFreezeMeter = metrics.NewRegisteredMeter("les/server/clientEvent/freeze", nil)
clientErrorMeter = metrics.NewRegisteredMeter("les/server/clientEvent/error", nil) clientErrorMeter = metrics.NewRegisteredMeter("les/server/clientEvent/error", nil)
requestRTT = metrics.NewRegisteredTimer("les/client/req/rtt", nil)
requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil)
) )
// meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of
@ -58,17 +100,11 @@ type meteredMsgReadWriter struct {
// newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the // newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the
// metrics system is disabled, this function returns the original object. // metrics system is disabled, this function returns the original object.
func newMeteredMsgWriter(rw p2p.MsgReadWriter) p2p.MsgReadWriter { func newMeteredMsgWriter(rw p2p.MsgReadWriter, version int) p2p.MsgReadWriter {
if !metrics.Enabled { if !metrics.Enabled {
return rw return rw
} }
return &meteredMsgReadWriter{MsgReadWriter: rw} return &meteredMsgReadWriter{MsgReadWriter: rw, version: version}
}
// Init sets the protocol version used by the stream to know which meters to
// increment in case of overlapping message ids between protocol versions.
func (rw *meteredMsgReadWriter) Init(version int) {
rw.version = version
} }
func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) { func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) {

@ -18,7 +18,9 @@ package les
import ( import (
"context" "context"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
@ -120,10 +122,11 @@ func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err erro
return func() { lreq.Request(reqID, p) } return func() { lreq.Request(reqID, p) }
}, },
} }
sent := mclock.Now()
if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil { if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil {
// retrieved from network, store in db // retrieved from network, store in db
req.StoreResult(odr.db) req.StoreResult(odr.db)
requestRTT.Update(time.Duration(mclock.Now() - sent))
} else { } else {
log.Debug("Failed to retrieve data from network", "err", err) log.Debug("Failed to retrieve data from network", "err", err)
} }

@ -39,6 +39,7 @@ import (
type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte
func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) } func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) }
func TestOdrGetBlockLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetBlock) }
func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var block *types.Block var block *types.Block
@ -55,6 +56,7 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon
} }
func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) } func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) }
func TestOdrGetReceiptsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetReceipts) }
func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var receipts types.Receipts var receipts types.Receipts
@ -75,6 +77,7 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain
} }
func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) } func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) }
func TestOdrAccountsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrAccounts) }
func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
@ -103,6 +106,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
} }
func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) } func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) }
func TestOdrContractCallLes3(t *testing.T) { testOdr(t, 3, 2, true, odrContractCall) }
type callmsg struct { type callmsg struct {
types.Message types.Message
@ -152,6 +156,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
} }
func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) } func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) }
func TestOdrTxStatusLes3(t *testing.T) { testOdr(t, 3, 1, false, odrTxStatus) }
func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var txs types.Transactions var txs types.Transactions
@ -178,21 +183,22 @@ func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainCon
// testOdr tests odr requests whose validation guaranteed by block headers. // testOdr tests odr requests whose validation guaranteed by block headers.
func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn odrTestFn) { func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn odrTestFn) {
// Assemble the test environment // Assemble the test environment
server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, true) server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true)
defer tearDown() defer tearDown()
client.pm.synchronise(client.rPeer)
client.handler.synchronise(client.peer.peer)
test := func(expFail uint64) { test := func(expFail uint64) {
// Mark this as a helper to put the failures at the correct lines // Mark this as a helper to put the failures at the correct lines
t.Helper() t.Helper()
for i := uint64(0); i <= server.pm.blockchain.CurrentHeader().Number.Uint64(); i++ { for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ {
bhash := rawdb.ReadCanonicalHash(server.db, i) bhash := rawdb.ReadCanonicalHash(server.db, i)
b1 := fn(light.NoOdr, server.db, server.pm.chainConfig, server.pm.blockchain.(*core.BlockChain), nil, bhash) b1 := fn(light.NoOdr, server.db, server.handler.server.chainConfig, server.handler.blockchain, nil, bhash)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel() b2 := fn(ctx, client.db, client.handler.backend.chainConfig, nil, client.handler.backend.blockchain, bhash)
b2 := fn(ctx, client.db, client.pm.chainConfig, nil, client.pm.blockchain.(*light.LightChain), bhash) cancel()
eq := bytes.Equal(b1, b2) eq := bytes.Equal(b1, b2)
exp := i < expFail exp := i < expFail
@ -204,22 +210,22 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od
} }
} }
} }
// temporarily remove peer to test odr fails
// expect retrievals to fail (except genesis block) without a les peer // expect retrievals to fail (except genesis block) without a les peer
client.peers.Unregister(client.rPeer.id) client.handler.backend.peers.lock.Lock()
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed client.peer.peer.hasBlock = func(common.Hash, uint64, bool) bool { return false }
client.handler.backend.peers.lock.Unlock()
test(expFail) test(expFail)
// expect all retrievals to pass // expect all retrievals to pass
client.peers.Register(client.rPeer) client.handler.backend.peers.lock.Lock()
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed client.peer.peer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
client.peers.lock.Lock() client.handler.backend.peers.lock.Unlock()
client.rPeer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
client.peers.lock.Unlock()
test(5) test(5)
// still expect all retrievals to pass, now data should be cached locally
if checkCached { if checkCached {
// still expect all retrievals to pass, now data should be cached locally client.handler.backend.peers.Unregister(client.peer.peer.id)
client.peers.Unregister(client.rPeer.id)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
test(5) test(5)
} }

@ -111,7 +111,7 @@ type peer struct {
fcServer *flowcontrol.ServerNode // nil if the peer is client only fcServer *flowcontrol.ServerNode // nil if the peer is client only
fcParams flowcontrol.ServerParams fcParams flowcontrol.ServerParams
fcCosts requestCostTable fcCosts requestCostTable
balanceTracker *balanceTracker // set by clientPool.connect, used and removed by ProtocolManager.handle balanceTracker *balanceTracker // set by clientPool.connect, used and removed by serverHandler.
trusted bool trusted bool
onlyAnnounce bool onlyAnnounce bool
@ -291,6 +291,11 @@ func (p *peer) updateCapacity(cap uint64) {
p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) }) p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) })
} }
func (p *peer) responseID() uint64 {
p.responseCount += 1
return p.responseCount
}
func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error { func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error {
type req struct { type req struct {
ReqID uint64 ReqID uint64
@ -373,6 +378,7 @@ func (p *peer) HasBlock(hash common.Hash, number uint64, hasState bool) bool {
} }
hasBlock := p.hasBlock hasBlock := p.hasBlock
p.lock.RUnlock() p.lock.RUnlock()
return head >= number && number >= since && (recent == 0 || number+recent+4 > head) && hasBlock != nil && hasBlock(hash, number, hasState) return head >= number && number >= since && (recent == 0 || number+recent+4 > head) && hasBlock != nil && hasBlock(hash, number, hasState)
} }
@ -571,6 +577,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
defer p.lock.Unlock() defer p.lock.Unlock()
var send keyValueList var send keyValueList
// Add some basic handshake fields
send = send.add("protocolVersion", uint64(p.version)) send = send.add("protocolVersion", uint64(p.version))
send = send.add("networkId", p.network) send = send.add("networkId", p.network)
send = send.add("headTd", td) send = send.add("headTd", td)
@ -578,7 +586,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
send = send.add("headNum", headNum) send = send.add("headNum", headNum)
send = send.add("genesisHash", genesis) send = send.add("genesisHash", genesis)
if server != nil { if server != nil {
if !server.onlyAnnounce { // Add some information which services server can offer.
if !server.config.UltraLightOnlyAnnounce {
send = send.add("serveHeaders", nil) send = send.add("serveHeaders", nil)
send = send.add("serveChainSince", uint64(0)) send = send.add("serveChainSince", uint64(0))
send = send.add("serveStateSince", uint64(0)) send = send.add("serveStateSince", uint64(0))
@ -594,25 +603,28 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
} }
send = send.add("flowControl/BL", server.defParams.BufLimit) send = send.add("flowControl/BL", server.defParams.BufLimit)
send = send.add("flowControl/MRR", server.defParams.MinRecharge) send = send.add("flowControl/MRR", server.defParams.MinRecharge)
var costList RequestCostList var costList RequestCostList
if server.costTracker != nil { if server.costTracker.testCostList != nil {
costList = server.costTracker.makeCostList(server.costTracker.globalFactor()) costList = server.costTracker.testCostList
} else { } else {
costList = testCostList(server.testCost) costList = server.costTracker.makeCostList(server.costTracker.globalFactor())
} }
send = send.add("flowControl/MRC", costList) send = send.add("flowControl/MRC", costList)
p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)]) p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)])
p.fcParams = server.defParams p.fcParams = server.defParams
if server.protocolManager != nil && server.protocolManager.reg != nil && server.protocolManager.reg.isRunning() { // Add advertised checkpoint and register block height which
cp, height := server.protocolManager.reg.stableCheckpoint() // client can verify the checkpoint validity.
if server.oracle != nil && server.oracle.isRunning() {
cp, height := server.oracle.stableCheckpoint()
if cp != nil { if cp != nil {
send = send.add("checkpoint/value", cp) send = send.add("checkpoint/value", cp)
send = send.add("checkpoint/registerHeight", height) send = send.add("checkpoint/registerHeight", height)
} }
} }
} else { } else {
//on client node // Add some client-specific handshake fields
p.announceType = announceTypeSimple p.announceType = announceTypeSimple
if p.trusted { if p.trusted {
p.announceType = announceTypeSigned p.announceType = announceTypeSigned
@ -663,17 +675,12 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
} }
if server != nil { if server != nil {
// until we have a proper peer connectivity API, allow LES connection to other servers
/*if recv.get("serveStateSince", nil) == nil {
return errResp(ErrUselessPeer, "wanted client, got server")
}*/
if recv.get("announceType", &p.announceType) != nil { if recv.get("announceType", &p.announceType) != nil {
//set default announceType on server side // set default announceType on server side
p.announceType = announceTypeSimple p.announceType = announceTypeSimple
} }
p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams)
} else { } else {
//mark OnlyAnnounce server if "serveHeaders", "serveChainSince", "serveStateSince" or "txRelay" fields don't exist
if recv.get("serveChainSince", &p.chainSince) != nil { if recv.get("serveChainSince", &p.chainSince) != nil {
p.onlyAnnounce = true p.onlyAnnounce = true
} }
@ -730,15 +737,10 @@ func (p *peer) updateFlowControl(update keyValueMap) {
if p.fcServer == nil { if p.fcServer == nil {
return return
} }
params := p.fcParams // If any of the flow control params is nil, refuse to update.
updateParams := false var params flowcontrol.ServerParams
if update.get("flowControl/BL", &params.BufLimit) == nil { if update.get("flowControl/BL", &params.BufLimit) == nil && update.get("flowControl/MRR", &params.MinRecharge) == nil {
updateParams = true // todo can light client set a minimal acceptable flow control params?
}
if update.get("flowControl/MRR", &params.MinRecharge) == nil {
updateParams = true
}
if updateParams {
p.fcParams = params p.fcParams = params
p.fcServer.UpdateParams(params) p.fcServer.UpdateParams(params)
} }

@ -18,47 +18,54 @@ package les
import ( import (
"math/big" "math/big"
"net"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
const ( const protocolVersion = lpv2
test_networkid = 10
protocol_version = lpv2
)
var ( var (
hash = common.HexToHash("some string") hash = common.HexToHash("deadbeef")
genesis = common.HexToHash("genesis hash") genesis = common.HexToHash("cafebabe")
headNum = uint64(1234) headNum = uint64(1234)
td = big.NewInt(123) td = big.NewInt(123)
) )
//ulc connects to trusted peer and send announceType=announceTypeSigned func newNodeID(t *testing.T) *enode.Node {
key, err := crypto.GenerateKey()
if err != nil {
t.Fatal("generate key err:", err)
}
return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000)
}
// ulc connects to trusted peer and send announceType=announceTypeSigned
func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testing.T) { func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testing.T) {
id := newNodeID(t).ID() id := newNodeID(t).ID()
//peer to connect(on ulc side) // peer to connect(on ulc side)
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
trusted: true, trusted: true,
rw: &rwStub{ rw: &rwStub{
WriteHook: func(recvList keyValueList) { WriteHook: func(recvList keyValueList) {
//checking that ulc sends to peer allowedRequests=onlyAnnounceRequests and announceType = announceTypeSigned
recv, _ := recvList.decode() recv, _ := recvList.decode()
var reqType uint64 var reqType uint64
err := recv.get("announceType", &reqType) err := recv.get("announceType", &reqType)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if reqType != announceTypeSigned { if reqType != announceTypeSigned {
t.Fatal("Expected announceTypeSigned") t.Fatal("Expected announceTypeSigned")
} }
@ -71,18 +78,15 @@ func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testi
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRR", uint64(0))
l = l.add("flowControl/MRC", testCostList(0)) l = l.add("flowControl/MRC", testCostList(0))
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, nil) err := p.Handshake(td, hash, headNum, genesis, nil)
if err != nil { if err != nil {
t.Fatalf("Handshake error: %s", err) t.Fatalf("Handshake error: %s", err)
} }
if p.announceType != announceTypeSigned { if p.announceType != announceTypeSigned {
t.Fatal("Incorrect announceType") t.Fatal("Incorrect announceType")
} }
@ -92,18 +96,16 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi
id := newNodeID(t).ID() id := newNodeID(t).ID()
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
WriteHook: func(recvList keyValueList) { WriteHook: func(recvList keyValueList) {
//checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned // checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned
recv, _ := recvList.decode() recv, _ := recvList.decode()
var reqType uint64 var reqType uint64
err := recv.get("announceType", &reqType) err := recv.get("announceType", &reqType)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if reqType == announceTypeSigned { if reqType == announceTypeSigned {
t.Fatal("Expected not announceTypeSigned") t.Fatal("Expected not announceTypeSigned")
} }
@ -116,13 +118,11 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRR", uint64(0))
l = l.add("flowControl/MRC", testCostList(0)) l = l.add("flowControl/MRC", testCostList(0))
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, nil) err := p.Handshake(td, hash, headNum, genesis, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -139,16 +139,15 @@ func TestPeerHandshakeDefaultAllRequests(t *testing.T) {
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
ReadHook: func(l keyValueList) keyValueList { ReadHook: func(l keyValueList) keyValueList {
l = l.add("announceType", uint64(announceTypeSigned)) l = l.add("announceType", uint64(announceTypeSigned))
l = l.add("allowedRequests", uint64(0)) l = l.add("allowedRequests", uint64(0))
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, s) err := p.Handshake(td, hash, headNum, genesis, s)
@ -165,15 +164,14 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) {
id := newNodeID(t).ID() id := newNodeID(t).ID()
s := generateLesServer() s := generateLesServer()
s.onlyAnnounce = true s.config.UltraLightOnlyAnnounce = true
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
ReadHook: func(l keyValueList) keyValueList { ReadHook: func(l keyValueList) keyValueList {
l = l.add("announceType", uint64(announceTypeSigned)) l = l.add("announceType", uint64(announceTypeSigned))
return l return l
}, },
WriteHook: func(l keyValueList) { WriteHook: func(l keyValueList) {
@ -187,7 +185,7 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) {
} }
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, s) err := p.Handshake(td, hash, headNum, genesis, s)
@ -200,7 +198,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) {
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
ReadHook: func(l keyValueList) keyValueList { ReadHook: func(l keyValueList) keyValueList {
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
@ -212,7 +210,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) {
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
trusted: true, trusted: true,
} }
@ -231,19 +229,17 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) {
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
ReadHook: func(l keyValueList) keyValueList { ReadHook: func(l keyValueList) keyValueList {
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRR", uint64(0))
l = l.add("flowControl/MRC", RequestCostList{}) l = l.add("flowControl/MRC", RequestCostList{})
l = l.add("announceType", uint64(announceTypeSigned)) l = l.add("announceType", uint64(announceTypeSigned))
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, nil) err := p.Handshake(td, hash, headNum, genesis, nil)
@ -254,12 +250,16 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) {
func generateLesServer() *LesServer { func generateLesServer() *LesServer {
s := &LesServer{ s := &LesServer{
lesCommons: lesCommons{
config: &eth.Config{UltraLightOnlyAnnounce: true},
},
defParams: flowcontrol.ServerParams{ defParams: flowcontrol.ServerParams{
BufLimit: uint64(300000000), BufLimit: uint64(300000000),
MinRecharge: uint64(50000), MinRecharge: uint64(50000),
}, },
fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}), fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}),
} }
s.costTracker, _ = newCostTracker(rawdb.NewMemoryDatabase(), s.config)
return s return s
} }
@ -270,8 +270,8 @@ type rwStub struct {
func (s *rwStub) ReadMsg() (p2p.Msg, error) { func (s *rwStub) ReadMsg() (p2p.Msg, error) {
payload := keyValueList{} payload := keyValueList{}
payload = payload.add("protocolVersion", uint64(protocol_version)) payload = payload.add("protocolVersion", uint64(protocolVersion))
payload = payload.add("networkId", uint64(test_networkid)) payload = payload.add("networkId", uint64(NetworkId))
payload = payload.add("headTd", td) payload = payload.add("headTd", td)
payload = payload.add("headHash", hash) payload = payload.add("headHash", hash)
payload = payload.add("headNum", headNum) payload = payload.add("headNum", headNum)
@ -280,12 +280,10 @@ func (s *rwStub) ReadMsg() (p2p.Msg, error) {
if s.ReadHook != nil { if s.ReadHook != nil {
payload = s.ReadHook(payload) payload = s.ReadHook(payload)
} }
size, p, err := rlp.EncodeToReader(payload) size, p, err := rlp.EncodeToReader(payload)
if err != nil { if err != nil {
return p2p.Msg{}, err return p2p.Msg{}, err
} }
return p2p.Msg{ return p2p.Msg{
Size: uint32(size), Size: uint32(size),
Payload: p, Payload: p,
@ -297,10 +295,8 @@ func (s *rwStub) WriteMsg(m p2p.Msg) error {
if err := m.Decode(&recvList); err != nil { if err := m.Decode(&recvList); err != nil {
return err return err
} }
if s.WriteHook != nil { if s.WriteHook != nil {
s.WriteHook(recvList) s.WriteHook(recvList)
} }
return nil return nil
} }

@ -37,18 +37,21 @@ func secAddr(addr common.Address) []byte {
type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest
func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) } func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) }
func TestBlockAccessLes3(t *testing.T) { testAccess(t, 3, tfBlockAccess) }
func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
return &light.BlockRequest{Hash: bhash, Number: number} return &light.BlockRequest{Hash: bhash, Number: number}
} }
func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) } func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) }
func TestReceiptsAccessLes3(t *testing.T) { testAccess(t, 3, tfReceiptsAccess) }
func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
return &light.ReceiptsRequest{Hash: bhash, Number: number} return &light.ReceiptsRequest{Hash: bhash, Number: number}
} }
func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) }
func TestTrieEntryAccessLes3(t *testing.T) { testAccess(t, 3, tfTrieEntryAccess) }
func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
if number := rawdb.ReadHeaderNumber(db, bhash); number != nil { if number := rawdb.ReadHeaderNumber(db, bhash); number != nil {
@ -58,6 +61,7 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh
} }
func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) }
func TestCodeAccessLes3(t *testing.T) { testAccess(t, 3, tfCodeAccess) }
func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest { func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest {
number := rawdb.ReadHeaderNumber(db, bhash) number := rawdb.ReadHeaderNumber(db, bhash)
@ -75,17 +79,18 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrReq
func testAccess(t *testing.T, protocol int, fn accessTestFn) { func testAccess(t *testing.T, protocol int, fn accessTestFn) {
// Assemble the test environment // Assemble the test environment
server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, true) server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true)
defer tearDown() defer tearDown()
client.pm.synchronise(client.rPeer) client.handler.synchronise(client.peer.peer)
test := func(expFail uint64) { test := func(expFail uint64) {
for i := uint64(0); i <= server.pm.blockchain.CurrentHeader().Number.Uint64(); i++ { for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ {
bhash := rawdb.ReadCanonicalHash(server.db, i) bhash := rawdb.ReadCanonicalHash(server.db, i)
if req := fn(client.db, bhash, i); req != nil { if req := fn(client.db, bhash, i); req != nil {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel() err := client.handler.backend.odr.Retrieve(ctx, req)
err := client.pm.odr.Retrieve(ctx, req) cancel()
got := err == nil got := err == nil
exp := i < expFail exp := i < expFail
if exp && !got { if exp && !got {
@ -97,18 +102,5 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) {
} }
} }
} }
// temporarily remove peer to test odr fails
client.peers.Unregister(client.rPeer.id)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
// expect retrievals to fail (except genesis block) without a les peer
test(0)
client.peers.Register(client.rPeer)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
client.rPeer.lock.Lock()
client.rPeer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
client.rPeer.lock.Unlock()
// expect all retrievals to pass
test(5) test(5)
} }

@ -18,15 +18,11 @@ package les
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"sync"
"time" "time"
"github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
@ -38,80 +34,94 @@ import (
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
const bufLimitRatio = 6000 // fixed bufLimit/MRR ratio
type LesServer struct { type LesServer struct {
lesCommons lesCommons
archiveMode bool // Flag whether the ethereum node runs in archive mode. archiveMode bool // Flag whether the ethereum node runs in archive mode.
handler *serverHandler
lesTopics []discv5.Topic
privateKey *ecdsa.PrivateKey
fcManager *flowcontrol.ClientManager // nil if our node is client only // Flow control and capacity management
fcManager *flowcontrol.ClientManager
costTracker *costTracker costTracker *costTracker
testCost uint64
defParams flowcontrol.ServerParams defParams flowcontrol.ServerParams
lesTopics []discv5.Topic servingQueue *servingQueue
privateKey *ecdsa.PrivateKey clientPool *clientPool
quitSync chan struct{}
onlyAnnounce bool
thcNormal, thcBlockProcessing int // serving thread count for normal operation and block processing mode
maxPeers int freeCapacity uint64 // The minimal client capacity used for free client.
minCapacity, maxCapacity, freeClientCap uint64 threadsIdle int // Request serving threads count when system is idle.
clientPool *clientPool threadsBusy int // Request serving threads count when system is busy(block insertion).
} }
func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) { func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
// Collect les protocol version information supported by local node.
lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions)) lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions))
for i, pv := range AdvertiseProtocolVersions { for i, pv := range AdvertiseProtocolVersions {
lesTopics[i] = lesTopic(e.BlockChain().Genesis().Hash(), pv) lesTopics[i] = lesTopic(e.BlockChain().Genesis().Hash(), pv)
} }
quitSync := make(chan struct{}) // Calculate the number of threads used to service the light client
// requests based on the user-specified value.
threads := config.LightServ * 4 / 100
if threads < 4 {
threads = 4
}
srv := &LesServer{ srv := &LesServer{
lesCommons: lesCommons{ lesCommons: lesCommons{
genesis: e.BlockChain().Genesis().Hash(),
config: config, config: config,
chainConfig: e.BlockChain().Config(),
iConfig: light.DefaultServerIndexerConfig, iConfig: light.DefaultServerIndexerConfig,
chainDb: e.ChainDb(), chainDb: e.ChainDb(),
peers: newPeerSet(),
chainReader: e.BlockChain(),
chtIndexer: light.NewChtIndexer(e.ChainDb(), nil, params.CHTFrequency, params.HelperTrieProcessConfirmations), chtIndexer: light.NewChtIndexer(e.ChainDb(), nil, params.CHTFrequency, params.HelperTrieProcessConfirmations),
bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency), bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency),
closeCh: make(chan struct{}),
}, },
archiveMode: e.ArchiveMode(), archiveMode: e.ArchiveMode(),
quitSync: quitSync,
lesTopics: lesTopics, lesTopics: lesTopics,
onlyAnnounce: config.UltraLightOnlyAnnounce, fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}),
servingQueue: newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100),
threadsBusy: config.LightServ/100 + 1,
threadsIdle: threads,
} }
srv.costTracker, srv.minCapacity = newCostTracker(e.ChainDb(), config) srv.handler = newServerHandler(srv, e.BlockChain(), e.ChainDb(), e.TxPool(), e.Synced)
srv.costTracker, srv.freeCapacity = newCostTracker(e.ChainDb(), config)
logger := log.New() // Set up checkpoint oracle.
srv.thcNormal = config.LightServ * 4 / 100 oracle := config.CheckpointOracle
if srv.thcNormal < 4 { if oracle == nil {
srv.thcNormal = 4 oracle = params.CheckpointOracles[e.BlockChain().Genesis().Hash()]
} }
srv.thcBlockProcessing = config.LightServ/100 + 1 srv.oracle = newCheckpointOracle(oracle, srv.localCheckpoint)
srv.fcManager = flowcontrol.NewClientManager(nil, &mclock.System{})
// Initialize server capacity management fields.
srv.defParams = flowcontrol.ServerParams{
BufLimit: srv.freeCapacity * bufLimitRatio,
MinRecharge: srv.freeCapacity,
}
// LES flow control tries to more or less guarantee the possibility for the
// clients to send a certain amount of requests at any time and get a quick
// response. Most of the clients want this guarantee but don't actually need
// to send requests most of the time. Our goal is to serve as many clients as
// possible while the actually used server capacity does not exceed the limits
totalRecharge := srv.costTracker.totalRecharge()
maxCapacity := srv.freeCapacity * uint64(srv.config.LightPeers)
if totalRecharge > maxCapacity {
maxCapacity = totalRecharge
}
srv.fcManager.SetCapacityLimits(srv.freeCapacity, maxCapacity, srv.freeCapacity*2)
srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, 10000, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) })
srv.peers.notify(srv.clientPool)
checkpoint := srv.latestLocalCheckpoint() checkpoint := srv.latestLocalCheckpoint()
if !checkpoint.Empty() { if !checkpoint.Empty() {
logger.Info("Loaded latest checkpoint", "section", checkpoint.SectionIndex, "head", checkpoint.SectionHead, log.Info("Loaded latest checkpoint", "section", checkpoint.SectionIndex, "head", checkpoint.SectionHead,
"chtroot", checkpoint.CHTRoot, "bloomroot", checkpoint.BloomRoot) "chtroot", checkpoint.CHTRoot, "bloomroot", checkpoint.BloomRoot)
} }
srv.chtIndexer.Start(e.BlockChain()) srv.chtIndexer.Start(e.BlockChain())
oracle := config.CheckpointOracle
if oracle == nil {
oracle = params.CheckpointOracles[e.BlockChain().Genesis().Hash()]
}
registrar := newCheckpointOracle(oracle, srv.getLocalCheckpoint)
// TODO(rjl493456442) Checkpoint is useless for les server, separate handler for client and server.
pm, err := NewProtocolManager(e.BlockChain().Config(), nil, light.DefaultServerIndexerConfig, config.UltraLightServers, config.UltraLightFraction, false, config.NetworkId, e.EventMux(), newPeerSet(), e.BlockChain(), e.TxPool(), e.ChainDb(), nil, nil, registrar, quitSync, new(sync.WaitGroup), e.Synced)
if err != nil {
return nil, err
}
srv.protocolManager = pm
pm.servingQueue = newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100)
pm.server = srv
return srv, nil return srv, nil
} }
@ -120,102 +130,29 @@ func (s *LesServer) APIs() []rpc.API {
{ {
Namespace: "les", Namespace: "les",
Version: "1.0", Version: "1.0",
Service: NewPrivateLightAPI(&s.lesCommons, s.protocolManager.reg), Service: NewPrivateLightAPI(&s.lesCommons),
Public: false, Public: false,
}, },
} }
} }
// startEventLoop starts an event handler loop that updates the recharge curve of
// the client manager and adjusts the client pool's size according to the total
// capacity updates coming from the client manager
func (s *LesServer) startEventLoop() {
s.protocolManager.wg.Add(1)
var (
processing, procLast bool
procStarted time.Time
)
blockProcFeed := make(chan bool, 100)
s.protocolManager.blockchain.(*core.BlockChain).SubscribeBlockProcessingEvent(blockProcFeed)
totalRechargeCh := make(chan uint64, 100)
totalRecharge := s.costTracker.subscribeTotalRecharge(totalRechargeCh)
totalCapacityCh := make(chan uint64, 100)
updateRecharge := func() {
if processing {
if !procLast {
procStarted = time.Now()
}
s.protocolManager.servingQueue.setThreads(s.thcBlockProcessing)
s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge, totalRecharge}})
} else {
if procLast {
blockProcessingTimer.UpdateSince(procStarted)
}
s.protocolManager.servingQueue.setThreads(s.thcNormal)
s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge / 16, totalRecharge / 2}, {totalRecharge / 2, totalRecharge / 2}, {totalRecharge, totalRecharge}})
}
procLast = processing
}
updateRecharge()
totalCapacity := s.fcManager.SubscribeTotalCapacity(totalCapacityCh)
s.clientPool.setLimits(s.maxPeers, totalCapacity)
var maxFreePeers uint64
go func() {
for {
select {
case processing = <-blockProcFeed:
updateRecharge()
case totalRecharge = <-totalRechargeCh:
updateRecharge()
case totalCapacity = <-totalCapacityCh:
totalCapacityGauge.Update(int64(totalCapacity))
newFreePeers := totalCapacity / s.freeClientCap
if newFreePeers < maxFreePeers && newFreePeers < uint64(s.maxPeers) {
log.Warn("Reduced total capacity", "maxFreePeers", newFreePeers)
}
maxFreePeers = newFreePeers
s.clientPool.setLimits(s.maxPeers, totalCapacity)
case <-s.protocolManager.quitSync:
s.protocolManager.wg.Done()
return
}
}
}()
}
func (s *LesServer) Protocols() []p2p.Protocol { func (s *LesServer) Protocols() []p2p.Protocol {
return s.makeProtocols(ServerProtocolVersions) return s.makeProtocols(ServerProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} {
if p := s.peers.Peer(peerIdToString(id)); p != nil {
return p.Info()
}
return nil
})
} }
// Start starts the LES server // Start starts the LES server
func (s *LesServer) Start(srvr *p2p.Server) { func (s *LesServer) Start(srvr *p2p.Server) {
s.maxPeers = s.config.LightPeers s.privateKey = srvr.PrivateKey
totalRecharge := s.costTracker.totalRecharge() s.handler.start()
if s.maxPeers > 0 {
s.freeClientCap = s.minCapacity //totalRecharge / uint64(s.maxPeers) s.wg.Add(1)
if s.freeClientCap < s.minCapacity { go s.capacityManagement()
s.freeClientCap = s.minCapacity
}
if s.freeClientCap > 0 {
s.defParams = flowcontrol.ServerParams{
BufLimit: s.freeClientCap * bufLimitRatio,
MinRecharge: s.freeClientCap,
}
}
}
s.maxCapacity = s.freeClientCap * uint64(s.maxPeers)
if totalRecharge > s.maxCapacity {
s.maxCapacity = totalRecharge
}
s.fcManager.SetCapacityLimits(s.freeClientCap, s.maxCapacity, s.freeClientCap*2)
s.clientPool = newClientPool(s.chainDb, s.freeClientCap, 10000, mclock.System{}, func(id enode.ID) { go s.protocolManager.removePeer(peerIdToString(id)) })
s.clientPool.setPriceFactors(priceFactors{0, 1, 1}, priceFactors{0, 1, 1})
s.protocolManager.peers.notify(s.clientPool)
s.startEventLoop()
s.protocolManager.Start(s.config.LightPeers)
if srvr.DiscV5 != nil { if srvr.DiscV5 != nil {
for _, topic := range s.lesTopics { for _, topic := range s.lesTopics {
topic := topic topic := topic
@ -224,12 +161,32 @@ func (s *LesServer) Start(srvr *p2p.Server) {
logger.Info("Starting topic registration") logger.Info("Starting topic registration")
defer logger.Info("Terminated topic registration") defer logger.Info("Terminated topic registration")
srvr.DiscV5.RegisterTopic(topic, s.quitSync) srvr.DiscV5.RegisterTopic(topic, s.closeCh)
}() }()
} }
} }
s.privateKey = srvr.PrivateKey }
s.protocolManager.blockLoop()
// Stop stops the LES service
func (s *LesServer) Stop() {
close(s.closeCh)
// Disconnect existing sessions.
// This also closes the gate for any new registrations on the peer set.
// sessions which are already established but not added to pm.peers yet
// will exit when they try to register.
s.peers.Close()
s.fcManager.Stop()
s.clientPool.stop()
s.costTracker.stop()
s.handler.stop()
s.servingQueue.stop()
// Note, bloom trie indexer is closed by parent bloombits indexer.
s.chtIndexer.Close()
s.wg.Wait()
log.Info("Les server stopped")
} }
func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) { func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) {
@ -238,78 +195,67 @@ func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) {
// SetClient sets the rpc client and starts running checkpoint contract if it is not yet watched. // SetClient sets the rpc client and starts running checkpoint contract if it is not yet watched.
func (s *LesServer) SetContractBackend(backend bind.ContractBackend) { func (s *LesServer) SetContractBackend(backend bind.ContractBackend) {
if s.protocolManager.reg != nil { if s.oracle == nil {
s.protocolManager.reg.start(backend) return
} }
s.oracle.start(backend)
} }
// Stop stops the LES service // capacityManagement starts an event handler loop that updates the recharge curve of
func (s *LesServer) Stop() { // the client manager and adjusts the client pool's size according to the total
s.fcManager.Stop() // capacity updates coming from the client manager
s.chtIndexer.Close() func (s *LesServer) capacityManagement() {
// bloom trie indexer is closed by parent bloombits indexer defer s.wg.Done()
go func() {
<-s.protocolManager.noMorePeers
}()
s.clientPool.stop()
s.costTracker.stop()
s.protocolManager.Stop()
}
// todo(rjl493456442) separate client and server implementation. processCh := make(chan bool, 100)
func (pm *ProtocolManager) blockLoop() { sub := s.handler.blockchain.SubscribeBlockProcessingEvent(processCh)
pm.wg.Add(1) defer sub.Unsubscribe()
headCh := make(chan core.ChainHeadEvent, 10)
headSub := pm.blockchain.SubscribeChainHeadEvent(headCh)
go func() {
var lastHead *types.Header
lastBroadcastTd := common.Big0
for {
select {
case ev := <-headCh:
peers := pm.peers.AllPeers()
if len(peers) > 0 {
header := ev.Block.Header()
hash := header.Hash()
number := header.Number.Uint64()
td := rawdb.ReadTd(pm.chainDb, hash, number)
if td != nil && td.Cmp(lastBroadcastTd) > 0 {
var reorg uint64
if lastHead != nil {
reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64()
}
lastHead = header
lastBroadcastTd = td
log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg) totalRechargeCh := make(chan uint64, 100)
totalRecharge := s.costTracker.subscribeTotalRecharge(totalRechargeCh)
announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg} totalCapacityCh := make(chan uint64, 100)
var ( totalCapacity := s.fcManager.SubscribeTotalCapacity(totalCapacityCh)
signed bool s.clientPool.setLimits(s.config.LightPeers, totalCapacity)
signedAnnounce announceData
)
for _, p := range peers { var (
p := p busy bool
switch p.announceType { freePeers uint64
case announceTypeSimple: blockProcess mclock.AbsTime
p.queueSend(func() { p.SendAnnounce(announce) }) )
case announceTypeSigned: updateRecharge := func() {
if !signed { if busy {
signedAnnounce = announce s.servingQueue.setThreads(s.threadsBusy)
signedAnnounce.sign(pm.server.privateKey) s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge, totalRecharge}})
signed = true } else {
} s.servingQueue.setThreads(s.threadsIdle)
p.queueSend(func() { p.SendAnnounce(signedAnnounce) }) s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge / 10, totalRecharge}, {totalRecharge, totalRecharge}})
} }
} }
} updateRecharge()
}
case <-pm.quitSync: for {
headSub.Unsubscribe() select {
pm.wg.Done() case busy = <-processCh:
return if busy {
blockProcess = mclock.Now()
} else {
blockProcessingTimer.Update(time.Duration(mclock.Now() - blockProcess))
} }
updateRecharge()
case totalRecharge = <-totalRechargeCh:
totalRechargeGauge.Update(int64(totalRecharge))
updateRecharge()
case totalCapacity = <-totalCapacityCh:
totalCapacityGauge.Update(int64(totalCapacity))
newFreePeers := totalCapacity / s.freeCapacity
if newFreePeers < freePeers && newFreePeers < uint64(s.config.LightPeers) {
log.Warn("Reduced free peer connections", "from", freePeers, "to", newFreePeers)
}
freePeers = newFreePeers
s.clientPool.setLimits(s.config.LightPeers, totalCapacity)
case <-s.closeCh:
return
} }
}() }
} }

@ -0,0 +1,921 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les
import (
"encoding/binary"
"encoding/json"
"errors"
"sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
const (
softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header
ethVersion = 63 // equivalent eth version for the downloader
MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request
MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request
MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request
MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request
MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request
MaxHelperTrieProofsFetch = 64 // Amount of helper tries to be fetched per retrieval request
MaxTxSend = 64 // Amount of transactions to be send per request
MaxTxStatus = 256 // Amount of transactions to queried per request
)
var errTooManyInvalidRequest = errors.New("too many invalid requests made")
// serverHandler is responsible for serving light client and process
// all incoming light requests.
type serverHandler struct {
blockchain *core.BlockChain
chainDb ethdb.Database
txpool *core.TxPool
server *LesServer
closeCh chan struct{} // Channel used to exit all background routines of handler.
wg sync.WaitGroup // WaitGroup used to track all background routines of handler.
synced func() bool // Callback function used to determine whether local node is synced.
// Testing fields
addTxsSync bool
}
func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler {
handler := &serverHandler{
server: server,
blockchain: blockchain,
chainDb: chainDb,
txpool: txpool,
closeCh: make(chan struct{}),
synced: synced,
}
return handler
}
// start starts the server handler.
func (h *serverHandler) start() {
h.wg.Add(1)
go h.broadcastHeaders()
}
// stop stops the server handler.
func (h *serverHandler) stop() {
close(h.closeCh)
h.wg.Wait()
}
// runPeer is the p2p protocol run function for the given version.
func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
peer := newPeer(int(version), h.server.config.NetworkId, false, p, newMeteredMsgWriter(rw, int(version)))
h.wg.Add(1)
defer h.wg.Done()
return h.handle(peer)
}
func (h *serverHandler) handle(p *peer) error {
// Reject light clients if server is not synced.
if !h.synced() {
return p2p.DiscRequested
}
p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
// Execute the LES handshake
var (
head = h.blockchain.CurrentHeader()
hash = head.Hash()
number = head.Number.Uint64()
td = h.blockchain.GetTd(hash, number)
)
if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil {
p.Log().Debug("Light Ethereum handshake failed", "err", err)
return err
}
defer p.fcClient.Disconnect()
// Register the peer locally
if err := h.server.peers.Register(p); err != nil {
p.Log().Error("Light Ethereum peer registration failed", "err", err)
return err
}
clientConnectionGauge.Update(int64(h.server.peers.Len()))
// add dummy balance tracker for tests
if p.balanceTracker == nil {
p.balanceTracker = &balanceTracker{}
p.balanceTracker.init(&mclock.System{}, 1)
}
connectedAt := mclock.Now()
defer func() {
p.balanceTracker = nil
h.server.peers.Unregister(p.id)
clientConnectionGauge.Update(int64(h.server.peers.Len()))
connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
}()
// Spawn a main loop to handle all incoming messages.
for {
select {
case err := <-p.errCh:
p.Log().Debug("Failed to send light ethereum response", "err", err)
return err
default:
}
if err := h.handleMsg(p); err != nil {
p.Log().Debug("Light Ethereum message handling failed", "err", err)
return err
}
}
}
// handleMsg is invoked whenever an inbound message is received from a remote
// peer. The remote connection is torn down upon returning any error.
func (h *serverHandler) handleMsg(p *peer) error {
// Read the next message from the remote peer, and ensure it's fully consumed
msg, err := p.rw.ReadMsg()
if err != nil {
return err
}
p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
// Discard large message which exceeds the limitation.
if msg.Size > ProtocolMaxMsgSize {
clientErrorMeter.Mark(1)
return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
}
defer msg.Discard()
var (
maxCost uint64
task *servingTask
)
p.responseCount++
responseCount := p.responseCount
// accept returns an indicator whether the request can be served.
// If so, deduct the max cost from the flow control buffer.
accept := func(reqID, reqCnt, maxCnt uint64) bool {
// Short circuit if the peer is already frozen or the request is invalid.
inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt {
p.fcClient.OneTimeCost(inSizeCost)
return false
}
// Prepaid max cost units before request been serving.
maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
if !accepted {
p.freezeClient()
p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
p.fcClient.OneTimeCost(inSizeCost)
return false
}
// Create a multi-stage task, estimate the time it takes for the task to
// execute, and cache it in the request service queue.
factor := h.server.costTracker.globalFactor()
if factor < 0.001 {
factor = 1
p.Log().Error("Invalid global cost factor", "factor", factor)
}
maxTime := uint64(float64(maxCost) / factor)
task = h.server.servingQueue.newTask(p, maxTime, priority)
if task.start() {
return true
}
p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
return false
}
// sendResponse sends back the response and updates the flow control statistic.
sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
p.responseLock.Lock()
defer p.responseLock.Unlock()
// Short circuit if the client is already frozen.
if p.isFrozen() {
realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0)
p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
return
}
// Positive correction buffer value with real cost.
var replySize uint32
if reply != nil {
replySize = reply.size()
}
var realCost uint64
if h.server.costTracker.testing {
realCost = maxCost // Assign a fake cost for testing purpose
} else {
realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
}
bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
if amount != 0 {
// Feed cost tracker request serving statistic.
h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
// Reduce priority "balance" for the specific peer.
p.balanceTracker.requestCost(realCost)
}
if reply != nil {
p.queueSend(func() {
if err := reply.send(bv); err != nil {
select {
case p.errCh <- err:
default:
}
}
})
}
}
switch msg.Code {
case GetBlockHeadersMsg:
p.Log().Trace("Received block header request")
if metrics.EnabledExpensive {
miscInHeaderPacketsMeter.Mark(1)
miscInHeaderTrafficMeter.Mark(int64(msg.Size))
}
var req struct {
ReqID uint64
Query getBlockHeadersData
}
if err := msg.Decode(&req); err != nil {
clientErrorMeter.Mark(1)
return errResp(ErrDecode, "%v: %v", msg, err)
}
query := req.Query
if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
go func() {
hashMode := query.Origin.Hash != (common.Hash{})
first := true
maxNonCanonical := uint64(100)
// Gather headers until the fetch or network limits is reached
var (
bytes common.StorageSize
headers []*types.Header
unknown bool
)
for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
if !first && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime)
return
}
// Retrieve the next header satisfying the query
var origin *types.Header
if hashMode {
if first {
origin = h.blockchain.GetHeaderByHash(query.Origin.Hash)
if origin != nil {
query.Origin.Number = origin.Number.Uint64()
}
} else {
origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
}
} else {
origin = h.blockchain.GetHeaderByNumber(query.Origin.Number)
}
if origin == nil {
atomic.AddUint32(&p.invalidCount, 1)
break
}
headers = append(headers, origin)
bytes += estHeaderRlpSize
// Advance to the next header of the query
switch {
case hashMode && query.Reverse:
// Hash based traversal towards the genesis block
ancestor := query.Skip + 1
if ancestor == 0 {
unknown = true
} else {
query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
unknown = query.Origin.Hash == common.Hash{}
}
case hashMode && !query.Reverse:
// Hash based traversal towards the leaf block
var (
current = origin.Number.Uint64()
next = current + query.Skip + 1
)
if next <= current {
infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ")
p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
unknown = true
} else {
if header := h.blockchain.GetHeaderByNumber(next); header != nil {
nextHash := header.Hash()
expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
if expOldHash == query.Origin.Hash {
query.Origin.Hash, query.Origin.Number = nextHash, next
} else {
unknown = true
}
} else {
unknown = true
}
}
case query.Reverse:
// Number based traversal towards the genesis block
if query.Origin.Number >= query.Skip+1 {
query.Origin.Number -= query.Skip + 1
} else {
unknown = true
}
case !query.Reverse:
// Number based traversal towards the leaf block
query.Origin.Number += query.Skip + 1
}
first = false
}
reply := p.ReplyBlockHeaders(req.ReqID, headers)
sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done())
if metrics.EnabledExpensive {
miscOutHeaderPacketsMeter.Mark(1)
miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
}
}()
}
case GetBlockBodiesMsg:
p.Log().Trace("Received block bodies request")
if metrics.EnabledExpensive {
miscInBodyPacketsMeter.Mark(1)
miscInBodyTrafficMeter.Mark(int64(msg.Size))
}
var req struct {
ReqID uint64
Hashes []common.Hash
}
if err := msg.Decode(&req); err != nil {
clientErrorMeter.Mark(1)
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
var (
bytes int
bodies []rlp.RawValue
)
reqCnt := len(req.Hashes)
if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
go func() {
for i, hash := range req.Hashes {
if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime)
return
}
if bytes >= softResponseLimit {
break
}
body := h.blockchain.GetBodyRLP(hash)
if body == nil {
atomic.AddUint32(&p.invalidCount, 1)
continue
}
bodies = append(bodies, body)
bytes += len(body)
}
reply := p.ReplyBlockBodiesRLP(req.ReqID, bodies)
sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
if metrics.EnabledExpensive {
miscOutBodyPacketsMeter.Mark(1)
miscOutBodyTrafficMeter.Mark(int64(reply.size()))
}
}()
}
case GetCodeMsg:
p.Log().Trace("Received code request")
if metrics.EnabledExpensive {
miscInCodePacketsMeter.Mark(1)
miscInCodeTrafficMeter.Mark(int64(msg.Size))
}
var req struct {
ReqID uint64
Reqs []CodeReq
}
if err := msg.Decode(&req); err != nil {
clientErrorMeter.Mark(1)
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
var (
bytes int
data [][]byte
)
reqCnt := len(req.Reqs)
if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
go func() {
for i, request := range req.Reqs {
if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime)
return
}
// Look up the root hash belonging to the request
header := h.blockchain.GetHeaderByHash(request.BHash)
if header == nil {
p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash)
atomic.AddUint32(&p.invalidCount, 1)
continue
}
// Refuse to search stale state data in the database since looking for
// a non-exist key is kind of expensive.
local := h.blockchain.CurrentHeader().Number.Uint64()
if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
atomic.AddUint32(&p.invalidCount, 1)
continue
}
triedb := h.blockchain.StateCache().TrieDB()
account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
if err != nil {
p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
atomic.AddUint32(&p.invalidCount, 1)
continue
}
code, err := triedb.Node(common.BytesToHash(account.CodeHash))
if err != nil {
p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
continue
}
// Accumulate the code and abort if enough data was retrieved
data = append(data, code)
if bytes += len(code); bytes >= softResponseLimit {
break
}
}
reply := p.ReplyCode(req.ReqID, data)
sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
if metrics.EnabledExpensive {
miscOutCodePacketsMeter.Mark(1)
miscOutCodeTrafficMeter.Mark(int64(reply.size()))
}
}()
}
case GetReceiptsMsg:
p.Log().Trace("Received receipts request")
if metrics.EnabledExpensive {
miscInReceiptPacketsMeter.Mark(1)
miscInReceiptTrafficMeter.Mark(int64(msg.Size))
}
var req struct {
ReqID uint64
Hashes []common.Hash
}
if err := msg.Decode(&req); err != nil {
clientErrorMeter.Mark(1)
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
var (
bytes int
receipts []rlp.RawValue
)
reqCnt := len(req.Hashes)
if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
go func() {
for i, hash := range req.Hashes {
if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime)
return
}
if bytes >= softResponseLimit {
break
}
// Retrieve the requested block's receipts, skipping if unknown to us
results := h.blockchain.GetReceiptsByHash(hash)
if results == nil {
if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
atomic.AddUint32(&p.invalidCount, 1)
continue
}
}
// If known, encode and queue for response packet
if encoded, err := rlp.EncodeToBytes(results); err != nil {
log.Error("Failed to encode receipt", "err", err)
} else {
receipts = append(receipts, encoded)
bytes += len(encoded)
}
}
reply := p.ReplyReceiptsRLP(req.ReqID, receipts)
sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
if metrics.EnabledExpensive {
miscOutReceiptPacketsMeter.Mark(1)
miscOutReceiptTrafficMeter.Mark(int64(reply.size()))
}
}()
}
case GetProofsV2Msg:
p.Log().Trace("Received les/2 proofs request")
if metrics.EnabledExpensive {
miscInTrieProofPacketsMeter.Mark(1)
miscInTrieProofTrafficMeter.Mark(int64(msg.Size))
}
var req struct {
ReqID uint64
Reqs []ProofReq
}
if err := msg.Decode(&req); err != nil {
clientErrorMeter.Mark(1)
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
// Gather state data until the fetch or network limits is reached
var (
lastBHash common.Hash
root common.Hash
)
reqCnt := len(req.Reqs)
if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
go func() {
nodes := light.NewNodeSet()
for i, request := range req.Reqs {
if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime)
return
}
// Look up the root hash belonging to the request
var (
number *uint64
header *types.Header
trie state.Trie
)
if request.BHash != lastBHash {
root, lastBHash = common.Hash{}, request.BHash
if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil {
p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", request.BHash)
atomic.AddUint32(&p.invalidCount, 1)
continue
}
// Refuse to search stale state data in the database since looking for
// a non-exist key is kind of expensive.
local := h.blockchain.CurrentHeader().Number.Uint64()
if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
atomic.AddUint32(&p.invalidCount, 1)
continue
}
root = header.Root
}
// If a header lookup failed (non existent), ignore subsequent requests for the same header
if root == (common.Hash{}) {
atomic.AddUint32(&p.invalidCount, 1)
continue
}
// Open the account or storage trie for the request
statedb := h.blockchain.StateCache()
switch len(request.AccKey) {
case 0:
// No account key specified, open an account trie
trie, err = statedb.OpenTrie(root)
if trie == nil || err != nil {
p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
continue
}
default:
// Account key specified, open a storage trie
account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
if err != nil {
p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
atomic.AddUint32(&p.invalidCount, 1)
continue
}
trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
if trie == nil || err != nil {
p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
continue
}
}
// Prove the user's request from the account or stroage trie
if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
continue
}
if nodes.DataSize() >= softResponseLimit {
break
}
}
reply := p.ReplyProofsV2(req.ReqID, nodes.NodeList())
sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
if metrics.EnabledExpensive {
miscOutTrieProofPacketsMeter.Mark(1)
miscOutTrieProofTrafficMeter.Mark(int64(reply.size()))
}
}()
}
case GetHelperTrieProofsMsg:
p.Log().Trace("Received helper trie proof request")
if metrics.EnabledExpensive {
miscInHelperTriePacketsMeter.Mark(1)
miscInHelperTrieTrafficMeter.Mark(int64(msg.Size))
}
var req struct {
ReqID uint64
Reqs []HelperTrieReq
}
if err := msg.Decode(&req); err != nil {
clientErrorMeter.Mark(1)
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
// Gather state data until the fetch or network limits is reached
var (
auxBytes int
auxData [][]byte
)
reqCnt := len(req.Reqs)
if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
go func() {
var (
lastIdx uint64
lastType uint
root common.Hash
auxTrie *trie.Trie
)
nodes := light.NewNodeSet()
for i, request := range req.Reqs {
if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime)
return
}
if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
var prefix string
if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix)))
}
}
if request.AuxReq == auxRoot {
var data []byte
if root != (common.Hash{}) {
data = root[:]
}
auxData = append(auxData, data)
auxBytes += len(data)
} else {
if auxTrie != nil {
auxTrie.Prove(request.Key, request.FromLevel, nodes)
}
if request.AuxReq != 0 {
data := h.getAuxiliaryHeaders(request)
auxData = append(auxData, data)
auxBytes += len(data)
}
}
if nodes.DataSize()+auxBytes >= softResponseLimit {
break
}
}
reply := p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData})
sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
if metrics.EnabledExpensive {
miscOutHelperTriePacketsMeter.Mark(1)
miscOutHelperTrieTrafficMeter.Mark(int64(reply.size()))
}
}()
}
case SendTxV2Msg:
p.Log().Trace("Received new transactions")
if metrics.EnabledExpensive {
miscInTxsPacketsMeter.Mark(1)
miscInTxsTrafficMeter.Mark(int64(msg.Size))
}
var req struct {
ReqID uint64
Txs []*types.Transaction
}
if err := msg.Decode(&req); err != nil {
clientErrorMeter.Mark(1)
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
reqCnt := len(req.Txs)
if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
go func() {
stats := make([]light.TxStatus, len(req.Txs))
for i, tx := range req.Txs {
if i != 0 && !task.waitOrStop() {
return
}
hash := tx.Hash()
stats[i] = h.txStatus(hash)
if stats[i].Status == core.TxStatusUnknown {
addFn := h.txpool.AddRemotes
// Add txs synchronously for testing purpose
if h.addTxsSync {
addFn = h.txpool.AddRemotesSync
}
if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
stats[i].Error = errs[0].Error()
continue
}
stats[i] = h.txStatus(hash)
}
}
reply := p.ReplyTxStatus(req.ReqID, stats)
sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
if metrics.EnabledExpensive {
miscOutTxsPacketsMeter.Mark(1)
miscOutTxsTrafficMeter.Mark(int64(reply.size()))
}
}()
}
case GetTxStatusMsg:
p.Log().Trace("Received transaction status query request")
if metrics.EnabledExpensive {
miscInTxStatusPacketsMeter.Mark(1)
miscInTxStatusTrafficMeter.Mark(int64(msg.Size))
}
var req struct {
ReqID uint64
Hashes []common.Hash
}
if err := msg.Decode(&req); err != nil {
clientErrorMeter.Mark(1)
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
reqCnt := len(req.Hashes)
if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
go func() {
stats := make([]light.TxStatus, len(req.Hashes))
for i, hash := range req.Hashes {
if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime)
return
}
stats[i] = h.txStatus(hash)
}
reply := p.ReplyTxStatus(req.ReqID, stats)
sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
if metrics.EnabledExpensive {
miscOutTxStatusPacketsMeter.Mark(1)
miscOutTxStatusTrafficMeter.Mark(int64(reply.size()))
}
}()
}
default:
p.Log().Trace("Received invalid message", "code", msg.Code)
clientErrorMeter.Mark(1)
return errResp(ErrInvalidMsgCode, "%v", msg.Code)
}
// If the client has made too much invalid request(e.g. request a non-exist data),
// reject them to prevent SPAM attack.
if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors {
clientErrorMeter.Mark(1)
return errTooManyInvalidRequest
}
return nil
}
// getAccount retrieves an account from the state based on root.
func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
trie, err := trie.New(root, triedb)
if err != nil {
return state.Account{}, err
}
blob, err := trie.TryGet(hash[:])
if err != nil {
return state.Account{}, err
}
var account state.Account
if err = rlp.DecodeBytes(blob, &account); err != nil {
return state.Account{}, err
}
return account, nil
}
// getHelperTrie returns the post-processed trie root for the given trie ID and section index
func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) {
switch typ {
case htCanonical:
sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1)
return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix
case htBloomBits:
sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1)
return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix
}
return common.Hash{}, ""
}
// getAuxiliaryHeaders returns requested auxiliary headers for the CHT request.
func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte {
if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 {
blockNum := binary.BigEndian.Uint64(req.Key)
hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum)
return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum)
}
return nil
}
// txStatus returns the status of a specified transaction.
func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
var stat light.TxStatus
// Looking the transaction in txpool first.
stat.Status = h.txpool.Status([]common.Hash{hash})[0]
// If the transaction is unknown to the pool, try looking it up locally.
if stat.Status == core.TxStatusUnknown {
lookup := h.blockchain.GetTransactionLookup(hash)
if lookup != nil {
stat.Status = core.TxStatusIncluded
stat.Lookup = lookup
}
}
return stat
}
// broadcastHeaders broadcasts new block information to all connected light
// clients. According to the agreement between client and server, server should
// only broadcast new announcement if the total difficulty is higher than the
// last one. Besides server will add the signature if client requires.
func (h *serverHandler) broadcastHeaders() {
defer h.wg.Done()
headCh := make(chan core.ChainHeadEvent, 10)
headSub := h.blockchain.SubscribeChainHeadEvent(headCh)
defer headSub.Unsubscribe()
var (
lastHead *types.Header
lastTd = common.Big0
)
for {
select {
case ev := <-headCh:
peers := h.server.peers.AllPeers()
if len(peers) == 0 {
continue
}
header := ev.Block.Header()
hash, number := header.Hash(), header.Number.Uint64()
td := h.blockchain.GetTd(hash, number)
if td == nil || td.Cmp(lastTd) <= 0 {
continue
}
var reorg uint64
if lastHead != nil {
reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
}
lastHead, lastTd = header, td
log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
var (
signed bool
signedAnnounce announceData
)
announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
for _, p := range peers {
p := p
switch p.announceType {
case announceTypeSimple:
p.queueSend(func() { p.SendAnnounce(announce) })
case announceTypeSigned:
if !signed {
signedAnnounce = announce
signedAnnounce.sign(h.server.privateKey)
signed = true
}
p.queueSend(func() { p.SendAnnounce(signedAnnounce) })
}
}
case <-h.closeCh:
return
}
}
}

@ -115,8 +115,6 @@ type serverPool struct {
db ethdb.Database db ethdb.Database
dbKey []byte dbKey []byte
server *p2p.Server server *p2p.Server
quit chan struct{}
wg *sync.WaitGroup
connWg sync.WaitGroup connWg sync.WaitGroup
topic discv5.Topic topic discv5.Topic
@ -137,14 +135,15 @@ type serverPool struct {
connCh chan *connReq connCh chan *connReq
disconnCh chan *disconnReq disconnCh chan *disconnReq
registerCh chan *registerReq registerCh chan *registerReq
closeCh chan struct{}
wg sync.WaitGroup
} }
// newServerPool creates a new serverPool instance // newServerPool creates a new serverPool instance
func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, trustedNodes []string) *serverPool { func newServerPool(db ethdb.Database, ulcServers []string) *serverPool {
pool := &serverPool{ pool := &serverPool{
db: db, db: db,
quit: quit,
wg: wg,
entries: make(map[enode.ID]*poolEntry), entries: make(map[enode.ID]*poolEntry),
timeout: make(chan *poolEntry, 1), timeout: make(chan *poolEntry, 1),
adjustStats: make(chan poolStatAdjust, 100), adjustStats: make(chan poolStatAdjust, 100),
@ -152,10 +151,11 @@ func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, tr
connCh: make(chan *connReq), connCh: make(chan *connReq),
disconnCh: make(chan *disconnReq), disconnCh: make(chan *disconnReq),
registerCh: make(chan *registerReq), registerCh: make(chan *registerReq),
closeCh: make(chan struct{}),
knownSelect: newWeightedRandomSelect(), knownSelect: newWeightedRandomSelect(),
newSelect: newWeightedRandomSelect(), newSelect: newWeightedRandomSelect(),
fastDiscover: true, fastDiscover: true,
trustedNodes: parseTrustedNodes(trustedNodes), trustedNodes: parseTrustedNodes(ulcServers),
} }
pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry) pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry)
@ -167,7 +167,6 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
pool.server = server pool.server = server
pool.topic = topic pool.topic = topic
pool.dbKey = append([]byte("serverPool/"), []byte(topic)...) pool.dbKey = append([]byte("serverPool/"), []byte(topic)...)
pool.wg.Add(1)
pool.loadNodes() pool.loadNodes()
pool.connectToTrustedNodes() pool.connectToTrustedNodes()
@ -178,9 +177,15 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
go pool.discoverNodes() go pool.discoverNodes()
} }
pool.checkDial() pool.checkDial()
pool.wg.Add(1)
go pool.eventLoop() go pool.eventLoop()
} }
func (pool *serverPool) stop() {
close(pool.closeCh)
pool.wg.Wait()
}
// discoverNodes wraps SearchTopic, converting result nodes to enode.Node. // discoverNodes wraps SearchTopic, converting result nodes to enode.Node.
func (pool *serverPool) discoverNodes() { func (pool *serverPool) discoverNodes() {
ch := make(chan *discv5.Node) ch := make(chan *discv5.Node)
@ -207,7 +212,7 @@ func (pool *serverPool) connect(p *peer, node *enode.Node) *poolEntry {
req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)} req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)}
select { select {
case pool.connCh <- req: case pool.connCh <- req:
case <-pool.quit: case <-pool.closeCh:
return nil return nil
} }
return <-req.result return <-req.result
@ -219,7 +224,7 @@ func (pool *serverPool) registered(entry *poolEntry) {
req := &registerReq{entry: entry, done: make(chan struct{})} req := &registerReq{entry: entry, done: make(chan struct{})}
select { select {
case pool.registerCh <- req: case pool.registerCh <- req:
case <-pool.quit: case <-pool.closeCh:
return return
} }
<-req.done <-req.done
@ -231,7 +236,7 @@ func (pool *serverPool) registered(entry *poolEntry) {
func (pool *serverPool) disconnect(entry *poolEntry) { func (pool *serverPool) disconnect(entry *poolEntry) {
stopped := false stopped := false
select { select {
case <-pool.quit: case <-pool.closeCh:
stopped = true stopped = true
default: default:
} }
@ -278,6 +283,7 @@ func (pool *serverPool) adjustResponseTime(entry *poolEntry, time time.Duration,
// eventLoop handles pool events and mutex locking for all internal functions // eventLoop handles pool events and mutex locking for all internal functions
func (pool *serverPool) eventLoop() { func (pool *serverPool) eventLoop() {
defer pool.wg.Done()
lookupCnt := 0 lookupCnt := 0
var convTime mclock.AbsTime var convTime mclock.AbsTime
if pool.discSetPeriod != nil { if pool.discSetPeriod != nil {
@ -361,7 +367,7 @@ func (pool *serverPool) eventLoop() {
case req := <-pool.connCh: case req := <-pool.connCh:
if pool.trustedNodes[req.p.ID()] != nil { if pool.trustedNodes[req.p.ID()] != nil {
// ignore trusted nodes // ignore trusted nodes
req.result <- nil req.result <- &poolEntry{trusted: true}
} else { } else {
// Handle peer connection requests. // Handle peer connection requests.
entry := pool.entries[req.p.ID()] entry := pool.entries[req.p.ID()]
@ -389,6 +395,9 @@ func (pool *serverPool) eventLoop() {
} }
case req := <-pool.registerCh: case req := <-pool.registerCh:
if req.entry.trusted {
continue
}
// Handle peer registration requests. // Handle peer registration requests.
entry := req.entry entry := req.entry
entry.state = psRegistered entry.state = psRegistered
@ -402,10 +411,13 @@ func (pool *serverPool) eventLoop() {
close(req.done) close(req.done)
case req := <-pool.disconnCh: case req := <-pool.disconnCh:
if req.entry.trusted {
continue
}
// Handle peer disconnection requests. // Handle peer disconnection requests.
disconnect(req, req.stopped) disconnect(req, req.stopped)
case <-pool.quit: case <-pool.closeCh:
if pool.discSetPeriod != nil { if pool.discSetPeriod != nil {
close(pool.discSetPeriod) close(pool.discSetPeriod)
} }
@ -421,7 +433,6 @@ func (pool *serverPool) eventLoop() {
disconnect(req, true) disconnect(req, true)
} }
pool.saveNodes() pool.saveNodes()
pool.wg.Done()
return return
} }
} }
@ -549,10 +560,10 @@ func (pool *serverPool) setRetryDial(entry *poolEntry) {
entry.delayedRetry = true entry.delayedRetry = true
go func() { go func() {
select { select {
case <-pool.quit: case <-pool.closeCh:
case <-time.After(delay): case <-time.After(delay):
select { select {
case <-pool.quit: case <-pool.closeCh:
case pool.enableRetry <- entry: case pool.enableRetry <- entry:
} }
} }
@ -618,10 +629,10 @@ func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) {
go func() { go func() {
pool.server.AddPeer(entry.node) pool.server.AddPeer(entry.node)
select { select {
case <-pool.quit: case <-pool.closeCh:
case <-time.After(dialTimeout): case <-time.After(dialTimeout):
select { select {
case <-pool.quit: case <-pool.closeCh:
case pool.timeout <- entry: case pool.timeout <- entry:
} }
} }
@ -662,14 +673,14 @@ type poolEntry struct {
lastConnected, dialed *poolEntryAddress lastConnected, dialed *poolEntryAddress
addrSelect weightedRandomSelect addrSelect weightedRandomSelect
lastDiscovered mclock.AbsTime lastDiscovered mclock.AbsTime
known, knownSelected bool known, knownSelected, trusted bool
connectStats, delayStats poolStats connectStats, delayStats poolStats
responseStats, timeoutStats poolStats responseStats, timeoutStats poolStats
state int state int
regTime mclock.AbsTime regTime mclock.AbsTime
queueIdx int queueIdx int
removed bool removed bool
delayedRetry bool delayedRetry bool
shortRetry int shortRetry int

@ -43,35 +43,6 @@ const (
checkpointSync checkpointSync
) )
// syncer is responsible for periodically synchronising with the network, both
// downloading hashes and blocks as well as handling the announcement handler.
func (pm *ProtocolManager) syncer() {
// Start and ensure cleanup of sync mechanisms
//pm.fetcher.Start()
//defer pm.fetcher.Stop()
defer pm.downloader.Terminate()
// Wait for different events to fire synchronisation operations
//forceSync := time.Tick(forceSyncCycle)
for {
select {
case <-pm.newPeerCh:
/* // Make sure we have peers to select from, then sync
if pm.peers.Len() < minDesiredPeerCount {
break
}
go pm.synchronise(pm.peers.BestPeer())
*/
/*case <-forceSync:
// Force a sync even if not enough peers are present
go pm.synchronise(pm.peers.BestPeer())
*/
case <-pm.noMorePeers:
return
}
}
}
// validateCheckpoint verifies the advertised checkpoint by peer is valid or not. // validateCheckpoint verifies the advertised checkpoint by peer is valid or not.
// //
// Each network has several hard-coded checkpoint signer addresses. Only the // Each network has several hard-coded checkpoint signer addresses. Only the
@ -80,22 +51,22 @@ func (pm *ProtocolManager) syncer() {
// In addition to the checkpoint registered in the registrar contract, there are // In addition to the checkpoint registered in the registrar contract, there are
// several legacy hardcoded checkpoints in our codebase. These checkpoints are // several legacy hardcoded checkpoints in our codebase. These checkpoints are
// also considered as valid. // also considered as valid.
func (pm *ProtocolManager) validateCheckpoint(peer *peer) error { func (h *clientHandler) validateCheckpoint(peer *peer) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
// Fetch the block header corresponding to the checkpoint registration. // Fetch the block header corresponding to the checkpoint registration.
cp := peer.checkpoint cp := peer.checkpoint
header, err := light.GetUntrustedHeaderByNumber(ctx, pm.odr, peer.checkpointNumber, peer.id) header, err := light.GetUntrustedHeaderByNumber(ctx, h.backend.odr, peer.checkpointNumber, peer.id)
if err != nil { if err != nil {
return err return err
} }
// Fetch block logs associated with the block header. // Fetch block logs associated with the block header.
logs, err := light.GetUntrustedBlockLogs(ctx, pm.odr, header) logs, err := light.GetUntrustedBlockLogs(ctx, h.backend.odr, header)
if err != nil { if err != nil {
return err return err
} }
events := pm.reg.contract.LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash()) events := h.backend.oracle.contract.LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash())
if len(events) == 0 { if len(events) == 0 {
return errInvalidCheckpoint return errInvalidCheckpoint
} }
@ -107,7 +78,7 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error {
for _, event := range events { for _, event := range events {
signatures = append(signatures, append(event.R[:], append(event.S[:], event.V)...)) signatures = append(signatures, append(event.R[:], append(event.S[:], event.V)...))
} }
valid, signers := pm.reg.verifySigners(index, hash, signatures) valid, signers := h.backend.oracle.verifySigners(index, hash, signatures)
if !valid { if !valid {
return errInvalidCheckpoint return errInvalidCheckpoint
} }
@ -116,14 +87,14 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error {
} }
// synchronise tries to sync up our local chain with a remote peer. // synchronise tries to sync up our local chain with a remote peer.
func (pm *ProtocolManager) synchronise(peer *peer) { func (h *clientHandler) synchronise(peer *peer) {
// Short circuit if the peer is nil. // Short circuit if the peer is nil.
if peer == nil { if peer == nil {
return return
} }
// Make sure the peer's TD is higher than our own. // Make sure the peer's TD is higher than our own.
latest := pm.blockchain.CurrentHeader() latest := h.backend.blockchain.CurrentHeader()
currentTd := rawdb.ReadTd(pm.chainDb, latest.Hash(), latest.Number.Uint64()) currentTd := rawdb.ReadTd(h.backend.chainDb, latest.Hash(), latest.Number.Uint64())
if currentTd != nil && peer.headBlockInfo().Td.Cmp(currentTd) < 0 { if currentTd != nil && peer.headBlockInfo().Td.Cmp(currentTd) < 0 {
return return
} }
@ -140,8 +111,8 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
// => Use provided checkpoint // => Use provided checkpoint
var checkpoint = &peer.checkpoint var checkpoint = &peer.checkpoint
var hardcoded bool var hardcoded bool
if pm.checkpoint != nil && pm.checkpoint.SectionIndex >= peer.checkpoint.SectionIndex { if h.checkpoint != nil && h.checkpoint.SectionIndex >= peer.checkpoint.SectionIndex {
checkpoint = pm.checkpoint // Use the hardcoded one. checkpoint = h.checkpoint // Use the hardcoded one.
hardcoded = true hardcoded = true
} }
// Determine whether we should run checkpoint syncing or normal light syncing. // Determine whether we should run checkpoint syncing or normal light syncing.
@ -157,34 +128,34 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
case checkpoint.Empty(): case checkpoint.Empty():
mode = lightSync mode = lightSync
log.Debug("Disable checkpoint syncing", "reason", "empty checkpoint") log.Debug("Disable checkpoint syncing", "reason", "empty checkpoint")
case latest.Number.Uint64() >= (checkpoint.SectionIndex+1)*pm.iConfig.ChtSize-1: case latest.Number.Uint64() >= (checkpoint.SectionIndex+1)*h.backend.iConfig.ChtSize-1:
mode = lightSync mode = lightSync
log.Debug("Disable checkpoint syncing", "reason", "local chain beyond the checkpoint") log.Debug("Disable checkpoint syncing", "reason", "local chain beyond the checkpoint")
case hardcoded: case hardcoded:
mode = legacyCheckpointSync mode = legacyCheckpointSync
log.Debug("Disable checkpoint syncing", "reason", "checkpoint is hardcoded") log.Debug("Disable checkpoint syncing", "reason", "checkpoint is hardcoded")
case pm.reg == nil || !pm.reg.isRunning(): case h.backend.oracle == nil || !h.backend.oracle.isRunning():
mode = legacyCheckpointSync mode = legacyCheckpointSync
log.Debug("Disable checkpoint syncing", "reason", "checkpoint syncing is not activated") log.Debug("Disable checkpoint syncing", "reason", "checkpoint syncing is not activated")
} }
// Notify testing framework if syncing has completed(for testing purpose). // Notify testing framework if syncing has completed(for testing purpose).
defer func() { defer func() {
if pm.reg != nil && pm.reg.syncDoneHook != nil { if h.backend.oracle != nil && h.backend.oracle.syncDoneHook != nil {
pm.reg.syncDoneHook() h.backend.oracle.syncDoneHook()
} }
}() }()
start := time.Now() start := time.Now()
if mode == checkpointSync || mode == legacyCheckpointSync { if mode == checkpointSync || mode == legacyCheckpointSync {
// Validate the advertised checkpoint // Validate the advertised checkpoint
if mode == legacyCheckpointSync { if mode == legacyCheckpointSync {
checkpoint = pm.checkpoint checkpoint = h.checkpoint
} else if mode == checkpointSync { } else if mode == checkpointSync {
if err := pm.validateCheckpoint(peer); err != nil { if err := h.validateCheckpoint(peer); err != nil {
log.Debug("Failed to validate checkpoint", "reason", err) log.Debug("Failed to validate checkpoint", "reason", err)
pm.removePeer(peer.id) h.removePeer(peer.id)
return return
} }
pm.blockchain.(*light.LightChain).AddTrustedCheckpoint(checkpoint) h.backend.blockchain.AddTrustedCheckpoint(checkpoint)
} }
log.Debug("Checkpoint syncing start", "peer", peer.id, "checkpoint", checkpoint.SectionIndex) log.Debug("Checkpoint syncing start", "peer", peer.id, "checkpoint", checkpoint.SectionIndex)
@ -197,14 +168,14 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
// of the latest epoch covered by checkpoint. // of the latest epoch covered by checkpoint.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
if !checkpoint.Empty() && !pm.blockchain.(*light.LightChain).SyncCheckpoint(ctx, checkpoint) { if !checkpoint.Empty() && !h.backend.blockchain.SyncCheckpoint(ctx, checkpoint) {
log.Debug("Sync checkpoint failed") log.Debug("Sync checkpoint failed")
pm.removePeer(peer.id) h.removePeer(peer.id)
return return
} }
} }
// Fetch the remaining block headers based on the current chain header. // Fetch the remaining block headers based on the current chain header.
if err := pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync); err != nil { if err := h.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync); err != nil {
log.Debug("Synchronise failed", "reason", err) log.Debug("Synchronise failed", "reason", err)
return return
} }

@ -57,7 +57,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
} }
} }
// Generate 512+4 blocks (totally 1 CHT sections) // Generate 512+4 blocks (totally 1 CHT sections)
server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false) server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, nil, 0, false, false)
defer tearDown() defer tearDown()
expected := config.ChtSize + config.ChtConfirms expected := config.ChtSize + config.ChtConfirms
@ -74,8 +74,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
} }
if syncMode == 1 { if syncMode == 1 {
// Register the assembled checkpoint as hardcoded one. // Register the assembled checkpoint as hardcoded one.
client.pm.checkpoint = cp client.handler.checkpoint = cp
client.pm.blockchain.(*light.LightChain).AddTrustedCheckpoint(cp) client.handler.backend.blockchain.AddTrustedCheckpoint(cp)
} else { } else {
// Register the assembled checkpoint into oracle. // Register the assembled checkpoint into oracle.
header := server.backend.Blockchain().CurrentHeader() header := server.backend.Blockchain().CurrentHeader()
@ -83,14 +83,14 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...) data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...)
sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey) sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey)
sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper
if _, err := server.pm.reg.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { if _, err := server.handler.server.oracle.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil {
t.Error("register checkpoint failed", err) t.Error("register checkpoint failed", err)
} }
server.backend.Commit() server.backend.Commit()
// Wait for the checkpoint registration // Wait for the checkpoint registration
for { for {
_, hash, _, err := server.pm.reg.contract.Contract().GetLatestCheckpoint(nil) _, hash, _, err := server.handler.server.oracle.contract.Contract().GetLatestCheckpoint(nil)
if err != nil || hash == [32]byte{} { if err != nil || hash == [32]byte{} {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
continue continue
@ -102,8 +102,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
} }
done := make(chan error) done := make(chan error)
client.pm.reg.syncDoneHook = func() { client.handler.backend.oracle.syncDoneHook = func() {
header := client.pm.blockchain.CurrentHeader() header := client.handler.backend.blockchain.CurrentHeader()
if header.Number.Uint64() == expected { if header.Number.Uint64() == expected {
done <- nil done <- nil
} else { } else {
@ -112,7 +112,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
} }
// Create connected peer pair. // Create connected peer pair.
peer, err1, lPeer, err2 := newTestPeerPair("peer", protocol, server.pm, client.pm) _, err1, _, err2 := newTestPeerPair("peer", protocol, server.handler, client.handler)
select { select {
case <-time.After(time.Millisecond * 100): case <-time.After(time.Millisecond * 100):
case err := <-err1: case err := <-err1:
@ -120,7 +120,6 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
case err := <-err2: case err := <-err2:
t.Fatalf("peer 2 handshake error: %v", err) t.Fatalf("peer 2 handshake error: %v", err)
} }
server.rPeer, client.rPeer = peer, lPeer
select { select {
case err := <-done: case err := <-done:

@ -23,7 +23,6 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"math/big" "math/big"
"sync"
"testing" "testing"
"time" "time"
@ -57,8 +56,8 @@ var (
userAddr1 = crypto.PubkeyToAddress(userKey1.PublicKey) userAddr1 = crypto.PubkeyToAddress(userKey1.PublicKey)
userAddr2 = crypto.PubkeyToAddress(userKey2.PublicKey) userAddr2 = crypto.PubkeyToAddress(userKey2.PublicKey)
testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056")
testContractAddr common.Address testContractAddr common.Address
testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056")
testContractCodeDeployed = testContractCode[16:] testContractCodeDeployed = testContractCode[16:]
testContractDeployed = uint64(2) testContractDeployed = uint64(2)
@ -77,8 +76,10 @@ var (
// The number of confirmations needed to generate a checkpoint(only used in test). // The number of confirmations needed to generate a checkpoint(only used in test).
processConfirms = big.NewInt(4) processConfirms = big.NewInt(4)
// // The token bucket buffer limit for testing purpose.
testBufLimit = uint64(1000000) testBufLimit = uint64(1000000)
// The buffer recharging speed for testing purpose.
testBufRecharge = uint64(1000) testBufRecharge = uint64(1000)
) )
@ -97,8 +98,8 @@ contract test {
} }
*/ */
// prepareTestchain pre-commits specified number customized blocks into chain. // prepare pre-commits specified number customized blocks into chain.
func prepareTestchain(n int, backend *backends.SimulatedBackend) { func prepare(n int, backend *backends.SimulatedBackend) {
var ( var (
ctx = context.Background() ctx = context.Background()
signer = types.HomesteadSigner{} signer = types.HomesteadSigner{}
@ -164,51 +165,88 @@ func testIndexers(db ethdb.Database, odr light.OdrBackend, config *light.Indexer
return indexers[:] return indexers[:]
} }
// newTestProtocolManager creates a new protocol manager for testing purposes, func newTestClientHandler(backend *backends.SimulatedBackend, odr *LesOdr, indexers []*core.ChainIndexer, db ethdb.Database, peers *peerSet, ulcServers []string, ulcFraction int) *clientHandler {
// with the given number of blocks already known, potential notification
// channels for different events and relative chain indexers array.
func newTestProtocolManager(lightSync bool, blocks int, odr *LesOdr, indexers []*core.ChainIndexer, peers *peerSet, db ethdb.Database, ulcServers []string, ulcFraction int, testCost uint64, clock mclock.Clock) (*ProtocolManager, *backends.SimulatedBackend, error) {
var ( var (
evmux = new(event.TypeMux) evmux = new(event.TypeMux)
engine = ethash.NewFaker() engine = ethash.NewFaker()
gspec = core.Genesis{ gspec = core.Genesis{
Config: params.AllEthashProtocolChanges, Config: params.AllEthashProtocolChanges,
Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}}, Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}},
GasLimit: 100000000,
} }
pool txPool oracle *checkpointOracle
chain BlockChain
exitCh = make(chan struct{})
) )
gspec.MustCommit(db) genesis := gspec.MustCommit(db)
if peers == nil { chain, _ := light.NewLightChain(odr, gspec.Config, engine, nil)
peers = newPeerSet() if indexers != nil {
checkpointConfig := &params.CheckpointOracleConfig{
Address: crypto.CreateAddress(bankAddr, 0),
Signers: []common.Address{signerAddr},
Threshold: 1,
}
getLocal := func(index uint64) params.TrustedCheckpoint {
chtIndexer := indexers[0]
sectionHead := chtIndexer.SectionHead(index)
return params.TrustedCheckpoint{
SectionIndex: index,
SectionHead: sectionHead,
CHTRoot: light.GetChtRoot(db, index, sectionHead),
BloomRoot: light.GetBloomTrieRoot(db, index, sectionHead),
}
}
oracle = newCheckpointOracle(checkpointConfig, getLocal)
} }
// create a simulation backend and pre-commit several customized block to the database. client := &LightEthereum{
simulation := backends.NewSimulatedBackendWithDatabase(db, gspec.Alloc, 100000000) lesCommons: lesCommons{
prepareTestchain(blocks, simulation) genesis: genesis.Hash(),
config: &eth.Config{LightPeers: 100, NetworkId: NetworkId},
// initialize empty chain for light client or pre-committed chain for server. chainConfig: params.AllEthashProtocolChanges,
if lightSync { iConfig: light.TestClientIndexerConfig,
chain, _ = light.NewLightChain(odr, gspec.Config, engine, nil) chainDb: db,
} else { oracle: oracle,
chain = simulation.Blockchain() chainReader: chain,
config := core.DefaultTxPoolConfig peers: peers,
config.Journal = "" closeCh: make(chan struct{}),
pool = core.NewTxPool(config, gspec.Config, simulation.Blockchain()) },
reqDist: odr.retriever.dist,
retriever: odr.retriever,
odr: odr,
engine: engine,
blockchain: chain,
eventMux: evmux,
} }
client.handler = newClientHandler(ulcServers, ulcFraction, nil, client)
// Create contract registrar if client.oracle != nil {
indexConfig := light.TestServerIndexerConfig client.oracle.start(backend)
if lightSync {
indexConfig = light.TestClientIndexerConfig
} }
config := &params.CheckpointOracleConfig{ return client.handler
Address: crypto.CreateAddress(bankAddr, 0), }
Signers: []common.Address{signerAddr},
Threshold: 1, func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, peers *peerSet, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) {
} var (
var reg *checkpointOracle gspec = core.Genesis{
Config: params.AllEthashProtocolChanges,
Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}},
GasLimit: 100000000,
}
oracle *checkpointOracle
)
genesis := gspec.MustCommit(db)
// create a simulation backend and pre-commit several customized block to the database.
simulation := backends.NewSimulatedBackendWithDatabase(db, gspec.Alloc, 100000000)
prepare(blocks, simulation)
txpoolConfig := core.DefaultTxPoolConfig
txpoolConfig.Journal = ""
txpool := core.NewTxPool(txpoolConfig, gspec.Config, simulation.Blockchain())
if indexers != nil { if indexers != nil {
checkpointConfig := &params.CheckpointOracleConfig{
Address: crypto.CreateAddress(bankAddr, 0),
Signers: []common.Address{signerAddr},
Threshold: 1,
}
getLocal := func(index uint64) params.TrustedCheckpoint { getLocal := func(index uint64) params.TrustedCheckpoint {
chtIndexer := indexers[0] chtIndexer := indexers[0]
sectionHead := chtIndexer.SectionHead(index) sectionHead := chtIndexer.SectionHead(index)
@ -219,72 +257,63 @@ func newTestProtocolManager(lightSync bool, blocks int, odr *LesOdr, indexers []
BloomRoot: light.GetBloomTrieRoot(db, index, sectionHead), BloomRoot: light.GetBloomTrieRoot(db, index, sectionHead),
} }
} }
reg = newCheckpointOracle(config, getLocal) oracle = newCheckpointOracle(checkpointConfig, getLocal)
}
pm, err := NewProtocolManager(gspec.Config, nil, indexConfig, ulcServers, ulcFraction, lightSync, NetworkId, evmux, peers, chain, pool, db, odr, nil, reg, exitCh, new(sync.WaitGroup), func() bool { return true })
if err != nil {
return nil, nil, err
} }
// Registrar initialization could failed if checkpoint contract is not specified. server := &LesServer{
if pm.reg != nil { lesCommons: lesCommons{
pm.reg.start(simulation) genesis: genesis.Hash(),
} config: &eth.Config{LightPeers: 100, NetworkId: NetworkId},
// Set up les server stuff. chainConfig: params.AllEthashProtocolChanges,
if !lightSync { iConfig: light.TestServerIndexerConfig,
srv := &LesServer{lesCommons: lesCommons{protocolManager: pm, chainDb: db}} chainDb: db,
pm.server = srv chainReader: simulation.Blockchain(),
pm.servingQueue = newServingQueue(int64(time.Millisecond*10), 1) oracle: oracle,
pm.servingQueue.setThreads(4) peers: peers,
closeCh: make(chan struct{}),
srv.defParams = flowcontrol.ServerParams{ },
servingQueue: newServingQueue(int64(time.Millisecond*10), 1),
defParams: flowcontrol.ServerParams{
BufLimit: testBufLimit, BufLimit: testBufLimit,
MinRecharge: testBufRecharge, MinRecharge: testBufRecharge,
} },
srv.testCost = testCost fcManager: flowcontrol.NewClientManager(nil, clock),
srv.fcManager = flowcontrol.NewClientManager(nil, clock)
} }
pm.Start(1000) server.costTracker, server.freeCapacity = newCostTracker(db, server.config)
return pm, simulation, nil server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism.
} server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true })
if server.oracle != nil {
// newTestProtocolManagerMust creates a new protocol manager for testing purposes, server.oracle.start(simulation)
// with the given number of blocks already known, potential notification channels
// for different events and relative chain indexers array. In case of an error, the
// constructor force-fails the test.
func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, odr *LesOdr, indexers []*core.ChainIndexer, peers *peerSet, db ethdb.Database, ulcServers []string, ulcFraction int) (*ProtocolManager, *backends.SimulatedBackend) {
pm, backend, err := newTestProtocolManager(lightSync, blocks, odr, indexers, peers, db, ulcServers, ulcFraction, 0, &mclock.System{})
if err != nil {
t.Fatalf("Failed to create protocol manager: %v", err)
} }
return pm, backend server.servingQueue.setThreads(4)
server.handler.start()
return server.handler, simulation
} }
// testPeer is a simulated peer to allow testing direct network calls. // testPeer is a simulated peer to allow testing direct network calls.
type testPeer struct { type testPeer struct {
peer *peer
net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side
*peer
} }
// newTestPeer creates a new peer registered at the given protocol manager. // newTestPeer creates a new peer registered at the given protocol manager.
func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, shake bool, testCost uint64) (*testPeer, <-chan error) { func newTestPeer(t *testing.T, name string, version int, handler *serverHandler, shake bool, testCost uint64) (*testPeer, <-chan error) {
// Create a message pipe to communicate through // Create a message pipe to communicate through
app, net := p2p.MsgPipe() app, net := p2p.MsgPipe()
// Generate a random id and create the peer // Generate a random id and create the peer
var id enode.ID var id enode.ID
rand.Read(id[:]) rand.Read(id[:])
peer := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), net)
peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
// Start the peer on a new thread // Start the peer on a new thread
errc := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
select { select {
case pm.newPeerCh <- peer: case <-handler.closeCh:
errc <- pm.handle(peer) errCh <- p2p.DiscQuitting
case <-pm.quitSync: case errCh <- handler.handle(peer):
errc <- p2p.DiscQuitting
} }
}() }()
tp := &testPeer{ tp := &testPeer{
@ -294,17 +323,27 @@ func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, sh
} }
// Execute any implicitly requested handshakes and return // Execute any implicitly requested handshakes and return
if shake { if shake {
// Customize the cost table if required.
if testCost != 0 {
handler.server.costTracker.testCostList = testCostList(testCost)
}
var ( var (
genesis = pm.blockchain.Genesis() genesis = handler.blockchain.Genesis()
head = pm.blockchain.CurrentHeader() head = handler.blockchain.CurrentHeader()
td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) td = handler.blockchain.GetTd(head.Hash(), head.Number.Uint64())
) )
tp.handshake(t, td, head.Hash(), head.Number.Uint64(), genesis.Hash(), testCost) tp.handshake(t, td, head.Hash(), head.Number.Uint64(), genesis.Hash(), testCostList(testCost))
} }
return tp, errc return tp, errCh
}
// close terminates the local side of the peer, notifying the remote protocol
// manager of termination.
func (p *testPeer) close() {
p.app.Close()
} }
func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer, <-chan error, *peer, <-chan error) { func newTestPeerPair(name string, version int, server *serverHandler, client *clientHandler) (*testPeer, <-chan error, *testPeer, <-chan error) {
// Create a message pipe to communicate through // Create a message pipe to communicate through
app, net := p2p.MsgPipe() app, net := p2p.MsgPipe()
@ -312,36 +351,34 @@ func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer,
var id enode.ID var id enode.ID
rand.Read(id[:]) rand.Read(id[:])
peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net) peer1 := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), net)
peer2 := pm2.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), app) peer2 := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), app)
// Start the peer on a new thread // Start the peer on a new thread
errc := make(chan error, 1) errc1 := make(chan error, 1)
errc2 := make(chan error, 1) errc2 := make(chan error, 1)
go func() { go func() {
select { select {
case pm.newPeerCh <- peer: case <-server.closeCh:
errc <- pm.handle(peer) errc1 <- p2p.DiscQuitting
case <-pm.quitSync: case errc1 <- server.handle(peer1):
errc <- p2p.DiscQuitting
} }
}() }()
go func() { go func() {
select { select {
case pm2.newPeerCh <- peer2: case <-client.closeCh:
errc2 <- pm2.handle(peer2) errc1 <- p2p.DiscQuitting
case <-pm2.quitSync: case errc1 <- client.handle(peer2):
errc2 <- p2p.DiscQuitting
} }
}() }()
return peer, errc, peer2, errc2 return &testPeer{peer: peer1, net: net, app: app}, errc1, &testPeer{peer: peer2, net: app, app: net}, errc2
} }
// handshake simulates a trivial handshake that expects the same state from the // handshake simulates a trivial handshake that expects the same state from the
// remote side as we are simulating locally. // remote side as we are simulating locally.
func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, testCost uint64) { func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, costList RequestCostList) {
var expList keyValueList var expList keyValueList
expList = expList.add("protocolVersion", uint64(p.version)) expList = expList.add("protocolVersion", uint64(p.peer.version))
expList = expList.add("networkId", uint64(NetworkId)) expList = expList.add("networkId", uint64(NetworkId))
expList = expList.add("headTd", td) expList = expList.add("headTd", td)
expList = expList.add("headHash", head) expList = expList.add("headHash", head)
@ -356,7 +393,7 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
expList = expList.add("txRelay", nil) expList = expList.add("txRelay", nil)
expList = expList.add("flowControl/BL", testBufLimit) expList = expList.add("flowControl/BL", testBufLimit)
expList = expList.add("flowControl/MRR", testBufRecharge) expList = expList.add("flowControl/MRR", testBufRecharge)
expList = expList.add("flowControl/MRC", testCostList(testCost)) expList = expList.add("flowControl/MRC", costList)
if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil { if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil {
t.Fatalf("status recv: %v", err) t.Fatalf("status recv: %v", err)
@ -364,113 +401,119 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
if err := p2p.Send(p.app, StatusMsg, sendList); err != nil { if err := p2p.Send(p.app, StatusMsg, sendList); err != nil {
t.Fatalf("status send: %v", err) t.Fatalf("status send: %v", err)
} }
p.peer.fcParams = flowcontrol.ServerParams{
p.fcParams = flowcontrol.ServerParams{
BufLimit: testBufLimit, BufLimit: testBufLimit,
MinRecharge: testBufRecharge, MinRecharge: testBufRecharge,
} }
} }
// close terminates the local side of the peer, notifying the remote protocol type indexerCallback func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer)
// manager of termination.
func (p *testPeer) close() {
p.app.Close()
}
// TestEntity represents a network entity for testing with necessary auxiliary fields. // testClient represents a client for testing with necessary auxiliary fields.
type TestEntity struct { type testClient struct {
clock mclock.Clock
db ethdb.Database db ethdb.Database
rPeer *peer peer *testPeer
tPeer *testPeer handler *clientHandler
peers *peerSet
pm *ProtocolManager chtIndexer *core.ChainIndexer
bloomIndexer *core.ChainIndexer
bloomTrieIndexer *core.ChainIndexer
}
// testServer represents a server for testing with necessary auxiliary fields.
type testServer struct {
clock mclock.Clock
backend *backends.SimulatedBackend backend *backends.SimulatedBackend
db ethdb.Database
peer *testPeer
handler *serverHandler
// Indexers
chtIndexer *core.ChainIndexer chtIndexer *core.ChainIndexer
bloomIndexer *core.ChainIndexer bloomIndexer *core.ChainIndexer
bloomTrieIndexer *core.ChainIndexer bloomTrieIndexer *core.ChainIndexer
} }
// newServerEnv creates a server testing environment with a connected test peer for testing purpose. func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64) (*testServer, func()) {
func newServerEnv(t *testing.T, blocks int, protocol int, waitIndexers func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer)) (*TestEntity, func()) {
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
indexers := testIndexers(db, nil, light.TestServerIndexerConfig) indexers := testIndexers(db, nil, light.TestServerIndexerConfig)
pm, b := newTestProtocolManagerMust(t, false, blocks, nil, indexers, nil, db, nil, 0) var clock mclock.Clock = &mclock.System{}
peer, _ := newTestPeer(t, "peer", protocol, pm, true, 0) if simClock {
clock = &mclock.Simulated{}
}
handler, b := newTestServerHandler(blocks, indexers, db, newPeerSet(), clock)
var peer *testPeer
if newPeer {
peer, _ = newTestPeer(t, "peer", protocol, handler, true, testCost)
}
cIndexer, bIndexer, btIndexer := indexers[0], indexers[1], indexers[2] cIndexer, bIndexer, btIndexer := indexers[0], indexers[1], indexers[2]
cIndexer.Start(pm.blockchain.(*core.BlockChain)) cIndexer.Start(handler.blockchain)
bIndexer.Start(pm.blockchain.(*core.BlockChain)) bIndexer.Start(handler.blockchain)
// Wait until indexers generate enough index data. // Wait until indexers generate enough index data.
if waitIndexers != nil { if callback != nil {
waitIndexers(cIndexer, bIndexer, btIndexer) callback(cIndexer, bIndexer, btIndexer)
} }
server := &testServer{
return &TestEntity{ clock: clock,
db: db, backend: b,
tPeer: peer, db: db,
pm: pm, peer: peer,
backend: b, handler: handler,
chtIndexer: cIndexer, chtIndexer: cIndexer,
bloomIndexer: bIndexer, bloomIndexer: bIndexer,
bloomTrieIndexer: btIndexer, bloomTrieIndexer: btIndexer,
}, func() { }
teardown := func() {
if newPeer {
peer.close() peer.close()
// Note bloom trie indexer will be closed by it parent recursively.
cIndexer.Close()
bIndexer.Close()
b.Close() b.Close()
} }
cIndexer.Close()
bIndexer.Close()
}
return server, teardown
} }
// newClientServerEnv creates a client/server arch environment with a connected les server and light client pair func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, ulcServers []string, ulcFraction int, simClock bool, connect bool) (*testServer, *testClient, func()) {
// for testing purpose. sdb, cdb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase()
func newClientServerEnv(t *testing.T, blocks int, protocol int, waitIndexers func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer), newPeer bool) (*TestEntity, *TestEntity, func()) { speers, cPeers := newPeerSet(), newPeerSet()
db, ldb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase()
peers, lPeers := newPeerSet(), newPeerSet()
dist := newRequestDistributor(lPeers, make(chan struct{}), &mclock.System{}) var clock mclock.Clock = &mclock.System{}
rm := newRetrieveManager(lPeers, dist, nil) if simClock {
odr := NewLesOdr(ldb, light.TestClientIndexerConfig, rm) clock = &mclock.Simulated{}
}
dist := newRequestDistributor(cPeers, clock)
rm := newRetrieveManager(cPeers, dist, nil)
odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm)
indexers := testIndexers(db, nil, light.TestServerIndexerConfig) sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig)
lIndexers := testIndexers(ldb, odr, light.TestClientIndexerConfig) cIndexers := testIndexers(cdb, odr, light.TestClientIndexerConfig)
cIndexer, bIndexer, btIndexer := indexers[0], indexers[1], indexers[2] scIndexer, sbIndexer, sbtIndexer := sindexers[0], sindexers[1], sindexers[2]
lcIndexer, lbIndexer, lbtIndexer := lIndexers[0], lIndexers[1], lIndexers[2] ccIndexer, cbIndexer, cbtIndexer := cIndexers[0], cIndexers[1], cIndexers[2]
odr.SetIndexers(ccIndexer, cbIndexer, cbtIndexer)
odr.SetIndexers(lcIndexer, lbtIndexer, lbIndexer) server, b := newTestServerHandler(blocks, sindexers, sdb, speers, clock)
client := newTestClientHandler(b, odr, cIndexers, cdb, cPeers, ulcServers, ulcFraction)
pm, b := newTestProtocolManagerMust(t, false, blocks, nil, indexers, peers, db, nil, 0) scIndexer.Start(server.blockchain)
lpm, lb := newTestProtocolManagerMust(t, true, 0, odr, lIndexers, lPeers, ldb, nil, 0) sbIndexer.Start(server.blockchain)
ccIndexer.Start(client.backend.blockchain)
cbIndexer.Start(client.backend.blockchain)
startIndexers := func(clientMode bool, pm *ProtocolManager) { if callback != nil {
if clientMode { callback(scIndexer, sbIndexer, sbtIndexer)
lcIndexer.Start(pm.blockchain.(*light.LightChain))
lbIndexer.Start(pm.blockchain.(*light.LightChain))
} else {
cIndexer.Start(pm.blockchain.(*core.BlockChain))
bIndexer.Start(pm.blockchain.(*core.BlockChain))
}
} }
startIndexers(false, pm)
startIndexers(true, lpm)
// Execute wait until function if it is specified.
if waitIndexers != nil {
waitIndexers(cIndexer, bIndexer, btIndexer)
}
var ( var (
peer, lPeer *peer speer, cpeer *testPeer
err1, err2 <-chan error err1, err2 <-chan error
) )
if newPeer { if connect {
peer, err1, lPeer, err2 = newTestPeerPair("peer", protocol, pm, lpm) cpeer, err1, speer, err2 = newTestPeerPair("peer", protocol, server, client)
select { select {
case <-time.After(time.Millisecond * 100): case <-time.After(time.Millisecond * 100):
case err := <-err1: case err := <-err1:
@ -479,32 +522,35 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, waitIndexers fun
t.Fatalf("peer 2 handshake error: %v", err) t.Fatalf("peer 2 handshake error: %v", err)
} }
} }
s := &testServer{
return &TestEntity{ clock: clock,
db: db, backend: b,
pm: pm, db: sdb,
rPeer: peer, peer: cpeer,
peers: peers, handler: server,
backend: b, chtIndexer: scIndexer,
chtIndexer: cIndexer, bloomIndexer: sbIndexer,
bloomIndexer: bIndexer, bloomTrieIndexer: sbtIndexer,
bloomTrieIndexer: btIndexer, }
}, &TestEntity{ c := &testClient{
db: ldb, clock: clock,
pm: lpm, db: cdb,
rPeer: lPeer, peer: speer,
peers: lPeers, handler: client,
backend: lb, chtIndexer: ccIndexer,
chtIndexer: lcIndexer, bloomIndexer: cbIndexer,
bloomIndexer: lbIndexer, bloomTrieIndexer: cbtIndexer,
bloomTrieIndexer: lbtIndexer, }
}, func() { teardown := func() {
// Note bloom trie indexers will be closed by their parents recursively. if connect {
cIndexer.Close() speer.close()
bIndexer.Close() cpeer.close()
lcIndexer.Close()
lbIndexer.Close()
b.Close()
lb.Close()
} }
ccIndexer.Close()
cbIndexer.Close()
scIndexer.Close()
sbIndexer.Close()
b.Close()
}
return s, c, teardown
} }

@ -17,151 +17,100 @@
package les package les
import ( import (
"crypto/ecdsa" "crypto/rand"
"fmt" "fmt"
"math/big"
"net" "net"
"reflect"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
) )
func TestULCSyncWithOnePeer(t *testing.T) { func TestULCAnnounceThresholdLes2(t *testing.T) { testULCAnnounceThreshold(t, 2) }
f := newFullPeerPair(t, 1, 4) func TestULCAnnounceThresholdLes3(t *testing.T) { testULCAnnounceThreshold(t, 3) }
l := newLightPeer(t, []string{f.Node.String()}, 100)
func testULCAnnounceThreshold(t *testing.T, protocol int) {
if reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { // todo figure out why it takes fetcher so longer to fetcher the announced header.
t.Fatal("blocks are equal") t.Skip("Sometimes it can failed")
} var cases = []struct {
_, _, err := connectPeers(f, l, 2) height []int
if err != nil { threshold int
t.Fatal(err) expect uint64
} }{
l.PM.fetcher.lock.Lock() {[]int{1}, 100, 1},
l.PM.fetcher.nextRequest() {[]int{0, 0, 0}, 100, 0},
l.PM.fetcher.lock.Unlock() {[]int{1, 2, 3}, 30, 3},
{[]int{1, 2, 3}, 60, 2},
if !reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { {[]int{3, 2, 1}, 67, 1},
t.Fatal("sync doesn't work") {[]int{3, 2, 1}, 100, 1},
} }
} for _, testcase := range cases {
var (
func TestULCReceiveAnnounce(t *testing.T) { servers []*testServer
f := newFullPeerPair(t, 1, 4) teardowns []func()
l := newLightPeer(t, []string{f.Node.String()}, 100) nodes []*enode.Node
fPeer, lPeer, err := connectPeers(f, l, 2) ids []string
if err != nil { )
t.Fatal(err) for i := 0; i < len(testcase.height); i++ {
} s, n, teardown := newServerPeer(t, 0, protocol)
l.PM.synchronise(fPeer)
servers = append(servers, s)
//check that the sync is finished correctly nodes = append(nodes, n)
if !reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { teardowns = append(teardowns, teardown)
t.Fatal("sync doesn't work") ids = append(ids, n.String())
} }
l.PM.peers.lock.Lock() c, teardown := newLightPeer(t, protocol, ids, testcase.threshold)
if len(l.PM.peers.peers) == 0 {
t.Fatal("peer list should not be empty")
}
l.PM.peers.lock.Unlock()
time.Sleep(time.Second)
//send a signed announce message(payload doesn't matter)
td := f.PM.blockchain.GetTd(l.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Number.Uint64())
announce := announceData{
Number: l.PM.blockchain.CurrentHeader().Number.Uint64() + 1,
Td: td.Add(td, big.NewInt(1)),
}
announce.sign(f.Key)
lPeer.SendAnnounce(announce)
}
func TestULCShouldNotSyncWithTwoPeersOneHaveEmptyChain(t *testing.T) {
f1 := newFullPeerPair(t, 1, 4)
f2 := newFullPeerPair(t, 2, 0)
l := newLightPeer(t, []string{f1.Node.String(), f2.Node.String()}, 100)
_, _, err := connectPeers(f1, l, 2)
if err != nil {
t.Fatal(err)
}
_, _, err = connectPeers(f2, l, 2)
if err != nil {
t.Fatal(err)
}
l.PM.fetcher.lock.Lock()
l.PM.fetcher.nextRequest()
l.PM.fetcher.lock.Unlock()
if reflect.DeepEqual(f2.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) {
t.Fatal("Incorrect hash: second peer has empty chain")
}
}
func TestULCShouldNotSyncWithThreePeersOneHaveEmptyChain(t *testing.T) {
f1 := newFullPeerPair(t, 1, 3)
f2 := newFullPeerPair(t, 2, 4)
f3 := newFullPeerPair(t, 3, 0)
l := newLightPeer(t, []string{f1.Node.String(), f2.Node.String(), f3.Node.String()}, 60) // Connect all servers.
_, _, err := connectPeers(f1, l, 2) for i := 0; i < len(servers); i++ {
if err != nil { connect(servers[i].handler, nodes[i].ID(), c.handler, protocol)
t.Fatal(err) }
} for i := 0; i < len(servers); i++ {
_, _, err = connectPeers(f2, l, 2) for j := 0; j < testcase.height[i]; j++ {
if err != nil { servers[i].backend.Commit()
t.Fatal(err) }
} }
_, _, err = connectPeers(f3, l, 2) time.Sleep(1500 * time.Millisecond) // Ensure the fetcher has done its work.
if err != nil { head := c.handler.backend.blockchain.CurrentHeader().Number.Uint64()
t.Fatal(err) if head != testcase.expect {
} t.Fatalf("chain height mismatch, want %d, got %d", testcase.expect, head)
l.PM.fetcher.lock.Lock() }
l.PM.fetcher.nextRequest()
l.PM.fetcher.lock.Unlock()
if !reflect.DeepEqual(f1.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) { // Release all servers and client resources.
t.Fatal("Incorrect hash") teardown()
for i := 0; i < len(teardowns); i++ {
teardowns[i]()
}
} }
} }
type pairPeer struct { func connect(server *serverHandler, serverId enode.ID, client *clientHandler, protocol int) (*peer, *peer, error) {
Name string
Node *enode.Node
PM *ProtocolManager
Key *ecdsa.PrivateKey
}
func connectPeers(full, light pairPeer, version int) (*peer, *peer, error) {
// Create a message pipe to communicate through // Create a message pipe to communicate through
app, net := p2p.MsgPipe() app, net := p2p.MsgPipe()
peerLight := full.PM.newPeer(version, NetworkId, p2p.NewPeer(light.Node.ID(), light.Name, nil), net) var id enode.ID
peerFull := light.PM.newPeer(version, NetworkId, p2p.NewPeer(full.Node.ID(), full.Name, nil), app) rand.Read(id[:])
peer1 := newPeer(protocol, NetworkId, true, p2p.NewPeer(serverId, "", nil), net) // Mark server as trusted
peer2 := newPeer(protocol, NetworkId, false, p2p.NewPeer(id, "", nil), app)
// Start the peerLight on a new thread // Start the peerLight on a new thread
errc1 := make(chan error, 1) errc1 := make(chan error, 1)
errc2 := make(chan error, 1) errc2 := make(chan error, 1)
go func() { go func() {
select { select {
case light.PM.newPeerCh <- peerFull: case <-server.closeCh:
errc1 <- light.PM.handle(peerFull)
case <-light.PM.quitSync:
errc1 <- p2p.DiscQuitting errc1 <- p2p.DiscQuitting
case errc1 <- server.handle(peer2):
} }
}() }()
go func() { go func() {
select { select {
case full.PM.newPeerCh <- peerLight: case <-client.closeCh:
errc2 <- full.PM.handle(peerLight) errc1 <- p2p.DiscQuitting
case <-full.PM.quitSync: case errc1 <- client.handle(peer1):
errc2 <- p2p.DiscQuitting
} }
}() }()
@ -172,48 +121,23 @@ func connectPeers(full, light pairPeer, version int) (*peer, *peer, error) {
case err := <-errc2: case err := <-errc2:
return nil, nil, fmt.Errorf("peerFull handshake error: %v", err) return nil, nil, fmt.Errorf("peerFull handshake error: %v", err)
} }
return peer1, peer2, nil
return peerFull, peerLight, nil
} }
// newFullPeerPair creates node with full sync mode // newServerPeer creates server peer.
func newFullPeerPair(t *testing.T, index int, numberOfblocks int) pairPeer { func newServerPeer(t *testing.T, blocks int, protocol int) (*testServer, *enode.Node, func()) {
db := rawdb.NewMemoryDatabase() s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0)
pmFull, _ := newTestProtocolManagerMust(t, false, numberOfblocks, nil, nil, nil, db, nil, 0)
peerPairFull := pairPeer{
Name: "full node",
PM: pmFull,
}
key, err := crypto.GenerateKey() key, err := crypto.GenerateKey()
if err != nil { if err != nil {
t.Fatal("generate key err:", err) t.Fatal("generate key err:", err)
} }
peerPairFull.Key = key s.handler.server.privateKey = key
peerPairFull.Node = enode.NewV4(&key.PublicKey, net.ParseIP("127.0.0.1"), 35000, 35000) n := enode.NewV4(&key.PublicKey, net.ParseIP("127.0.0.1"), 35000, 35000)
return peerPairFull return s, n, teardown
} }
// newLightPeer creates node with light sync mode // newLightPeer creates node with light sync mode
func newLightPeer(t *testing.T, ulcServers []string, ulcFraction int) pairPeer { func newLightPeer(t *testing.T, protocol int, ulcServers []string, ulcFraction int) (*testClient, func()) {
peers := newPeerSet() _, c, teardown := newClientServerEnv(t, 0, protocol, nil, ulcServers, ulcFraction, false, false)
dist := newRequestDistributor(peers, make(chan struct{}), &mclock.System{}) return c, teardown
rm := newRetrieveManager(peers, dist, nil)
ldb := rawdb.NewMemoryDatabase()
odr := NewLesOdr(ldb, light.DefaultClientIndexerConfig, rm)
pmLight, _ := newTestProtocolManagerMust(t, true, 0, odr, nil, peers, ldb, ulcServers, ulcFraction)
peerPairLight := pairPeer{
Name: "ulc node",
PM: pmLight,
}
key, err := crypto.GenerateKey()
if err != nil {
t.Fatal("generate key err:", err)
}
peerPairLight.Key = key
peerPairLight.Node = enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000)
return peerPairLight
} }

@ -60,7 +60,7 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ
} }
} }
if number >= chtCount*odr.IndexerConfig().ChtSize { if number >= chtCount*odr.IndexerConfig().ChtSize {
return nil, ErrNoTrustedCht return nil, errNoTrustedCht
} }
r := &ChtRequest{ChtRoot: GetChtRoot(db, chtCount-1, sectionHead), ChtNum: chtCount - 1, BlockNum: number, Config: odr.IndexerConfig()} r := &ChtRequest{ChtRoot: GetChtRoot(db, chtCount-1, sectionHead), ChtNum: chtCount - 1, BlockNum: number, Config: odr.IndexerConfig()}
if err := odr.Retrieve(ctx, r); err != nil { if err := odr.Retrieve(ctx, r); err != nil {
@ -124,7 +124,7 @@ func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint
// Retrieve the block header and body contents // Retrieve the block header and body contents
header := rawdb.ReadHeader(odr.Database(), hash, number) header := rawdb.ReadHeader(odr.Database(), hash, number)
if header == nil { if header == nil {
return nil, ErrNoHeader return nil, errNoHeader
} }
body, err := GetBody(ctx, odr, hash, number) body, err := GetBody(ctx, odr, hash, number)
if err != nil { if err != nil {
@ -241,7 +241,7 @@ func GetBloomBits(ctx context.Context, odr OdrBackend, bitIdx uint, sectionIdxLi
} else { } else {
// TODO(rjl493456442) Convert sectionIndex to BloomTrie relative index // TODO(rjl493456442) Convert sectionIndex to BloomTrie relative index
if sectionIdx >= bloomTrieCount { if sectionIdx >= bloomTrieCount {
return nil, ErrNoTrustedBloomTrie return nil, errNoTrustedBloomTrie
} }
reqList = append(reqList, sectionIdx) reqList = append(reqList, sectionIdx)
reqIdx = append(reqIdx, i) reqIdx = append(reqIdx, i)

@ -98,9 +98,9 @@ var (
) )
var ( var (
ErrNoTrustedCht = errors.New("no trusted canonical hash trie") errNoTrustedCht = errors.New("no trusted canonical hash trie")
ErrNoTrustedBloomTrie = errors.New("no trusted bloom trie") errNoTrustedBloomTrie = errors.New("no trusted bloom trie")
ErrNoHeader = errors.New("header not found") errNoHeader = errors.New("header not found")
chtPrefix = []byte("chtRootV2-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash chtPrefix = []byte("chtRootV2-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash
ChtTablePrefix = "cht-" ChtTablePrefix = "cht-"
) )

Loading…
Cancel
Save