diff --git a/les/balance.go b/les/balance.go index a36a997cf3..2813db01c5 100644 --- a/les/balance.go +++ b/les/balance.go @@ -67,7 +67,7 @@ type balanceCallback struct { // init initializes balanceTracker func (bt *balanceTracker) init(clock mclock.Clock, capacity uint64) { bt.clock = clock - bt.initTime = clock.Now() + bt.initTime, bt.lastUpdate = clock.Now(), clock.Now() // Init timestamps for i := range bt.callbackIndex { bt.callbackIndex[i] = -1 } diff --git a/les/balance_test.go b/les/balance_test.go new file mode 100644 index 0000000000..b571c2cc5c --- /dev/null +++ b/les/balance_test.go @@ -0,0 +1,260 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" +) + +func TestSetBalance(t *testing.T) { + var clock = &mclock.Simulated{} + var inputs = []struct { + pos uint64 + neg uint64 + }{ + {1000, 0}, + {0, 1000}, + {1000, 1000}, + } + + tracker := balanceTracker{} + tracker.init(clock, 1000) + defer tracker.stop(clock.Now()) + + for _, i := range inputs { + tracker.setBalance(i.pos, i.neg) + pos, neg := tracker.getBalance(clock.Now()) + if pos != i.pos { + t.Fatalf("Positive balance mismatch, want %v, got %v", i.pos, pos) + } + if neg != i.neg { + t.Fatalf("Negative balance mismatch, want %v, got %v", i.neg, neg) + } + } +} + +func TestBalanceTimeCost(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000) + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + tracker.setBalance(uint64(time.Minute), 0) // 1 minute time allowance + + var inputs = []struct { + runTime time.Duration + expPos uint64 + expNeg uint64 + }{ + {time.Second, uint64(time.Second * 59), 0}, + {0, uint64(time.Second * 59), 0}, + {time.Second * 59, 0, 0}, + {time.Second, 0, uint64(time.Second)}, + } + for _, i := range inputs { + clock.Run(i.runTime) + if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) + } + if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) + } + } + + tracker.setBalance(uint64(time.Minute), 0) // Refill 1 minute time allowance + for _, i := range inputs { + clock.Run(i.runTime) + if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) + } + if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) + } + } +} + +func TestBalanceReqCost(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000) + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + tracker.setBalance(uint64(time.Minute), 0) // 1 minute time serving time allowance + var inputs = []struct { + reqCost uint64 + expPos uint64 + expNeg uint64 + }{ + {uint64(time.Second), uint64(time.Second * 59), 0}, + {0, uint64(time.Second * 59), 0}, + {uint64(time.Second * 59), 0, 0}, + {uint64(time.Second), 0, uint64(time.Second)}, + } + for _, i := range inputs { + tracker.requestCost(i.reqCost) + if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) + } + if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) + } + } +} + +func TestBalanceToPriority(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000) // cap = 1000 + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + var inputs = []struct { + pos uint64 + neg uint64 + priority int64 + }{ + {1000, 0, ^int64(1)}, + {2000, 0, ^int64(2)}, // Higher balance, lower priority value + {0, 0, 0}, + {0, 1000, 1000}, + } + for _, i := range inputs { + tracker.setBalance(i.pos, i.neg) + priority := tracker.getPriority(clock.Now()) + if priority != i.priority { + t.Fatalf("Priority mismatch, want %v, got %v", i.priority, priority) + } + } +} + +func TestEstimatedPriority(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000000000) // cap = 1000,000,000 + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + tracker.setBalance(uint64(time.Minute), 0) + var inputs = []struct { + runTime time.Duration // time cost + futureTime time.Duration // diff of future time + reqCost uint64 // single request cost + priority int64 // expected estimated priority + }{ + {time.Second, time.Second, 0, ^int64(58)}, + {0, time.Second, 0, ^int64(58)}, + + // 2 seconds time cost, 1 second estimated time cost, 10^9 request cost, + // 10^9 estimated request cost per second. + {time.Second, time.Second, 1000000000, ^int64(55)}, + + // 3 seconds time cost, 3 second estimated time cost, 10^9*2 request cost, + // 4*10^9 estimated request cost. + {time.Second, 3 * time.Second, 1000000000, ^int64(48)}, + + // All positive balance is used up + {time.Second * 55, 0, 0, 0}, + + // 1 minute estimated time cost, 4/58 * 10^9 estimated request cost per sec. + {0, time.Minute, 0, int64(time.Minute) + int64(time.Second)*120/29}, + } + for _, i := range inputs { + clock.Run(i.runTime) + tracker.requestCost(i.reqCost) + priority := tracker.estimatedPriority(clock.Now()+mclock.AbsTime(i.futureTime), true) + if priority != i.priority { + t.Fatalf("Estimated priority mismatch, want %v, got %v", i.priority, priority) + } + } +} + +func TestCallbackChecking(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000000) // cap = 1000,000 + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + var inputs = []struct { + priority int64 + expDiff time.Duration + }{ + {^int64(500), time.Millisecond * 500}, + {0, time.Second}, + {int64(time.Second), 2 * time.Second}, + } + tracker.setBalance(uint64(time.Second), 0) + for _, i := range inputs { + diff, _ := tracker.timeUntil(i.priority) + if diff != i.expDiff { + t.Fatalf("Time difference mismatch, want %v, got %v", i.expDiff, diff) + } + } +} + +func TestCallback(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000) // cap = 1000 + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + callCh := make(chan struct{}, 1) + tracker.setBalance(uint64(time.Minute), 0) + tracker.addCallback(balanceCallbackZero, 0, func() { callCh <- struct{}{} }) + + clock.Run(time.Minute) + select { + case <-callCh: + case <-time.NewTimer(time.Second).C: + t.Fatalf("Callback hasn't been called yet") + } + + tracker.setBalance(uint64(time.Minute), 0) + tracker.addCallback(balanceCallbackZero, 0, func() { callCh <- struct{}{} }) + tracker.removeCallback(balanceCallbackZero) + + clock.Run(time.Minute) + select { + case <-callCh: + t.Fatalf("Callback shouldn't be called") + case <-time.NewTimer(time.Millisecond * 100).C: + } +} diff --git a/les/clientpool.go b/les/clientpool.go index 6773aab551..2df538620b 100644 --- a/les/clientpool.go +++ b/les/clientpool.go @@ -17,67 +17,81 @@ package les import ( + "encoding/binary" "io" "math" "sync" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/prque" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" + "github.com/hashicorp/golang-lru" ) const ( - negBalanceExpTC = time.Hour // time constant for exponentially reducing negative balance - fixedPointMultiplier = 0x1000000 // constant to convert logarithms to fixed point format - connectedBias = time.Minute * 5 // this bias is applied in favor of already connected clients in order to avoid kicking them out very soon - lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue -) - -var ( - clientPoolDbKey = []byte("clientPool") - clientBalanceDbKey = []byte("clientPool-balance") + negBalanceExpTC = time.Hour // time constant for exponentially reducing negative balance + fixedPointMultiplier = 0x1000000 // constant to convert logarithms to fixed point format + lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue + persistCumulativeTimeRefresh = time.Minute * 5 // refresh period of the cumulative running time persistence + posBalanceCacheLimit = 8192 // the maximum number of cached items in positive balance queue + negBalanceCacheLimit = 8192 // the maximum number of cached items in negative balance queue + + // connectedBias is applied to already connected clients So that + // already connected client won't be kicked out very soon and we + // can ensure all connected clients can have enough time to request + // or sync some data. + // + // todo(rjl493456442) make it configurable. It can be the option of + // free trial time! + connectedBias = time.Minute * 3 ) // clientPool implements a client database that assigns a priority to each client // based on a positive and negative balance. Positive balance is externally assigned // to prioritized clients and is decreased with connection time and processed // requests (unless the price factors are zero). If the positive balance is zero -// then negative balance is accumulated. Balance tracking and priority calculation -// for connected clients is done by balanceTracker. connectedQueue ensures that -// clients with the lowest positive or highest negative balance get evicted when -// the total capacity allowance is full and new clients with a better balance want -// to connect. Already connected nodes receive a small bias in their favor in order -// to avoid accepting and instantly kicking out clients. -// Balances of disconnected clients are stored in posBalanceQueue and negBalanceQueue -// and are also saved in the database. Negative balance is transformed into a -// logarithmic form with a constantly shifting linear offset in order to implement -// an exponential decrease. negBalanceQueue has a limited size and drops the smallest -// values when necessary. Positive balances are stored in the database as long as -// they exist, posBalanceQueue only acts as a cache for recently accessed entries. +// then negative balance is accumulated. +// +// Balance tracking and priority calculation for connected clients is done by +// balanceTracker. connectedQueue ensures that clients with the lowest positive or +// highest negative balance get evicted when the total capacity allowance is full +// and new clients with a better balance want to connect. +// +// Already connected nodes receive a small bias in their favor in order to avoid +// accepting and instantly kicking out clients. In theory, we try to ensure that +// each client can have several minutes of connection time. +// +// Balances of disconnected clients are stored in nodeDB including positive balance +// and negative banalce. Negative balance is transformed into a logarithmic form +// with a constantly shifting linear offset in order to implement an exponential +// decrease. Besides nodeDB will have a background thread to check the negative +// balance of disconnected client. If the balance is low enough, then the record +// will be dropped. type clientPool struct { - db ethdb.Database + ndb *nodeDB lock sync.Mutex clock mclock.Clock - stopCh chan chan struct{} + stopCh chan struct{} closed bool removePeer func(enode.ID) - queueLimit, countLimit int - freeClientCap, capacityLimit, connectedCapacity uint64 + connectedMap map[enode.ID]*clientInfo + connectedQueue *prque.LazyQueue + + posFactors, negFactors priceFactors - connectedMap map[enode.ID]*clientInfo - posBalanceMap map[enode.ID]*posBalance - negBalanceMap map[string]*negBalance - connectedQueue *prque.LazyQueue - posBalanceQueue, negBalanceQueue *prque.Prque - posFactors, negFactors priceFactors - posBalanceAccessCounter int64 - startupTime mclock.AbsTime - logOffsetAtStartup int64 + connLimit int // The maximum number of connections that clientpool can support + capLimit uint64 // The maximum cumulative capacity that clientpool can support + connectedCap uint64 // The sum of the capacity of the current clientpool connected + freeClientCap uint64 // The capacity value of each free client + startTime mclock.AbsTime // The timestamp at which the clientpool started running + cumulativeTime int64 // The cumulative running time of clientpool at the start point. + disableBias bool // Disable connection bias(used in testing) } // clientPeer represents a client in the pool. @@ -138,22 +152,25 @@ type priceFactors struct { } // newClientPool creates a new client pool -func newClientPool(db ethdb.Database, freeClientCap uint64, queueLimit int, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { +func newClientPool(db ethdb.Database, freeClientCap uint64, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { + ndb := newNodeDB(db, clock) pool := &clientPool{ - db: db, - clock: clock, - connectedMap: make(map[enode.ID]*clientInfo), - posBalanceMap: make(map[enode.ID]*posBalance), - negBalanceMap: make(map[string]*negBalance), - connectedQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh), - negBalanceQueue: prque.New(negSetIndex), - posBalanceQueue: prque.New(posSetIndex), - freeClientCap: freeClientCap, - queueLimit: queueLimit, - removePeer: removePeer, - stopCh: make(chan chan struct{}), - } - pool.loadFromDb() + ndb: ndb, + clock: clock, + connectedMap: make(map[enode.ID]*clientInfo), + connectedQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh), + freeClientCap: freeClientCap, + removePeer: removePeer, + startTime: clock.Now(), + cumulativeTime: ndb.getCumulativeTime(), + stopCh: make(chan struct{}), + } + // If the negative balance of free client is even lower than 1, + // delete this entry. + ndb.nbEvictCallBack = func(now mclock.AbsTime, b negBalance) bool { + balance := math.Exp(float64(b.logValue-pool.logOffset(now)) / fixedPointMultiplier) + return balance <= 1 + } go func() { for { select { @@ -161,8 +178,9 @@ func newClientPool(db ethdb.Database, freeClientCap uint64, queueLimit int, cloc pool.lock.Lock() pool.connectedQueue.Refresh() pool.lock.Unlock() - case stop := <-pool.stopCh: - close(stop) + case <-clock.After(persistCumulativeTimeRefresh): + pool.ndb.setCumulativeTime(pool.logOffset(clock.Now())) + case <-pool.stopCh: return } } @@ -172,13 +190,12 @@ func newClientPool(db ethdb.Database, freeClientCap uint64, queueLimit int, cloc // stop shuts the client pool down func (f *clientPool) stop() { - stop := make(chan struct{}) - f.stopCh <- stop - <-stop + close(f.stopCh) f.lock.Lock() f.closed = true - f.saveToDb() f.lock.Unlock() + f.ndb.setCumulativeTime(f.logOffset(f.clock.Now())) + f.ndb.close() } // connect should be called after a successful handshake. If the connection was @@ -187,7 +204,7 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { f.lock.Lock() defer f.lock.Unlock() - // Short circuit is clientPool is already closed. + // Short circuit if clientPool is already closed. if f.closed { return false } @@ -199,14 +216,19 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { return false } // Create a clientInfo but do not add it yet - now := f.clock.Now() - posBalance := f.getPosBalance(id).value + var ( + posBalance uint64 + negBalance uint64 + now = f.clock.Now() + ) + pb := f.ndb.getOrNewPB(id) + posBalance = pb.value e := &clientInfo{pool: f, peer: peer, address: freeID, queueIndex: -1, id: id, priority: posBalance != 0} - var negBalance uint64 - nb := f.negBalanceMap[freeID] - if nb != nil { + nb := f.ndb.getOrNewNB(freeID) + if nb.logValue != 0 { negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(now)) / fixedPointMultiplier)) + negBalance *= uint64(time.Second) } // If the client is a free client, assign with a low free capacity, // Otherwise assign with the given value(priority client) @@ -219,6 +241,7 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { } e.capacity = capacity + // Starts a balance tracker e.balanceTracker.init(f.clock, capacity) e.balanceTracker.setBalance(posBalance, negBalance) f.setClientPriceFactors(e) @@ -228,9 +251,9 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { // // If the priority of the newly added client is lower than the priority of // all connected clients, the client is rejected. - newCapacity := f.connectedCapacity + capacity + newCapacity := f.connectedCap + capacity newCount := f.connectedQueue.Size() + 1 - if newCapacity > f.capacityLimit || newCount > f.countLimit { + if newCapacity > f.capLimit || newCount > f.connLimit { var ( kickList []*clientInfo kickPriority int64 @@ -241,10 +264,13 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { kickPriority = priority newCapacity -= c.capacity newCount-- - return newCapacity > f.capacityLimit || newCount > f.countLimit + return newCapacity > f.capLimit || newCount > f.connLimit }) - if newCapacity > f.capacityLimit || newCount > f.countLimit || (e.balanceTracker.estimatedPriority(now+mclock.AbsTime(connectedBias), false)-kickPriority) > 0 { - // reject client + bias := connectedBias + if f.disableBias { + bias = 0 + } + if newCapacity > f.capLimit || newCount > f.connLimit || (e.balanceTracker.estimatedPriority(now+mclock.AbsTime(bias), false)-kickPriority) > 0 { for _, c := range kickList { f.connectedQueue.Push(c) } @@ -257,21 +283,22 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { f.dropClient(c, now, true) } } - // client accepted, finish setting it up - if nb != nil { - delete(f.negBalanceMap, freeID) - f.negBalanceQueue.Remove(nb.queueIndex) - } + // Register new client to connection queue. + f.connectedMap[id] = e + f.connectedQueue.Push(e) + f.connectedCap += e.capacity + + // If the current client is a paid client, monitor the status of client, + // downgrade it to normal client if positive balance is used up. if e.priority { e.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) } - f.connectedMap[id] = e - f.connectedQueue.Push(e) - f.connectedCapacity += e.capacity - totalConnectedGauge.Update(int64(f.connectedCapacity)) + // If the capacity of client is not the default value(free capacity), notify + // it to update capacity. if e.capacity != f.freeClientCap { e.peer.updateCapacity(e.capacity) } + totalConnectedGauge.Update(int64(f.connectedCap)) clientConnectedMeter.Mark(1) log.Debug("Client accepted", "address", freeID) return true @@ -284,15 +311,14 @@ func (f *clientPool) disconnect(p clientPeer) { f.lock.Lock() defer f.lock.Unlock() + // Short circuit if client pool is already closed. if f.closed { return } - address := p.freeClientId() - id := p.ID() // Short circuit if the peer hasn't been registered. - e := f.connectedMap[id] + e := f.connectedMap[p.ID()] if e == nil { - log.Debug("Client not connected", "address", address, "id", peerIdToString(id)) + log.Debug("Client not connected", "address", p.freeClientId(), "id", peerIdToString(p.ID())) return } f.dropClient(e, f.clock.Now(), false) @@ -307,8 +333,8 @@ func (f *clientPool) dropClient(e *clientInfo, now mclock.AbsTime, kick bool) { f.finalizeBalance(e, now) f.connectedQueue.Remove(e.queueIndex) delete(f.connectedMap, e.id) - f.connectedCapacity -= e.capacity - totalConnectedGauge.Update(int64(f.connectedCapacity)) + f.connectedCap -= e.capacity + totalConnectedGauge.Update(int64(f.connectedCap)) if kick { clientKickedMeter.Mark(1) log.Debug("Client kicked out", "address", e.address) @@ -324,18 +350,17 @@ func (f *clientPool) dropClient(e *clientInfo, now mclock.AbsTime, kick bool) { func (f *clientPool) finalizeBalance(c *clientInfo, now mclock.AbsTime) { c.balanceTracker.stop(now) pos, neg := c.balanceTracker.getBalance(now) - pb := f.getPosBalance(c.id) + + pb, nb := f.ndb.getOrNewPB(c.id), f.ndb.getOrNewNB(c.address) pb.value = pos - f.storePosBalance(pb) - if neg < 1 { - neg = 1 - } - nb := &negBalance{address: c.address, queueIndex: -1, logValue: int64(math.Log(float64(neg))*fixedPointMultiplier) + f.logOffset(now)} - f.negBalanceMap[c.address] = nb - f.negBalanceQueue.Push(nb, -nb.logValue) - if f.negBalanceQueue.Size() > f.queueLimit { - nn := f.negBalanceQueue.PopItem().(*negBalance) - delete(f.negBalanceMap, nn.address) + f.ndb.setPB(c.id, pb) + + neg /= uint64(time.Second) // Convert the expanse to second level. + if neg > 1 { + nb.logValue = int64(math.Log(float64(neg))*fixedPointMultiplier) + f.logOffset(now) + f.ndb.setNB(c.address, nb) + } else { + f.ndb.delNB(c.address) // Negative balance is small enough, drop it directly. } } @@ -351,27 +376,26 @@ func (f *clientPool) balanceExhausted(id enode.ID) { } c.priority = false if c.capacity != f.freeClientCap { - f.connectedCapacity += f.freeClientCap - c.capacity - totalConnectedGauge.Update(int64(f.connectedCapacity)) + f.connectedCap += f.freeClientCap - c.capacity + totalConnectedGauge.Update(int64(f.connectedCap)) c.capacity = f.freeClientCap c.peer.updateCapacity(c.capacity) } + f.ndb.delPB(id) } // setConnLimit sets the maximum number and total capacity of connected clients, // dropping some of them if necessary. -func (f *clientPool) setLimits(count int, totalCap uint64) { +func (f *clientPool) setLimits(totalConn int, totalCap uint64) { f.lock.Lock() defer f.lock.Unlock() - f.countLimit = count - f.capacityLimit = totalCap - if f.connectedCapacity > f.capacityLimit || f.connectedQueue.Size() > f.countLimit { - now := mclock.Now() + f.connLimit = totalConn + f.capLimit = totalCap + if f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit { f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { - c := data.(*clientInfo) - f.dropClient(c, now, true) - return f.connectedCapacity > f.capacityLimit || f.connectedQueue.Size() > f.countLimit + f.dropClient(data.(*clientInfo), mclock.Now(), true) + return f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit }) } } @@ -390,11 +414,14 @@ func (f *clientPool) requestCost(p *peer, cost uint64) { // logOffset calculates the time-dependent offset for the logarithmic // representation of negative balance +// +// From another point of view, the result returned by the function represents +// the total time that the clientpool is cumulatively running(total_hours/multiplier). func (f *clientPool) logOffset(now mclock.AbsTime) int64 { // Note: fixedPointMultiplier acts as a multiplier here; the reason for dividing the divisor // is to avoid int64 overflow. We assume that int64(negBalanceExpTC) >> fixedPointMultiplier. - logDecay := int64((time.Duration(now - f.startupTime)) / (negBalanceExpTC / fixedPointMultiplier)) - return f.logOffsetAtStartup + logDecay + cumulativeTime := int64((time.Duration(now - f.startTime)) / (negBalanceExpTC / fixedPointMultiplier)) + return f.cumulativeTime + cumulativeTime } // setPriceFactors changes pricing factors for both positive and negative balances. @@ -415,100 +442,6 @@ func (f *clientPool) setClientPriceFactors(c *clientInfo) { c.balanceTracker.setFactors(false, f.posFactors.timeFactor+float64(c.capacity)*f.posFactors.capacityFactor/1000000, f.posFactors.requestFactor) } -// clientPoolStorage is the RLP representation of the pool's database storage -type clientPoolStorage struct { - LogOffset uint64 - List []*negBalance -} - -// loadFromDb restores pool status from the database storage -// (automatically called at initialization) -func (f *clientPool) loadFromDb() { - enc, err := f.db.Get(clientPoolDbKey) - if err != nil { - return - } - var storage clientPoolStorage - err = rlp.DecodeBytes(enc, &storage) - if err != nil { - log.Error("Failed to decode client list", "err", err) - return - } - f.logOffsetAtStartup = int64(storage.LogOffset) - f.startupTime = f.clock.Now() - for _, e := range storage.List { - log.Debug("Loaded free client record", "address", e.address, "logValue", e.logValue) - f.negBalanceMap[e.address] = e - f.negBalanceQueue.Push(e, -e.logValue) - } -} - -// saveToDb saves pool status to the database storage -// (automatically called during shutdown) -func (f *clientPool) saveToDb() { - now := f.clock.Now() - storage := clientPoolStorage{ - LogOffset: uint64(f.logOffset(now)), - } - for _, c := range f.connectedMap { - f.finalizeBalance(c, now) - } - i := 0 - storage.List = make([]*negBalance, len(f.negBalanceMap)) - for _, e := range f.negBalanceMap { - storage.List[i] = e - i++ - } - enc, err := rlp.EncodeToBytes(storage) - if err != nil { - log.Error("Failed to encode negative balance list", "err", err) - } else { - f.db.Put(clientPoolDbKey, enc) - } -} - -// storePosBalance stores a single positive balance entry in the database -func (f *clientPool) storePosBalance(b *posBalance) { - if b.value == b.lastStored { - return - } - enc, err := rlp.EncodeToBytes(b) - if err != nil { - log.Error("Failed to encode client balance", "err", err) - } else { - f.db.Put(append(clientBalanceDbKey, b.id[:]...), enc) - b.lastStored = b.value - } -} - -// getPosBalance retrieves a single positive balance entry from cache or the database -func (f *clientPool) getPosBalance(id enode.ID) *posBalance { - if b, ok := f.posBalanceMap[id]; ok { - f.posBalanceQueue.Remove(b.queueIndex) - f.posBalanceAccessCounter-- - f.posBalanceQueue.Push(b, f.posBalanceAccessCounter) - return b - } - balance := &posBalance{} - if enc, err := f.db.Get(append(clientBalanceDbKey, id[:]...)); err == nil { - if err := rlp.DecodeBytes(enc, balance); err != nil { - log.Error("Failed to decode client balance", "err", err) - balance = &posBalance{} - } - } - balance.id = id - balance.queueIndex = -1 - if f.posBalanceQueue.Size() >= f.queueLimit { - b := f.posBalanceQueue.PopItem().(*posBalance) - f.storePosBalance(b) - delete(f.posBalanceMap, b.id) - } - f.posBalanceAccessCounter-- - f.posBalanceQueue.Push(balance, f.posBalanceAccessCounter) - f.posBalanceMap[id] = balance - return balance -} - // addBalance updates the positive balance of a client. // If setTotal is false then the given amount is added to the balance. // If setTotal is true then amount represents the total amount ever added to the @@ -518,11 +451,18 @@ func (f *clientPool) addBalance(id enode.ID, amount uint64, setTotal bool) { f.lock.Lock() defer f.lock.Unlock() - pb := f.getPosBalance(id) + pb := f.ndb.getOrNewPB(id) c := f.connectedMap[id] - var negBalance uint64 if c != nil { - pb.value, negBalance = c.balanceTracker.getBalance(f.clock.Now()) + posBalance, negBalance := c.balanceTracker.getBalance(f.clock.Now()) + pb.value = posBalance + defer func() { + c.balanceTracker.setBalance(pb.value, negBalance) + if !c.priority && pb.value > 0 { + c.priority = true + c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) + } + }() } if setTotal { if pb.value+amount > pb.lastTotal { @@ -535,21 +475,12 @@ func (f *clientPool) addBalance(id enode.ID, amount uint64, setTotal bool) { pb.value += amount pb.lastTotal += amount } - f.storePosBalance(pb) - if c != nil { - c.balanceTracker.setBalance(pb.value, negBalance) - if !c.priority && pb.value > 0 { - c.priority = true - c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) - } - } + f.ndb.setPB(id, pb) } // posBalance represents a recently accessed positive balance entry type posBalance struct { - id enode.ID - value, lastStored, lastTotal uint64 - queueIndex int // position in posBalanceQueue + value, lastTotal uint64 } // EncodeRLP implements rlp.Encoder @@ -566,44 +497,207 @@ func (e *posBalance) DecodeRLP(s *rlp.Stream) error { return err } e.value = entry.Value - e.lastStored = entry.Value e.lastTotal = entry.LastTotal return nil } -// posSetIndex callback updates posBalance item index in posBalanceQueue -func posSetIndex(a interface{}, index int) { - a.(*posBalance).queueIndex = index -} - // negBalance represents a negative balance entry of a disconnected client -type negBalance struct { - address string - logValue int64 - queueIndex int // position in negBalanceQueue -} +type negBalance struct{ logValue int64 } // EncodeRLP implements rlp.Encoder func (e *negBalance) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, []interface{}{e.address, uint64(e.logValue)}) + return rlp.Encode(w, []interface{}{uint64(e.logValue)}) } // DecodeRLP implements rlp.Decoder func (e *negBalance) DecodeRLP(s *rlp.Stream) error { var entry struct { - Address string LogValue uint64 } if err := s.Decode(&entry); err != nil { return err } - e.address = entry.Address e.logValue = int64(entry.LogValue) - e.queueIndex = -1 return nil } -// negSetIndex callback updates negBalance item index in negBalanceQueue -func negSetIndex(a interface{}, index int) { - a.(*negBalance).queueIndex = index +const ( + // nodeDBVersion is the version identifier of the node data in db + nodeDBVersion = 0 + + // dbCleanupCycle is the cycle of db for useless data cleanup + dbCleanupCycle = time.Hour +) + +var ( + positiveBalancePrefix = []byte("pb:") // dbVersion(uint16 big endian) + positiveBalancePrefix + id -> balance + negativeBalancePrefix = []byte("nb:") // dbVersion(uint16 big endian) + negativeBalancePrefix + ip -> balance + cumulativeRunningTimeKey = []byte("cumulativeTime:") // dbVersion(uint16 big endian) + cumulativeRunningTimeKey -> cumulativeTime +) + +type nodeDB struct { + db ethdb.Database + pcache *lru.Cache + ncache *lru.Cache + auxbuf []byte // 37-byte auxiliary buffer for key encoding + verbuf [2]byte // 2-byte auxiliary buffer for db version + nbEvictCallBack func(mclock.AbsTime, negBalance) bool // Callback to determine whether the negative balance can be evicted. + clock mclock.Clock + closeCh chan struct{} + cleanupHook func() // Test hook used for testing +} + +func newNodeDB(db ethdb.Database, clock mclock.Clock) *nodeDB { + pcache, _ := lru.New(posBalanceCacheLimit) + ncache, _ := lru.New(negBalanceCacheLimit) + ndb := &nodeDB{ + db: db, + pcache: pcache, + ncache: ncache, + auxbuf: make([]byte, 37), + clock: clock, + closeCh: make(chan struct{}), + } + binary.BigEndian.PutUint16(ndb.verbuf[:], uint16(nodeDBVersion)) + go ndb.expirer() + return ndb +} + +func (db *nodeDB) close() { + close(db.closeCh) +} + +func (db *nodeDB) key(id []byte, neg bool) []byte { + prefix := positiveBalancePrefix + if neg { + prefix = negativeBalancePrefix + } + if len(prefix)+len(db.verbuf)+len(id) > len(db.auxbuf) { + db.auxbuf = append(db.auxbuf, make([]byte, len(prefix)+len(db.verbuf)+len(id)-len(db.auxbuf))...) + } + copy(db.auxbuf[:len(db.verbuf)], db.verbuf[:]) + copy(db.auxbuf[len(db.verbuf):len(db.verbuf)+len(prefix)], prefix) + copy(db.auxbuf[len(prefix)+len(db.verbuf):len(prefix)+len(db.verbuf)+len(id)], id) + return db.auxbuf[:len(prefix)+len(db.verbuf)+len(id)] +} + +func (db *nodeDB) getCumulativeTime() int64 { + blob, err := db.db.Get(append(cumulativeRunningTimeKey, db.verbuf[:]...)) + if err != nil || len(blob) == 0 { + return 0 + } + return int64(binary.BigEndian.Uint64(blob)) +} + +func (db *nodeDB) setCumulativeTime(v int64) { + binary.BigEndian.PutUint64(db.auxbuf[:8], uint64(v)) + db.db.Put(append(cumulativeRunningTimeKey, db.verbuf[:]...), db.auxbuf[:8]) +} + +func (db *nodeDB) getOrNewPB(id enode.ID) posBalance { + key := db.key(id.Bytes(), false) + item, exist := db.pcache.Get(string(key)) + if exist { + return item.(posBalance) + } + var balance posBalance + if enc, err := db.db.Get(key); err == nil { + if err := rlp.DecodeBytes(enc, &balance); err != nil { + log.Error("Failed to decode positive balance", "err", err) + } + } + db.pcache.Add(string(key), balance) + return balance +} + +func (db *nodeDB) setPB(id enode.ID, b posBalance) { + key := db.key(id.Bytes(), false) + enc, err := rlp.EncodeToBytes(&(b)) + if err != nil { + log.Error("Failed to encode positive balance", "err", err) + return + } + db.db.Put(key, enc) + db.pcache.Add(string(key), b) +} + +func (db *nodeDB) delPB(id enode.ID) { + key := db.key(id.Bytes(), false) + db.db.Delete(key) + db.pcache.Remove(string(key)) +} + +func (db *nodeDB) getOrNewNB(id string) negBalance { + key := db.key([]byte(id), true) + item, exist := db.ncache.Get(string(key)) + if exist { + return item.(negBalance) + } + var balance negBalance + if enc, err := db.db.Get(key); err == nil { + if err := rlp.DecodeBytes(enc, &balance); err != nil { + log.Error("Failed to decode negative balance", "err", err) + } + } + db.ncache.Add(string(key), balance) + return balance +} + +func (db *nodeDB) setNB(id string, b negBalance) { + key := db.key([]byte(id), true) + enc, err := rlp.EncodeToBytes(&(b)) + if err != nil { + log.Error("Failed to encode negative balance", "err", err) + return + } + db.db.Put(key, enc) + db.ncache.Add(string(key), b) +} + +func (db *nodeDB) delNB(id string) { + key := db.key([]byte(id), true) + db.db.Delete(key) + db.ncache.Remove(string(key)) +} + +func (db *nodeDB) expirer() { + for { + select { + case <-db.clock.After(dbCleanupCycle): + db.expireNodes() + case <-db.closeCh: + return + } + } +} + +// expireNodes iterates the whole node db and checks whether the negative balance +// entry can deleted. +// +// The rationale behind this is: server doesn't need to keep the negative balance +// records if they are low enough. +func (db *nodeDB) expireNodes() { + var ( + visited int + deleted int + start = time.Now() + ) + iter := db.db.NewIteratorWithPrefix(append(db.verbuf[:], negativeBalancePrefix...)) + for iter.Next() { + visited += 1 + var balance negBalance + if err := rlp.DecodeBytes(iter.Value(), &balance); err != nil { + log.Error("Failed to decode negative balance", "err", err) + continue + } + if db.nbEvictCallBack != nil && db.nbEvictCallBack(db.clock.Now(), balance) { + deleted += 1 + db.db.Delete(iter.Key()) + } + } + // Invoke testing hook if it's not nil. + if db.cleanupHook != nil { + db.cleanupHook() + } + log.Debug("Expire nodes", "visited", visited, "deleted", deleted, "elapsed", common.PrettyDuration(time.Since(start))) } diff --git a/les/clientpool_test.go b/les/clientpool_test.go index 225f828ec6..5b9494c2aa 100644 --- a/les/clientpool_test.go +++ b/les/clientpool_test.go @@ -17,8 +17,11 @@ package les import ( + "bytes" "fmt" + "math" "math/rand" + "reflect" "testing" "time" @@ -51,7 +54,7 @@ func TestClientPoolL100C300P20(t *testing.T) { testClientPool(t, 100, 300, 20, false) } -const testClientPoolTicks = 500000 +const testClientPoolTicks = 100000 type poolTestPeer int @@ -76,8 +79,9 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD disconnFn = func(id enode.ID) { disconnCh <- int(id[0]) + int(id[1])<<8 } - pool = newClientPool(db, 1, 10000, &clock, disconnFn) + pool = newClientPool(db, 1, &clock, disconnFn) ) + pool.disableBias = true pool.setLimits(connLimit, uint64(connLimit)) pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) @@ -89,16 +93,9 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD t.Fatalf("Test peer #%d rejected", i) } } - // since all accepted peers are new and should not be kicked out, the next one should be rejected - if pool.connect(poolTestPeer(connLimit), 0) { - connected[connLimit] = true - t.Fatalf("Peer accepted over connected limit") - } - // randomly connect and disconnect peers, expect to have a similar total connection time at the end for tickCounter := 0; tickCounter < testClientPoolTicks; tickCounter++ { clock.Run(1 * time.Second) - //time.Sleep(time.Microsecond * 100) if tickCounter == testClientPoolTicks/4 { // give a positive balance to some of the peers @@ -137,11 +134,11 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD } expTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2*(connLimit-paidCount)/(clientCount-paidCount) - expMin := expTicks - expTicks/10 - expMax := expTicks + expTicks/10 + expMin := expTicks - expTicks/5 + expMax := expTicks + expTicks/5 paidTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2 - paidMin := paidTicks - paidTicks/10 - paidMax := paidTicks + paidTicks/10 + paidMin := paidTicks - paidTicks/5 + paidMax := paidTicks + paidTicks/5 // check if the total connected time of peers are all in the expected range for i, c := range connected { @@ -157,24 +154,371 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD t.Errorf("Total connected time of test node #%d (%d) outside expected range (%d to %d)", i, connTicks[i], min, max) } } + pool.stop() +} + +func TestConnectPaidClient(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + ) + pool := newClientPool(db, 1, &clock, nil) + defer pool.stop() + pool.setLimits(10, uint64(10)) + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - // a previously unknown peer should be accepted now - if !pool.connect(poolTestPeer(54321), 0) { - t.Fatalf("Previously unknown peer rejected") + // Add balance for an external client and mark it as paid client + pool.addBalance(poolTestPeer(0).ID(), 1000, false) + + if !pool.connect(poolTestPeer(0), 10) { + t.Fatalf("Failed to connect paid client") } +} - // close and restart pool - pool.stop() - pool = newClientPool(db, 1, 10000, &clock, func(id enode.ID) {}) - pool.setLimits(connLimit, uint64(connLimit)) +func TestConnectPaidClientToSmallPool(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + ) + pool := newClientPool(db, 1, &clock, nil) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + // Add balance for an external client and mark it as paid client + pool.addBalance(poolTestPeer(0).ID(), 1000, false) - // try connecting all known peers (connLimit should be filled up) - for i := 0; i < clientCount; i++ { - pool.connect(poolTestPeer(i), 0) + // Connect a fat paid client to pool, should reject it. + if pool.connect(poolTestPeer(0), 100) { + t.Fatalf("Connected fat paid client, should reject it") } - // expect pool to remember known nodes and kick out one of them to accept a new one - if !pool.connect(poolTestPeer(54322), 0) { - t.Errorf("Previously unknown peer rejected after restarting pool") +} + +func TestConnectPaidClientToFullPool(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + ) + removeFn := func(enode.ID) {} // Noop + pool := newClientPool(db, 1, &clock, removeFn) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + for i := 0; i < 10; i++ { + pool.addBalance(poolTestPeer(i).ID(), 1000000000, false) + pool.connect(poolTestPeer(i), 1) + } + pool.addBalance(poolTestPeer(11).ID(), 1000, false) // Add low balance to new paid client + if pool.connect(poolTestPeer(11), 1) { + t.Fatalf("Low balance paid client should be rejected") + } + clock.Run(time.Second) + pool.addBalance(poolTestPeer(12).ID(), 1000000000*60*3, false) // Add high balance to new paid client + if !pool.connect(poolTestPeer(12), 1) { + t.Fatalf("High balance paid client should be accpected") + } +} + +func TestPaidClientKickedOut(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + kickedCh = make(chan int, 1) + ) + removeFn := func(id enode.ID) { kickedCh <- int(id[0]) } + pool := newClientPool(db, 1, &clock, removeFn) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + for i := 0; i < 10; i++ { + pool.addBalance(poolTestPeer(i).ID(), 1000000000, false) // 1 second allowance + pool.connect(poolTestPeer(i), 1) + clock.Run(time.Millisecond) + } + clock.Run(time.Second) + clock.Run(connectedBias) + if !pool.connect(poolTestPeer(11), 0) { + t.Fatalf("Free client should be accectped") + } + select { + case id := <-kickedCh: + if id != 0 { + t.Fatalf("Kicked client mismatch, want %v, got %v", 0, id) + } + case <-time.NewTimer(time.Second).C: + t.Fatalf("timeout") + } +} + +func TestConnectFreeClient(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + ) + pool := newClientPool(db, 1, &clock, nil) + defer pool.stop() + pool.setLimits(10, uint64(10)) + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + if !pool.connect(poolTestPeer(0), 10) { + t.Fatalf("Failed to connect free client") + } +} + +func TestConnectFreeClientToFullPool(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + ) + removeFn := func(enode.ID) {} // Noop + pool := newClientPool(db, 1, &clock, removeFn) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + for i := 0; i < 10; i++ { + pool.connect(poolTestPeer(i), 1) + } + if pool.connect(poolTestPeer(11), 1) { + t.Fatalf("New free client should be rejected") + } + clock.Run(time.Minute) + if pool.connect(poolTestPeer(12), 1) { + t.Fatalf("New free client should be rejected") + } + clock.Run(time.Millisecond) + clock.Run(4 * time.Minute) + if !pool.connect(poolTestPeer(13), 1) { + t.Fatalf("Old client connects more than 5min should be kicked") + } +} + +func TestFreeClientKickedOut(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + kicked = make(chan int, 10) + ) + removeFn := func(id enode.ID) { kicked <- int(id[0]) } + pool := newClientPool(db, 1, &clock, removeFn) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + for i := 0; i < 10; i++ { + pool.connect(poolTestPeer(i), 1) + clock.Run(100 * time.Millisecond) + } + if pool.connect(poolTestPeer(11), 1) { + t.Fatalf("New free client should be rejected") + } + clock.Run(5 * time.Minute) + for i := 0; i < 10; i++ { + pool.connect(poolTestPeer(i+10), 1) + } + for i := 0; i < 10; i++ { + select { + case id := <-kicked: + if id != i { + t.Fatalf("Kicked client mismatch, want %v, got %v", i, id) + } + case <-time.NewTimer(time.Second).C: + t.Fatalf("timeout") + } + } +} + +func TestPositiveBalanceCalculation(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + kicked = make(chan int, 10) + ) + removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop + pool := newClientPool(db, 1, &clock, removeFn) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + pool.addBalance(poolTestPeer(0).ID(), uint64(time.Minute*3), false) + pool.connect(poolTestPeer(0), 10) + clock.Run(time.Minute) + + pool.disconnect(poolTestPeer(0)) + pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID()) + if pb.value != uint64(time.Minute*2) { + t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute*2), pb.value) + } +} + +func TestDowngradePriorityClient(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + kicked = make(chan int, 10) + ) + removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop + pool := newClientPool(db, 1, &clock, removeFn) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + pool.addBalance(poolTestPeer(0).ID(), uint64(time.Minute), false) + pool.connect(poolTestPeer(0), 10) + clock.Run(time.Minute) // All positive balance should be used up. + time.Sleep(300 * time.Millisecond) // Ensure the callback is called + + pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID()) + if pb.value != 0 { + t.Fatalf("Positive balance mismatch, want %v, got %v", 0, pb.value) + } + + pool.addBalance(poolTestPeer(0).ID(), uint64(time.Minute), false) + pb = pool.ndb.getOrNewPB(poolTestPeer(0).ID()) + if pb.value != uint64(time.Minute) { + t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute), pb.value) + } +} + +func TestNegativeBalanceCalculation(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + kicked = make(chan int, 10) + ) + removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop + pool := newClientPool(db, 1, &clock, removeFn) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + for i := 0; i < 10; i++ { + pool.connect(poolTestPeer(i), 1) + } + clock.Run(time.Second) + + for i := 0; i < 10; i++ { + pool.disconnect(poolTestPeer(i)) + nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId()) + if nb.logValue != 0 { + t.Fatalf("Short connection shouldn't be recorded") + } + } + + for i := 0; i < 10; i++ { + pool.connect(poolTestPeer(i), 1) + } + clock.Run(time.Minute) + for i := 0; i < 10; i++ { + pool.disconnect(poolTestPeer(i)) + nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId()) + nb.logValue -= pool.logOffset(clock.Now()) + nb.logValue /= fixedPointMultiplier + if nb.logValue != int64(math.Log(float64(time.Minute/time.Second))) { + t.Fatalf("Negative balance mismatch, want %v, got %v", int64(math.Log(float64(time.Minute/time.Second))), nb.logValue) + } + } +} + +func TestNodeDB(t *testing.T) { + ndb := newNodeDB(rawdb.NewMemoryDatabase(), mclock.System{}) + defer ndb.close() + + if !bytes.Equal(ndb.verbuf[:], []byte{0x00, 0x00}) { + t.Fatalf("version buffer mismatch, want %v, got %v", []byte{0x00, 0x00}, ndb.verbuf) + } + var cases = []struct { + id enode.ID + ip string + balance interface{} + positive bool + }{ + {enode.ID{0x00, 0x01, 0x02}, "", posBalance{value: 100, lastTotal: 200}, true}, + {enode.ID{0x00, 0x01, 0x02}, "", posBalance{value: 200, lastTotal: 300}, true}, + {enode.ID{}, "127.0.0.1", negBalance{logValue: 10}, false}, + {enode.ID{}, "127.0.0.1", negBalance{logValue: 20}, false}, + } + for _, c := range cases { + if c.positive { + ndb.setPB(c.id, c.balance.(posBalance)) + if pb := ndb.getOrNewPB(c.id); !reflect.DeepEqual(pb, c.balance.(posBalance)) { + t.Fatalf("Positive balance mismatch, want %v, got %v", c.balance.(posBalance), pb) + } + } else { + ndb.setNB(c.ip, c.balance.(negBalance)) + if nb := ndb.getOrNewNB(c.ip); !reflect.DeepEqual(nb, c.balance.(negBalance)) { + t.Fatalf("Negative balance mismatch, want %v, got %v", c.balance.(negBalance), nb) + } + } + } + for _, c := range cases { + if c.positive { + ndb.delPB(c.id) + if pb := ndb.getOrNewPB(c.id); !reflect.DeepEqual(pb, posBalance{}) { + t.Fatalf("Positive balance mismatch, want %v, got %v", posBalance{}, pb) + } + } else { + ndb.delNB(c.ip) + if nb := ndb.getOrNewNB(c.ip); !reflect.DeepEqual(nb, negBalance{}) { + t.Fatalf("Negative balance mismatch, want %v, got %v", negBalance{}, nb) + } + } + } + ndb.setCumulativeTime(100) + if ndb.getCumulativeTime() != 100 { + t.Fatalf("Cumulative time mismatch, want %v, got %v", 100, ndb.getCumulativeTime()) + } +} + +func TestNodeDBExpiration(t *testing.T) { + var ( + iterated int + done = make(chan struct{}, 1) + ) + callback := func(now mclock.AbsTime, b negBalance) bool { + iterated += 1 + return true + } + clock := &mclock.Simulated{} + ndb := newNodeDB(rawdb.NewMemoryDatabase(), clock) + defer ndb.close() + ndb.nbEvictCallBack = callback + ndb.cleanupHook = func() { done <- struct{}{} } + + var cases = []struct { + ip string + balance negBalance + }{ + {"127.0.0.1", negBalance{logValue: 1}}, + {"127.0.0.2", negBalance{logValue: 1}}, + {"127.0.0.3", negBalance{logValue: 1}}, + {"127.0.0.4", negBalance{logValue: 1}}, + } + for _, c := range cases { + ndb.setNB(c.ip, c.balance) + } + time.Sleep(100 * time.Millisecond) // Ensure the db expirer is registered. + clock.Run(time.Hour + time.Minute) + select { + case <-done: + case <-time.NewTimer(time.Second).C: + t.Fatalf("timeout") + } + if iterated != 4 { + t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 4, iterated) + } + + for _, c := range cases { + ndb.setNB(c.ip, c.balance) + } + clock.Run(time.Hour + time.Minute) + select { + case <-done: + case <-time.NewTimer(time.Second).C: + t.Fatalf("timeout") + } + if iterated != 8 { + t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 4, iterated) } - pool.stop() } diff --git a/les/odr_test.go b/les/odr_test.go index 97217e9488..74808b345f 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -188,6 +188,12 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od client.handler.synchronise(client.peer.peer) + // Ensure the client has synced all necessary data. + clientHead := client.handler.backend.blockchain.CurrentHeader() + if clientHead.Number.Uint64() != 4 { + t.Fatalf("Failed to sync the chain with server, head: %v", clientHead.Number.Uint64()) + } + test := func(expFail uint64) { // Mark this as a helper to put the failures at the correct lines t.Helper() diff --git a/les/request_test.go b/les/request_test.go index 69b57ca317..8d09703c57 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -81,8 +81,15 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { // Assemble the test environment server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true) defer tearDown() + client.handler.synchronise(client.peer.peer) + // Ensure the client has synced all necessary data. + clientHead := client.handler.backend.blockchain.CurrentHeader() + if clientHead.Number.Uint64() != 4 { + t.Fatalf("Failed to sync the chain with server, head: %v", clientHead.Number.Uint64()) + } + test := func(expFail uint64) { for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ { bhash := rawdb.ReadCanonicalHash(server.db, i) diff --git a/les/server.go b/les/server.go index 7e11833fb6..997a24191b 100644 --- a/les/server.go +++ b/les/server.go @@ -113,7 +113,7 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) { maxCapacity = totalRecharge } srv.fcManager.SetCapacityLimits(srv.freeCapacity, maxCapacity, srv.freeCapacity*2) - srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, 10000, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) }) + srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) }) srv.clientPool.setPriceFactors(priceFactors{0, 1, 1}, priceFactors{0, 1, 1}) checkpoint := srv.latestLocalCheckpoint() @@ -183,9 +183,9 @@ func (s *LesServer) Stop() { s.peers.Close() s.fcManager.Stop() - s.clientPool.stop() s.costTracker.stop() s.handler.stop() + s.clientPool.stop() // client pool should be closed after handler. s.servingQueue.stop() // Note, bloom trie indexer is closed by parent bloombits indexer. diff --git a/les/sync_test.go b/les/sync_test.go index b02c3582f0..7eef13d4ca 100644 --- a/les/sync_test.go +++ b/les/sync_test.go @@ -30,17 +30,14 @@ import ( ) // Test light syncing which will download all headers from genesis. -func TestLightSyncingLes2(t *testing.T) { testCheckpointSyncing(t, 2, 0) } func TestLightSyncingLes3(t *testing.T) { testCheckpointSyncing(t, 3, 0) } // Test legacy checkpoint syncing which will download tail headers // based on a hardcoded checkpoint. -func TestLegacyCheckpointSyncingLes2(t *testing.T) { testCheckpointSyncing(t, 2, 1) } func TestLegacyCheckpointSyncingLes3(t *testing.T) { testCheckpointSyncing(t, 3, 1) } // Test checkpoint syncing which will download tail headers based // on a verified checkpoint. -func TestCheckpointSyncingLes2(t *testing.T) { testCheckpointSyncing(t, 2, 2) } func TestCheckpointSyncingLes3(t *testing.T) { testCheckpointSyncing(t, 3, 2) } func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { diff --git a/les/test_helper.go b/les/test_helper.go index 79cf323d62..67b0225fed 100644 --- a/les/test_helper.go +++ b/les/test_helper.go @@ -280,7 +280,7 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da } server.costTracker, server.freeCapacity = newCostTracker(db, server.config) server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism. - server.clientPool = newClientPool(db, 1, 10000, clock, nil) + server.clientPool = newClientPool(db, 1, clock, nil) server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true }) if server.oracle != nil { @@ -517,7 +517,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer if connect { cpeer, err1, speer, err2 = newTestPeerPair("peer", protocol, server, client) select { - case <-time.After(time.Millisecond * 100): + case <-time.After(time.Millisecond * 300): case err := <-err1: t.Fatalf("peer 1 handshake error: %v", err) case err := <-err2: