les/utils: UDP rate limiter (#21930)

* les/utils: Limiter

* les/utils: dropped prior weight vs variable cost logic, using fixed weights

* les/utils: always create node selector in addressGroup

* les/utils: renamed request weight to request cost

* les/utils: simplified and improved the DoS penalty mechanism

* les/utils: minor fixes

* les/utils: made selection weight calculation nicer

* les/utils: fixed linter warning

* les/utils: more precise and reliable probabilistic test

* les/utils: fixed linter warning
pull/22251/head
Felföldi Zsolt 4 years ago committed by GitHub
parent eb21c652c0
commit 7a800f98f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 405
      les/utils/limiter.go
  2. 206
      les/utils/limiter_test.go
  3. 28
      les/utils/weighted_select.go

@ -0,0 +1,405 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package utils
import (
"sort"
"sync"
"github.com/ethereum/go-ethereum/p2p/enode"
)
const maxSelectionWeight = 1000000000 // maximum selection weight of each individual node/address group
// Limiter protects a network request serving mechanism from denial-of-service attacks.
// It limits the total amount of resources used for serving requests while ensuring that
// the most valuable connections always have a reasonable chance of being served.
type Limiter struct {
lock sync.Mutex
cond *sync.Cond
quit bool
nodes map[enode.ID]*nodeQueue
addresses map[string]*addressGroup
addressSelect, valueSelect *WeightedRandomSelect
maxValue float64
maxCost, sumCost, sumCostLimit uint
selectAddressNext bool
}
// nodeQueue represents queued requests coming from a single node ID
type nodeQueue struct {
queue []request // always nil if penaltyCost != 0
id enode.ID
address string
value float64
flatWeight, valueWeight uint64 // current selection weights in the address/value selectors
sumCost uint // summed cost of requests queued by the node
penaltyCost uint // cumulative cost of dropped requests since last processed request
groupIndex int
}
// addressGroup is a group of node IDs that have sent their last requests from the same
// network address
type addressGroup struct {
nodes []*nodeQueue
nodeSelect *WeightedRandomSelect
sumFlatWeight, groupWeight uint64
}
// request represents an incoming request scheduled for processing
type request struct {
process chan chan struct{}
cost uint
}
// flatWeight distributes weights equally between each active network address
func flatWeight(item interface{}) uint64 { return item.(*nodeQueue).flatWeight }
// add adds the node queue to the address group. It is the caller's responsibility to
// add the address group to the address map and the address selector if it wasn't
// there before.
func (ag *addressGroup) add(nq *nodeQueue) {
if nq.groupIndex != -1 {
panic("added node queue is already in an address group")
}
l := len(ag.nodes)
nq.groupIndex = l
ag.nodes = append(ag.nodes, nq)
ag.sumFlatWeight += nq.flatWeight
ag.groupWeight = ag.sumFlatWeight / uint64(l+1)
ag.nodeSelect.Update(ag.nodes[l])
}
// update updates the selection weight of the node queue inside the address group.
// It is the caller's responsibility to update the group's selection weight in the
// address selector.
func (ag *addressGroup) update(nq *nodeQueue, weight uint64) {
if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
panic("updated node queue is not in this address group")
}
ag.sumFlatWeight += weight - nq.flatWeight
nq.flatWeight = weight
ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes))
ag.nodeSelect.Update(nq)
}
// remove removes the node queue from the address group. It is the caller's responsibility
// to remove the address group from the address map if it is empty.
func (ag *addressGroup) remove(nq *nodeQueue) {
if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
panic("removed node queue is not in this address group")
}
l := len(ag.nodes) - 1
if nq.groupIndex != l {
ag.nodes[nq.groupIndex] = ag.nodes[l]
ag.nodes[nq.groupIndex].groupIndex = nq.groupIndex
}
nq.groupIndex = -1
ag.nodes = ag.nodes[:l]
ag.sumFlatWeight -= nq.flatWeight
if l >= 1 {
ag.groupWeight = ag.sumFlatWeight / uint64(l)
} else {
ag.groupWeight = 0
}
ag.nodeSelect.Remove(nq)
}
// choose selects one of the node queues belonging to the address group
func (ag *addressGroup) choose() *nodeQueue {
return ag.nodeSelect.Choose().(*nodeQueue)
}
// NewLimiter creates a new Limiter
func NewLimiter(sumCostLimit uint) *Limiter {
l := &Limiter{
addressSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*addressGroup).groupWeight }),
valueSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*nodeQueue).valueWeight }),
nodes: make(map[enode.ID]*nodeQueue),
addresses: make(map[string]*addressGroup),
sumCostLimit: sumCostLimit,
}
l.cond = sync.NewCond(&l.lock)
go l.processLoop()
return l
}
// selectionWeights calculates the selection weights of a node for both the address and
// the value selector. The selection weight depends on the next request cost or the
// summed cost of recently dropped requests.
func (l *Limiter) selectionWeights(reqCost uint, value float64) (flatWeight, valueWeight uint64) {
if value > l.maxValue {
l.maxValue = value
}
if value > 0 {
// normalize value to <= 1
value /= l.maxValue
}
if reqCost > l.maxCost {
l.maxCost = reqCost
}
relCost := float64(reqCost) / float64(l.maxCost)
var f float64
if relCost <= 0.001 {
f = 1
} else {
f = 0.001 / relCost
}
f *= maxSelectionWeight
flatWeight, valueWeight = uint64(f), uint64(f*value)
if flatWeight == 0 {
flatWeight = 1
}
return
}
// Add adds a new request to the node queue belonging to the given id. Value belongs
// to the requesting node. A higher value gives the request a higher chance of being
// served quickly in case of heavy load or a DDoS attack. Cost is a rough estimate
// of the serving cost of the request. A lower cost also gives the request a
// better chance.
func (l *Limiter) Add(id enode.ID, address string, value float64, reqCost uint) chan chan struct{} {
l.lock.Lock()
defer l.lock.Unlock()
process := make(chan chan struct{}, 1)
if l.quit {
close(process)
return process
}
if reqCost == 0 {
reqCost = 1
}
if nq, ok := l.nodes[id]; ok {
if nq.queue != nil {
nq.queue = append(nq.queue, request{process, reqCost})
nq.sumCost += reqCost
nq.value = value
if address != nq.address {
// known id sending request from a new address, move to different address group
l.removeFromGroup(nq)
l.addToGroup(nq, address)
}
} else {
// already waiting on a penalty, just add to the penalty cost and drop the request
nq.penaltyCost += reqCost
l.update(nq)
close(process)
return process
}
} else {
nq := &nodeQueue{
queue: []request{{process, reqCost}},
id: id,
value: value,
sumCost: reqCost,
groupIndex: -1,
}
nq.flatWeight, nq.valueWeight = l.selectionWeights(reqCost, value)
if len(l.nodes) == 0 {
l.cond.Signal()
}
l.nodes[id] = nq
if nq.valueWeight != 0 {
l.valueSelect.Update(nq)
}
l.addToGroup(nq, address)
}
l.sumCost += reqCost
if l.sumCost > l.sumCostLimit {
l.dropRequests()
}
return process
}
// update updates the selection weights of the node queue
func (l *Limiter) update(nq *nodeQueue) {
var cost uint
if nq.queue != nil {
cost = nq.queue[0].cost
} else {
cost = nq.penaltyCost
}
flatWeight, valueWeight := l.selectionWeights(cost, nq.value)
ag := l.addresses[nq.address]
ag.update(nq, flatWeight)
l.addressSelect.Update(ag)
nq.valueWeight = valueWeight
l.valueSelect.Update(nq)
}
// addToGroup adds the node queue to the given address group. The group is created if
// it does not exist yet.
func (l *Limiter) addToGroup(nq *nodeQueue, address string) {
nq.address = address
ag := l.addresses[address]
if ag == nil {
ag = &addressGroup{nodeSelect: NewWeightedRandomSelect(flatWeight)}
l.addresses[address] = ag
}
ag.add(nq)
l.addressSelect.Update(ag)
}
// removeFromGroup removes the node queue from its address group
func (l *Limiter) removeFromGroup(nq *nodeQueue) {
ag := l.addresses[nq.address]
ag.remove(nq)
if len(ag.nodes) == 0 {
delete(l.addresses, nq.address)
}
l.addressSelect.Update(ag)
}
// remove removes the node queue from its address group, the nodes map and the value
// selector
func (l *Limiter) remove(nq *nodeQueue) {
l.removeFromGroup(nq)
if nq.valueWeight != 0 {
l.valueSelect.Remove(nq)
}
delete(l.nodes, nq.id)
}
// choose selects the next node queue to process.
func (l *Limiter) choose() *nodeQueue {
if l.valueSelect.IsEmpty() || l.selectAddressNext {
if ag, ok := l.addressSelect.Choose().(*addressGroup); ok {
l.selectAddressNext = false
return ag.choose()
}
}
nq, _ := l.valueSelect.Choose().(*nodeQueue)
l.selectAddressNext = true
return nq
}
// processLoop processes requests sequentially
func (l *Limiter) processLoop() {
l.lock.Lock()
defer l.lock.Unlock()
for {
if l.quit {
for _, nq := range l.nodes {
for _, request := range nq.queue {
close(request.process)
}
}
return
}
nq := l.choose()
if nq == nil {
l.cond.Wait()
continue
}
if nq.queue != nil {
request := nq.queue[0]
nq.queue = nq.queue[1:]
nq.sumCost -= request.cost
l.sumCost -= request.cost
l.lock.Unlock()
ch := make(chan struct{})
request.process <- ch
<-ch
l.lock.Lock()
if len(nq.queue) > 0 {
l.update(nq)
} else {
l.remove(nq)
}
} else {
// penalized queue removed, next request will be added to a clean queue
l.remove(nq)
}
}
}
// Stop stops the processing loop. All queued and future requests are rejected.
func (l *Limiter) Stop() {
l.lock.Lock()
defer l.lock.Unlock()
l.quit = true
l.cond.Signal()
}
type (
dropList []dropListItem
dropListItem struct {
nq *nodeQueue
priority float64
}
)
func (l dropList) Len() int {
return len(l)
}
func (l dropList) Less(i, j int) bool {
return l[i].priority < l[j].priority
}
func (l dropList) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
// dropRequests selects the nodes with the highest queued request cost to selection
// weight ratio and drops their queued request. The empty node queues stay in the
// selectors with a low selection weight in order to penalize these nodes.
func (l *Limiter) dropRequests() {
var (
sumValue float64
list dropList
)
for _, nq := range l.nodes {
sumValue += nq.value
}
for _, nq := range l.nodes {
if nq.sumCost == 0 {
continue
}
w := 1 / float64(len(l.addresses)*len(l.addresses[nq.address].nodes))
if sumValue > 0 {
w += nq.value / sumValue
}
list = append(list, dropListItem{
nq: nq,
priority: w / float64(nq.sumCost),
})
}
sort.Sort(list)
for _, item := range list {
for _, request := range item.nq.queue {
close(request.process)
}
// make the queue penalized; no more requests are accepted until the node is
// selected based on the penalty cost which is the cumulative cost of all dropped
// requests. This ensures that sending excess requests is always penalized
// and incentivizes the sender to stop for a while if no replies are received.
item.nq.queue = nil
item.nq.penaltyCost = item.nq.sumCost
l.sumCost -= item.nq.sumCost // penalty costs are not counted in sumCost
item.nq.sumCost = 0
l.update(item.nq)
if l.sumCost <= l.sumCostLimit/2 {
return
}
}
}

@ -0,0 +1,206 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package utils
import (
"math/rand"
"testing"
"github.com/ethereum/go-ethereum/p2p/enode"
)
const (
ltTolerance = 0.03
ltRounds = 7
)
type (
ltNode struct {
addr, id int
value, exp float64
cost uint
reqRate float64
reqMax, runCount int
lastTotalCost uint
served, dropped int
}
ltResult struct {
node *ltNode
ch chan struct{}
}
limTest struct {
limiter *Limiter
results chan ltResult
runCount int
expCost, totalCost uint
}
)
func (lt *limTest) request(n *ltNode) {
var (
address string
id enode.ID
)
if n.addr >= 0 {
address = string([]byte{byte(n.addr)})
} else {
var b [32]byte
rand.Read(b[:])
address = string(b[:])
}
if n.id >= 0 {
id = enode.ID{byte(n.id)}
} else {
rand.Read(id[:])
}
lt.runCount++
n.runCount++
cch := lt.limiter.Add(id, address, n.value, n.cost)
go func() {
lt.results <- ltResult{n, <-cch}
}()
}
func (lt *limTest) moreRequests(n *ltNode) {
maxStart := int(float64(lt.totalCost-n.lastTotalCost) * n.reqRate)
if maxStart != 0 {
n.lastTotalCost = lt.totalCost
}
for n.reqMax > n.runCount && maxStart > 0 {
lt.request(n)
maxStart--
}
}
func (lt *limTest) process() {
res := <-lt.results
lt.runCount--
res.node.runCount--
if res.ch != nil {
res.node.served++
if res.node.exp != 0 {
lt.expCost += res.node.cost
}
lt.totalCost += res.node.cost
close(res.ch)
} else {
res.node.dropped++
}
}
func TestLimiter(t *testing.T) {
limTests := [][]*ltNode{
{ // one id from an individual address and two ids from a shared address
{addr: 0, id: 0, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.5},
{addr: 1, id: 1, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25},
{addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25},
},
{ // varying request costs
{addr: 0, id: 0, value: 0, cost: 10, reqRate: 0.2, reqMax: 1, exp: 0.5},
{addr: 1, id: 1, value: 0, cost: 3, reqRate: 0.5, reqMax: 1, exp: 0.25},
{addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25},
},
{ // different request rate
{addr: 0, id: 0, value: 0, cost: 1, reqRate: 2, reqMax: 2, exp: 0.5},
{addr: 1, id: 1, value: 0, cost: 1, reqRate: 10, reqMax: 10, exp: 0.25},
{addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25},
},
{ // adding value
{addr: 0, id: 0, value: 3, cost: 1, reqRate: 1, reqMax: 1, exp: (0.5 + 0.3) / 2},
{addr: 1, id: 1, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25 / 2},
{addr: 1, id: 2, value: 7, cost: 1, reqRate: 1, reqMax: 1, exp: (0.25 + 0.7) / 2},
},
{ // DoS attack from a single address with a single id
{addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 3, id: 3, value: 0, cost: 1, reqRate: 10, reqMax: 1000000000, exp: 0},
},
{ // DoS attack from a single address with different ids
{addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 3, id: -1, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0},
},
{ // DDoS attack from different addresses with a single id
{addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: -1, id: 3, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0},
},
{ // DDoS attack from different addresses with different ids
{addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
{addr: -1, id: -1, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0},
},
}
lt := &limTest{
limiter: NewLimiter(100),
results: make(chan ltResult),
}
for _, test := range limTests {
lt.expCost, lt.totalCost = 0, 0
iterCount := 10000
for j := 0; j < ltRounds; j++ {
// try to reach expected target range in multiple rounds with increasing iteration counts
last := j == ltRounds-1
for _, n := range test {
lt.request(n)
}
for i := 0; i < iterCount; i++ {
lt.process()
for _, n := range test {
lt.moreRequests(n)
}
}
for lt.runCount > 0 {
lt.process()
}
if spamRatio := 1 - float64(lt.expCost)/float64(lt.totalCost); spamRatio > 0.5*(1+ltTolerance) {
t.Errorf("Spam ratio too high (%f)", spamRatio)
}
fail, success := false, true
for _, n := range test {
if n.exp != 0 {
if n.dropped > 0 {
t.Errorf("Dropped %d requests of non-spam node", n.dropped)
fail = true
}
r := float64(n.served) * float64(n.cost) / float64(lt.expCost)
if r < n.exp*(1-ltTolerance) || r > n.exp*(1+ltTolerance) {
if last {
// print error only if the target is still not reached in the last round
t.Errorf("Request ratio (%f) does not match expected value (%f)", r, n.exp)
}
success = false
}
}
}
if fail || success {
break
}
// neither failed nor succeeded; try more iterations to reach probability targets
iterCount *= 2
}
}
lt.limiter.Stop()
}

@ -52,17 +52,17 @@ func (w *WeightedRandomSelect) Remove(item WrsItem) {
// IsEmpty returns true if the set is empty
func (w *WeightedRandomSelect) IsEmpty() bool {
return w.root.sumWeight == 0
return w.root.sumCost == 0
}
// setWeight sets an item's weight to a specific value (removes it if zero)
func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
if weight > math.MaxInt64-w.root.sumWeight {
// old weight is still included in sumWeight, remove and check again
if weight > math.MaxInt64-w.root.sumCost {
// old weight is still included in sumCost, remove and check again
w.setWeight(item, 0)
if weight > math.MaxInt64-w.root.sumWeight {
log.Error("WeightedRandomSelect overflow", "sumWeight", w.root.sumWeight, "new weight", weight)
weight = math.MaxInt64 - w.root.sumWeight
if weight > math.MaxInt64-w.root.sumCost {
log.Error("WeightedRandomSelect overflow", "sumCost", w.root.sumCost, "new weight", weight)
weight = math.MaxInt64 - w.root.sumCost
}
}
idx, ok := w.idx[item]
@ -75,9 +75,9 @@ func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
if weight != 0 {
if w.root.itemCnt == w.root.maxItems {
// add a new level
newRoot := &wrsNode{sumWeight: w.root.sumWeight, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches}
newRoot := &wrsNode{sumCost: w.root.sumCost, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches}
newRoot.items[0] = w.root
newRoot.weights[0] = w.root.sumWeight
newRoot.weights[0] = w.root.sumCost
w.root = newRoot
}
w.idx[item] = w.root.insert(item, weight)
@ -91,10 +91,10 @@ func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
// updates its weight and selects another one
func (w *WeightedRandomSelect) Choose() WrsItem {
for {
if w.root.sumWeight == 0 {
if w.root.sumCost == 0 {
return nil
}
val := uint64(rand.Int63n(int64(w.root.sumWeight)))
val := uint64(rand.Int63n(int64(w.root.sumCost)))
choice, lastWeight := w.root.choose(val)
weight := w.wfn(choice)
if weight != lastWeight {
@ -112,7 +112,7 @@ const wrsBranches = 8 // max number of branches in the wrsNode tree
type wrsNode struct {
items [wrsBranches]interface{}
weights [wrsBranches]uint64
sumWeight uint64
sumCost uint64
level, itemCnt, maxItems int
}
@ -126,7 +126,7 @@ func (n *wrsNode) insert(item WrsItem, weight uint64) int {
}
}
n.itemCnt++
n.sumWeight += weight
n.sumCost += weight
n.weights[branch] += weight
if n.level == 0 {
n.items[branch] = item
@ -150,7 +150,7 @@ func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
oldWeight := n.weights[idx]
n.weights[idx] = weight
diff := weight - oldWeight
n.sumWeight += diff
n.sumCost += diff
if weight == 0 {
n.items[idx] = nil
n.itemCnt--
@ -161,7 +161,7 @@ func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
branch := idx / branchItems
diff := n.items[branch].(*wrsNode).setWeight(idx-branch*branchItems, weight)
n.weights[branch] += diff
n.sumWeight += diff
n.sumCost += diff
if weight == 0 {
n.itemCnt--
}

Loading…
Cancel
Save