les, les/flowcontrol: implement LES/3 (#19329)

Felföldi Zsolt committed by GitHub
parent 3d58268bba
commit 58497f46bd
22 changed files:
  1. cmd/utils/flags.go (8 changes)
  2. core/blockchain.go (12 changes)
  3. core/blockchain_test.go (22 changes)
  4. les/api.go (32 changes)
  5. les/costtracker.go (188 changes)
  6. les/csvlogger/csvlogger.go (227 changes)
  7. les/distributor.go (33 changes)
  8. les/execqueue.go (9 changes)
  9. les/flowcontrol/control.go (111 changes)
  10. les/flowcontrol/manager.go (193 changes)
  11. les/flowcontrol/manager_test.go (12 changes)
  12. les/freeclient.go (57 changes)
  13. les/freeclient_test.go (4 changes)
  14. les/handler.go (649 changes)
  15. les/handler_test.go (51 changes)
  16. les/helper_test.go (29 changes)
  17. les/peer.go (110 changes)
  18. les/peer_test.go (4 changes)
  19. les/protocol.go (10 changes)
  20. les/retrieve.go (70 changes)
  21. les/server.go (91 changes)
  22. les/servingqueue.go (219 changes)

@@ -205,13 +205,13 @@ var (
 	}
 	LightBandwidthInFlag = cli.IntFlag{
 		Name:  "lightbwin",
-		Usage: "Incoming bandwidth limit for light server (1000 bytes/sec, 0 = unlimited)",
-		Value: 1000,
+		Usage: "Incoming bandwidth limit for light server (kilobytes/sec, 0 = unlimited)",
+		Value: 0,
 	}
 	LightBandwidthOutFlag = cli.IntFlag{
 		Name:  "lightbwout",
-		Usage: "Outgoing bandwidth limit for light server (1000 bytes/sec, 0 = unlimited)",
-		Value: 5000,
+		Usage: "Outgoing bandwidth limit for light server (kilobytes/sec, 0 = unlimited)",
+		Value: 0,
 	}
 	LightPeersFlag = cli.IntFlag{
 		Name:  "lightpeers",

@@ -74,7 +74,7 @@ const (
 	maxFutureBlocks     = 256
 	maxTimeFutureBlocks = 30
 	badBlockLimit       = 10
-	triesInMemory       = 128
+	TriesInMemory       = 128
 	// BlockChainVersion ensures that an incompatible database forces a resync from scratch.
 	//
@@ -799,7 +799,7 @@ func (bc *BlockChain) Stop() {
 	if !bc.cacheConfig.TrieDirtyDisabled {
 		triedb := bc.stateCache.TrieDB()
-		for _, offset := range []uint64{0, 1, triesInMemory - 1} {
+		for _, offset := range []uint64{0, 1, TriesInMemory - 1} {
 			if number := bc.CurrentBlock().NumberU64(); number > offset {
 				recent := bc.GetBlockByNumber(number - offset)
@@ -1224,7 +1224,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
 			triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive
 			bc.triegc.Push(root, -int64(block.NumberU64()))
-			if current := block.NumberU64(); current > triesInMemory {
+			if current := block.NumberU64(); current > TriesInMemory {
 				// If we exceeded our memory allowance, flush matured singleton nodes to disk
 				var (
 					nodes, imgs = triedb.Size()
@@ -1234,7 +1234,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
 					triedb.Cap(limit - ethdb.IdealBatchSize)
 				}
 				// Find the next state trie we need to commit
-				chosen := current - triesInMemory
+				chosen := current - TriesInMemory
 				// If we exceeded out time allowance, flush an entire trie to disk
 				if bc.gcproc > bc.cacheConfig.TrieTimeLimit {
@@ -1246,8 +1246,8 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
 				} else {
 					// If we're exceeding limits but haven't reached a large enough memory gap,
 					// warn the user that the system is becoming unstable.
-					if chosen < lastWrite+triesInMemory && bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit {
-						log.Info("State in memory for too long, committing", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+					if chosen < lastWrite+TriesInMemory && bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit {
+						log.Info("State in memory for too long, committing", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/TriesInMemory)
 					}
 					// Flush an entire trie and restart the counters
 					triedb.Commit(header.Root, true)

@ -1534,7 +1534,7 @@ func TestTrieForkGC(t *testing.T) {
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
genesis := new(Genesis).MustCommit(db) genesis := new(Genesis).MustCommit(db)
blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) }) blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 2*TriesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) })
// Generate a bunch of fork blocks, each side forking from the canonical chain // Generate a bunch of fork blocks, each side forking from the canonical chain
forks := make([]*types.Block, len(blocks)) forks := make([]*types.Block, len(blocks))
@ -1563,7 +1563,7 @@ func TestTrieForkGC(t *testing.T) {
} }
} }
// Dereference all the recent tries and ensure no past trie is left in // Dereference all the recent tries and ensure no past trie is left in
for i := 0; i < triesInMemory; i++ { for i := 0; i < TriesInMemory; i++ {
chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root()) chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root())
chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root()) chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root())
} }
@ -1582,8 +1582,8 @@ func TestLargeReorgTrieGC(t *testing.T) {
genesis := new(Genesis).MustCommit(db) genesis := new(Genesis).MustCommit(db)
shared, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 64, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) }) shared, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 64, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) })
original, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) }) original, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*TriesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) })
competitor, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory+1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{3}) }) competitor, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*TriesInMemory+1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{3}) })
// Import the shared chain and the original canonical one // Import the shared chain and the original canonical one
diskdb := rawdb.NewMemoryDatabase() diskdb := rawdb.NewMemoryDatabase()
@ -1618,7 +1618,7 @@ func TestLargeReorgTrieGC(t *testing.T) {
if _, err := chain.InsertChain(competitor[len(competitor)-2:]); err != nil { if _, err := chain.InsertChain(competitor[len(competitor)-2:]); err != nil {
t.Fatalf("failed to finalize competitor chain: %v", err) t.Fatalf("failed to finalize competitor chain: %v", err)
} }
for i, block := range competitor[:len(competitor)-triesInMemory] { for i, block := range competitor[:len(competitor)-TriesInMemory] {
if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil { if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil {
t.Fatalf("competitor %d: competing chain state missing", i) t.Fatalf("competitor %d: competing chain state missing", i)
} }
@ -1753,7 +1753,7 @@ func TestLowDiffLongChain(t *testing.T) {
// We must use a pretty long chain to ensure that the fork doesn't overtake us // We must use a pretty long chain to ensure that the fork doesn't overtake us
// until after at least 128 blocks post tip // until after at least 128 blocks post tip
blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 6*triesInMemory, func(i int, b *BlockGen) { blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 6*TriesInMemory, func(i int, b *BlockGen) {
b.SetCoinbase(common.Address{1}) b.SetCoinbase(common.Address{1})
b.OffsetTime(-9) b.OffsetTime(-9)
}) })
@ -1771,7 +1771,7 @@ func TestLowDiffLongChain(t *testing.T) {
} }
// Generate fork chain, starting from an early block // Generate fork chain, starting from an early block
parent := blocks[10] parent := blocks[10]
fork, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 8*triesInMemory, func(i int, b *BlockGen) { fork, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 8*TriesInMemory, func(i int, b *BlockGen) {
b.SetCoinbase(common.Address{2}) b.SetCoinbase(common.Address{2})
}) })
@ -1806,7 +1806,7 @@ func testSideImport(t *testing.T, numCanonBlocksInSidechain, blocksBetweenCommon
genesis := new(Genesis).MustCommit(db) genesis := new(Genesis).MustCommit(db)
// Generate and import the canonical chain // Generate and import the canonical chain
blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 2*triesInMemory, nil) blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 2*TriesInMemory, nil)
diskdb := rawdb.NewMemoryDatabase() diskdb := rawdb.NewMemoryDatabase()
new(Genesis).MustCommit(diskdb) new(Genesis).MustCommit(diskdb)
chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{}, nil) chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{}, nil)
@ -1817,9 +1817,9 @@ func testSideImport(t *testing.T, numCanonBlocksInSidechain, blocksBetweenCommon
t.Fatalf("block %d: failed to insert into chain: %v", n, err) t.Fatalf("block %d: failed to insert into chain: %v", n, err)
} }
lastPrunedIndex := len(blocks) - triesInMemory - 1 lastPrunedIndex := len(blocks) - TriesInMemory - 1
lastPrunedBlock := blocks[lastPrunedIndex] lastPrunedBlock := blocks[lastPrunedIndex]
firstNonPrunedBlock := blocks[len(blocks)-triesInMemory] firstNonPrunedBlock := blocks[len(blocks)-TriesInMemory]
// Verify pruning of lastPrunedBlock // Verify pruning of lastPrunedBlock
if chain.HasBlockAndState(lastPrunedBlock.Hash(), lastPrunedBlock.NumberU64()) { if chain.HasBlockAndState(lastPrunedBlock.Hash(), lastPrunedBlock.NumberU64()) {
@ -1836,7 +1836,7 @@ func testSideImport(t *testing.T, numCanonBlocksInSidechain, blocksBetweenCommon
// Generate fork chain, make it longer than canon // Generate fork chain, make it longer than canon
parentIndex := lastPrunedIndex + blocksBetweenCommonAncestorAndPruneblock parentIndex := lastPrunedIndex + blocksBetweenCommonAncestorAndPruneblock
parent := blocks[parentIndex] parent := blocks[parentIndex]
fork, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 2*triesInMemory, func(i int, b *BlockGen) { fork, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 2*TriesInMemory, func(i int, b *BlockGen) {
b.SetCoinbase(common.Address{2}) b.SetCoinbase(common.Address{2})
}) })
// Prepend the parent(s) // Prepend the parent(s)

@ -19,11 +19,13 @@ package les
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/les/csvlogger"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
@ -99,7 +101,7 @@ func (s tcSubs) send(tc uint64, underrun bool) {
// MinimumCapacity queries minimum assignable capacity for a single client // MinimumCapacity queries minimum assignable capacity for a single client
func (api *PrivateLightServerAPI) MinimumCapacity() hexutil.Uint64 { func (api *PrivateLightServerAPI) MinimumCapacity() hexutil.Uint64 {
return hexutil.Uint64(minCapacity) return hexutil.Uint64(api.server.minCapacity)
} }
// FreeClientCapacity queries the capacity provided for free clients // FreeClientCapacity queries the capacity provided for free clients
@ -115,7 +117,7 @@ func (api *PrivateLightServerAPI) FreeClientCapacity() hexutil.Uint64 {
// Note: assigned capacity can be changed while the client is connected with // Note: assigned capacity can be changed while the client is connected with
// immediate effect. // immediate effect.
func (api *PrivateLightServerAPI) SetClientCapacity(id enode.ID, cap uint64) error { func (api *PrivateLightServerAPI) SetClientCapacity(id enode.ID, cap uint64) error {
if cap != 0 && cap < minCapacity { if cap != 0 && cap < api.server.minCapacity {
return ErrMinCap return ErrMinCap
} }
return api.server.priorityClientPool.setClientCapacity(id, cap) return api.server.priorityClientPool.setClientCapacity(id, cap)
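
For context, SetClientCapacity is the operator-facing knob for assigning guaranteed capacity to a specific client, now validated against the dynamically derived minimum. A hedged sketch of calling it over JSON-RPC, assuming the API is registered under the les namespace (IPC path, node ID and capacity value are placeholders):

```go
package main

import (
	"log"

	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/rpc"
)

func main() {
	client, err := rpc.Dial("/path/to/geth.ipc") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	var id enode.ID // fill in the connected light client's node ID
	// Assumption: PrivateLightServerAPI is exposed under the "les" namespace,
	// so the method name on the wire is les_setClientCapacity.
	if err := client.Call(nil, "les_setClientCapacity", id, 20000); err != nil {
		log.Fatal(err)
	}
}
```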
@ -144,6 +146,8 @@ type priorityClientPool struct {
totalCap, totalCapAnnounced uint64 totalCap, totalCapAnnounced uint64
totalConnectedCap, freeClientCap uint64 totalConnectedCap, freeClientCap uint64
maxPeers, priorityCount int maxPeers, priorityCount int
logger *csvlogger.Logger
logTotalPriConn *csvlogger.Channel
subs tcSubs subs tcSubs
updateSchedule []scheduledUpdate updateSchedule []scheduledUpdate
@ -164,12 +168,14 @@ type priorityClientInfo struct {
} }
// newPriorityClientPool creates a new priority client pool // newPriorityClientPool creates a new priority client pool
func newPriorityClientPool(freeClientCap uint64, ps *peerSet, child clientPool) *priorityClientPool { func newPriorityClientPool(freeClientCap uint64, ps *peerSet, child clientPool, metricsLogger, eventLogger *csvlogger.Logger) *priorityClientPool {
return &priorityClientPool{ return &priorityClientPool{
clients: make(map[enode.ID]priorityClientInfo), clients: make(map[enode.ID]priorityClientInfo),
freeClientCap: freeClientCap, freeClientCap: freeClientCap,
ps: ps, ps: ps,
child: child, child: child,
logger: eventLogger,
logTotalPriConn: metricsLogger.NewChannel("totalPriConn", 0),
} }
} }
@ -185,6 +191,7 @@ func (v *priorityClientPool) registerPeer(p *peer) {
id := p.ID() id := p.ID()
c := v.clients[id] c := v.clients[id]
v.logger.Event(fmt.Sprintf("priorityClientPool: registerPeer cap=%d connected=%v, %x", c.cap, c.connected, id.Bytes()))
if c.connected { if c.connected {
return return
} }
@ -192,6 +199,7 @@ func (v *priorityClientPool) registerPeer(p *peer) {
v.child.registerPeer(p) v.child.registerPeer(p)
} }
if c.cap != 0 && v.totalConnectedCap+c.cap > v.totalCap { if c.cap != 0 && v.totalConnectedCap+c.cap > v.totalCap {
v.logger.Event(fmt.Sprintf("priorityClientPool: rejected, %x", id.Bytes()))
go v.ps.Unregister(p.id) go v.ps.Unregister(p.id)
return return
} }
@ -202,6 +210,8 @@ func (v *priorityClientPool) registerPeer(p *peer) {
if c.cap != 0 { if c.cap != 0 {
v.priorityCount++ v.priorityCount++
v.totalConnectedCap += c.cap v.totalConnectedCap += c.cap
v.logger.Event(fmt.Sprintf("priorityClientPool: accepted with %d capacity, %x", c.cap, id.Bytes()))
v.logTotalPriConn.Update(float64(v.totalConnectedCap))
if v.child != nil { if v.child != nil {
v.child.setLimits(v.maxPeers-v.priorityCount, v.totalCap-v.totalConnectedCap) v.child.setLimits(v.maxPeers-v.priorityCount, v.totalCap-v.totalConnectedCap)
} }
@ -217,6 +227,7 @@ func (v *priorityClientPool) unregisterPeer(p *peer) {
id := p.ID() id := p.ID()
c := v.clients[id] c := v.clients[id]
v.logger.Event(fmt.Sprintf("priorityClientPool: unregisterPeer cap=%d connected=%v, %x", c.cap, c.connected, id.Bytes()))
if !c.connected { if !c.connected {
return return
} }
@ -225,6 +236,7 @@ func (v *priorityClientPool) unregisterPeer(p *peer) {
v.clients[id] = c v.clients[id] = c
v.priorityCount-- v.priorityCount--
v.totalConnectedCap -= c.cap v.totalConnectedCap -= c.cap
v.logTotalPriConn.Update(float64(v.totalConnectedCap))
if v.child != nil { if v.child != nil {
v.child.setLimits(v.maxPeers-v.priorityCount, v.totalCap-v.totalConnectedCap) v.child.setLimits(v.maxPeers-v.priorityCount, v.totalCap-v.totalConnectedCap)
} }
@ -299,8 +311,10 @@ func (v *priorityClientPool) setLimitsNow(count int, totalCap uint64) {
if v.priorityCount > count || v.totalConnectedCap > totalCap { if v.priorityCount > count || v.totalConnectedCap > totalCap {
for id, c := range v.clients { for id, c := range v.clients {
if c.connected { if c.connected {
v.logger.Event(fmt.Sprintf("priorityClientPool: setLimitsNow kicked out, %x", id.Bytes()))
c.connected = false c.connected = false
v.totalConnectedCap -= c.cap v.totalConnectedCap -= c.cap
v.logTotalPriConn.Update(float64(v.totalConnectedCap))
v.priorityCount-- v.priorityCount--
v.clients[id] = c v.clients[id] = c
go v.ps.Unregister(c.peer.id) go v.ps.Unregister(c.peer.id)
@ -356,6 +370,7 @@ func (v *priorityClientPool) setClientCapacity(id enode.ID, cap uint64) error {
v.priorityCount-- v.priorityCount--
} }
v.totalConnectedCap += cap - c.cap v.totalConnectedCap += cap - c.cap
v.logTotalPriConn.Update(float64(v.totalConnectedCap))
if v.child != nil { if v.child != nil {
v.child.setLimits(v.maxPeers-v.priorityCount, v.totalCap-v.totalConnectedCap) v.child.setLimits(v.maxPeers-v.priorityCount, v.totalCap-v.totalConnectedCap)
} }
@ -374,6 +389,9 @@ func (v *priorityClientPool) setClientCapacity(id enode.ID, cap uint64) error {
} else { } else {
delete(v.clients, id) delete(v.clients, id)
} }
if c.connected {
v.logger.Event(fmt.Sprintf("priorityClientPool: changed capacity to %d, %x", cap, id.Bytes()))
}
return nil return nil
} }

@ -18,6 +18,7 @@ package les
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"math" "math"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -26,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/les/csvlogger"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
) )
@@ -52,7 +54,7 @@ var (
 		GetCodeMsg:             {0, 80},
 		GetProofsV2Msg:         {0, 80},
 		GetHelperTrieProofsMsg: {0, 20},
-		SendTxV2Msg:            {0, 66000},
+		SendTxV2Msg:            {0, 16500},
 		GetTxStatusMsg:         {0, 50},
 	}
 	// maximum outgoing message size estimates
@@ -66,17 +68,27 @@ var (
 		SendTxV2Msg:            {0, 100},
 		GetTxStatusMsg:         {0, 100},
 	}
-	minBufLimit = uint64(50000000 * maxCostFactor)  // minimum buffer limit allowed for a client
-	minCapacity = (minBufLimit-1)/bufLimitRatio + 1 // minimum capacity allowed for a client
+	// request amounts that have to fit into the minimum buffer size minBufferMultiplier times
+	minBufferReqAmount = map[uint64]uint64{
+		GetBlockHeadersMsg:     192,
+		GetBlockBodiesMsg:      1,
+		GetReceiptsMsg:         1,
+		GetCodeMsg:             1,
+		GetProofsV2Msg:         1,
+		GetHelperTrieProofsMsg: 16,
+		SendTxV2Msg:            8,
+		GetTxStatusMsg:         64,
+	}
+	minBufferMultiplier = 3
 )
 const (
 	maxCostFactor    = 2 // ratio of maximum and average cost estimates
-	gfInitWeight     = time.Second * 10
-	gfMaxWeight      = time.Hour
 	gfUsageThreshold = 0.5
 	gfUsageTC        = time.Second
-	gfDbKey          = "_globalCostFactor"
+	gfRaiseTC        = time.Second * 200
+	gfDropTC         = time.Second * 50
+	gfDbKey          = "_globalCostFactorV3"
 )
// costTracker is responsible for calculating costs and cost estimates on the // costTracker is responsible for calculating costs and cost estimates on the
@ -94,21 +106,30 @@ type costTracker struct {
inSizeFactor, outSizeFactor float64 inSizeFactor, outSizeFactor float64
gf, utilTarget float64 gf, utilTarget float64
minBufLimit uint64
gfUpdateCh chan gfUpdate gfUpdateCh chan gfUpdate
gfLock sync.RWMutex gfLock sync.RWMutex
totalRechargeCh chan uint64 totalRechargeCh chan uint64
stats map[uint64][]uint64 stats map[uint64][]uint64
logger *csvlogger.Logger
logRecentTime, logRecentAvg, logTotalRecharge, logRelCost *csvlogger.Channel
} }
// newCostTracker creates a cost tracker and loads the cost factor statistics from the database // newCostTracker creates a cost tracker and loads the cost factor statistics from the database.
func newCostTracker(db ethdb.Database, config *eth.Config) *costTracker { // It also returns the minimum capacity that can be assigned to any peer.
func newCostTracker(db ethdb.Database, config *eth.Config, logger *csvlogger.Logger) (*costTracker, uint64) {
utilTarget := float64(config.LightServ) * flowcontrol.FixedPointMultiplier / 100 utilTarget := float64(config.LightServ) * flowcontrol.FixedPointMultiplier / 100
ct := &costTracker{ ct := &costTracker{
db: db, db: db,
stopCh: make(chan chan struct{}), stopCh: make(chan chan struct{}),
utilTarget: utilTarget, utilTarget: utilTarget,
logger: logger,
logRelCost: logger.NewMinMaxChannel("relativeCost", true),
logRecentTime: logger.NewMinMaxChannel("recentTime", true),
logRecentAvg: logger.NewMinMaxChannel("recentAvg", true),
logTotalRecharge: logger.NewChannel("totalRecharge", 0.01),
} }
if config.LightBandwidthIn > 0 { if config.LightBandwidthIn > 0 {
ct.inSizeFactor = utilTarget / float64(config.LightBandwidthIn) ct.inSizeFactor = utilTarget / float64(config.LightBandwidthIn)
@ -123,7 +144,16 @@ func newCostTracker(db ethdb.Database, config *eth.Config) *costTracker {
} }
} }
ct.gfLoop() ct.gfLoop()
return ct costList := ct.makeCostList(ct.globalFactor() * 1.25)
for _, c := range costList {
amount := minBufferReqAmount[c.MsgCode]
cost := c.BaseCost + amount*c.ReqCost
if cost > ct.minBufLimit {
ct.minBufLimit = cost
}
}
ct.minBufLimit *= uint64(minBufferMultiplier)
return ct, (ct.minBufLimit-1)/bufLimitRatio + 1
} }
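
The minimum capacity is thus no longer a hard-coded constant: the minimum buffer must hold the most expensive of the reference request batches minBufferMultiplier (3) times over, and the capacity is derived from that buffer limit. A rough worked example with made-up numbers (bufLimitRatio is a package constant not shown in this excerpt; its value here is only an assumption):

```go
package main

import "fmt"

func main() {
	// Illustrative numbers only; the real values depend on the measured cost factor.
	worstBatchCost := uint64(1200000) // max over minBufferReqAmount of BaseCost + amount*ReqCost
	minBufLimit := worstBatchCost * 3 // minBufferMultiplier
	bufLimitRatio := uint64(6000)     // assumption: package-level constant, value illustrative
	minCapacity := (minBufLimit-1)/bufLimitRatio + 1
	fmt.Println(minCapacity) // 600
}
```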
// stop stops the cost tracker and saves the cost factor statistics to the database // stop stops the cost tracker and saves the cost factor statistics to the database
@ -138,16 +168,14 @@ func (ct *costTracker) stop() {
// makeCostList returns upper cost estimates based on the hardcoded cost estimate // makeCostList returns upper cost estimates based on the hardcoded cost estimate
// tables and the optionally specified incoming/outgoing bandwidth limits // tables and the optionally specified incoming/outgoing bandwidth limits
func (ct *costTracker) makeCostList() RequestCostList { func (ct *costTracker) makeCostList(globalFactor float64) RequestCostList {
maxCost := func(avgTime, inSize, outSize uint64) uint64 { maxCost := func(avgTimeCost, inSize, outSize uint64) uint64 {
globalFactor := ct.globalFactor() cost := avgTimeCost * maxCostFactor
inSizeCost := uint64(float64(inSize) * ct.inSizeFactor * globalFactor)
cost := avgTime * maxCostFactor
inSizeCost := uint64(float64(inSize) * ct.inSizeFactor * globalFactor * maxCostFactor)
if inSizeCost > cost { if inSizeCost > cost {
cost = inSizeCost cost = inSizeCost
} }
outSizeCost := uint64(float64(outSize) * ct.outSizeFactor * globalFactor * maxCostFactor) outSizeCost := uint64(float64(outSize) * ct.outSizeFactor * globalFactor)
if outSizeCost > cost { if outSizeCost > cost {
cost = outSizeCost cost = outSizeCost
} }
@ -155,17 +183,29 @@ func (ct *costTracker) makeCostList() RequestCostList {
} }
var list RequestCostList var list RequestCostList
for code, data := range reqAvgTimeCost { for code, data := range reqAvgTimeCost {
baseCost := maxCost(data.baseCost, reqMaxInSize[code].baseCost, reqMaxOutSize[code].baseCost)
reqCost := maxCost(data.reqCost, reqMaxInSize[code].reqCost, reqMaxOutSize[code].reqCost)
if ct.minBufLimit != 0 {
// if minBufLimit is set then always enforce maximum request cost <= minBufLimit
maxCost := baseCost + reqCost*minBufferReqAmount[code]
if maxCost > ct.minBufLimit {
mul := 0.999 * float64(ct.minBufLimit) / float64(maxCost)
baseCost = uint64(float64(baseCost) * mul)
reqCost = uint64(float64(reqCost) * mul)
}
}
list = append(list, requestCostListItem{ list = append(list, requestCostListItem{
MsgCode: code, MsgCode: code,
BaseCost: maxCost(data.baseCost, reqMaxInSize[code].baseCost, reqMaxOutSize[code].baseCost), BaseCost: baseCost,
ReqCost: maxCost(data.reqCost, reqMaxInSize[code].reqCost, reqMaxOutSize[code].reqCost), ReqCost: reqCost,
}) })
} }
return list return list
} }
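
The clamp added above guarantees that, for every message type, a request of the reference amount always fits into the advertised minimum buffer. A worked example with made-up numbers:

```go
package main

import "fmt"

func main() {
	// Illustrative numbers only.
	baseCost, reqCost := uint64(150000), uint64(30000)
	amount := uint64(16)                 // e.g. the reference amount for GetHelperTrieProofsMsg
	maxCost := baseCost + reqCost*amount // 630000
	minBufLimit := uint64(400000)
	if maxCost > minBufLimit {
		mul := 0.999 * float64(minBufLimit) / float64(maxCost) // ~0.634
		baseCost = uint64(float64(baseCost) * mul)             // 95142
		reqCost = uint64(float64(reqCost) * mul)               // 19028
	}
	fmt.Println(baseCost, reqCost)
}
```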
type gfUpdate struct { type gfUpdate struct {
avgTime, servingTime float64 avgTimeCost, servingTime float64
} }
// gfLoop starts an event loop which updates the global cost factor which is // gfLoop starts an event loop which updates the global cost factor which is
@ -178,45 +218,74 @@ type gfUpdate struct {
// total allowed serving time per second but nominated in cost units, should // total allowed serving time per second but nominated in cost units, should
// also be scaled with the cost factor and is also updated by this loop. // also be scaled with the cost factor and is also updated by this loop.
func (ct *costTracker) gfLoop() { func (ct *costTracker) gfLoop() {
var gfUsage, gfSum, gfWeight float64 var gfLog, recentTime, recentAvg float64
lastUpdate := mclock.Now() lastUpdate := mclock.Now()
expUpdate := lastUpdate expUpdate := lastUpdate
data, _ := ct.db.Get([]byte(gfDbKey)) data, _ := ct.db.Get([]byte(gfDbKey))
if len(data) == 16 { if len(data) == 8 {
gfSum = math.Float64frombits(binary.BigEndian.Uint64(data[0:8])) gfLog = math.Float64frombits(binary.BigEndian.Uint64(data[:]))
gfWeight = math.Float64frombits(binary.BigEndian.Uint64(data[8:16]))
} }
if gfWeight < float64(gfInitWeight) { gf := math.Exp(gfLog)
gfSum = float64(gfInitWeight)
gfWeight = float64(gfInitWeight)
}
gf := gfSum / gfWeight
ct.gf = gf ct.gf = gf
totalRecharge := ct.utilTarget * gf
ct.gfUpdateCh = make(chan gfUpdate, 100) ct.gfUpdateCh = make(chan gfUpdate, 100)
threshold := gfUsageThreshold * float64(gfUsageTC) * ct.utilTarget / 1000000
go func() { go func() {
saveCostFactor := func() {
var data [8]byte
binary.BigEndian.PutUint64(data[:], math.Float64bits(gfLog))
ct.db.Put([]byte(gfDbKey), data[:])
log.Debug("global cost factor saved", "value", gf)
}
saveTicker := time.NewTicker(time.Minute * 10)
for { for {
select { select {
case r := <-ct.gfUpdateCh: case r := <-ct.gfUpdateCh:
now := mclock.Now() now := mclock.Now()
max := r.servingTime * gf if ct.logRelCost != nil && r.avgTimeCost > 1e-20 {
if r.avgTime > max { ct.logRelCost.Update(r.servingTime * gf / r.avgTimeCost)
max = r.avgTime }
if r.servingTime > 1000000000 {
ct.logger.Event(fmt.Sprintf("Very long servingTime = %f avgTimeCost = %f costFactor = %f", r.servingTime, r.avgTimeCost, gf))
} }
dt := float64(now - expUpdate) dt := float64(now - expUpdate)
expUpdate = now expUpdate = now
gfUsage = gfUsage*math.Exp(-dt/float64(gfUsageTC)) + max*1000000/float64(gfUsageTC) exp := math.Exp(-dt / float64(gfUsageTC))
// calculate gf correction until now, based on previous values
var gfCorr float64
max := recentTime
if recentAvg > max {
max = recentAvg
}
// we apply continuous correction when MAX(recentTime, recentAvg) > threshold
if max > threshold {
// calculate correction time between last expUpdate and now
if max*exp >= threshold {
gfCorr = dt
} else {
gfCorr = math.Log(max/threshold) * float64(gfUsageTC)
}
// calculate log(gf) correction with the right direction and time constant
if recentTime > recentAvg {
// drop gf if actual serving times are larger than average estimates
gfCorr /= -float64(gfDropTC)
} else {
// raise gf if actual serving times are smaller than average estimates
gfCorr /= float64(gfRaiseTC)
}
}
// update recent cost values with current request
recentTime = recentTime*exp + r.servingTime
recentAvg = recentAvg*exp + r.avgTimeCost/gf
if gfUsage >= gfUsageThreshold*ct.utilTarget*gf { if gfCorr != 0 {
gfSum += r.avgTime gfLog += gfCorr
gfWeight += r.servingTime gf = math.Exp(gfLog)
if time.Duration(now-lastUpdate) > time.Second { if time.Duration(now-lastUpdate) > time.Second {
gf = gfSum / gfWeight totalRecharge = ct.utilTarget * gf
if gfWeight >= float64(gfMaxWeight) {
gfSum = gf * float64(gfMaxWeight)
gfWeight = float64(gfMaxWeight)
}
lastUpdate = now lastUpdate = now
ct.gfLock.Lock() ct.gfLock.Lock()
ct.gf = gf ct.gf = gf
@ -224,19 +293,22 @@ func (ct *costTracker) gfLoop() {
ct.gfLock.Unlock() ct.gfLock.Unlock()
if ch != nil { if ch != nil {
select { select {
case ct.totalRechargeCh <- uint64(ct.utilTarget * gf): case ct.totalRechargeCh <- uint64(totalRecharge):
default: default:
} }
} }
log.Debug("global cost factor updated", "gf", gf, "weight", time.Duration(gfWeight)) log.Debug("global cost factor updated", "gf", gf)
} }
} }
ct.logRecentTime.Update(recentTime)
ct.logRecentAvg.Update(recentAvg)
ct.logTotalRecharge.Update(totalRecharge)
case <-saveTicker.C:
saveCostFactor()
case stopCh := <-ct.stopCh: case stopCh := <-ct.stopCh:
var data [16]byte saveCostFactor()
binary.BigEndian.PutUint64(data[0:8], math.Float64bits(gfSum))
binary.BigEndian.PutUint64(data[8:16], math.Float64bits(gfWeight))
ct.db.Put([]byte(gfDbKey), data[:])
log.Debug("global cost factor saved", "sum", time.Duration(gfSum), "weight", time.Duration(gfWeight))
close(stopCh) close(stopCh)
return return
} }
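
To summarize the new adaptation scheme: the factor is kept in log space and drifts so that measured serving times and table-based estimates agree, raising slowly (gfRaiseTC) and dropping quickly (gfDropTC), and only while the decayed load is above the usage threshold. A condensed sketch of the correction step, assuming it lives inside package les next to the constants above (not a drop-in replacement for the loop):

```go
// gfCorrection returns the signed delta to apply to log(gf). recentTime and
// recentAvg are exponentially decayed sums of measured serving time and of
// table-based estimates divided by gf; dt is the time since the last update
// and exp the decay factor math.Exp(-dt / float64(gfUsageTC)).
func gfCorrection(recentTime, recentAvg, threshold, dt, exp float64) float64 {
	max := math.Max(recentTime, recentAvg)
	if max <= threshold {
		return 0 // server mostly idle: leave the factor alone
	}
	corr := dt // portion of dt spent above the threshold
	if max*exp < threshold {
		corr = math.Log(max/threshold) * float64(gfUsageTC)
	}
	if recentTime > recentAvg {
		return -corr / float64(gfDropTC) // real cost higher than estimated: drop gf
	}
	return corr / float64(gfRaiseTC) // real cost lower than estimated: raise gf
}
```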
@ -275,15 +347,15 @@ func (ct *costTracker) subscribeTotalRecharge(ch chan uint64) uint64 {
// average estimate statistics // average estimate statistics
func (ct *costTracker) updateStats(code, amount, servingTime, realCost uint64) { func (ct *costTracker) updateStats(code, amount, servingTime, realCost uint64) {
avg := reqAvgTimeCost[code] avg := reqAvgTimeCost[code]
avgTime := avg.baseCost + amount*avg.reqCost avgTimeCost := avg.baseCost + amount*avg.reqCost
select { select {
case ct.gfUpdateCh <- gfUpdate{float64(avgTime), float64(servingTime)}: case ct.gfUpdateCh <- gfUpdate{float64(avgTimeCost), float64(servingTime)}:
default: default:
} }
if makeCostStats { if makeCostStats {
realCost <<= 4 realCost <<= 4
l := 0 l := 0
for l < 9 && realCost > avgTime { for l < 9 && realCost > avgTimeCost {
l++ l++
realCost >>= 1 realCost >>= 1
} }
@ -339,8 +411,8 @@ type (
} }
) )
// getCost calculates the estimated cost for a given request type and amount // getMaxCost calculates the estimated cost for a given request type and amount
func (table requestCostTable) getCost(code, amount uint64) uint64 { func (table requestCostTable) getMaxCost(code, amount uint64) uint64 {
costs := table[code] costs := table[code]
return costs.baseCost + amount*costs.reqCost return costs.baseCost + amount*costs.reqCost
} }
@ -360,7 +432,7 @@ func (list RequestCostList) decode(protocolLength uint64) requestCostTable {
} }
// testCostList returns a dummy request cost list used by tests // testCostList returns a dummy request cost list used by tests
func testCostList() RequestCostList { func testCostList(testCost uint64) RequestCostList {
cl := make(RequestCostList, len(reqAvgTimeCost)) cl := make(RequestCostList, len(reqAvgTimeCost))
var max uint64 var max uint64
for code := range reqAvgTimeCost { for code := range reqAvgTimeCost {
@ -372,7 +444,7 @@ func testCostList() RequestCostList {
for code := uint64(0); code <= max; code++ { for code := uint64(0); code <= max; code++ {
if _, ok := reqAvgTimeCost[code]; ok { if _, ok := reqAvgTimeCost[code]; ok {
cl[i].MsgCode = code cl[i].MsgCode = code
cl[i].BaseCost = 0 cl[i].BaseCost = testCost
cl[i].ReqCost = 0 cl[i].ReqCost = 0
i++ i++
} }

@ -0,0 +1,227 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package csvlogger
import (
"fmt"
"os"
"sync"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/log"
)
// Logger is a metrics/events logger that writes logged values and events into a comma separated file
type Logger struct {
file *os.File
started mclock.AbsTime
channels []*Channel
period time.Duration
stopCh, stopped chan struct{}
storeCh chan string
eventHeader string
}
// NewLogger creates a new Logger
func NewLogger(fileName string, updatePeriod time.Duration, eventHeader string) *Logger {
if fileName == "" {
return nil
}
f, err := os.Create(fileName)
if err != nil {
log.Error("Error creating log file", "name", fileName, "error", err)
return nil
}
return &Logger{
file: f,
period: updatePeriod,
stopCh: make(chan struct{}),
storeCh: make(chan string, 1),
eventHeader: eventHeader,
}
}
// NewChannel creates a new value logger channel that writes values in a single
// column. If the relative change of the value is bigger than the given threshold
// then a new line is added immediately (threshold can also be 0).
func (l *Logger) NewChannel(name string, threshold float64) *Channel {
if l == nil {
return nil
}
c := &Channel{
logger: l,
name: name,
threshold: threshold,
}
l.channels = append(l.channels, c)
return c
}
// NewMinMaxChannel creates a new value logger channel that writes the minimum and
// maximum of the tracked value in two columns. It never triggers adding a new line.
// If zeroDefault is true then 0 is written to both min and max columns if no update
// was given during the last period. If it is false then the last update will appear
// in both columns.
func (l *Logger) NewMinMaxChannel(name string, zeroDefault bool) *Channel {
if l == nil {
return nil
}
c := &Channel{
logger: l,
name: name,
minmax: true,
mmZeroDefault: zeroDefault,
}
l.channels = append(l.channels, c)
return c
}
func (l *Logger) store(event string) {
s := fmt.Sprintf("%g", float64(mclock.Now()-l.started)/1000000000)
for _, ch := range l.channels {
s += ", " + ch.store()
}
if event != "" {
s += ", " + event
}
l.file.WriteString(s + "\n")
}
// Start writes the header line and starts the logger
func (l *Logger) Start() {
if l == nil {
return
}
l.started = mclock.Now()
s := "Time"
for _, ch := range l.channels {
s += ", " + ch.header()
}
if l.eventHeader != "" {
s += ", " + l.eventHeader
}
l.file.WriteString(s + "\n")
go func() {
timer := time.NewTimer(l.period)
for {
select {
case <-timer.C:
l.store("")
timer.Reset(l.period)
case event := <-l.storeCh:
l.store(event)
if !timer.Stop() {
<-timer.C
}
timer.Reset(l.period)
case <-l.stopCh:
close(l.stopped)
return
}
}
}()
}
// Stop stops the logger and closes the file
func (l *Logger) Stop() {
if l == nil {
return
}
l.stopped = make(chan struct{})
close(l.stopCh)
<-l.stopped
l.file.Close()
}
// Event immediately adds a new line and adds the given event string in the last column
func (l *Logger) Event(event string) {
if l == nil {
return
}
select {
case l.storeCh <- event:
case <-l.stopCh:
}
}
// Channel represents a logger channel tracking a single value
type Channel struct {
logger *Logger
lock sync.Mutex
name string
threshold, storeMin, storeMax, lastValue, min, max float64
minmax, mmSet, mmZeroDefault bool
}
// Update updates the tracked value
func (lc *Channel) Update(value float64) {
if lc == nil {
return
}
lc.lock.Lock()
defer lc.lock.Unlock()
lc.lastValue = value
if lc.minmax {
if value > lc.max || !lc.mmSet {
lc.max = value
}
if value < lc.min || !lc.mmSet {
lc.min = value
}
lc.mmSet = true
} else {
if value < lc.storeMin || value > lc.storeMax {
select {
case lc.logger.storeCh <- "":
default:
}
}
}
}
func (lc *Channel) store() (s string) {
lc.lock.Lock()
defer lc.lock.Unlock()
if lc.minmax {
s = fmt.Sprintf("%g, %g", lc.min, lc.max)
lc.mmSet = false
if lc.mmZeroDefault {
lc.min = 0
} else {
lc.min = lc.lastValue
}
lc.max = lc.min
} else {
s = fmt.Sprintf("%g", lc.lastValue)
lc.storeMin = lc.lastValue * (1 - lc.threshold)
lc.storeMax = lc.lastValue * (1 + lc.threshold)
if lc.lastValue < 0 {
lc.storeMin, lc.storeMax = lc.storeMax, lc.storeMin
}
}
return
}
func (lc *Channel) header() string {
if lc.minmax {
return lc.name + " (min), " + lc.name + " (max)"
}
return lc.name
}
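
Since csvlogger is a brand new package, a minimal usage sketch may help (file name, channel names and values are illustrative; all methods are nil-safe, so passing a nil *Logger around simply disables logging):

```go
package main

import (
	"time"

	"github.com/ethereum/go-ethereum/les/csvlogger"
)

func main() {
	// Write a line at least every 10 seconds, with an extra "event" column.
	logger := csvlogger.NewLogger("les-metrics.csv", 10*time.Second, "event")
	totalCap := logger.NewChannel("totalCapacity", 0.01)        // extra line on >1% relative change
	servingTime := logger.NewMinMaxChannel("servingTime", true) // min/max columns, 0 when idle

	logger.Start()
	defer logger.Stop()

	totalCap.Update(300000)
	servingTime.Update(1.25)
	logger.Event("client connected")
}
```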

@ -62,9 +62,10 @@ type distReq struct {
canSend func(distPeer) bool canSend func(distPeer) bool
request func(distPeer) func() request func(distPeer) func()
reqOrder uint64 reqOrder uint64
sentChn chan distPeer sentChn chan distPeer
element *list.Element element *list.Element
waitForPeers mclock.AbsTime
} }
// newRequestDistributor creates a new request distributor // newRequestDistributor creates a new request distributor
@ -106,7 +107,11 @@ func (d *requestDistributor) registerTestPeer(p distPeer) {
// distMaxWait is the maximum waiting time after which further necessary waiting // distMaxWait is the maximum waiting time after which further necessary waiting
// times are recalculated based on new feedback from the servers // times are recalculated based on new feedback from the servers
const distMaxWait = time.Millisecond * 10 const distMaxWait = time.Millisecond * 50
// waitForPeers is the time window in which a request does not fail even if it
// has no suitable peers to send to at the moment
const waitForPeers = time.Second * 3
// main event loop // main event loop
func (d *requestDistributor) loop() { func (d *requestDistributor) loop() {
@ -179,8 +184,6 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
checkedPeers := make(map[distPeer]struct{}) checkedPeers := make(map[distPeer]struct{})
elem := d.reqQueue.Front() elem := d.reqQueue.Front()
var ( var (
bestPeer distPeer
bestReq *distReq
bestWait time.Duration bestWait time.Duration
sel *weightedRandomSelect sel *weightedRandomSelect
) )
@ -188,9 +191,18 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
d.peerLock.RLock() d.peerLock.RLock()
defer d.peerLock.RUnlock() defer d.peerLock.RUnlock()
for (len(d.peers) > 0 || elem == d.reqQueue.Front()) && elem != nil { peerCount := len(d.peers)
for (len(checkedPeers) < peerCount || elem == d.reqQueue.Front()) && elem != nil {
req := elem.Value.(*distReq) req := elem.Value.(*distReq)
canSend := false canSend := false
now := d.clock.Now()
if req.waitForPeers > now {
canSend = true
wait := time.Duration(req.waitForPeers - now)
if bestWait == 0 || wait < bestWait {
bestWait = wait
}
}
for peer := range d.peers { for peer := range d.peers {
if _, ok := checkedPeers[peer]; !ok && peer.canQueue() && req.canSend(peer) { if _, ok := checkedPeers[peer]; !ok && peer.canQueue() && req.canSend(peer) {
canSend = true canSend = true
@ -202,9 +214,7 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
} }
sel.update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1}) sel.update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1})
} else { } else {
if bestReq == nil || wait < bestWait { if bestWait == 0 || wait < bestWait {
bestPeer = peer
bestReq = req
bestWait = wait bestWait = wait
} }
} }
@ -223,7 +233,7 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
c := sel.choose().(selectPeerItem) c := sel.choose().(selectPeerItem)
return c.peer, c.req, 0 return c.peer, c.req, 0
} }
return bestPeer, bestReq, bestWait return nil, nil, bestWait
} }
// queue adds a request to the distribution queue, returns a channel where the // queue adds a request to the distribution queue, returns a channel where the
@ -237,6 +247,7 @@ func (d *requestDistributor) queue(r *distReq) chan distPeer {
if r.reqOrder == 0 { if r.reqOrder == 0 {
d.lastReqOrder++ d.lastReqOrder++
r.reqOrder = d.lastReqOrder r.reqOrder = d.lastReqOrder
r.waitForPeers = d.clock.Now() + mclock.AbsTime(waitForPeers)
} }
back := d.reqQueue.Back() back := d.reqQueue.Back()
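
The practical effect for callers of queue(): a request no longer fails immediately when no suitable peer is available; it stays queued for up to waitForPeers (3s) before being dropped. A sketch of the caller side, assuming it sits inside package les with an already initialized distributor (field names follow les/distributor.go):

```go
// Sketch only; d is assumed to be the client's initialized *requestDistributor.
func queueExample(d *requestDistributor) {
	r := &distReq{
		getCost: func(dp distPeer) uint64 { return 100 },
		canSend: func(dp distPeer) bool { return true },
		request: func(dp distPeer) func() {
			return func() { /* encode and send the request on the wire */ }
		},
	}
	if p := <-d.queue(r); p == nil {
		// channel closed: the request could not be handed to any peer, even
		// after the 3 second waitForPeers grace window introduced here
	}
}
```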

@@ -44,7 +44,7 @@ func (q *execQueue) loop() {
 func (q *execQueue) waitNext(drop bool) (f func()) {
 	q.mu.Lock()
-	if drop {
+	if drop && len(q.funcs) > 0 {
 		// Remove the function that just executed. We do this here instead of when
 		// dequeuing so len(q.funcs) includes the function that is running.
 		q.funcs = append(q.funcs[:0], q.funcs[1:]...)
@@ -84,6 +84,13 @@ func (q *execQueue) queue(f func()) bool {
 	return ok
 }
+// clear drops all queued functions
+func (q *execQueue) clear() {
+	q.mu.Lock()
+	q.funcs = q.funcs[:0]
+	q.mu.Unlock()
+}
 // quit stops the exec queue.
 // quit waits for the current execution to finish before returning.
 func (q *execQueue) quit() {
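
clear() gives the peer handler a way to throw away queued-but-not-yet-started tasks, e.g. when a peer is being dropped. A short usage sketch, assuming package les context:

```go
// Sketch only: dropping pending tasks when a peer goes away.
func clearExample() {
	q := newExecQueue(16) // up to 16 pending functions
	q.queue(func() {
		// serve a queued request
	})
	q.clear() // drop everything still pending; the currently running function is unaffected
	q.quit()  // stop the queue once the current execution has finished
}
```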

@ -56,11 +56,12 @@ type scheduledUpdate struct {
// (used in server mode only) // (used in server mode only)
type ClientNode struct { type ClientNode struct {
params ServerParams params ServerParams
bufValue uint64 bufValue int64
lastTime mclock.AbsTime lastTime mclock.AbsTime
updateSchedule []scheduledUpdate updateSchedule []scheduledUpdate
sumCost uint64 // sum of req costs received from this client sumCost uint64 // sum of req costs received from this client
accepted map[uint64]uint64 // value = sumCost after accepting the given req accepted map[uint64]uint64 // value = sumCost after accepting the given req
connected bool
lock sync.Mutex lock sync.Mutex
cm *ClientManager cm *ClientManager
log *logger log *logger
@ -70,11 +71,12 @@ type ClientNode struct {
// NewClientNode returns a new ClientNode // NewClientNode returns a new ClientNode
func NewClientNode(cm *ClientManager, params ServerParams) *ClientNode { func NewClientNode(cm *ClientManager, params ServerParams) *ClientNode {
node := &ClientNode{ node := &ClientNode{
cm: cm, cm: cm,
params: params, params: params,
bufValue: params.BufLimit, bufValue: int64(params.BufLimit),
lastTime: cm.clock.Now(), lastTime: cm.clock.Now(),
accepted: make(map[uint64]uint64), accepted: make(map[uint64]uint64),
connected: true,
} }
if keepLogs > 0 { if keepLogs > 0 {
node.log = newLogger(keepLogs) node.log = newLogger(keepLogs)
@ -85,9 +87,55 @@ func NewClientNode(cm *ClientManager, params ServerParams) *ClientNode {
// Disconnect should be called when a client is disconnected // Disconnect should be called when a client is disconnected
func (node *ClientNode) Disconnect() { func (node *ClientNode) Disconnect() {
node.lock.Lock()
defer node.lock.Unlock()
node.connected = false
node.cm.disconnect(node) node.cm.disconnect(node)
} }
// BufferStatus returns the current buffer value and limit
func (node *ClientNode) BufferStatus() (uint64, uint64) {
node.lock.Lock()
defer node.lock.Unlock()
if !node.connected {
return 0, 0
}
now := node.cm.clock.Now()
node.update(now)
node.cm.updateBuffer(node, 0, now)
bv := node.bufValue
if bv < 0 {
bv = 0
}
return uint64(bv), node.params.BufLimit
}
// OneTimeCost subtracts the given amount from the node's buffer.
//
// Note: this call can take the buffer into the negative region internally.
// In this case zero buffer value is returned by exported calls and no requests
// are accepted.
func (node *ClientNode) OneTimeCost(cost uint64) {
node.lock.Lock()
defer node.lock.Unlock()
now := node.cm.clock.Now()
node.update(now)
node.bufValue -= int64(cost)
node.cm.updateBuffer(node, -int64(cost), now)
}
// Freeze notifies the client manager about a client freeze event in which case
// the total capacity allowance is slightly reduced.
func (node *ClientNode) Freeze() {
node.lock.Lock()
frozenCap := node.params.MinRecharge
node.lock.Unlock()
node.cm.reduceTotalCapacity(frozenCap)
}
// update recalculates the buffer value at a specified time while also performing // update recalculates the buffer value at a specified time while also performing
// scheduled flow control parameter updates if necessary // scheduled flow control parameter updates if necessary
func (node *ClientNode) update(now mclock.AbsTime) { func (node *ClientNode) update(now mclock.AbsTime) {
@ -105,9 +153,9 @@ func (node *ClientNode) recalcBV(now mclock.AbsTime) {
if now < node.lastTime { if now < node.lastTime {
dt = 0 dt = 0
} }
node.bufValue += node.params.MinRecharge * dt / uint64(fcTimeConst) node.bufValue += int64(node.params.MinRecharge * dt / uint64(fcTimeConst))
if node.bufValue > node.params.BufLimit { if node.bufValue > int64(node.params.BufLimit) {
node.bufValue = node.params.BufLimit node.bufValue = int64(node.params.BufLimit)
} }
if node.log != nil { if node.log != nil {
node.log.add(now, fmt.Sprintf("updated bv=%d MRR=%d BufLimit=%d", node.bufValue, node.params.MinRecharge, node.params.BufLimit)) node.log.add(now, fmt.Sprintf("updated bv=%d MRR=%d BufLimit=%d", node.bufValue, node.params.MinRecharge, node.params.BufLimit))
@ -139,11 +187,11 @@ func (node *ClientNode) UpdateParams(params ServerParams) {
// updateParams updates the flow control parameters of the node // updateParams updates the flow control parameters of the node
func (node *ClientNode) updateParams(params ServerParams, now mclock.AbsTime) { func (node *ClientNode) updateParams(params ServerParams, now mclock.AbsTime) {
diff := params.BufLimit - node.params.BufLimit diff := int64(params.BufLimit - node.params.BufLimit)
if int64(diff) > 0 { if diff > 0 {
node.bufValue += diff node.bufValue += diff
} else if node.bufValue > params.BufLimit { } else if node.bufValue > int64(params.BufLimit) {
node.bufValue = params.BufLimit node.bufValue = int64(params.BufLimit)
} }
node.cm.updateParams(node, params, now) node.cm.updateParams(node, params, now)
} }
@ -157,14 +205,14 @@ func (node *ClientNode) AcceptRequest(reqID, index, maxCost uint64) (accepted bo
now := node.cm.clock.Now() now := node.cm.clock.Now()
node.update(now) node.update(now)
if maxCost > node.bufValue { if int64(maxCost) > node.bufValue {
if node.log != nil { if node.log != nil {
node.log.add(now, fmt.Sprintf("rejected reqID=%d bv=%d maxCost=%d", reqID, node.bufValue, maxCost)) node.log.add(now, fmt.Sprintf("rejected reqID=%d bv=%d maxCost=%d", reqID, node.bufValue, maxCost))
node.log.dump(now) node.log.dump(now)
} }
return false, maxCost - node.bufValue, 0 return false, maxCost - uint64(node.bufValue), 0
} }
node.bufValue -= maxCost node.bufValue -= int64(maxCost)
node.sumCost += maxCost node.sumCost += maxCost
if node.log != nil { if node.log != nil {
node.log.add(now, fmt.Sprintf("accepted reqID=%d bv=%d maxCost=%d sumCost=%d", reqID, node.bufValue, maxCost, node.sumCost)) node.log.add(now, fmt.Sprintf("accepted reqID=%d bv=%d maxCost=%d sumCost=%d", reqID, node.bufValue, maxCost, node.sumCost))
@ -174,19 +222,22 @@ func (node *ClientNode) AcceptRequest(reqID, index, maxCost uint64) (accepted bo
} }
// RequestProcessed should be called when the request has been processed // RequestProcessed should be called when the request has been processed
func (node *ClientNode) RequestProcessed(reqID, index, maxCost, realCost uint64) (bv uint64) { func (node *ClientNode) RequestProcessed(reqID, index, maxCost, realCost uint64) uint64 {
node.lock.Lock() node.lock.Lock()
defer node.lock.Unlock() defer node.lock.Unlock()
now := node.cm.clock.Now() now := node.cm.clock.Now()
node.update(now) node.update(now)
node.cm.processed(node, maxCost, realCost, now) node.cm.processed(node, maxCost, realCost, now)
bv = node.bufValue + node.sumCost - node.accepted[index] bv := node.bufValue + int64(node.sumCost-node.accepted[index])
if node.log != nil { if node.log != nil {
node.log.add(now, fmt.Sprintf("processed reqID=%d bv=%d maxCost=%d realCost=%d sumCost=%d oldSumCost=%d reportedBV=%d", reqID, node.bufValue, maxCost, realCost, node.sumCost, node.accepted[index], bv)) node.log.add(now, fmt.Sprintf("processed reqID=%d bv=%d maxCost=%d realCost=%d sumCost=%d oldSumCost=%d reportedBV=%d", reqID, node.bufValue, maxCost, realCost, node.sumCost, node.accepted[index], bv))
} }
delete(node.accepted, index) delete(node.accepted, index)
return if bv < 0 {
return 0
}
return uint64(bv)
} }
// ServerNode is the flow control system's representation of a server // ServerNode is the flow control system's representation of a server
@ -345,6 +396,28 @@ func (node *ServerNode) ReceivedReply(reqID, bv uint64) {
} }
} }
// ResumeFreeze cleans all pending requests and sets the buffer estimate to the
// reported value after resuming from a frozen state
func (node *ServerNode) ResumeFreeze(bv uint64) {
node.lock.Lock()
defer node.lock.Unlock()
for reqID := range node.pending {
delete(node.pending, reqID)
}
now := node.clock.Now()
node.recalcBLE(now)
if bv > node.params.BufLimit {
bv = node.params.BufLimit
}
node.bufEstimate = bv
node.bufRecharge = node.bufEstimate < node.params.BufLimit
node.lastTime = now
if node.log != nil {
node.log.add(now, fmt.Sprintf("unfreeze bv=%d sumCost=%d", bv, node.sumCost))
}
}
// DumpLogs dumps the event log if logging is used // DumpLogs dumps the event log if logging is used
func (node *ServerNode) DumpLogs() { func (node *ServerNode) DumpLogs() {
node.lock.Lock() node.lock.Lock()
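
Taken together, the ClientNode additions (BufferStatus, OneTimeCost, Freeze) and ServerNode.ResumeFreeze give both sides the hooks needed for the new LES/3 freeze/resume mechanism. A hedged sketch of the server-side request life cycle against a ClientNode (parameter values are illustrative):

```go
package main

import (
	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/les/flowcontrol"
)

func main() {
	cm := flowcontrol.NewClientManager(nil, mclock.System{})
	defer cm.Stop()

	params := flowcontrol.ServerParams{BufLimit: 300000, MinRecharge: 50000} // illustrative
	node := flowcontrol.NewClientNode(cm, params)
	defer node.Disconnect()

	reqID, index, maxCost := uint64(1), uint64(0), uint64(10000)
	if ok, _, _ := node.AcceptRequest(reqID, index, maxCost); ok {
		// ... serve the request, measure its real cost ...
		bv := node.RequestProcessed(reqID, index, maxCost, 8000)
		_ = bv // reported back to the client as its new buffer value
	}
	bufValue, bufLimit := node.BufferStatus() // new in this change
	_, _ = bufValue, bufLimit
}
```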

@@ -47,9 +47,9 @@ type cmNodeFields struct {
 const FixedPointMultiplier = 1000000
 var (
-	capFactorDropTC         = 1 / float64(time.Second*10) // time constant for dropping the capacity factor
-	capFactorRaiseTC        = 1 / float64(time.Hour)      // time constant for raising the capacity factor
-	capFactorRaiseThreshold = 0.75                        // connected / total capacity ratio threshold for raising the capacity factor
+	capacityDropFactor          = 0.1
+	capacityRaiseTC             = 1 / (3 * float64(time.Hour)) // time constant for raising the capacity factor
+	capacityRaiseThresholdRatio = 1.125                        // total/connected capacity ratio threshold for raising the capacity factor
 )
// ClientManager controls the capacity assigned to the clients of a server. // ClientManager controls the capacity assigned to the clients of a server.
@ -61,10 +61,14 @@ type ClientManager struct {
clock mclock.Clock clock mclock.Clock
lock sync.Mutex lock sync.Mutex
enabledCh chan struct{} enabledCh chan struct{}
stop chan chan struct{}
curve PieceWiseLinear curve PieceWiseLinear
sumRecharge, totalRecharge, totalConnected uint64 sumRecharge, totalRecharge, totalConnected uint64
capLogFactor, totalCapacity float64 logTotalCap, totalCapacity float64
logTotalCapRaiseLimit float64
minLogTotalCap, maxLogTotalCap float64
capacityRaiseThreshold uint64
capLastUpdate mclock.AbsTime capLastUpdate mclock.AbsTime
totalCapacityCh chan uint64 totalCapacityCh chan uint64
@ -106,13 +110,35 @@ func NewClientManager(curve PieceWiseLinear, clock mclock.Clock) *ClientManager
clock: clock, clock: clock,
rcQueue: prque.New(func(a interface{}, i int) { a.(*ClientNode).queueIndex = i }), rcQueue: prque.New(func(a interface{}, i int) { a.(*ClientNode).queueIndex = i }),
capLastUpdate: clock.Now(), capLastUpdate: clock.Now(),
stop: make(chan chan struct{}),
} }
if curve != nil { if curve != nil {
cm.SetRechargeCurve(curve) cm.SetRechargeCurve(curve)
} }
go func() {
// regularly recalculate and update total capacity
for {
select {
case <-time.After(time.Minute):
cm.lock.Lock()
cm.updateTotalCapacity(cm.clock.Now(), true)
cm.lock.Unlock()
case stop := <-cm.stop:
close(stop)
return
}
}
}()
return cm return cm
} }
// Stop stops the client manager
func (cm *ClientManager) Stop() {
stop := make(chan struct{})
cm.stop <- stop
<-stop
}
// SetRechargeCurve updates the recharge curve // SetRechargeCurve updates the recharge curve
func (cm *ClientManager) SetRechargeCurve(curve PieceWiseLinear) { func (cm *ClientManager) SetRechargeCurve(curve PieceWiseLinear) {
cm.lock.Lock() cm.lock.Lock()
@ -120,13 +146,29 @@ func (cm *ClientManager) SetRechargeCurve(curve PieceWiseLinear) {
now := cm.clock.Now() now := cm.clock.Now()
cm.updateRecharge(now) cm.updateRecharge(now)
cm.updateCapFactor(now, false)
cm.curve = curve cm.curve = curve
if len(curve) > 0 { if len(curve) > 0 {
cm.totalRecharge = curve[len(curve)-1].Y cm.totalRecharge = curve[len(curve)-1].Y
} else { } else {
cm.totalRecharge = 0 cm.totalRecharge = 0
} }
}
// SetCapacityLimits sets a threshold value used for raising capFactor.
// Either if the difference between total allowed and connected capacity is less
// than this threshold or if their ratio is less than capacityRaiseThresholdRatio
// then capFactor is allowed to slowly raise.
func (cm *ClientManager) SetCapacityLimits(min, max, raiseThreshold uint64) {
if min < 1 {
min = 1
}
cm.minLogTotalCap = math.Log(float64(min))
if max < 1 {
max = 1
}
cm.maxLogTotalCap = math.Log(float64(max))
cm.logTotalCap = cm.maxLogTotalCap
cm.capacityRaiseThreshold = raiseThreshold
cm.refreshCapacity() cm.refreshCapacity()
} }
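
SetCapacityLimits replaces the old implicit behaviour: total capacity is now pinned between explicit bounds and may only rise while the connected capacity stays close to the allowed total. A hedged sketch of how a server could wire it up (values are illustrative; the real call site is in les/server.go, not shown in this excerpt):

```go
package main

import (
	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/les/flowcontrol"
)

func main() {
	cm := flowcontrol.NewClientManager(nil, mclock.System{})
	defer cm.Stop()

	// Illustrative values: minCapacity as returned by newCostTracker, a free
	// client capacity of the same size and a total capacity target.
	minCapacity, freeClientCap, totalCapacity := uint64(600), uint64(600), uint64(1000000)
	// Allow logTotalCap to rise only while connected capacity is within
	// 2*freeClientCap (or within the 1.125 ratio) of the allowed total.
	cm.SetCapacityLimits(minCapacity, totalCapacity, 2*freeClientCap)
}
```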
@ -141,8 +183,9 @@ func (cm *ClientManager) connect(node *ClientNode) {
node.corrBufValue = int64(node.params.BufLimit) node.corrBufValue = int64(node.params.BufLimit)
node.rcLastIntValue = cm.rcLastIntValue node.rcLastIntValue = cm.rcLastIntValue
node.queueIndex = -1 node.queueIndex = -1
cm.updateCapFactor(now, true) cm.updateTotalCapacity(now, true)
cm.totalConnected += node.params.MinRecharge cm.totalConnected += node.params.MinRecharge
cm.updateRaiseLimit()
} }
// disconnect should be called when a client is disconnected // disconnect should be called when a client is disconnected
@ -152,8 +195,9 @@ func (cm *ClientManager) disconnect(node *ClientNode) {
now := cm.clock.Now() now := cm.clock.Now()
cm.updateRecharge(cm.clock.Now()) cm.updateRecharge(cm.clock.Now())
cm.updateCapFactor(now, true) cm.updateTotalCapacity(now, true)
cm.totalConnected -= node.params.MinRecharge cm.totalConnected -= node.params.MinRecharge
cm.updateRaiseLimit()
} }
// accepted is called when a request with given maximum cost is accepted. // accepted is called when a request with given maximum cost is accepted.
@ -174,18 +218,24 @@ func (cm *ClientManager) accepted(node *ClientNode, maxCost uint64, now mclock.A
// //
// Note: processed should always be called for all accepted requests // Note: processed should always be called for all accepted requests
func (cm *ClientManager) processed(node *ClientNode, maxCost, realCost uint64, now mclock.AbsTime) { func (cm *ClientManager) processed(node *ClientNode, maxCost, realCost uint64, now mclock.AbsTime) {
cm.lock.Lock()
defer cm.lock.Unlock()
if realCost > maxCost { if realCost > maxCost {
realCost = maxCost realCost = maxCost
} }
cm.updateNodeRc(node, int64(maxCost-realCost), &node.params, now) cm.updateBuffer(node, int64(maxCost-realCost), now)
if uint64(node.corrBufValue) > node.bufValue { }
// updateBuffer recalculates the corrected buffer value, adds the given value to it
// and updates the node's actual buffer value if possible
func (cm *ClientManager) updateBuffer(node *ClientNode, add int64, now mclock.AbsTime) {
cm.lock.Lock()
defer cm.lock.Unlock()
cm.updateNodeRc(node, add, &node.params, now)
if node.corrBufValue > node.bufValue {
if node.log != nil { if node.log != nil {
node.log.add(now, fmt.Sprintf("corrected bv=%d oldBv=%d", node.corrBufValue, node.bufValue)) node.log.add(now, fmt.Sprintf("corrected bv=%d oldBv=%d", node.corrBufValue, node.bufValue))
} }
node.bufValue = uint64(node.corrBufValue) node.bufValue = node.corrBufValue
} }
} }
@ -195,11 +245,30 @@ func (cm *ClientManager) updateParams(node *ClientNode, params ServerParams, now
defer cm.lock.Unlock() defer cm.lock.Unlock()
cm.updateRecharge(now) cm.updateRecharge(now)
cm.updateCapFactor(now, true) cm.updateTotalCapacity(now, true)
cm.totalConnected += params.MinRecharge - node.params.MinRecharge cm.totalConnected += params.MinRecharge - node.params.MinRecharge
cm.updateRaiseLimit()
cm.updateNodeRc(node, 0, &params, now) cm.updateNodeRc(node, 0, &params, now)
} }
// updateRaiseLimit recalculates the limiting value until which logTotalCap
// can be raised when no client freeze events occur
func (cm *ClientManager) updateRaiseLimit() {
if cm.capacityRaiseThreshold == 0 {
cm.logTotalCapRaiseLimit = 0
return
}
limit := float64(cm.totalConnected + cm.capacityRaiseThreshold)
limit2 := float64(cm.totalConnected) * capacityRaiseThresholdRatio
if limit2 > limit {
limit = limit2
}
if limit < 1 {
limit = 1
}
cm.logTotalCapRaiseLimit = math.Log(limit)
}
// updateRecharge updates the recharge integrator and checks the recharge queue // updateRecharge updates the recharge integrator and checks the recharge queue
// for nodes with recently filled buffers // for nodes with recently filled buffers
func (cm *ClientManager) updateRecharge(now mclock.AbsTime) { func (cm *ClientManager) updateRecharge(now mclock.AbsTime) {
@ -208,9 +277,15 @@ func (cm *ClientManager) updateRecharge(now mclock.AbsTime) {
// updating is done in multiple steps if node buffers are filled and sumRecharge // updating is done in multiple steps if node buffers are filled and sumRecharge
// is decreased before the given target time // is decreased before the given target time
for cm.sumRecharge > 0 { for cm.sumRecharge > 0 {
bonusRatio := cm.curve.ValueAt(cm.sumRecharge) / float64(cm.sumRecharge) sumRecharge := cm.sumRecharge
if bonusRatio < 1 { if sumRecharge > cm.totalRecharge {
bonusRatio = 1 sumRecharge = cm.totalRecharge
}
bonusRatio := float64(1)
v := cm.curve.ValueAt(sumRecharge)
s := float64(sumRecharge)
if v > s && s > 0 {
bonusRatio = v / s
} }
dt := now - lastUpdate dt := now - lastUpdate
// fetch the client that finishes first // fetch the client that finishes first
@ -228,7 +303,6 @@ func (cm *ClientManager) updateRecharge(now mclock.AbsTime) {
// finished recharging, update corrBufValue and sumRecharge if necessary and do next step // finished recharging, update corrBufValue and sumRecharge if necessary and do next step
if rcqNode.corrBufValue < int64(rcqNode.params.BufLimit) { if rcqNode.corrBufValue < int64(rcqNode.params.BufLimit) {
rcqNode.corrBufValue = int64(rcqNode.params.BufLimit) rcqNode.corrBufValue = int64(rcqNode.params.BufLimit)
cm.updateCapFactor(lastUpdate, true)
cm.sumRecharge -= rcqNode.params.MinRecharge cm.sumRecharge -= rcqNode.params.MinRecharge
} }
cm.rcLastIntValue = rcqNode.rcFullIntValue cm.rcLastIntValue = rcqNode.rcFullIntValue
@ -249,9 +323,6 @@ func (cm *ClientManager) updateNodeRc(node *ClientNode, bvc int64, params *Serve
node.rcLastIntValue = cm.rcLastIntValue node.rcLastIntValue = cm.rcLastIntValue
} }
node.corrBufValue += bvc node.corrBufValue += bvc
if node.corrBufValue < 0 {
node.corrBufValue = 0
}
diff := int64(params.BufLimit - node.params.BufLimit) diff := int64(params.BufLimit - node.params.BufLimit)
if diff > 0 { if diff > 0 {
node.corrBufValue += diff node.corrBufValue += diff
@ -261,15 +332,14 @@ func (cm *ClientManager) updateNodeRc(node *ClientNode, bvc int64, params *Serve
node.corrBufValue = int64(params.BufLimit) node.corrBufValue = int64(params.BufLimit)
isFull = true isFull = true
} }
sumRecharge := cm.sumRecharge
if !wasFull { if !wasFull {
sumRecharge -= node.params.MinRecharge cm.sumRecharge -= node.params.MinRecharge
} }
if params != &node.params { if params != &node.params {
node.params = *params node.params = *params
} }
if !isFull { if !isFull {
sumRecharge += node.params.MinRecharge cm.sumRecharge += node.params.MinRecharge
if node.queueIndex != -1 { if node.queueIndex != -1 {
cm.rcQueue.Remove(node.queueIndex) cm.rcQueue.Remove(node.queueIndex)
} }
@ -277,63 +347,54 @@ func (cm *ClientManager) updateNodeRc(node *ClientNode, bvc int64, params *Serve
node.rcFullIntValue = cm.rcLastIntValue + (int64(node.params.BufLimit)-node.corrBufValue)*FixedPointMultiplier/int64(node.params.MinRecharge) node.rcFullIntValue = cm.rcLastIntValue + (int64(node.params.BufLimit)-node.corrBufValue)*FixedPointMultiplier/int64(node.params.MinRecharge)
cm.rcQueue.Push(node, -node.rcFullIntValue) cm.rcQueue.Push(node, -node.rcFullIntValue)
} }
if sumRecharge != cm.sumRecharge { }
cm.updateCapFactor(now, true)
cm.sumRecharge = sumRecharge
}
// reduceTotalCapacity reduces the total capacity allowance in case of a client freeze event
func (cm *ClientManager) reduceTotalCapacity(frozenCap uint64) {
cm.lock.Lock()
defer cm.lock.Unlock()
ratio := float64(1)
if frozenCap < cm.totalConnected {
ratio = float64(frozenCap) / float64(cm.totalConnected)
}
now := cm.clock.Now()
cm.updateTotalCapacity(now, false)
cm.logTotalCap -= capacityDropFactor * ratio
if cm.logTotalCap < cm.minLogTotalCap {
cm.logTotalCap = cm.minLogTotalCap
}
cm.updateTotalCapacity(now, true)
} }
// updateCapFactor updates the total capacity factor. The capacity factor allows // updateTotalCapacity updates the total capacity factor. The capacity factor allows
// the total capacity of the system to go over the allowed total recharge value // the total capacity of the system to go over the allowed total recharge value
// if the sum of momentarily recharging clients only exceeds the total recharge // if clients go to frozen state sufficiently rarely.
// allowance in a very small fraction of time. // The capacity factor is dropped instantly by a small amount if a client is frozen.
// The capacity factor is dropped quickly (with a small time constant) if sumRecharge // It is raised slowly (with a large time constant) if the total connected capacity
// exceeds totalRecharge. It is raised slowly (with a large time constant) if most // is close to the total allowed amount and no clients are frozen.
// of the total capacity is used by connected clients (totalConnected is larger than func (cm *ClientManager) updateTotalCapacity(now mclock.AbsTime, refresh bool) {
// totalCapacity*capFactorRaiseThreshold) and sumRecharge stays under
// totalRecharge*totalConnected/totalCapacity.
func (cm *ClientManager) updateCapFactor(now mclock.AbsTime, refresh bool) {
if cm.totalRecharge == 0 {
return
}
dt := now - cm.capLastUpdate dt := now - cm.capLastUpdate
cm.capLastUpdate = now cm.capLastUpdate = now
var d float64 if cm.logTotalCap < cm.logTotalCapRaiseLimit {
if cm.sumRecharge > cm.totalRecharge { cm.logTotalCap += capacityRaiseTC * float64(dt)
d = (1 - float64(cm.sumRecharge)/float64(cm.totalRecharge)) * capFactorDropTC if cm.logTotalCap > cm.logTotalCapRaiseLimit {
} else { cm.logTotalCap = cm.logTotalCapRaiseLimit
totalConnected := float64(cm.totalConnected)
var connRatio float64
if totalConnected < cm.totalCapacity {
connRatio = totalConnected / cm.totalCapacity
} else {
connRatio = 1
}
if connRatio > capFactorRaiseThreshold {
sumRecharge := float64(cm.sumRecharge)
limit := float64(cm.totalRecharge) * connRatio
if sumRecharge < limit {
d = (1 - sumRecharge/limit) * (connRatio - capFactorRaiseThreshold) * (1 / (1 - capFactorRaiseThreshold)) * capFactorRaiseTC
}
} }
} }
if d != 0 { if cm.logTotalCap > cm.maxLogTotalCap {
cm.capLogFactor += d * float64(dt) cm.logTotalCap = cm.maxLogTotalCap
if cm.capLogFactor < 0 { }
cm.capLogFactor = 0 if refresh {
} cm.refreshCapacity()
if refresh {
cm.refreshCapacity()
}
} }
} }
// refreshCapacity recalculates the total capacity value and sends an update to the subscription // refreshCapacity recalculates the total capacity value and sends an update to the subscription
// channel if the relative change of the value since the last update is more than 0.1 percent // channel if the relative change of the value since the last update is more than 0.1 percent
func (cm *ClientManager) refreshCapacity() { func (cm *ClientManager) refreshCapacity() {
totalCapacity := float64(cm.totalRecharge) * math.Exp(cm.capLogFactor) totalCapacity := math.Exp(cm.logTotalCap)
if totalCapacity >= cm.totalCapacity*0.999 && totalCapacity <= cm.totalCapacity*1.001 { if totalCapacity >= cm.totalCapacity*0.999 && totalCapacity <= cm.totalCapacity*1.001 {
return return
} }
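
Read together, the manager.go changes above replace the old capFactor feedback loop with a simpler controller: total capacity is tracked in log space (logTotalCap), dropped instantly on a client freeze event in proportion to the frozen client's share of the connected capacity, and raised slowly towards a limit derived from the connected capacity. The standalone Go sketch below reproduces that drop/raise behaviour; it is not the les/flowcontrol code, and the three constants are illustrative placeholders rather than the values used in the package.

package main

import (
	"fmt"
	"math"
)

const (
	capacityDropFactor          = 0.1                    // assumed: instant log-space drop for a full-capacity freeze
	capacityRaiseTC             = 1.0 / (3 * 3600 * 1e9) // assumed: raise rate per nanosecond of freeze-free operation
	capacityRaiseThresholdRatio = 1.125                  // assumed: ratio variant of the raise threshold
)

type capacityModel struct {
	logTotalCap, minLog, maxLog, raiseLimit float64
	totalConnected                          uint64
}

// updateRaiseLimit mirrors the diff: logTotalCap may only rise while it is below
// log(max(totalConnected+raiseThreshold, totalConnected*capacityRaiseThresholdRatio)).
func (m *capacityModel) updateRaiseLimit(raiseThreshold uint64) {
	limit := float64(m.totalConnected + raiseThreshold)
	if l2 := float64(m.totalConnected) * capacityRaiseThresholdRatio; l2 > limit {
		limit = l2
	}
	if limit < 1 {
		limit = 1
	}
	m.raiseLimit = math.Log(limit)
}

// freezeEvent drops the capacity in proportion to the frozen client's share,
// like reduceTotalCapacity above, clamped at the configured minimum.
func (m *capacityModel) freezeEvent(frozenCap uint64) {
	ratio := 1.0
	if frozenCap < m.totalConnected {
		ratio = float64(frozenCap) / float64(m.totalConnected)
	}
	m.logTotalCap -= capacityDropFactor * ratio
	if m.logTotalCap < m.minLog {
		m.logTotalCap = m.minLog
	}
}

// advance raises the capacity slowly while no freeze occurs, like updateTotalCapacity.
func (m *capacityModel) advance(dtNanos float64) {
	if m.logTotalCap < m.raiseLimit {
		m.logTotalCap += capacityRaiseTC * dtNanos
		if m.logTotalCap > m.raiseLimit {
			m.logTotalCap = m.raiseLimit
		}
	}
	if m.logTotalCap > m.maxLog {
		m.logTotalCap = m.maxLog
	}
}

func main() {
	m := &capacityModel{
		logTotalCap:    math.Log(4000000),
		minLog:         math.Log(1000000),
		maxLog:         math.Log(8000000),
		totalConnected: 3000000,
	}
	m.updateRaiseLimit(1000000)
	m.freezeEvent(500000)     // a client with 500k capacity froze: small instant drop
	m.advance(10 * 60 * 1e9)  // ten freeze-free minutes: slow recovery up to the raise limit
	fmt.Printf("total capacity is now about %.0f\n", math.Exp(m.logTotalCap))
}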

@ -63,7 +63,7 @@ func testConstantTotalCapacity(t *testing.T, nodeCount, maxCapacityNodes, random
} }
m := NewClientManager(PieceWiseLinear{{0, totalCapacity}}, clock) m := NewClientManager(PieceWiseLinear{{0, totalCapacity}}, clock)
for _, n := range nodes { for _, n := range nodes {
n.bufLimit = n.capacity * 6000 //uint64(2000+rand.Intn(10000)) n.bufLimit = n.capacity * 6000
n.node = NewClientNode(m, ServerParams{BufLimit: n.bufLimit, MinRecharge: n.capacity}) n.node = NewClientNode(m, ServerParams{BufLimit: n.bufLimit, MinRecharge: n.capacity})
} }
maxNodes := make([]int, maxCapacityNodes) maxNodes := make([]int, maxCapacityNodes)
@ -73,6 +73,7 @@ func testConstantTotalCapacity(t *testing.T, nodeCount, maxCapacityNodes, random
maxNodes[i] = rand.Intn(nodeCount) maxNodes[i] = rand.Intn(nodeCount)
} }
var sendCount int
for i := 0; i < testLength; i++ { for i := 0; i < testLength; i++ {
now := clock.Now() now := clock.Now()
for _, idx := range maxNodes { for _, idx := range maxNodes {
@ -83,13 +84,15 @@ func testConstantTotalCapacity(t *testing.T, nodeCount, maxCapacityNodes, random
maxNodes[rand.Intn(maxCapacityNodes)] = rand.Intn(nodeCount) maxNodes[rand.Intn(maxCapacityNodes)] = rand.Intn(nodeCount)
} }
sendCount := randomSend sendCount += randomSend
for sendCount > 0 { failCount := randomSend * 10
for sendCount > 0 && failCount > 0 {
if nodes[rand.Intn(nodeCount)].send(t, now) { if nodes[rand.Intn(nodeCount)].send(t, now) {
sendCount-- sendCount--
} else {
failCount--
} }
} }
clock.Run(time.Millisecond) clock.Run(time.Millisecond)
} }
@ -117,7 +120,6 @@ func (n *testNode) send(t *testing.T, now mclock.AbsTime) bool {
if bv < testMaxCost { if bv < testMaxCost {
n.waitUntil = now + mclock.AbsTime((testMaxCost-bv)*1001000/n.capacity) n.waitUntil = now + mclock.AbsTime((testMaxCost-bv)*1001000/n.capacity)
} }
//n.waitUntil = now + mclock.AbsTime(float64(testMaxCost)*1001000/float64(n.capacity)*(1-float64(bv)/float64(n.bufLimit)))
n.totalCost += rcost n.totalCost += rcost
return true return true
} }
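
The test above drives the manager with a flat recharge curve; the rewritten updateRecharge loop in manager.go clamps sumRecharge to totalRecharge and only applies a speed-up when the curve value exceeds it. Below is a minimal sketch of that bonusRatio computation, with a hand-rolled two-segment curve standing in for flowcontrol.PieceWiseLinear (the curve shape is assumed for illustration, not taken from the package).

package main

import "fmt"

// point and valueAt are simplified stand-ins for PieceWiseLinear and its ValueAt method.
type point struct{ X, Y uint64 }

func valueAt(curve []point, x uint64) float64 {
	for i := 1; i < len(curve); i++ {
		if x <= curve[i].X {
			dx := float64(curve[i].X - curve[i-1].X)
			dy := float64(curve[i].Y - curve[i-1].Y)
			return float64(curve[i-1].Y) + dy*float64(x-curve[i-1].X)/dx
		}
	}
	return float64(curve[len(curve)-1].Y)
}

func main() {
	const totalRecharge = 1000
	curve := []point{{0, 0}, {500, 1000}, {1000, 1000}} // assumed shape

	for _, sumRecharge := range []uint64{200, 800, 1500} {
		s := sumRecharge
		if s > totalRecharge {
			s = totalRecharge // clamp, as in the diff
		}
		bonusRatio := 1.0
		if v := valueAt(curve, s); v > float64(s) && s > 0 {
			bonusRatio = v / float64(s) // few recharging clients: buffers refill faster
		}
		fmt.Printf("sumRecharge=%4d -> bonusRatio=%.2f\n", sumRecharge, bonusRatio)
	}
}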

@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/common/prque" "github.com/ethereum/go-ethereum/common/prque"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/les/csvlogger"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -52,6 +53,8 @@ type freeClientPool struct {
connectedLimit, totalLimit int connectedLimit, totalLimit int
freeClientCap uint64 freeClientCap uint64
logger *csvlogger.Logger
logTotalFreeConn *csvlogger.Channel
addressMap map[string]*freeClientPoolEntry addressMap map[string]*freeClientPoolEntry
connPool, disconnPool *prque.Prque connPool, disconnPool *prque.Prque
@ -66,16 +69,18 @@ const (
) )
// newFreeClientPool creates a new free client pool // newFreeClientPool creates a new free client pool
func newFreeClientPool(db ethdb.Database, freeClientCap uint64, totalLimit int, clock mclock.Clock, removePeer func(string)) *freeClientPool { func newFreeClientPool(db ethdb.Database, freeClientCap uint64, totalLimit int, clock mclock.Clock, removePeer func(string), metricsLogger, eventLogger *csvlogger.Logger) *freeClientPool {
pool := &freeClientPool{ pool := &freeClientPool{
db: db, db: db,
clock: clock, clock: clock,
addressMap: make(map[string]*freeClientPoolEntry), addressMap: make(map[string]*freeClientPoolEntry),
connPool: prque.New(poolSetIndex), connPool: prque.New(poolSetIndex),
disconnPool: prque.New(poolSetIndex), disconnPool: prque.New(poolSetIndex),
freeClientCap: freeClientCap, freeClientCap: freeClientCap,
totalLimit: totalLimit, totalLimit: totalLimit,
removePeer: removePeer, logger: eventLogger,
logTotalFreeConn: metricsLogger.NewChannel("totalFreeConn", 0),
removePeer: removePeer,
} }
pool.loadFromDb() pool.loadFromDb()
return pool return pool
@ -88,10 +93,25 @@ func (f *freeClientPool) stop() {
f.lock.Unlock() f.lock.Unlock()
} }
// freeClientId returns a string identifier for the peer. Multiple peers with the
// same identifier can not be in the free client pool simultaneously.
func freeClientId(p *peer) string {
if addr, ok := p.RemoteAddr().(*net.TCPAddr); ok {
if addr.IP.IsLoopback() {
// using peer id instead of loopback ip address allows multiple free
// connections from local machine to own server
return p.id
} else {
return addr.IP.String()
}
}
return ""
}
// registerPeer implements clientPool // registerPeer implements clientPool
func (f *freeClientPool) registerPeer(p *peer) { func (f *freeClientPool) registerPeer(p *peer) {
if addr, ok := p.RemoteAddr().(*net.TCPAddr); ok { if freeId := freeClientId(p); freeId != "" {
if !f.connect(addr.IP.String(), p.id) { if !f.connect(freeId, p.id) {
f.removePeer(p.id) f.removePeer(p.id)
} }
} }
@ -107,7 +127,9 @@ func (f *freeClientPool) connect(address, id string) bool {
return false return false
} }
f.logger.Event("freeClientPool: connecting from " + address + ", " + id)
if f.connectedLimit == 0 { if f.connectedLimit == 0 {
f.logger.Event("freeClientPool: rejected, " + id)
log.Debug("Client rejected", "address", address) log.Debug("Client rejected", "address", address)
return false return false
} }
@ -119,6 +141,7 @@ func (f *freeClientPool) connect(address, id string) bool {
f.addressMap[address] = e f.addressMap[address] = e
} else { } else {
if e.connected { if e.connected {
f.logger.Event("freeClientPool: already connected, " + id)
log.Debug("Client already connected", "address", address) log.Debug("Client already connected", "address", address)
return false return false
} }
@ -131,9 +154,11 @@ func (f *freeClientPool) connect(address, id string) bool {
if e.linUsage+int64(connectedBias)-i.linUsage < 0 { if e.linUsage+int64(connectedBias)-i.linUsage < 0 {
// kick it out and accept the new client // kick it out and accept the new client
f.dropClient(i, now) f.dropClient(i, now)
f.logger.Event("freeClientPool: kicked out, " + i.id)
} else { } else {
// keep the old client and reject the new one // keep the old client and reject the new one
f.connPool.Push(i, i.linUsage) f.connPool.Push(i, i.linUsage)
f.logger.Event("freeClientPool: rejected, " + id)
log.Debug("Client rejected", "address", address) log.Debug("Client rejected", "address", address)
return false return false
} }
@ -142,17 +167,19 @@ func (f *freeClientPool) connect(address, id string) bool {
e.connected = true e.connected = true
e.id = id e.id = id
f.connPool.Push(e, e.linUsage) f.connPool.Push(e, e.linUsage)
f.logTotalFreeConn.Update(float64(uint64(f.connPool.Size()) * f.freeClientCap))
if f.connPool.Size()+f.disconnPool.Size() > f.totalLimit { if f.connPool.Size()+f.disconnPool.Size() > f.totalLimit {
f.disconnPool.Pop() f.disconnPool.Pop()
} }
f.logger.Event("freeClientPool: accepted, " + id)
log.Debug("Client accepted", "address", address) log.Debug("Client accepted", "address", address)
return true return true
} }
// unregisterPeer implements clientPool // unregisterPeer implements clientPool
func (f *freeClientPool) unregisterPeer(p *peer) { func (f *freeClientPool) unregisterPeer(p *peer) {
if addr, ok := p.RemoteAddr().(*net.TCPAddr); ok { if freeId := freeClientId(p); freeId != "" {
f.disconnect(addr.IP.String()) f.disconnect(freeId)
} }
} }
@ -174,9 +201,11 @@ func (f *freeClientPool) disconnect(address string) {
} }
f.connPool.Remove(e.index) f.connPool.Remove(e.index)
f.logTotalFreeConn.Update(float64(uint64(f.connPool.Size()) * f.freeClientCap))
f.calcLogUsage(e, now) f.calcLogUsage(e, now)
e.connected = false e.connected = false
f.disconnPool.Push(e, -e.logUsage) f.disconnPool.Push(e, -e.logUsage)
f.logger.Event("freeClientPool: disconnected, " + e.id)
log.Debug("Client disconnected", "address", address) log.Debug("Client disconnected", "address", address)
} }
@ -194,6 +223,7 @@ func (f *freeClientPool) setLimits(count int, totalCap uint64) {
for f.connPool.Size() > f.connectedLimit { for f.connPool.Size() > f.connectedLimit {
i := f.connPool.PopItem().(*freeClientPoolEntry) i := f.connPool.PopItem().(*freeClientPoolEntry)
f.dropClient(i, now) f.dropClient(i, now)
f.logger.Event("freeClientPool: setLimits kicked out, " + i.id)
} }
} }
@ -201,6 +231,7 @@ func (f *freeClientPool) setLimits(count int, totalCap uint64) {
// disconnected pool // disconnected pool
func (f *freeClientPool) dropClient(i *freeClientPoolEntry, now mclock.AbsTime) { func (f *freeClientPool) dropClient(i *freeClientPoolEntry, now mclock.AbsTime) {
f.connPool.Remove(i.index) f.connPool.Remove(i.index)
f.logTotalFreeConn.Update(float64(uint64(f.connPool.Size()) * f.freeClientCap))
f.calcLogUsage(i, now) f.calcLogUsage(i, now)
i.connected = false i.connected = false
f.disconnPool.Push(i, -i.logUsage) f.disconnPool.Push(i, -i.logUsage)
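
A small standalone illustration of the freeClientId rule introduced above: loopback peers are keyed by node id, so several local light clients can connect to their own server, while remote peers share one free-pool identity per IP address. The fakePeer type below is only a stand-in to make the example runnable, not the les peer type.

package main

import (
	"fmt"
	"net"
)

type fakePeer struct {
	id   string
	addr net.Addr
}

func (p *fakePeer) RemoteAddr() net.Addr { return p.addr }

// freeClientId mirrors the pool identifier logic from the diff above.
func freeClientId(p *fakePeer) string {
	if addr, ok := p.RemoteAddr().(*net.TCPAddr); ok {
		if addr.IP.IsLoopback() {
			return p.id // local connections are distinguished by node id
		}
		return addr.IP.String() // remote connections share one slot per IP
	}
	return ""
}

func main() {
	local := &fakePeer{id: "node-a", addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 30303}}
	remote := &fakePeer{id: "node-b", addr: &net.TCPAddr{IP: net.ParseIP("203.0.113.7"), Port: 30303}}
	fmt.Println(freeClientId(local))  // node-a
	fmt.Println(freeClientId(remote)) // 203.0.113.7
}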

@ -61,7 +61,7 @@ func testFreeClientPool(t *testing.T, connLimit, clientCount int) {
} }
disconnCh <- i disconnCh <- i
} }
pool = newFreeClientPool(db, 1, 10000, &clock, disconnFn) pool = newFreeClientPool(db, 1, 10000, &clock, disconnFn, nil, nil)
) )
pool.setLimits(connLimit, uint64(connLimit)) pool.setLimits(connLimit, uint64(connLimit))
@ -130,7 +130,7 @@ func testFreeClientPool(t *testing.T, connLimit, clientCount int) {
// close and restart pool // close and restart pool
pool.stop() pool.stop()
pool = newFreeClientPool(db, 1, 10000, &clock, disconnFn) pool = newFreeClientPool(db, 1, 10000, &clock, disconnFn, nil, nil)
pool.setLimits(connLimit, uint64(connLimit)) pool.setLimits(connLimit, uint64(connLimit))
// try connecting all known peers (connLimit should be filled up) // try connecting all known peers (connLimit should be filled up)

@ -34,6 +34,7 @@ import (
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/les/csvlogger"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
@ -118,6 +119,7 @@ type ProtocolManager struct {
wg *sync.WaitGroup wg *sync.WaitGroup
eventMux *event.TypeMux eventMux *event.TypeMux
logger *csvlogger.Logger
// Callbacks // Callbacks
synced func() bool synced func() bool
@ -165,8 +167,6 @@ func NewProtocolManager(
if odr != nil { if odr != nil {
manager.retriever = odr.retriever manager.retriever = odr.retriever
manager.reqDist = odr.retriever.dist manager.reqDist = odr.retriever.dist
} else {
manager.servingQueue = newServingQueue(int64(time.Millisecond * 10))
} }
if ulcConfig != nil { if ulcConfig != nil {
@ -272,6 +272,7 @@ func (pm *ProtocolManager) handle(p *peer) error {
// Ignore maxPeers if this is a trusted peer // Ignore maxPeers if this is a trusted peer
// In server mode we try to check into the client pool after handshake // In server mode we try to check into the client pool after handshake
if pm.client && pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted { if pm.client && pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted {
pm.logger.Event("Rejected (too many peers), " + p.id)
return p2p.DiscTooManyPeers return p2p.DiscTooManyPeers
} }
// Reject light clients if server is not synced. // Reject light clients if server is not synced.
@ -290,6 +291,7 @@ func (pm *ProtocolManager) handle(p *peer) error {
) )
if err := p.Handshake(td, hash, number, genesis.Hash(), pm.server); err != nil { if err := p.Handshake(td, hash, number, genesis.Hash(), pm.server); err != nil {
p.Log().Debug("Light Ethereum handshake failed", "err", err) p.Log().Debug("Light Ethereum handshake failed", "err", err)
pm.logger.Event("Handshake error: " + err.Error() + ", " + p.id)
return err return err
} }
if p.fcClient != nil { if p.fcClient != nil {
@ -303,9 +305,12 @@ func (pm *ProtocolManager) handle(p *peer) error {
// Register the peer locally // Register the peer locally
if err := pm.peers.Register(p); err != nil { if err := pm.peers.Register(p); err != nil {
p.Log().Error("Light Ethereum peer registration failed", "err", err) p.Log().Error("Light Ethereum peer registration failed", "err", err)
pm.logger.Event("Peer registration error: " + err.Error() + ", " + p.id)
return err return err
} }
pm.logger.Event("Connection established, " + p.id)
defer func() { defer func() {
pm.logger.Event("Closed connection, " + p.id)
pm.removePeer(p.id) pm.removePeer(p.id)
}() }()
@ -326,6 +331,7 @@ func (pm *ProtocolManager) handle(p *peer) error {
// main loop. handle incoming messages. // main loop. handle incoming messages.
for { for {
if err := pm.handleMsg(p); err != nil { if err := pm.handleMsg(p); err != nil {
pm.logger.Event("Message handling error: " + err.Error() + ", " + p.id)
p.Log().Debug("Light Ethereum message handling failed", "err", err) p.Log().Debug("Light Ethereum message handling failed", "err", err)
if p.fcServer != nil { if p.fcServer != nil {
p.fcServer.DumpLogs() p.fcServer.DumpLogs()
@ -358,23 +364,40 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
) )
accept := func(reqID, reqCnt, maxCnt uint64) bool { accept := func(reqID, reqCnt, maxCnt uint64) bool {
if reqCnt == 0 { inSizeCost := func() uint64 {
return false if pm.server.costTracker != nil {
return pm.server.costTracker.realCost(0, msg.Size, 0)
}
return 0
} }
if p.fcClient == nil || reqCnt > maxCnt { if p.isFrozen() || reqCnt == 0 || p.fcClient == nil || reqCnt > maxCnt {
p.fcClient.OneTimeCost(inSizeCost())
return false return false
} }
maxCost = p.fcCosts.getCost(msg.Code, reqCnt) maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
gf := float64(1)
if pm.server.costTracker != nil {
gf = pm.server.costTracker.globalFactor()
if gf < 0.001 {
p.Log().Error("Invalid global cost factor", "globalFactor", gf)
gf = 1
}
}
maxTime := uint64(float64(maxCost) / gf)
if accepted, bufShort, servingPriority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost); !accepted { if accepted, bufShort, servingPriority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost); !accepted {
if bufShort > 0 { p.freezeClient()
p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) p.Log().Warn("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
} p.fcClient.OneTimeCost(inSizeCost())
return false return false
} else { } else {
task = pm.servingQueue.newTask(servingPriority) task = pm.servingQueue.newTask(p, maxTime, servingPriority)
}
if task.start() {
return true
} }
return task.start() p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost())
return false
} }
if msg.Size > ProtocolMaxMsgSize { if msg.Size > ProtocolMaxMsgSize {
@ -388,6 +411,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
p.responseLock.Lock() p.responseLock.Lock()
defer p.responseLock.Unlock() defer p.responseLock.Unlock()
if p.isFrozen() {
amount = 0
reply = nil
}
var replySize uint32 var replySize uint32
if reply != nil { if reply != nil {
replySize = reply.size() replySize = reply.size()
@ -395,7 +422,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
var realCost uint64 var realCost uint64
if pm.server.costTracker != nil { if pm.server.costTracker != nil {
realCost = pm.server.costTracker.realCost(servingTime, msg.Size, replySize) realCost = pm.server.costTracker.realCost(servingTime, msg.Size, replySize)
pm.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) if amount != 0 {
pm.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
}
} else { } else {
realCost = maxCost realCost = maxCost
} }
@ -463,94 +492,94 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
} }
query := req.Query query := req.Query
if !accept(req.ReqID, query.Amount, MaxHeaderFetch) { if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
return errResp(ErrRequestRejected, "") go func() {
} hashMode := query.Origin.Hash != (common.Hash{})
go func() { first := true
hashMode := query.Origin.Hash != (common.Hash{}) maxNonCanonical := uint64(100)
first := true
maxNonCanonical := uint64(100) // Gather headers until the fetch or network limits are reached
var (
// Gather headers until the fetch or network limits are reached bytes common.StorageSize
var ( headers []*types.Header
bytes common.StorageSize unknown bool
headers []*types.Header )
unknown bool for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
) if !first && !task.waitOrStop() {
for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit { sendResponse(req.ReqID, 0, nil, task.servingTime)
if !first && !task.waitOrStop() { return
return }
} // Retrieve the next header satisfying the query
// Retrieve the next header satisfying the query var origin *types.Header
var origin *types.Header if hashMode {
if hashMode { if first {
if first { origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) if origin != nil {
if origin != nil { query.Origin.Number = origin.Number.Uint64()
query.Origin.Number = origin.Number.Uint64() }
} else {
origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
} }
} else { } else {
origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number)
} }
} else { if origin == nil {
origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) break
}
if origin == nil {
break
}
headers = append(headers, origin)
bytes += estHeaderRlpSize
// Advance to the next header of the query
switch {
case hashMode && query.Reverse:
// Hash based traversal towards the genesis block
ancestor := query.Skip + 1
if ancestor == 0 {
unknown = true
} else {
query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
unknown = (query.Origin.Hash == common.Hash{})
} }
case hashMode && !query.Reverse: headers = append(headers, origin)
// Hash based traversal towards the leaf block bytes += estHeaderRlpSize
var (
current = origin.Number.Uint64() // Advance to the next header of the query
next = current + query.Skip + 1 switch {
) case hashMode && query.Reverse:
if next <= current { // Hash based traversal towards the genesis block
infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") ancestor := query.Skip + 1
p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos) if ancestor == 0 {
unknown = true unknown = true
} else { } else {
if header := pm.blockchain.GetHeaderByNumber(next); header != nil { query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
nextHash := header.Hash() unknown = (query.Origin.Hash == common.Hash{})
expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) }
if expOldHash == query.Origin.Hash { case hashMode && !query.Reverse:
query.Origin.Hash, query.Origin.Number = nextHash, next // Hash based traversal towards the leaf block
var (
current = origin.Number.Uint64()
next = current + query.Skip + 1
)
if next <= current {
infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ")
p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
unknown = true
} else {
if header := pm.blockchain.GetHeaderByNumber(next); header != nil {
nextHash := header.Hash()
expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
if expOldHash == query.Origin.Hash {
query.Origin.Hash, query.Origin.Number = nextHash, next
} else {
unknown = true
}
} else { } else {
unknown = true unknown = true
} }
}
case query.Reverse:
// Number based traversal towards the genesis block
if query.Origin.Number >= query.Skip+1 {
query.Origin.Number -= query.Skip + 1
} else { } else {
unknown = true unknown = true
} }
}
case query.Reverse:
// Number based traversal towards the genesis block
if query.Origin.Number >= query.Skip+1 {
query.Origin.Number -= query.Skip + 1
} else {
unknown = true
}
case !query.Reverse: case !query.Reverse:
// Number based traversal towards the leaf block // Number based traversal towards the leaf block
query.Origin.Number += query.Skip + 1 query.Origin.Number += query.Skip + 1
}
first = false
} }
first = false sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done())
} }()
sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done()) }
}()
case BlockHeadersMsg: case BlockHeadersMsg:
if pm.downloader == nil { if pm.downloader == nil {
@ -592,27 +621,27 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
bodies []rlp.RawValue bodies []rlp.RawValue
) )
reqCnt := len(req.Hashes) reqCnt := len(req.Hashes)
if !accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) { if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
return errResp(ErrRequestRejected, "") go func() {
} for i, hash := range req.Hashes {
go func() { if i != 0 && !task.waitOrStop() {
for i, hash := range req.Hashes { sendResponse(req.ReqID, 0, nil, task.servingTime)
if i != 0 && !task.waitOrStop() { return
return }
} if bytes >= softResponseLimit {
if bytes >= softResponseLimit { break
break }
} // Retrieve the requested block body, stopping if enough was found
// Retrieve the requested block body, stopping if enough was found if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil {
if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil { if data := rawdb.ReadBodyRLP(pm.chainDb, hash, *number); len(data) != 0 {
if data := rawdb.ReadBodyRLP(pm.chainDb, hash, *number); len(data) != 0 { bodies = append(bodies, data)
bodies = append(bodies, data) bytes += len(data)
bytes += len(data) }
} }
} }
} sendResponse(req.ReqID, uint64(reqCnt), p.ReplyBlockBodiesRLP(req.ReqID, bodies), task.done())
sendResponse(req.ReqID, uint64(reqCnt), p.ReplyBlockBodiesRLP(req.ReqID, bodies), task.done()) }()
}() }
case BlockBodiesMsg: case BlockBodiesMsg:
if pm.odr == nil { if pm.odr == nil {
@ -651,45 +680,45 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
data [][]byte data [][]byte
) )
reqCnt := len(req.Reqs) reqCnt := len(req.Reqs)
if !accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) { if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
return errResp(ErrRequestRejected, "") go func() {
} for i, request := range req.Reqs {
go func() { if i != 0 && !task.waitOrStop() {
for i, req := range req.Reqs { sendResponse(req.ReqID, 0, nil, task.servingTime)
if i != 0 && !task.waitOrStop() { return
return }
} // Look up the root hash belonging to the request
// Look up the root hash belonging to the request number := rawdb.ReadHeaderNumber(pm.chainDb, request.BHash)
number := rawdb.ReadHeaderNumber(pm.chainDb, req.BHash) if number == nil {
if number == nil { p.Log().Warn("Failed to retrieve block num for code", "hash", request.BHash)
p.Log().Warn("Failed to retrieve block num for code", "hash", req.BHash) continue
continue }
} header := rawdb.ReadHeader(pm.chainDb, request.BHash, *number)
header := rawdb.ReadHeader(pm.chainDb, req.BHash, *number) if header == nil {
if header == nil { p.Log().Warn("Failed to retrieve header for code", "block", *number, "hash", request.BHash)
p.Log().Warn("Failed to retrieve header for code", "block", *number, "hash", req.BHash) continue
continue }
} triedb := pm.blockchain.StateCache().TrieDB()
triedb := pm.blockchain.StateCache().TrieDB()
account, err := pm.getAccount(triedb, header.Root, common.BytesToHash(req.AccKey)) account, err := pm.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
if err != nil { if err != nil {
p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(req.AccKey), "err", err) p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
continue continue
} }
code, err := triedb.Node(common.BytesToHash(account.CodeHash)) code, err := triedb.Node(common.BytesToHash(account.CodeHash))
if err != nil { if err != nil {
p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(req.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
continue continue
} }
// Accumulate the code and abort if enough data was retrieved // Accumulate the code and abort if enough data was retrieved
data = append(data, code) data = append(data, code)
if bytes += len(code); bytes >= softResponseLimit { if bytes += len(code); bytes >= softResponseLimit {
break break
}
} }
} sendResponse(req.ReqID, uint64(reqCnt), p.ReplyCode(req.ReqID, data), task.done())
sendResponse(req.ReqID, uint64(reqCnt), p.ReplyCode(req.ReqID, data), task.done()) }()
}() }
case CodeMsg: case CodeMsg:
if pm.odr == nil { if pm.odr == nil {
@ -728,37 +757,37 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
receipts []rlp.RawValue receipts []rlp.RawValue
) )
reqCnt := len(req.Hashes) reqCnt := len(req.Hashes)
if !accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) { if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
return errResp(ErrRequestRejected, "") go func() {
} for i, hash := range req.Hashes {
go func() { if i != 0 && !task.waitOrStop() {
for i, hash := range req.Hashes { sendResponse(req.ReqID, 0, nil, task.servingTime)
if i != 0 && !task.waitOrStop() { return
return }
} if bytes >= softResponseLimit {
if bytes >= softResponseLimit { break
break }
} // Retrieve the requested block's receipts, skipping if unknown to us
// Retrieve the requested block's receipts, skipping if unknown to us var results types.Receipts
var results types.Receipts if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil {
if number := rawdb.ReadHeaderNumber(pm.chainDb, hash); number != nil { results = rawdb.ReadRawReceipts(pm.chainDb, hash, *number)
results = rawdb.ReadRawReceipts(pm.chainDb, hash, *number) }
} if results == nil {
if results == nil { if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { continue
continue }
}
// If known, encode and queue for response packet
if encoded, err := rlp.EncodeToBytes(results); err != nil {
log.Error("Failed to encode receipt", "err", err)
} else {
receipts = append(receipts, encoded)
bytes += len(encoded)
} }
} }
// If known, encode and queue for response packet sendResponse(req.ReqID, uint64(reqCnt), p.ReplyReceiptsRLP(req.ReqID, receipts), task.done())
if encoded, err := rlp.EncodeToBytes(results); err != nil { }()
log.Error("Failed to encode receipt", "err", err) }
} else {
receipts = append(receipts, encoded)
bytes += len(encoded)
}
}
sendResponse(req.ReqID, uint64(reqCnt), p.ReplyReceiptsRLP(req.ReqID, receipts), task.done())
}()
case ReceiptsMsg: case ReceiptsMsg:
if pm.odr == nil { if pm.odr == nil {
@ -797,70 +826,70 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
root common.Hash root common.Hash
) )
reqCnt := len(req.Reqs) reqCnt := len(req.Reqs)
if !accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) { if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
return errResp(ErrRequestRejected, "") go func() {
} nodes := light.NewNodeSet()
go func() {
nodes := light.NewNodeSet() for i, request := range req.Reqs {
if i != 0 && !task.waitOrStop() {
for i, req := range req.Reqs { sendResponse(req.ReqID, 0, nil, task.servingTime)
if i != 0 && !task.waitOrStop() { return
return
}
// Look up the root hash belonging to the request
var (
number *uint64
header *types.Header
trie state.Trie
)
if req.BHash != lastBHash {
root, lastBHash = common.Hash{}, req.BHash
if number = rawdb.ReadHeaderNumber(pm.chainDb, req.BHash); number == nil {
p.Log().Warn("Failed to retrieve block num for proof", "hash", req.BHash)
continue
} }
if header = rawdb.ReadHeader(pm.chainDb, req.BHash, *number); header == nil { // Look up the root hash belonging to the request
p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", req.BHash) var (
continue number *uint64
header *types.Header
trie state.Trie
)
if request.BHash != lastBHash {
root, lastBHash = common.Hash{}, request.BHash
if number = rawdb.ReadHeaderNumber(pm.chainDb, request.BHash); number == nil {
p.Log().Warn("Failed to retrieve block num for proof", "hash", request.BHash)
continue
}
if header = rawdb.ReadHeader(pm.chainDb, request.BHash, *number); header == nil {
p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", request.BHash)
continue
}
root = header.Root
} }
root = header.Root // Open the account or storage trie for the request
} statedb := pm.blockchain.StateCache()
// Open the account or storage trie for the request
statedb := pm.blockchain.StateCache() switch len(request.AccKey) {
case 0:
switch len(req.AccKey) { // No account key specified, open an account trie
case 0: trie, err = statedb.OpenTrie(root)
// No account key specified, open an account trie if trie == nil || err != nil {
trie, err = statedb.OpenTrie(root) p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
if trie == nil || err != nil { continue
p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) }
continue default:
// Account key specified, open a storage trie
account, err := pm.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
if err != nil {
p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
continue
}
trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
if trie == nil || err != nil {
p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
continue
}
} }
default: // Prove the user's request from the account or storage trie
// Account key specified, open a storage trie if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
account, err := pm.getAccount(statedb.TrieDB(), root, common.BytesToHash(req.AccKey)) p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
if err != nil {
p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(req.AccKey), "err", err)
continue continue
} }
trie, err = statedb.OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root) if nodes.DataSize() >= softResponseLimit {
if trie == nil || err != nil { break
p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(req.AccKey), "root", account.Root, "err", err)
continue
} }
} }
// Prove the user's request from the account or storage trie if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
if err := trie.Prove(req.Key, req.FromLevel, nodes); err != nil { }()
p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) }
continue
}
if nodes.DataSize() >= softResponseLimit {
break
}
}
sendResponse(req.ReqID, uint64(reqCnt), p.ReplyProofsV2(req.ReqID, nodes.NodeList()), task.done())
}()
case ProofsV2Msg: case ProofsV2Msg:
if pm.odr == nil { if pm.odr == nil {
@ -899,53 +928,53 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
auxData [][]byte auxData [][]byte
) )
reqCnt := len(req.Reqs) reqCnt := len(req.Reqs)
if !accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) { if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
return errResp(ErrRequestRejected, "") go func() {
}
go func() {
var (
lastIdx uint64
lastType uint
root common.Hash
auxTrie *trie.Trie
)
nodes := light.NewNodeSet()
for i, req := range req.Reqs {
if i != 0 && !task.waitOrStop() {
return
}
if auxTrie == nil || req.Type != lastType || req.TrieIdx != lastIdx {
auxTrie, lastType, lastIdx = nil, req.Type, req.TrieIdx
var prefix string var (
if root, prefix = pm.getHelperTrie(req.Type, req.TrieIdx); root != (common.Hash{}) { lastIdx uint64
auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(pm.chainDb, prefix))) lastType uint
} root common.Hash
} auxTrie *trie.Trie
if req.AuxReq == auxRoot { )
var data []byte nodes := light.NewNodeSet()
if root != (common.Hash{}) { for i, request := range req.Reqs {
data = root[:] if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime)
return
} }
auxData = append(auxData, data) if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
auxBytes += len(data) auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
} else {
if auxTrie != nil { var prefix string
auxTrie.Prove(req.Key, req.FromLevel, nodes) if root, prefix = pm.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(pm.chainDb, prefix)))
}
} }
if req.AuxReq != 0 { if request.AuxReq == auxRoot {
data := pm.getHelperTrieAuxData(req) var data []byte
if root != (common.Hash{}) {
data = root[:]
}
auxData = append(auxData, data) auxData = append(auxData, data)
auxBytes += len(data) auxBytes += len(data)
} else {
if auxTrie != nil {
auxTrie.Prove(request.Key, request.FromLevel, nodes)
}
if request.AuxReq != 0 {
data := pm.getHelperTrieAuxData(request)
auxData = append(auxData, data)
auxBytes += len(data)
}
}
if nodes.DataSize()+auxBytes >= softResponseLimit {
break
} }
} }
if nodes.DataSize()+auxBytes >= softResponseLimit { sendResponse(req.ReqID, uint64(reqCnt), p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}), task.done())
break }()
} }
}
sendResponse(req.ReqID, uint64(reqCnt), p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}), task.done())
}()
case HelperTrieProofsMsg: case HelperTrieProofsMsg:
if pm.odr == nil { if pm.odr == nil {
@ -981,27 +1010,27 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrDecode, "msg %v: %v", msg, err) return errResp(ErrDecode, "msg %v: %v", msg, err)
} }
reqCnt := len(req.Txs) reqCnt := len(req.Txs)
if !accept(req.ReqID, uint64(reqCnt), MaxTxSend) { if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
return errResp(ErrRequestRejected, "") go func() {
} stats := make([]light.TxStatus, len(req.Txs))
go func() { for i, tx := range req.Txs {
stats := make([]light.TxStatus, len(req.Txs)) if i != 0 && !task.waitOrStop() {
for i, tx := range req.Txs { sendResponse(req.ReqID, 0, nil, task.servingTime)
if i != 0 && !task.waitOrStop() { return
return
}
hash := tx.Hash()
stats[i] = pm.txStatus(hash)
if stats[i].Status == core.TxStatusUnknown {
if errs := pm.txpool.AddRemotes([]*types.Transaction{tx}); errs[0] != nil {
stats[i].Error = errs[0].Error()
continue
} }
hash := tx.Hash()
stats[i] = pm.txStatus(hash) stats[i] = pm.txStatus(hash)
if stats[i].Status == core.TxStatusUnknown {
if errs := pm.txpool.AddRemotes([]*types.Transaction{tx}); errs[0] != nil {
stats[i].Error = errs[0].Error()
continue
}
stats[i] = pm.txStatus(hash)
}
} }
} sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done())
sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done()) }()
}() }
case GetTxStatusMsg: case GetTxStatusMsg:
if pm.txpool == nil { if pm.txpool == nil {
@ -1016,19 +1045,19 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrDecode, "msg %v: %v", msg, err) return errResp(ErrDecode, "msg %v: %v", msg, err)
} }
reqCnt := len(req.Hashes) reqCnt := len(req.Hashes)
if !accept(req.ReqID, uint64(reqCnt), MaxTxStatus) { if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
return errResp(ErrRequestRejected, "") go func() {
} stats := make([]light.TxStatus, len(req.Hashes))
go func() { for i, hash := range req.Hashes {
stats := make([]light.TxStatus, len(req.Hashes)) if i != 0 && !task.waitOrStop() {
for i, hash := range req.Hashes { sendResponse(req.ReqID, 0, nil, task.servingTime)
if i != 0 && !task.waitOrStop() { return
return }
stats[i] = pm.txStatus(hash)
} }
stats[i] = pm.txStatus(hash) sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done())
} }()
sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done()) }
}()
case TxStatusMsg: case TxStatusMsg:
if pm.odr == nil { if pm.odr == nil {
@ -1053,6 +1082,26 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
Obj: resp.Status, Obj: resp.Status,
} }
case StopMsg:
if pm.odr == nil {
return errResp(ErrUnexpectedResponse, "")
}
p.freezeServer(true)
pm.retriever.frozen(p)
p.Log().Warn("Service stopped")
case ResumeMsg:
if pm.odr == nil {
return errResp(ErrUnexpectedResponse, "")
}
var bv uint64
if err := msg.Decode(&bv); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
p.fcServer.ResumeFreeze(bv)
p.freezeServer(false)
p.Log().Warn("Service resumed")
default: default:
p.Log().Trace("Received unknown message", "code", msg.Code) p.Log().Trace("Received unknown message", "code", msg.Code)
return errResp(ErrInvalidMsgCode, "%v", msg.Code) return errResp(ErrInvalidMsgCode, "%v", msg.Code)
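
The reworked accept closure above makes three decisions: rejected requests are still charged for their inbound message size, the serving-time budget maxTime is derived from maxCost and the cost tracker's global factor, and a client whose buffer runs short is frozen instead of disconnected. Below is a hedged sketch of the admission arithmetic, with plain values standing in for the costTracker and flowcontrol calls; the per-request cost and the 192-request limit in main are assumed for illustration.

package main

import "fmt"

type decision struct {
	accepted bool
	maxCost  uint64 // charged against the client's flow-control buffer
	maxTime  uint64 // serving-time budget (ns) handed to the serving queue
}

// admit mirrors the structure of the accept closure; frozen, empty or oversized
// requests are rejected (in the real handler they are still charged the inbound
// message size, which is not modelled here).
func admit(reqCnt, maxCnt, costPerReq uint64, frozen bool, globalFactor float64) decision {
	if frozen || reqCnt == 0 || reqCnt > maxCnt {
		return decision{}
	}
	maxCost := costPerReq * reqCnt
	gf := globalFactor
	if gf < 0.001 { // implausible global factor: fall back to 1, as the handler does
		gf = 1
	}
	return decision{accepted: true, maxCost: maxCost, maxTime: uint64(float64(maxCost) / gf)}
}

func main() {
	// e.g. a GetBlockHeaders request for 64 headers at an assumed cost of
	// 1000 units each, with a global factor of 2.5 cost units per nanosecond.
	d := admit(64, 192, 1000, false, 2.5)
	fmt.Printf("accepted=%v maxCost=%d maxTime=%dns\n", d.accepted, d.maxCost, d.maxTime)
}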

@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
@ -438,7 +439,7 @@ func TestTransactionStatusLes2(t *testing.T) {
config.Journal = "" config.Journal = ""
txpool := core.NewTxPool(config, params.TestChainConfig, chain) txpool := core.NewTxPool(config, params.TestChainConfig, chain)
pm.txpool = txpool pm.txpool = txpool
peer, _ := newTestPeer(t, "peer", 2, pm, true) peer, _ := newTestPeer(t, "peer", 2, pm, true, 0)
defer peer.close() defer peer.close()
var reqID uint64 var reqID uint64
@ -519,3 +520,51 @@ func TestTransactionStatusLes2(t *testing.T) {
test(tx1, false, light.TxStatus{Status: core.TxStatusPending}) test(tx1, false, light.TxStatus{Status: core.TxStatusPending})
test(tx2, false, light.TxStatus{Status: core.TxStatusPending}) test(tx2, false, light.TxStatus{Status: core.TxStatusPending})
} }
func TestStopResumeLes3(t *testing.T) {
db := rawdb.NewMemoryDatabase()
clock := &mclock.Simulated{}
testCost := testBufLimit / 10
pm, err := newTestProtocolManager(false, 0, nil, nil, nil, db, nil, testCost, clock)
if err != nil {
t.Fatalf("Failed to create protocol manager: %v", err)
}
peer, _ := newTestPeer(t, "peer", 3, pm, true, testCost)
defer peer.close()
expBuf := testBufLimit
var reqID uint64
req := func() {
reqID++
sendRequest(peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: common.Hash{1}}, Amount: 1})
}
for i := 1; i <= 5; i++ {
// send requests while we still have enough buffer and expect a response
for expBuf >= testCost {
req()
expBuf -= testCost
if err := expectResponse(peer.app, BlockHeadersMsg, reqID, expBuf, nil); err != nil {
t.Errorf("expected response and failed: %v", err)
}
}
// send some more requests in excess and expect a single StopMsg
c := i
for c > 0 {
req()
c--
}
if err := p2p.ExpectMsg(peer.app, StopMsg, nil); err != nil {
t.Errorf("expected StopMsg and failed: %v", err)
}
// wait until the buffer is recharged by half of the limit
wait := testBufLimit / testBufRecharge / 2
clock.Run(time.Millisecond * time.Duration(wait))
// expect a ResumeMsg with the partially recharged buffer value
expBuf += testBufRecharge * wait
if err := p2p.ExpectMsg(peer.app, ResumeMsg, expBuf); err != nil {
t.Errorf("expected ResumeMsg and failed: %v", err)
}
}
}
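
A quick worked check of the numbers in this test: with testBufLimit = 1,000,000 and testBufRecharge = 1,000 buffer units per simulated millisecond, the wait in each round is 1,000,000 / 1,000 / 2 = 500 ms, so the ResumeMsg is expected to carry a buffer value 1,000 * 500 = 500,000 units above its value at the moment the server sent StopMsg.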

@ -62,7 +62,8 @@ var (
testEventEmitterCode = common.Hex2Bytes("60606040523415600e57600080fd5b7f57050ab73f6b9ebdd9f76b8d4997793f48cf956e965ee070551b9ca0bb71584e60405160405180910390a160358060476000396000f3006060604052600080fd00a165627a7a723058203f727efcad8b5811f8cb1fc2620ce5e8c63570d697aef968172de296ea3994140029") testEventEmitterCode = common.Hex2Bytes("60606040523415600e57600080fd5b7f57050ab73f6b9ebdd9f76b8d4997793f48cf956e965ee070551b9ca0bb71584e60405160405180910390a160358060476000396000f3006060604052600080fd00a165627a7a723058203f727efcad8b5811f8cb1fc2620ce5e8c63570d697aef968172de296ea3994140029")
testEventEmitterAddr common.Address testEventEmitterAddr common.Address
testBufLimit = uint64(100) testBufLimit = uint64(1000000)
testBufRecharge = uint64(1000)
) )
/* /*
@ -138,7 +139,7 @@ func testIndexers(db ethdb.Database, odr light.OdrBackend, iConfig *light.Indexe
// newTestProtocolManager creates a new protocol manager for testing purposes, // newTestProtocolManager creates a new protocol manager for testing purposes,
// with the given number of blocks already known, potential notification // with the given number of blocks already known, potential notification
// channels for different events and relative chain indexers array. // channels for different events and relative chain indexers array.
func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *core.BlockGen), odr *LesOdr, peers *peerSet, db ethdb.Database, ulcConfig *eth.ULCConfig) (*ProtocolManager, error) { func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *core.BlockGen), odr *LesOdr, peers *peerSet, db ethdb.Database, ulcConfig *eth.ULCConfig, testCost uint64, clock mclock.Clock) (*ProtocolManager, error) {
var ( var (
evmux = new(event.TypeMux) evmux = new(event.TypeMux)
engine = ethash.NewFaker() engine = ethash.NewFaker()
@ -177,14 +178,15 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor
if !lightSync { if !lightSync {
srv := &LesServer{lesCommons: lesCommons{protocolManager: pm}} srv := &LesServer{lesCommons: lesCommons{protocolManager: pm}}
pm.server = srv pm.server = srv
pm.servingQueue = newServingQueue(int64(time.Millisecond*10), 1, nil)
pm.servingQueue.setThreads(4) pm.servingQueue.setThreads(4)
srv.defParams = flowcontrol.ServerParams{ srv.defParams = flowcontrol.ServerParams{
BufLimit: testBufLimit, BufLimit: testBufLimit,
MinRecharge: 1, MinRecharge: testBufRecharge,
} }
srv.testCost = testCost
srv.fcManager = flowcontrol.NewClientManager(nil, &mclock.System{}) srv.fcManager = flowcontrol.NewClientManager(nil, clock)
} }
pm.Start(1000) pm.Start(1000)
return pm, nil return pm, nil
@ -195,7 +197,7 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor
// channels for different events and relative chain indexers array. In case of an error, the constructor force- // channels for different events and relative chain indexers array. In case of an error, the constructor force-
// fails the test. // fails the test.
func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, generator func(int, *core.BlockGen), odr *LesOdr, peers *peerSet, db ethdb.Database, ulcConfig *eth.ULCConfig) *ProtocolManager { func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, generator func(int, *core.BlockGen), odr *LesOdr, peers *peerSet, db ethdb.Database, ulcConfig *eth.ULCConfig) *ProtocolManager {
pm, err := newTestProtocolManager(lightSync, blocks, generator, odr, peers, db, ulcConfig) pm, err := newTestProtocolManager(lightSync, blocks, generator, odr, peers, db, ulcConfig, 0, &mclock.System{})
if err != nil { if err != nil {
t.Fatalf("Failed to create protocol manager: %v", err) t.Fatalf("Failed to create protocol manager: %v", err)
} }
@ -210,7 +212,7 @@ type testPeer struct {
} }
// newTestPeer creates a new peer registered at the given protocol manager. // newTestPeer creates a new peer registered at the given protocol manager.
func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, shake bool) (*testPeer, <-chan error) { func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, shake bool, testCost uint64) (*testPeer, <-chan error) {
// Create a message pipe to communicate through // Create a message pipe to communicate through
app, net := p2p.MsgPipe() app, net := p2p.MsgPipe()
@ -242,7 +244,7 @@ func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, sh
head = pm.blockchain.CurrentHeader() head = pm.blockchain.CurrentHeader()
td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
) )
tp.handshake(t, td, head.Hash(), head.Number.Uint64(), genesis.Hash()) tp.handshake(t, td, head.Hash(), head.Number.Uint64(), genesis.Hash(), testCost)
} }
return tp, errc return tp, errc
} }
@ -282,7 +284,7 @@ func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer,
// handshake simulates a trivial handshake that expects the same state from the // handshake simulates a trivial handshake that expects the same state from the
// remote side as we are simulating locally. // remote side as we are simulating locally.
func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash) { func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, testCost uint64) {
var expList keyValueList var expList keyValueList
expList = expList.add("protocolVersion", uint64(p.version)) expList = expList.add("protocolVersion", uint64(p.version))
expList = expList.add("networkId", uint64(NetworkId)) expList = expList.add("networkId", uint64(NetworkId))
@ -295,10 +297,11 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
expList = expList.add("serveHeaders", nil) expList = expList.add("serveHeaders", nil)
expList = expList.add("serveChainSince", uint64(0)) expList = expList.add("serveChainSince", uint64(0))
expList = expList.add("serveStateSince", uint64(0)) expList = expList.add("serveStateSince", uint64(0))
expList = expList.add("serveRecentState", uint64(core.TriesInMemory-4))
expList = expList.add("txRelay", nil) expList = expList.add("txRelay", nil)
expList = expList.add("flowControl/BL", testBufLimit) expList = expList.add("flowControl/BL", testBufLimit)
expList = expList.add("flowControl/MRR", uint64(1)) expList = expList.add("flowControl/MRR", testBufRecharge)
expList = expList.add("flowControl/MRC", testCostList()) expList = expList.add("flowControl/MRC", testCostList(testCost))
if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil { if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil {
t.Fatalf("status recv: %v", err) t.Fatalf("status recv: %v", err)
@ -309,7 +312,7 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
p.fcParams = flowcontrol.ServerParams{ p.fcParams = flowcontrol.ServerParams{
BufLimit: testBufLimit, BufLimit: testBufLimit,
MinRecharge: 1, MinRecharge: testBufRecharge,
} }
} }
@ -338,7 +341,7 @@ func newServerEnv(t *testing.T, blocks int, protocol int, waitIndexers func(*cor
cIndexer, bIndexer, btIndexer := testIndexers(db, nil, light.TestServerIndexerConfig) cIndexer, bIndexer, btIndexer := testIndexers(db, nil, light.TestServerIndexerConfig)
pm := newTestProtocolManagerMust(t, false, blocks, testChainGen, nil, nil, db, nil) pm := newTestProtocolManagerMust(t, false, blocks, testChainGen, nil, nil, db, nil)
peer, _ := newTestPeer(t, "peer", protocol, pm, true) peer, _ := newTestPeer(t, "peer", protocol, pm, true, 0)
cIndexer.Start(pm.blockchain.(*core.BlockChain)) cIndexer.Start(pm.blockchain.(*core.BlockChain))
bIndexer.Start(pm.blockchain.(*core.BlockChain)) bIndexer.Start(pm.blockchain.(*core.BlockChain))
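
For reference, a compact sketch of the handshake key/value pairs the updated test now expects from a LES/3 server. keyValue and add below are simplified stand-ins for the les keyValueList helpers, and TriesInMemory is 128 in core at the time of this change; the point of interest is the new serveRecentState entry, which advertises how many recent blocks still have servable state, with a small safety margin below the in-memory trie limit.

package main

import "fmt"

type keyValue struct {
	Key   string
	Value interface{}
}

func add(list []keyValue, key string, value interface{}) []keyValue {
	return append(list, keyValue{key, value})
}

func main() {
	const triesInMemory = 128 // core.TriesInMemory
	var list []keyValue
	list = add(list, "serveChainSince", uint64(0))
	list = add(list, "serveStateSince", uint64(0))
	list = add(list, "serveRecentState", uint64(triesInMemory-4)) // new in LES/3
	list = add(list, "flowControl/BL", uint64(1000000))           // testBufLimit
	list = add(list, "flowControl/MRR", uint64(1000))             // testBufRecharge
	for _, kv := range list {
		fmt.Println(kv.Key, kv.Value)
	}
}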

@ -20,11 +20,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"math/rand"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
@ -47,10 +50,16 @@ const (
allowedUpdateRate = time.Millisecond * 10 // time constant for recharging one byte of allowance allowedUpdateRate = time.Millisecond * 10 // time constant for recharging one byte of allowance
) )
const (
freezeTimeBase = time.Millisecond * 700 // fixed component of client freeze time
freezeTimeRandom = time.Millisecond * 600 // random component of client freeze time
freezeCheckPeriod = time.Millisecond * 100 // buffer value recheck period after initial freeze time has elapsed
)
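
A standalone sketch of the timing these constants describe: after sending a Stop message the server waits a uniformly random 700-1300 ms, then rechecks every 100 ms until the client's buffer has recharged above one eighth of its limit. Only the constants are taken from above; the helpers are illustrative.

```go
package main

import (
	"fmt"
	"math/rand"
	"time"
)

const (
	freezeTimeBase    = 700 * time.Millisecond
	freezeTimeRandom  = 600 * time.Millisecond
	freezeCheckPeriod = 100 * time.Millisecond
)

// freezeDuration samples the initial sleep applied after sending a Stop message.
func freezeDuration() time.Duration {
	return freezeTimeBase + time.Duration(rand.Int63n(int64(freezeTimeRandom)))
}

// canUnfreeze mirrors the recheck condition used below: the client is resumed
// once its buffer value has recharged above 1/8 of the buffer limit.
func canUnfreeze(bufValue, bufLimit uint64) bool {
	return bufLimit > 0 && bufValue > bufLimit/8
}

func main() {
	fmt.Println("initial freeze:", freezeDuration())                 // somewhere in [700ms, 1300ms)
	fmt.Println("recheck every:", freezeCheckPeriod)                 // 100ms polling afterwards
	fmt.Println("unfreeze at 20% of limit:", canUnfreeze(200, 1000)) // true: 200 > 1000/8
}
```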
// if the total encoded size of a sent transaction batch is over txSizeCostLimit // if the total encoded size of a sent transaction batch is over txSizeCostLimit
// per transaction then the request cost is calculated as proportional to the // per transaction then the request cost is calculated as proportional to the
// encoded size instead of the transaction count // encoded size instead of the transaction count
const txSizeCostLimit = 0x10000 const txSizeCostLimit = 0x4000
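
As a rough illustration of the rule in the comment above: once the average encoded size per transaction exceeds txSizeCostLimit, the batch is charged by size rather than by count. The helper and its unit-based charging are assumptions; the real accounting lives in the cost tracker and protocol handler.

```go
package main

import "fmt"

const txSizeCostLimit = 0x4000 // 16 KiB per transaction, as set above

// adjustedTxUnits is an illustrative helper (not the real cost tracker code):
// if the average encoded size per transaction exceeds txSizeCostLimit, the
// request is charged per 16 KiB chunk instead of per transaction.
func adjustedTxUnits(txCount, encodedSize uint64) uint64 {
	if encodedSize > txCount*txSizeCostLimit {
		return (encodedSize + txSizeCostLimit - 1) / txSizeCostLimit
	}
	return txCount
}

func main() {
	fmt.Println(adjustedTxUnits(10, 8_000))  // 10: small batch, charged per transaction
	fmt.Println(adjustedTxUnits(2, 100_000)) // 7: oversized batch, charged per 16 KiB chunk
}
```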
const ( const (
announceTypeNone = iota announceTypeNone = iota
@ -86,14 +95,17 @@ type peer struct {
responseErrors int responseErrors int
updateCounter uint64 updateCounter uint64
updateTime mclock.AbsTime updateTime mclock.AbsTime
frozen uint32 // 1 if client is in frozen state
fcClient *flowcontrol.ClientNode // nil if the peer is server only fcClient *flowcontrol.ClientNode // nil if the peer is server only
fcServer *flowcontrol.ServerNode // nil if the peer is client only fcServer *flowcontrol.ServerNode // nil if the peer is client only
fcParams flowcontrol.ServerParams fcParams flowcontrol.ServerParams
fcCosts requestCostTable fcCosts requestCostTable
isTrusted bool isTrusted bool
isOnlyAnnounce bool isOnlyAnnounce bool
chainSince, chainRecent uint64
stateSince, stateRecent uint64
} }
func newPeer(version int, network uint64, isTrusted bool, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { func newPeer(version int, network uint64, isTrusted bool, p *p2p.Peer, rw p2p.MsgReadWriter) *peer {
@ -129,8 +141,59 @@ func (p *peer) rejectUpdate(size uint64) bool {
return p.updateCounter > allowedUpdateBytes return p.updateCounter > allowedUpdateBytes
} }
// freezeClient temporarily puts the client in a frozen state which means all
// unprocessed and subsequent requests are dropped. Unfreezing happens automatically
// after a short time if the client's buffer value is at least in the slightly positive
// region. The client is also notified about being frozen/unfrozen with a Stop/Resume
// message.
func (p *peer) freezeClient() {
if p.version < lpv3 {
// if Stop/Resume is not supported then just drop the peer after setting
// its frozen status permanently
atomic.StoreUint32(&p.frozen, 1)
p.Peer.Disconnect(p2p.DiscUselessPeer)
return
}
if atomic.SwapUint32(&p.frozen, 1) == 0 {
go func() {
p.SendStop()
time.Sleep(freezeTimeBase + time.Duration(rand.Int63n(int64(freezeTimeRandom))))
for {
bufValue, bufLimit := p.fcClient.BufferStatus()
if bufLimit == 0 {
return
}
if bufValue <= bufLimit/8 {
time.Sleep(freezeCheckPeriod)
} else {
atomic.StoreUint32(&p.frozen, 0)
p.SendResume(bufValue)
break
}
}
}()
}
}
// freezeServer processes Stop/Resume messages from the given server
func (p *peer) freezeServer(frozen bool) {
var f uint32
if frozen {
f = 1
}
if atomic.SwapUint32(&p.frozen, f) != f && frozen {
p.sendQueue.clear()
}
}
// isFrozen returns true if the client is frozen or the server has put our
// client in frozen state
func (p *peer) isFrozen() bool {
return atomic.LoadUint32(&p.frozen) != 0
}
func (p *peer) canQueue() bool { func (p *peer) canQueue() bool {
return p.sendQueue.canQueue() return p.sendQueue.canQueue() && !p.isFrozen()
} }
func (p *peer) queueSend(f func()) { func (p *peer) queueSend(f func()) {
@ -265,10 +328,21 @@ func (p *peer) GetTxRelayCost(amount, size int) uint64 {
// HasBlock checks if the peer has a given block // HasBlock checks if the peer has a given block
func (p *peer) HasBlock(hash common.Hash, number uint64, hasState bool) bool { func (p *peer) HasBlock(hash common.Hash, number uint64, hasState bool) bool {
var head, since, recent uint64
p.lock.RLock() p.lock.RLock()
if p.headInfo != nil {
head = p.headInfo.Number
}
if hasState {
since = p.stateSince
recent = p.stateRecent
} else {
since = p.chainSince
recent = p.chainRecent
}
hasBlock := p.hasBlock hasBlock := p.hasBlock
p.lock.RUnlock() p.lock.RUnlock()
return hasBlock != nil && hasBlock(hash, number, hasState) return head >= number && number >= since && (recent == 0 || number+recent+4 > head) && hasBlock != nil && hasBlock(hash, number, hasState)
} }
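
A worked example of the range check above: with serveRecentState advertised as core.TriesInMemory-4 = 124 and the peer's announced head at block 1000, state requests are considered servable only for the most recent 128 blocks. The predicate is copied from HasBlock; the concrete numbers are illustrative.

```go
package main

import "fmt"

// available mirrors the range check in HasBlock: the block must not be newer
// than the peer's announced head, not older than the advertised "since" value,
// and (if a recent-block limit is advertised) close enough to the head.
func available(number, head, since, recent uint64) bool {
	return head >= number && number >= since && (recent == 0 || number+recent+4 > head)
}

func main() {
	const (
		head   = 1000
		since  = 0
		recent = 124 // core.TriesInMemory - 4, as advertised in the handshake
	)
	fmt.Println(available(1000, head, since, recent)) // true: the head itself
	fmt.Println(available(873, head, since, recent))  // true: oldest covered block (head-127)
	fmt.Println(available(872, head, since, recent))  // false: state assumed pruned already
	fmt.Println(available(500, head, since, 0))       // true: no recent limit advertised
}
```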
// SendAnnounce announces the availability of a number of blocks through // SendAnnounce announces the availability of a number of blocks through
@ -277,6 +351,16 @@ func (p *peer) SendAnnounce(request announceData) error {
return p2p.Send(p.rw, AnnounceMsg, request) return p2p.Send(p.rw, AnnounceMsg, request)
} }
// SendStop notifies the client about being in frozen state
func (p *peer) SendStop() error {
return p2p.Send(p.rw, StopMsg, struct{}{})
}
// SendResume notifies the client about getting out of frozen state
func (p *peer) SendResume(bv uint64) error {
return p2p.Send(p.rw, ResumeMsg, bv)
}
// ReplyBlockHeaders creates a reply with a batch of block headers // ReplyBlockHeaders creates a reply with a batch of block headers
func (p *peer) ReplyBlockHeaders(reqID uint64, headers []*types.Header) *reply { func (p *peer) ReplyBlockHeaders(reqID uint64, headers []*types.Header) *reply {
data, _ := rlp.EncodeToBytes(headers) data, _ := rlp.EncodeToBytes(headers)
@ -464,19 +548,19 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
send = send.add("genesisHash", genesis) send = send.add("genesisHash", genesis)
if server != nil { if server != nil {
if !server.onlyAnnounce { if !server.onlyAnnounce {
//only announce server. It sends only announce requests
send = send.add("serveHeaders", nil) send = send.add("serveHeaders", nil)
send = send.add("serveChainSince", uint64(0)) send = send.add("serveChainSince", uint64(0))
send = send.add("serveStateSince", uint64(0)) send = send.add("serveStateSince", uint64(0))
send = send.add("serveRecentState", uint64(core.TriesInMemory-4))
send = send.add("txRelay", nil) send = send.add("txRelay", nil)
} }
send = send.add("flowControl/BL", server.defParams.BufLimit) send = send.add("flowControl/BL", server.defParams.BufLimit)
send = send.add("flowControl/MRR", server.defParams.MinRecharge) send = send.add("flowControl/MRR", server.defParams.MinRecharge)
var costList RequestCostList var costList RequestCostList
if server.costTracker != nil { if server.costTracker != nil {
costList = server.costTracker.makeCostList() costList = server.costTracker.makeCostList(server.costTracker.globalFactor())
} else { } else {
costList = testCostList() costList = testCostList(server.testCost)
} }
send = send.add("flowControl/MRC", costList) send = send.add("flowControl/MRC", costList)
p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)]) p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)])
@ -544,12 +628,18 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams)
} else { } else {
//mark OnlyAnnounce server if "serveHeaders", "serveChainSince", "serveStateSince" or "txRelay" fields don't exist //mark OnlyAnnounce server if "serveHeaders", "serveChainSince", "serveStateSince" or "txRelay" fields don't exist
if recv.get("serveChainSince", nil) != nil { if recv.get("serveChainSince", &p.chainSince) != nil {
p.isOnlyAnnounce = true p.isOnlyAnnounce = true
} }
if recv.get("serveStateSince", nil) != nil { if recv.get("serveRecentChain", &p.chainRecent) != nil {
p.chainRecent = 0
}
if recv.get("serveStateSince", &p.stateSince) != nil {
p.isOnlyAnnounce = true p.isOnlyAnnounce = true
} }
if recv.get("serveRecentState", &p.stateRecent) != nil {
p.stateRecent = 0
}
if recv.get("txRelay", nil) != nil { if recv.get("txRelay", nil) != nil {
p.isOnlyAnnounce = true p.isOnlyAnnounce = true
} }

@ -54,7 +54,7 @@ func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testi
l = l.add("txRelay", nil) l = l.add("txRelay", nil)
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRR", uint64(0))
l = l.add("flowControl/MRC", testCostList()) l = l.add("flowControl/MRC", testCostList(0))
return l return l
}, },
@ -99,7 +99,7 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi
l = l.add("txRelay", nil) l = l.add("txRelay", nil)
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRR", uint64(0))
l = l.add("flowControl/MRC", testCostList()) l = l.add("flowControl/MRC", testCostList(0))
return l return l
}, },

@ -32,17 +32,18 @@ import (
// Constants to match up protocol versions and messages // Constants to match up protocol versions and messages
const ( const (
lpv2 = 2 lpv2 = 2
lpv3 = 3
) )
// Supported versions of the les protocol (first is primary) // Supported versions of the les protocol (first is primary)
var ( var (
ClientProtocolVersions = []uint{lpv2} ClientProtocolVersions = []uint{lpv2, lpv3}
ServerProtocolVersions = []uint{lpv2} ServerProtocolVersions = []uint{lpv2, lpv3}
AdvertiseProtocolVersions = []uint{lpv2} // clients are searching for the first advertised protocol in the list AdvertiseProtocolVersions = []uint{lpv2} // clients are searching for the first advertised protocol in the list
) )
// Number of implemented message corresponding to different protocol versions. // Number of implemented message corresponding to different protocol versions.
var ProtocolLengths = map[uint]uint64{lpv2: 22} var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24}
const ( const (
NetworkId = 1 NetworkId = 1
@ -70,6 +71,9 @@ const (
SendTxV2Msg = 0x13 SendTxV2Msg = 0x13
GetTxStatusMsg = 0x14 GetTxStatusMsg = 0x14
TxStatusMsg = 0x15 TxStatusMsg = 0x15
// Protocol messages introduced in LPV3
StopMsg = 0x16
ResumeMsg = 0x17
) )
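
On the client side the two new codes are consumed roughly as sketched below: a Stop marks the server as frozen, clears the local send queue and tells the retrieve manager not to expect answers from that peer, while a Resume clears the flag. This is only an illustrative fragment of the message switch: the real dispatch is in les/handler.go (not part of this excerpt), and fcServer.ResumeFreeze is an assumed name for re-synchronising the buffer estimate from the Resume payload; freezeServer and the retrieve manager's frozen call are the ones added in this change.

```go
// Illustrative excerpt of a switch over msg.Code in the client-side handler.
case StopMsg:
	if p.version < lpv3 {
		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
	}
	p.freezeServer(true)   // drop queued sends, mark the server frozen
	pm.retriever.frozen(p) // pending requests to this peer become "not delivered"

case ResumeMsg:
	if p.version < lpv3 {
		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
	}
	var bv uint64
	if err := msg.Decode(&bv); err != nil {
		return errResp(ErrDecode, "msg %v: %v", msg, err)
	}
	p.fcServer.ResumeFreeze(bv) // assumed call: re-sync the flow control buffer estimate
	p.freezeServer(false)
```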
type requestInfo struct { type requestInfo struct {

@ -78,8 +78,8 @@ type sentReq struct {
// after which delivered is set to true, the validity of the response is sent on the // after which delivered is set to true, the validity of the response is sent on the
// valid channel and no more responses are accepted. // valid channel and no more responses are accepted.
type sentReqToPeer struct { type sentReqToPeer struct {
delivered bool delivered, frozen bool
valid chan bool event chan int
} }
// reqPeerEvent is sent by the request-from-peer goroutine (tryRequest) to the // reqPeerEvent is sent by the request-from-peer goroutine (tryRequest) to the
@ -95,6 +95,7 @@ const (
rpHardTimeout rpHardTimeout
rpDeliveredValid rpDeliveredValid
rpDeliveredInvalid rpDeliveredInvalid
rpNotDelivered
) )
// newRetrieveManager creates the retrieve manager // newRetrieveManager creates the retrieve manager
@ -149,7 +150,7 @@ func (rm *retrieveManager) sendReq(reqID uint64, req *distReq, val validatorFunc
req.request = func(p distPeer) func() { req.request = func(p distPeer) func() {
// before actually sending the request, put an entry into the sentTo map // before actually sending the request, put an entry into the sentTo map
r.lock.Lock() r.lock.Lock()
r.sentTo[p] = sentReqToPeer{false, make(chan bool, 1)} r.sentTo[p] = sentReqToPeer{delivered: false, frozen: false, event: make(chan int, 1)}
r.lock.Unlock() r.lock.Unlock()
return request(p) return request(p)
} }
@ -173,6 +174,17 @@ func (rm *retrieveManager) deliver(peer distPeer, msg *Msg) error {
return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID) return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID)
} }
// frozen is called by the LES protocol manager when a server has suspended its service and we
// should not expect an answer for the requests already sent there
func (rm *retrieveManager) frozen(peer distPeer) {
rm.lock.RLock()
defer rm.lock.RUnlock()
for _, req := range rm.sentReqs {
req.frozen(peer)
}
}
// reqStateFn represents a state of the retrieve loop state machine // reqStateFn represents a state of the retrieve loop state machine
type reqStateFn func() reqStateFn type reqStateFn func() reqStateFn
@ -215,7 +227,7 @@ func (r *sentReq) stateRequesting() reqStateFn {
go r.tryRequest() go r.tryRequest()
r.lastReqQueued = true r.lastReqQueued = true
return r.stateRequesting return r.stateRequesting
case rpDeliveredInvalid: case rpDeliveredInvalid, rpNotDelivered:
// if it was the last sent request (set to nil by update) then start a new one // if it was the last sent request (set to nil by update) then start a new one
if !r.lastReqQueued && r.lastReqSentTo == nil { if !r.lastReqQueued && r.lastReqSentTo == nil {
go r.tryRequest() go r.tryRequest()
@ -277,7 +289,7 @@ func (r *sentReq) update(ev reqPeerEvent) {
r.reqSrtoCount++ r.reqSrtoCount++
case rpHardTimeout: case rpHardTimeout:
r.reqSrtoCount-- r.reqSrtoCount--
case rpDeliveredValid, rpDeliveredInvalid: case rpDeliveredValid, rpDeliveredInvalid, rpNotDelivered:
if ev.peer == r.lastReqSentTo { if ev.peer == r.lastReqSentTo {
r.lastReqSentTo = nil r.lastReqSentTo = nil
} else { } else {
@ -343,12 +355,13 @@ func (r *sentReq) tryRequest() {
}() }()
select { select {
case ok := <-s.valid: case event := <-s.event:
if ok { if event == rpNotDelivered {
r.eventsCh <- reqPeerEvent{rpDeliveredValid, p} r.lock.Lock()
} else { delete(r.sentTo, p)
r.eventsCh <- reqPeerEvent{rpDeliveredInvalid, p} r.lock.Unlock()
} }
r.eventsCh <- reqPeerEvent{event, p}
return return
case <-time.After(softRequestTimeout): case <-time.After(softRequestTimeout):
srto = true srto = true
@ -356,12 +369,13 @@ func (r *sentReq) tryRequest() {
} }
select { select {
case ok := <-s.valid: case event := <-s.event:
if ok { if event == rpNotDelivered {
r.eventsCh <- reqPeerEvent{rpDeliveredValid, p} r.lock.Lock()
} else { delete(r.sentTo, p)
r.eventsCh <- reqPeerEvent{rpDeliveredInvalid, p} r.lock.Unlock()
} }
r.eventsCh <- reqPeerEvent{event, p}
case <-time.After(hardRequestTimeout): case <-time.After(hardRequestTimeout):
hrto = true hrto = true
r.eventsCh <- reqPeerEvent{rpHardTimeout, p} r.eventsCh <- reqPeerEvent{rpHardTimeout, p}
@ -377,15 +391,37 @@ func (r *sentReq) deliver(peer distPeer, msg *Msg) error {
if !ok || s.delivered { if !ok || s.delivered {
return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID) return errResp(ErrUnexpectedResponse, "reqID = %v", msg.ReqID)
} }
if s.frozen {
return nil
}
valid := r.validate(peer, msg) == nil valid := r.validate(peer, msg) == nil
r.sentTo[peer] = sentReqToPeer{true, s.valid} r.sentTo[peer] = sentReqToPeer{delivered: true, frozen: false, event: s.event}
s.valid <- valid if valid {
s.event <- rpDeliveredValid
} else {
s.event <- rpDeliveredInvalid
}
if !valid { if !valid {
return errResp(ErrInvalidResponse, "reqID = %v", msg.ReqID) return errResp(ErrInvalidResponse, "reqID = %v", msg.ReqID)
} }
return nil return nil
} }
// frozen sends a "not delivered" event to the peer event channel belonging to the
// given peer if the request has been sent there, causing the state machine to not
// expect an answer and potentially even send the request to the same peer again
// when canSend allows it.
func (r *sentReq) frozen(peer distPeer) {
r.lock.Lock()
defer r.lock.Unlock()
s, ok := r.sentTo[peer]
if ok && !s.delivered && !s.frozen {
r.sentTo[peer] = sentReqToPeer{delivered: false, frozen: true, event: s.event}
s.event <- rpNotDelivered
}
}
// stop stops the retrieval process and sets an error code that will be returned // stop stops the retrieval process and sets an error code that will be returned
// by getError // by getError
func (r *sentReq) stop(err error) { func (r *sentReq) stop(err error) {

@ -19,6 +19,7 @@ package les
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"sync" "sync"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
@ -26,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/csvlogger"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
@ -37,26 +39,43 @@ import (
const bufLimitRatio = 6000 // fixed bufLimit/MRR ratio const bufLimitRatio = 6000 // fixed bufLimit/MRR ratio
const (
logFileName = "" // csv log file name (disabled if empty)
logClientPoolMetrics = true // log client pool metrics
logClientPoolEvents = false // detailed client pool event logging
logRequestServing = true // log request serving metrics and events
logBlockProcEvents = true // log block processing events
logProtocolHandler = true // log protocol handler events
)
type LesServer struct { type LesServer struct {
lesCommons lesCommons
fcManager *flowcontrol.ClientManager // nil if our node is client only fcManager *flowcontrol.ClientManager // nil if our node is client only
costTracker *costTracker costTracker *costTracker
testCost uint64
defParams flowcontrol.ServerParams defParams flowcontrol.ServerParams
lesTopics []discv5.Topic lesTopics []discv5.Topic
privateKey *ecdsa.PrivateKey privateKey *ecdsa.PrivateKey
quitSync chan struct{} quitSync chan struct{}
onlyAnnounce bool onlyAnnounce bool
csvLogger *csvlogger.Logger
logTotalCap *csvlogger.Channel
thcNormal, thcBlockProcessing int // serving thread count for normal operation and block processing mode thcNormal, thcBlockProcessing int // serving thread count for normal operation and block processing mode
maxPeers int maxPeers int
freeClientCap uint64 minCapacity, freeClientCap uint64
freeClientPool *freeClientPool freeClientPool *freeClientPool
priorityClientPool *priorityClientPool priorityClientPool *priorityClientPool
} }
func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) { func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) {
var csvLogger *csvlogger.Logger
if logFileName != "" {
csvLogger = csvlogger.NewLogger(logFileName, time.Second*10, "event, peerId")
}
quitSync := make(chan struct{}) quitSync := make(chan struct{})
pm, err := NewProtocolManager( pm, err := NewProtocolManager(
eth.BlockChain().Config(), eth.BlockChain().Config(),
@ -78,6 +97,14 @@ func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if logProtocolHandler {
pm.logger = csvLogger
}
requestLogger := csvLogger
if !logRequestServing {
requestLogger = nil
}
pm.servingQueue = newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100, requestLogger)
lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions)) lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions))
for i, pv := range AdvertiseProtocolVersions { for i, pv := range AdvertiseProtocolVersions {
@ -93,11 +120,13 @@ func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) {
bloomTrieIndexer: light.NewBloomTrieIndexer(eth.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency), bloomTrieIndexer: light.NewBloomTrieIndexer(eth.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency),
protocolManager: pm, protocolManager: pm,
}, },
costTracker: newCostTracker(eth.ChainDb(), config),
quitSync: quitSync, quitSync: quitSync,
lesTopics: lesTopics, lesTopics: lesTopics,
onlyAnnounce: config.OnlyAnnounce, onlyAnnounce: config.OnlyAnnounce,
csvLogger: csvLogger,
logTotalCap: requestLogger.NewChannel("totalCapacity", 0.01),
} }
srv.costTracker, srv.minCapacity = newCostTracker(eth.ChainDb(), config, requestLogger)
logger := log.New() logger := log.New()
pm.server = srv pm.server = srv
@ -144,7 +173,11 @@ func (s *LesServer) APIs() []rpc.API {
func (s *LesServer) startEventLoop() { func (s *LesServer) startEventLoop() {
s.protocolManager.wg.Add(1) s.protocolManager.wg.Add(1)
var processing bool blockProcLogger := s.csvLogger
if !logBlockProcEvents {
blockProcLogger = nil
}
var processing, procLast bool
blockProcFeed := make(chan bool, 100) blockProcFeed := make(chan bool, 100)
s.protocolManager.blockchain.(*core.BlockChain).SubscribeBlockProcessingEvent(blockProcFeed) s.protocolManager.blockchain.(*core.BlockChain).SubscribeBlockProcessingEvent(blockProcFeed)
totalRechargeCh := make(chan uint64, 100) totalRechargeCh := make(chan uint64, 100)
@ -152,17 +185,25 @@ func (s *LesServer) startEventLoop() {
totalCapacityCh := make(chan uint64, 100) totalCapacityCh := make(chan uint64, 100)
updateRecharge := func() { updateRecharge := func() {
if processing { if processing {
if !procLast {
blockProcLogger.Event("block processing started")
}
s.protocolManager.servingQueue.setThreads(s.thcBlockProcessing) s.protocolManager.servingQueue.setThreads(s.thcBlockProcessing)
s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge, totalRecharge}}) s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge, totalRecharge}})
} else { } else {
if procLast {
blockProcLogger.Event("block processing finished")
}
s.protocolManager.servingQueue.setThreads(s.thcNormal) s.protocolManager.servingQueue.setThreads(s.thcNormal)
s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge / 10, totalRecharge}, {totalRecharge, totalRecharge}}) s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge / 16, totalRecharge / 2}, {totalRecharge / 2, totalRecharge / 2}, {totalRecharge, totalRecharge}})
} }
procLast = processing
} }
updateRecharge() updateRecharge()
totalCapacity := s.fcManager.SubscribeTotalCapacity(totalCapacityCh) totalCapacity := s.fcManager.SubscribeTotalCapacity(totalCapacityCh)
s.priorityClientPool.setLimits(s.maxPeers, totalCapacity) s.priorityClientPool.setLimits(s.maxPeers, totalCapacity)
var maxFreePeers uint64
go func() { go func() {
for { for {
select { select {
@ -171,6 +212,12 @@ func (s *LesServer) startEventLoop() {
case totalRecharge = <-totalRechargeCh: case totalRecharge = <-totalRechargeCh:
updateRecharge() updateRecharge()
case totalCapacity = <-totalCapacityCh: case totalCapacity = <-totalCapacityCh:
s.logTotalCap.Update(float64(totalCapacity))
newFreePeers := totalCapacity / s.freeClientCap
if newFreePeers < maxFreePeers && newFreePeers < uint64(s.maxPeers) {
log.Warn("Reduced total capacity", "maxFreePeers", newFreePeers)
}
maxFreePeers = newFreePeers
s.priorityClientPool.setLimits(s.maxPeers, totalCapacity) s.priorityClientPool.setLimits(s.maxPeers, totalCapacity)
case <-s.protocolManager.quitSync: case <-s.protocolManager.quitSync:
s.protocolManager.wg.Done() s.protocolManager.wg.Done()
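
Reading the curve points set in updateRecharge above: the piecewise-linear curve maps the total recharge demand of currently recharging clients to the total recharge rate actually handed out, so the new normal-mode curve over-recharges while the server is lightly loaded (demand up to totalRecharge/16 is served at totalRecharge/2), holds the total at totalRecharge/2 until demand reaches that level, and falls back to nominal recharging beyond it. A small illustrative evaluator follows; the real type is flowcontrol.PieceWiseLinear, and this helper is only a sketch of the interpolation.

```go
package main

import "fmt"

type point struct{ X, Y float64 }

// valueAt linearly interpolates a sorted piecewise-linear curve, clamping
// outside the covered range. Conceptual stand-in for flowcontrol.PieceWiseLinear.
func valueAt(curve []point, x float64) float64 {
	if len(curve) == 0 {
		return 0
	}
	if x <= curve[0].X {
		return curve[0].Y
	}
	for i := 1; i < len(curve); i++ {
		if x <= curve[i].X {
			dx := curve[i].X - curve[i-1].X
			if dx == 0 {
				return curve[i].Y
			}
			return curve[i-1].Y + (curve[i].Y-curve[i-1].Y)*(x-curve[i-1].X)/dx
		}
	}
	return curve[len(curve)-1].Y
}

func main() {
	const total = 1600.0 // stand-in for totalRecharge
	normal := []point{{0, 0}, {total / 16, total / 2}, {total / 2, total / 2}, {total, total}}
	for _, demand := range []float64{50, 100, 400, 800, 1200} {
		fmt.Printf("demand %v -> total recharge %v\n", demand, valueAt(normal, demand))
	}
	// prints 400, 800, 800, 800, 1200: boosted at low load, nominal at high load
}
```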
@ -189,9 +236,9 @@ func (s *LesServer) Start(srvr *p2p.Server) {
s.maxPeers = s.config.LightPeers s.maxPeers = s.config.LightPeers
totalRecharge := s.costTracker.totalRecharge() totalRecharge := s.costTracker.totalRecharge()
if s.maxPeers > 0 { if s.maxPeers > 0 {
s.freeClientCap = minCapacity //totalRecharge / uint64(s.maxPeers) s.freeClientCap = s.minCapacity //totalRecharge / uint64(s.maxPeers)
if s.freeClientCap < minCapacity { if s.freeClientCap < s.minCapacity {
s.freeClientCap = minCapacity s.freeClientCap = s.minCapacity
} }
if s.freeClientCap > 0 { if s.freeClientCap > 0 {
s.defParams = flowcontrol.ServerParams{ s.defParams = flowcontrol.ServerParams{
@ -200,15 +247,25 @@ func (s *LesServer) Start(srvr *p2p.Server) {
} }
} }
} }
freePeers := int(totalRecharge / s.freeClientCap)
if freePeers < s.maxPeers {
log.Warn("Light peer count limited", "specified", s.maxPeers, "allowed", freePeers)
}
s.freeClientPool = newFreeClientPool(s.chainDb, s.freeClientCap, 10000, mclock.System{}, func(id string) { go s.protocolManager.removePeer(id) }) maxCapacity := s.freeClientCap * uint64(s.maxPeers)
s.priorityClientPool = newPriorityClientPool(s.freeClientCap, s.protocolManager.peers, s.freeClientPool) if totalRecharge > maxCapacity {
maxCapacity = totalRecharge
}
s.fcManager.SetCapacityLimits(s.freeClientCap, maxCapacity, s.freeClientCap*2)
poolMetricsLogger := s.csvLogger
if !logClientPoolMetrics {
poolMetricsLogger = nil
}
poolEventLogger := s.csvLogger
if !logClientPoolEvents {
poolEventLogger = nil
}
s.freeClientPool = newFreeClientPool(s.chainDb, s.freeClientCap, 10000, mclock.System{}, func(id string) { go s.protocolManager.removePeer(id) }, poolMetricsLogger, poolEventLogger)
s.priorityClientPool = newPriorityClientPool(s.freeClientCap, s.protocolManager.peers, s.freeClientPool, poolMetricsLogger, poolEventLogger)
s.protocolManager.peers.notify(s.priorityClientPool) s.protocolManager.peers.notify(s.priorityClientPool)
s.csvLogger.Start()
s.startEventLoop() s.startEventLoop()
s.protocolManager.Start(s.config.LightPeers) s.protocolManager.Start(s.config.LightPeers)
if srvr.DiscV5 != nil { if srvr.DiscV5 != nil {
@ -233,6 +290,7 @@ func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) {
// Stop stops the LES service // Stop stops the LES service
func (s *LesServer) Stop() { func (s *LesServer) Stop() {
s.fcManager.Stop()
s.chtIndexer.Close() s.chtIndexer.Close()
// bloom trie indexer is closed by parent bloombits indexer // bloom trie indexer is closed by parent bloombits indexer
go func() { go func() {
@ -241,6 +299,7 @@ func (s *LesServer) Stop() {
s.freeClientPool.stop() s.freeClientPool.stop()
s.costTracker.stop() s.costTracker.stop()
s.protocolManager.Stop() s.protocolManager.Stop()
s.csvLogger.Stop()
} }
// todo(rjl493456442) separate client and server implementation. // todo(rjl493456442) separate client and server implementation.

@ -17,16 +17,24 @@
package les package les
import ( import (
"fmt"
"sort"
"sync" "sync"
"sync/atomic"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/common/prque" "github.com/ethereum/go-ethereum/common/prque"
"github.com/ethereum/go-ethereum/les/csvlogger"
) )
// servingQueue allows running tasks in a limited number of threads and puts the // servingQueue allows running tasks in a limited number of threads and puts the
// waiting tasks in a priority queue // waiting tasks in a priority queue
type servingQueue struct { type servingQueue struct {
tokenCh chan runToken recentTime, queuedTime, servingTimeDiff uint64
burstLimit, burstDropLimit uint64
burstDecRate float64
lastUpdate mclock.AbsTime
queueAddCh, queueBestCh chan *servingTask queueAddCh, queueBestCh chan *servingTask
stopThreadCh, quit chan struct{} stopThreadCh, quit chan struct{}
setThreadsCh chan int setThreadsCh chan int
@ -36,6 +44,10 @@ type servingQueue struct {
queue *prque.Prque // priority queue for waiting or suspended tasks queue *prque.Prque // priority queue for waiting or suspended tasks
best *servingTask // the highest priority task (not included in the queue) best *servingTask // the highest priority task (not included in the queue)
suspendBias int64 // priority bias against suspending an already running task suspendBias int64 // priority bias against suspending an already running task
logger *csvlogger.Logger
logRecentTime *csvlogger.Channel
logQueuedTime *csvlogger.Channel
} }
// servingTask represents a request serving task. Tasks can be implemented to // servingTask represents a request serving task. Tasks can be implemented to
@ -47,12 +59,13 @@ type servingQueue struct {
// - run: execute a single step; return true if finished // - run: execute a single step; return true if finished
// - after: executed after run finishes or returns an error, receives the total serving time // - after: executed after run finishes or returns an error, receives the total serving time
type servingTask struct { type servingTask struct {
sq *servingQueue sq *servingQueue
servingTime uint64 servingTime, timeAdded, maxTime, expTime uint64
priority int64 peer *peer
biasAdded bool priority int64
token runToken biasAdded bool
tokenCh chan runToken token runToken
tokenCh chan runToken
} }
// runToken received by servingTask.start allows the task to run. Closing the // runToken received by servingTask.start allows the task to run. Closing the
@ -63,20 +76,19 @@ type runToken chan struct{}
// start blocks until the task can start and returns true if it is allowed to run. // start blocks until the task can start and returns true if it is allowed to run.
// Returning false means that the task should be cancelled. // Returning false means that the task should be cancelled.
func (t *servingTask) start() bool { func (t *servingTask) start() bool {
if t.peer.isFrozen() {
return false
}
t.tokenCh = make(chan runToken, 1)
select { select {
case t.token = <-t.sq.tokenCh: case t.sq.queueAddCh <- t:
default: case <-t.sq.quit:
t.tokenCh = make(chan runToken, 1) return false
select { }
case t.sq.queueAddCh <- t: select {
case <-t.sq.quit: case t.token = <-t.tokenCh:
return false case <-t.sq.quit:
} return false
select {
case t.token = <-t.tokenCh:
case <-t.sq.quit:
return false
}
} }
if t.token == nil { if t.token == nil {
return false return false
@ -90,6 +102,14 @@ func (t *servingTask) start() bool {
func (t *servingTask) done() uint64 { func (t *servingTask) done() uint64 {
t.servingTime += uint64(mclock.Now()) t.servingTime += uint64(mclock.Now())
close(t.token) close(t.token)
diff := t.servingTime - t.timeAdded
t.timeAdded = t.servingTime
if t.expTime > diff {
t.expTime -= diff
atomic.AddUint64(&t.sq.servingTimeDiff, t.expTime)
} else {
t.expTime = 0
}
return t.servingTime return t.servingTime
} }
@ -107,16 +127,22 @@ func (t *servingTask) waitOrStop() bool {
} }
// newServingQueue returns a new servingQueue // newServingQueue returns a new servingQueue
func newServingQueue(suspendBias int64) *servingQueue { func newServingQueue(suspendBias int64, utilTarget float64, logger *csvlogger.Logger) *servingQueue {
sq := &servingQueue{ sq := &servingQueue{
queue: prque.New(nil), queue: prque.New(nil),
suspendBias: suspendBias, suspendBias: suspendBias,
tokenCh: make(chan runToken), queueAddCh: make(chan *servingTask, 100),
queueAddCh: make(chan *servingTask, 100), queueBestCh: make(chan *servingTask),
queueBestCh: make(chan *servingTask), stopThreadCh: make(chan struct{}),
stopThreadCh: make(chan struct{}), quit: make(chan struct{}),
quit: make(chan struct{}), setThreadsCh: make(chan int, 10),
setThreadsCh: make(chan int, 10), burstLimit: uint64(utilTarget * bufLimitRatio * 1200000),
burstDropLimit: uint64(utilTarget * bufLimitRatio * 1000000),
burstDecRate: utilTarget,
lastUpdate: mclock.Now(),
logger: logger,
logRecentTime: logger.NewMinMaxChannel("recentTime", false),
logQueuedTime: logger.NewMinMaxChannel("queuedTime", false),
} }
sq.wg.Add(2) sq.wg.Add(2)
go sq.queueLoop() go sq.queueLoop()
@ -125,9 +151,12 @@ func newServingQueue(suspendBias int64) *servingQueue {
} }
// newTask creates a new task with the given priority // newTask creates a new task with the given priority
func (sq *servingQueue) newTask(priority int64) *servingTask { func (sq *servingQueue) newTask(peer *peer, maxTime uint64, priority int64) *servingTask {
return &servingTask{ return &servingTask{
sq: sq, sq: sq,
peer: peer,
maxTime: maxTime,
expTime: maxTime,
priority: priority, priority: priority,
} }
} }
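
For context, a hypothetical call site showing how the servingTask API above fits together: a request handler creates a task with an expected maximum serving time and a priority, blocks in start() until a serving thread is available (or the peer is frozen), may yield via waitOrStop(), and reports the measured time with done(). handleRequest and its parameters are placeholders; the real call sites live in les/handler.go.

```go
// Illustrative only: how a request handler might drive a servingTask
// (fragment in package les, using the API defined in this file).
func handleRequest(sq *servingQueue, p *peer, maxServingTime uint64, priority int64) {
	task := sq.newTask(p, maxServingTime, priority)
	if !task.start() {
		return // queue shutting down or peer frozen: drop the request
	}
	// ... serve part of the request ...
	if !task.waitOrStop() {
		return // asked to yield and not allowed to resume
	}
	// ... finish serving ...
	served := task.done() // total serving time in nanoseconds, fed back into flow control
	_ = served
}
```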
@ -144,18 +173,12 @@ func (sq *servingQueue) threadController() {
select { select {
case best := <-sq.queueBestCh: case best := <-sq.queueBestCh:
best.tokenCh <- token best.tokenCh <- token
default: case <-sq.stopThreadCh:
select { sq.wg.Done()
case best := <-sq.queueBestCh: return
best.tokenCh <- token case <-sq.quit:
case sq.tokenCh <- token: sq.wg.Done()
case <-sq.stopThreadCh: return
sq.wg.Done()
return
case <-sq.quit:
sq.wg.Done()
return
}
} }
<-token <-token
select { select {
@ -170,6 +193,100 @@ func (sq *servingQueue) threadController() {
} }
} }
type (
// peerTasks lists the tasks received from a given peer when selecting peers to freeze
peerTasks struct {
peer *peer
list []*servingTask
sumTime uint64
priority float64
}
// peerList is a sortable list of peerTasks
peerList []*peerTasks
)
func (l peerList) Len() int {
return len(l)
}
func (l peerList) Less(i, j int) bool {
return l[i].priority < l[j].priority
}
func (l peerList) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
// freezePeers selects the peers with the lowest-priority queued tasks and freezes
// them until the accounted serving time (recentTime+queuedTime) drops below
// burstDropLimit or all peers are frozen
func (sq *servingQueue) freezePeers() {
peerMap := make(map[*peer]*peerTasks)
var peerList peerList
if sq.best != nil {
sq.queue.Push(sq.best, sq.best.priority)
}
sq.best = nil
for sq.queue.Size() > 0 {
task := sq.queue.PopItem().(*servingTask)
tasks := peerMap[task.peer]
if tasks == nil {
bufValue, bufLimit := task.peer.fcClient.BufferStatus()
if bufLimit < 1 {
bufLimit = 1
}
tasks = &peerTasks{
peer: task.peer,
priority: float64(bufValue) / float64(bufLimit), // lower value comes first
}
peerMap[task.peer] = tasks
peerList = append(peerList, tasks)
}
tasks.list = append(tasks.list, task)
tasks.sumTime += task.expTime
}
sort.Sort(peerList)
drop := true
sq.logger.Event("freezing peers")
for _, tasks := range peerList {
if drop {
tasks.peer.freezeClient()
tasks.peer.fcClient.Freeze()
sq.queuedTime -= tasks.sumTime
if sq.logQueuedTime != nil {
sq.logQueuedTime.Update(float64(sq.queuedTime) / 1000)
}
sq.logger.Event(fmt.Sprintf("frozen peer sumTime=%d, %v", tasks.sumTime, tasks.peer.id))
drop = sq.recentTime+sq.queuedTime > sq.burstDropLimit
for _, task := range tasks.list {
task.tokenCh <- nil
}
} else {
for _, task := range tasks.list {
sq.queue.Push(task, task.priority)
}
}
}
if sq.queue.Size() > 0 {
sq.best = sq.queue.PopItem().(*servingTask)
}
}
// updateRecentTime recalculates the recent serving time value
func (sq *servingQueue) updateRecentTime() {
subTime := atomic.SwapUint64(&sq.servingTimeDiff, 0)
now := mclock.Now()
dt := now - sq.lastUpdate
sq.lastUpdate = now
if dt > 0 {
subTime += uint64(float64(dt) * sq.burstDecRate)
}
if sq.recentTime > subTime {
sq.recentTime -= subTime
} else {
sq.recentTime = 0
}
}
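
Plugging in the values wired up in server.go (utilTarget = LightServ/100 and bufLimitRatio = 6000), a server started with --lightserv 50 gets burstLimit of roughly 3.6 s and burstDropLimit of roughly 3 s of accounted serving time, and recentTime drains at the utilization target rate, i.e. one second of accounted serving time per two seconds of wall time. A quick check of that arithmetic; the flag value is just an example.

```go
package main

import (
	"fmt"
	"time"
)

const bufLimitRatio = 6000 // same constant as in server.go

func main() {
	lightServ := 50 // --lightserv percentage
	utilTarget := float64(lightServ) / 100

	burstLimit := uint64(utilTarget * bufLimitRatio * 1200000)     // ns of serving time
	burstDropLimit := uint64(utilTarget * bufLimitRatio * 1000000) // ns of serving time

	fmt.Println("burstLimit:", time.Duration(burstLimit))         // 3.6s
	fmt.Println("burstDropLimit:", time.Duration(burstDropLimit)) // 3s

	// recentTime decays by utilTarget ns of serving time per ns of wall time,
	// so one second of accounted serving time drains in 2s at a 50% target.
	drain := time.Duration(float64(time.Second) / utilTarget)
	fmt.Println("1s of recentTime drains in:", drain) // 2s
}
```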
// addTask inserts a task into the priority queue // addTask inserts a task into the priority queue
func (sq *servingQueue) addTask(task *servingTask) { func (sq *servingQueue) addTask(task *servingTask) {
if sq.best == nil { if sq.best == nil {
@ -177,10 +294,18 @@ func (sq *servingQueue) addTask(task *servingTask) {
} else if task.priority > sq.best.priority { } else if task.priority > sq.best.priority {
sq.queue.Push(sq.best, sq.best.priority) sq.queue.Push(sq.best, sq.best.priority)
sq.best = task sq.best = task
return
} else { } else {
sq.queue.Push(task, task.priority) sq.queue.Push(task, task.priority)
} }
sq.updateRecentTime()
sq.queuedTime += task.expTime
if sq.logQueuedTime != nil {
sq.logRecentTime.Update(float64(sq.recentTime) / 1000)
sq.logQueuedTime.Update(float64(sq.queuedTime) / 1000)
}
if sq.recentTime+sq.queuedTime > sq.burstLimit {
sq.freezePeers()
}
} }
// queueLoop is an event loop running in a goroutine. It receives tasks from queueAddCh // queueLoop is an event loop running in a goroutine. It receives tasks from queueAddCh
@ -189,10 +314,18 @@ func (sq *servingQueue) addTask(task *servingTask) {
func (sq *servingQueue) queueLoop() { func (sq *servingQueue) queueLoop() {
for { for {
if sq.best != nil { if sq.best != nil {
expTime := sq.best.expTime
select { select {
case task := <-sq.queueAddCh: case task := <-sq.queueAddCh:
sq.addTask(task) sq.addTask(task)
case sq.queueBestCh <- sq.best: case sq.queueBestCh <- sq.best:
sq.updateRecentTime()
sq.queuedTime -= expTime
sq.recentTime += expTime
if sq.logQueuedTime != nil {
sq.logRecentTime.Update(float64(sq.recentTime) / 1000)
sq.logQueuedTime.Update(float64(sq.queuedTime) / 1000)
}
if sq.queue.Size() == 0 { if sq.queue.Size() == 0 {
sq.best = nil sq.best = nil
} else { } else {
