mirror of https://github.com/ethereum/go-ethereum
les, les/lespay: implement new server pool (#20758)
This PR reimplements the light client server pool. It is also a first step to move certain logic into a new lespay package. This package will contain the implementation of the lespay token sale functions, the token buying and selling logic and other components related to peer selection/prioritization and service quality evaluation. Over the long term this package will be reusable for incentivizing future protocols. Since the LES peer logic is now based on enode.Iterator, it can now use DNS-based fallback discovery to find servers. This document describes the function of the new components: https://gist.github.com/zsfelfoldi/3c7ace895234b7b345ab4f71dab102d4pull/21118/head
parent 65ce550b37
commit b4a2681120
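Note on the enode.Iterator-based design mentioned above: dial candidates now flow through the standard enode.Iterator interface (Next/Node/Close), so any source such as discovery, DNS trees or a stored node list can feed the pool. Below is a minimal illustrative sketch of consuming such an iterator; it is not part of the PR, and the consume/handle names are made up for the example.

// Illustrative sketch only: draining an enode.Iterator, the interface the new
// dial logic reads from. "handle" stands in for whatever the caller does with
// each candidate (e.g. dialing it).
package main

import (
    "fmt"

    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/enr"
)

func consume(it enode.Iterator, handle func(*enode.Node)) {
    defer it.Close()
    for it.Next() { // Next blocks until a node is available or the iterator is closed
        handle(it.Node())
    }
}

func main() {
    // enode.IterNodes wraps a static slice in the Iterator interface.
    nodes := []*enode.Node{
        enode.SignNull(new(enr.Record), enode.ID{1}),
        enode.SignNull(new(enr.Record), enode.ID{2}),
    }
    consume(enode.IterNodes(nodes), func(n *enode.Node) { fmt.Println(n.ID()) })
}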
@@ -0,0 +1,107 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
    "sync"

    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

// FillSet tries to read nodes from an input iterator and add them to a node set by
// setting the specified node state flag(s) until the size of the set reaches the target.
// Note that other mechanisms (like other FillSet instances reading from different inputs)
// can also set the same flag(s) and FillSet will always care about the total number of
// nodes having those flags.
type FillSet struct {
    lock          sync.Mutex
    cond          *sync.Cond
    ns            *nodestate.NodeStateMachine
    input         enode.Iterator
    closed        bool
    flags         nodestate.Flags
    count, target int
}

// NewFillSet creates a new FillSet
func NewFillSet(ns *nodestate.NodeStateMachine, input enode.Iterator, flags nodestate.Flags) *FillSet {
    fs := &FillSet{
        ns:    ns,
        input: input,
        flags: flags,
    }
    fs.cond = sync.NewCond(&fs.lock)

    ns.SubscribeState(flags, func(n *enode.Node, oldState, newState nodestate.Flags) {
        fs.lock.Lock()
        if oldState.Equals(flags) {
            fs.count--
        }
        if newState.Equals(flags) {
            fs.count++
        }
        if fs.target > fs.count {
            fs.cond.Signal()
        }
        fs.lock.Unlock()
    })

    go fs.readLoop()
    return fs
}

// readLoop keeps reading nodes from the input and setting the specified flags for them
// whenever the node set size is under the current target
func (fs *FillSet) readLoop() {
    for {
        fs.lock.Lock()
        for fs.target <= fs.count && !fs.closed {
            fs.cond.Wait()
        }

        fs.lock.Unlock()
        if !fs.input.Next() {
            return
        }
        fs.ns.SetState(fs.input.Node(), fs.flags, nodestate.Flags{}, 0)
    }
}

// SetTarget sets the current target for node set size. If the previous target was not
// reached and FillSet was still waiting for the next node from the input then the next
// incoming node will be added to the set regardless of the target. This ensures that
// all nodes coming from the input are eventually added to the set.
func (fs *FillSet) SetTarget(target int) {
    fs.lock.Lock()
    defer fs.lock.Unlock()

    fs.target = target
    if fs.target > fs.count {
        fs.cond.Signal()
    }
}

// Close shuts FillSet down and closes the input iterator
func (fs *FillSet) Close() {
    fs.lock.Lock()
    defer fs.lock.Unlock()

    fs.closed = true
    fs.input.Close()
    fs.cond.Signal()
}
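A brief usage sketch of the component above (illustrative only, not from the PR): the FillSet is constructed before the state machine is started, then a target is set so nodes pulled from the input iterator receive the given flag. The flag name "discovered" and the static input slice are assumptions made up for the example.

// Illustrative sketch only: keep a small set of nodes flagged as "discovered",
// feeding them from an enode.Iterator (here a static one standing in for discovery).
package main

import (
    "fmt"
    "time"

    "github.com/ethereum/go-ethereum/common/mclock"
    "github.com/ethereum/go-ethereum/les/lespay/client"
    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/enr"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

func main() {
    setup := &nodestate.Setup{}
    discovered := setup.NewFlag("discovered") // hypothetical flag name

    ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, setup)
    input := enode.IterNodes([]*enode.Node{ // stand-in for a discovery or DNS iterator
        enode.SignNull(new(enr.Record), enode.ID{1}),
        enode.SignNull(new(enr.Record), enode.ID{2}),
    })

    fs := client.NewFillSet(ns, input, discovered) // subscriptions must be installed before Start
    ns.Start()
    fs.SetTarget(2) // FillSet now pulls nodes from the input and flags them

    time.Sleep(100 * time.Millisecond) // crude wait; flags are set asynchronously by readLoop
    ns.ForEach(discovered, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) {
        fmt.Println("discovered", n.ID())
    })

    fs.Close()
    ns.Stop()
}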
@@ -0,0 +1,113 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
    "math/rand"
    "testing"
    "time"

    "github.com/ethereum/go-ethereum/common/mclock"
    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/enr"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

type testIter struct {
    waitCh chan struct{}
    nodeCh chan *enode.Node
    node   *enode.Node
}

func (i *testIter) Next() bool {
    i.waitCh <- struct{}{}
    i.node = <-i.nodeCh
    return i.node != nil
}

func (i *testIter) Node() *enode.Node {
    return i.node
}

func (i *testIter) Close() {}

func (i *testIter) push() {
    var id enode.ID
    rand.Read(id[:])
    i.nodeCh <- enode.SignNull(new(enr.Record), id)
}

func (i *testIter) waiting(timeout time.Duration) bool {
    select {
    case <-i.waitCh:
        return true
    case <-time.After(timeout):
        return false
    }
}

func TestFillSet(t *testing.T) {
    ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
    iter := &testIter{
        waitCh: make(chan struct{}),
        nodeCh: make(chan *enode.Node),
    }
    fs := NewFillSet(ns, iter, sfTest1)
    ns.Start()

    expWaiting := func(i int, push bool) {
        for ; i > 0; i-- {
            if !iter.waiting(time.Second * 10) {
                t.Fatalf("FillSet not waiting for new nodes")
            }
            if push {
                iter.push()
            }
        }
    }

    expNotWaiting := func() {
        if iter.waiting(time.Millisecond * 100) {
            t.Fatalf("FillSet unexpectedly waiting for new nodes")
        }
    }

    expNotWaiting()
    fs.SetTarget(3)
    expWaiting(3, true)
    expNotWaiting()
    fs.SetTarget(100)
    expWaiting(2, true)
    expWaiting(1, false)
    // lower the target before the previous one has been filled up
    fs.SetTarget(0)
    iter.push()
    expNotWaiting()
    fs.SetTarget(10)
    expWaiting(4, true)
    expNotWaiting()
    // remove all previously set flags
    ns.ForEach(sfTest1, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
        ns.SetState(node, nodestate.Flags{}, sfTest1, 0)
    })
    // now expect FillSet to fill the set up again with 10 new nodes
    expWaiting(10, true)
    expNotWaiting()

    fs.Close()
    ns.Stop()
}
@@ -0,0 +1,123 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
    "sync"

    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

// QueueIterator returns nodes from the specified selectable set in the same order as
// they entered the set.
type QueueIterator struct {
    lock sync.Mutex
    cond *sync.Cond

    ns           *nodestate.NodeStateMachine
    queue        []*enode.Node
    nextNode     *enode.Node
    waitCallback func(bool)
    fifo, closed bool
}

// NewQueueIterator creates a new QueueIterator. Nodes are selectable if they have all the required
// and none of the disabled flags set. When a node is selected the selectedFlag is set which also
// disables further selectability until it is removed or times out.
func NewQueueIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, fifo bool, waitCallback func(bool)) *QueueIterator {
    qi := &QueueIterator{
        ns:           ns,
        fifo:         fifo,
        waitCallback: waitCallback,
    }
    qi.cond = sync.NewCond(&qi.lock)

    ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) {
        oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
        newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
        if newMatch == oldMatch {
            return
        }

        qi.lock.Lock()
        defer qi.lock.Unlock()

        if newMatch {
            qi.queue = append(qi.queue, n)
        } else {
            id := n.ID()
            for i, qn := range qi.queue {
                if qn.ID() == id {
                    copy(qi.queue[i:len(qi.queue)-1], qi.queue[i+1:])
                    qi.queue = qi.queue[:len(qi.queue)-1]
                    break
                }
            }
        }
        qi.cond.Signal()
    })
    return qi
}

// Next moves to the next selectable node.
func (qi *QueueIterator) Next() bool {
    qi.lock.Lock()
    if !qi.closed && len(qi.queue) == 0 {
        if qi.waitCallback != nil {
            qi.waitCallback(true)
        }
        for !qi.closed && len(qi.queue) == 0 {
            qi.cond.Wait()
        }
        if qi.waitCallback != nil {
            qi.waitCallback(false)
        }
    }
    if qi.closed {
        qi.nextNode = nil
        qi.lock.Unlock()
        return false
    }
    // Move to the next node in queue.
    if qi.fifo {
        qi.nextNode = qi.queue[0]
        copy(qi.queue[:len(qi.queue)-1], qi.queue[1:])
        qi.queue = qi.queue[:len(qi.queue)-1]
    } else {
        qi.nextNode = qi.queue[len(qi.queue)-1]
        qi.queue = qi.queue[:len(qi.queue)-1]
    }
    qi.lock.Unlock()
    return true
}

// Close ends the iterator.
func (qi *QueueIterator) Close() {
    qi.lock.Lock()
    qi.closed = true
    qi.lock.Unlock()
    qi.cond.Signal()
}

// Node returns the current node.
func (qi *QueueIterator) Node() *enode.Node {
    qi.lock.Lock()
    defer qi.lock.Unlock()

    return qi.nextNode
}
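An illustrative sketch of wiring up such a queue (not from the PR): the flags "canDial" (required) and "dialing" (disabling) are hypothetical names made up for the example. In FIFO mode the oldest selectable node is returned first, and setting the disabling flag on the returned node takes it out of the selectable set.

// Illustrative sketch only: a FIFO QueueIterator over nodes carrying "canDial"
// and not carrying "dialing".
package main

import (
    "fmt"

    "github.com/ethereum/go-ethereum/common/mclock"
    "github.com/ethereum/go-ethereum/les/lespay/client"
    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/enr"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

func main() {
    setup := &nodestate.Setup{}
    canDial := setup.NewFlag("canDial") // hypothetical flag names
    dialing := setup.NewFlag("dialing")

    ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, setup)
    it := client.NewQueueIterator(ns, canDial, dialing, true, nil) // fifo order, no wait callback
    ns.Start()
    defer ns.Stop()

    // Make one node selectable by giving it the required flag.
    n := enode.SignNull(new(enr.Record), enode.ID{1})
    ns.SetState(n, canDial, nodestate.Flags{}, 0)

    if it.Next() { // returns immediately because the queue is non-empty
        got := it.Node()
        // Setting the disabling flag keeps the node out of the selectable set
        // until "dialing" is reset again.
        ns.SetState(got, dialing, nodestate.Flags{}, 0)
        fmt.Println("selected", got.ID())
    }
    it.Close()
}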
@@ -0,0 +1,106 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
    "testing"
    "time"

    "github.com/ethereum/go-ethereum/common/mclock"
    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/enr"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

func testNodeID(i int) enode.ID {
    return enode.ID{42, byte(i % 256), byte(i / 256)}
}

func testNodeIndex(id enode.ID) int {
    if id[0] != 42 {
        return -1
    }
    return int(id[1]) + int(id[2])*256
}

func testNode(i int) *enode.Node {
    return enode.SignNull(new(enr.Record), testNodeID(i))
}

func TestQueueIteratorFIFO(t *testing.T) {
    testQueueIterator(t, true)
}

func TestQueueIteratorLIFO(t *testing.T) {
    testQueueIterator(t, false)
}

func testQueueIterator(t *testing.T, fifo bool) {
    ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
    qi := NewQueueIterator(ns, sfTest2, sfTest3.Or(sfTest4), fifo, nil)
    ns.Start()
    for i := 1; i <= iterTestNodeCount; i++ {
        ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0)
    }
    next := func() int {
        ch := make(chan struct{})
        go func() {
            qi.Next()
            close(ch)
        }()
        select {
        case <-ch:
        case <-time.After(time.Second * 5):
            t.Fatalf("Iterator.Next() timeout")
        }
        node := qi.Node()
        ns.SetState(node, sfTest4, nodestate.Flags{}, 0)
        return testNodeIndex(node.ID())
    }
    exp := func(i int) {
        n := next()
        if n != i {
            t.Errorf("Wrong item returned by iterator (expected %d, got %d)", i, n)
        }
    }
    explist := func(list []int) {
        for i := range list {
            if fifo {
                exp(list[i])
            } else {
                exp(list[len(list)-1-i])
            }
        }
    }

    ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0)
    ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0)
    ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0)
    explist([]int{1, 2, 3})
    ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0)
    ns.SetState(testNode(5), sfTest2, nodestate.Flags{}, 0)
    ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0)
    ns.SetState(testNode(5), sfTest3, nodestate.Flags{}, 0)
    explist([]int{4, 6})
    ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
    ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0)
    ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
    ns.SetState(testNode(2), sfTest3, nodestate.Flags{}, 0)
    ns.SetState(testNode(2), nodestate.Flags{}, sfTest3, 0)
    explist([]int{1, 3, 2})
    ns.Stop()
}
@@ -0,0 +1,128 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
    "sync"

    "github.com/ethereum/go-ethereum/les/utils"
    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

// WrsIterator returns nodes from the specified selectable set with a weighted random
// selection. Selection weights are provided by a callback function.
type WrsIterator struct {
    lock sync.Mutex
    cond *sync.Cond

    ns       *nodestate.NodeStateMachine
    wrs      *utils.WeightedRandomSelect
    nextNode *enode.Node
    closed   bool
}

// NewWrsIterator creates a new WrsIterator. Nodes are selectable if they have all the required
// and none of the disabled flags set. When a node is selected the selectedFlag is set which also
// disables further selectability until it is removed or times out.
func NewWrsIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, weightField nodestate.Field) *WrsIterator {
    wfn := func(i interface{}) uint64 {
        n := ns.GetNode(i.(enode.ID))
        if n == nil {
            return 0
        }
        wt, _ := ns.GetField(n, weightField).(uint64)
        return wt
    }

    w := &WrsIterator{
        ns:  ns,
        wrs: utils.NewWeightedRandomSelect(wfn),
    }
    w.cond = sync.NewCond(&w.lock)

    ns.SubscribeField(weightField, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
        if state.HasAll(requireFlags) && state.HasNone(disableFlags) {
            w.lock.Lock()
            w.wrs.Update(n.ID())
            w.lock.Unlock()
            w.cond.Signal()
        }
    })

    ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) {
        oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
        newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
        if newMatch == oldMatch {
            return
        }

        w.lock.Lock()
        if newMatch {
            w.wrs.Update(n.ID())
        } else {
            w.wrs.Remove(n.ID())
        }
        w.lock.Unlock()
        w.cond.Signal()
    })
    return w
}

// Next selects the next node.
func (w *WrsIterator) Next() bool {
    w.nextNode = w.chooseNode()
    return w.nextNode != nil
}

func (w *WrsIterator) chooseNode() *enode.Node {
    w.lock.Lock()
    defer w.lock.Unlock()

    for {
        for !w.closed && w.wrs.IsEmpty() {
            w.cond.Wait()
        }
        if w.closed {
            return nil
        }
        // Choose the next node at random. Even though w.wrs is guaranteed
        // non-empty here, Choose might return nil if all items have weight
        // zero.
        if c := w.wrs.Choose(); c != nil {
            id := c.(enode.ID)
            w.wrs.Remove(id)
            return w.ns.GetNode(id)
        }
    }
}

// Close ends the iterator.
func (w *WrsIterator) Close() {
    w.lock.Lock()
    w.closed = true
    w.lock.Unlock()
    w.cond.Signal()
}

// Node returns the current node.
func (w *WrsIterator) Node() *enode.Node {
    w.lock.Lock()
    defer w.lock.Unlock()
    return w.nextNode
}
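An illustrative sketch of the weighted variant (not from the PR): the flag names and the uint64 "weight" field are made-up assumptions. A node becomes selectable once it carries the required flag, and its probability of being returned is proportional to its weight field; the caller would typically set a disabling flag on each returned node.

// Illustrative sketch only: weighted random selection over nodes carrying
// "candidate" but not "selected"; weights come from a uint64 node field.
package main

import (
    "fmt"
    "reflect"

    "github.com/ethereum/go-ethereum/common/mclock"
    "github.com/ethereum/go-ethereum/les/lespay/client"
    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/enr"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

func main() {
    setup := &nodestate.Setup{}
    candidate := setup.NewFlag("candidate") // hypothetical names
    selected := setup.NewFlag("selected")
    weight := setup.NewField("weight", reflect.TypeOf(uint64(0)))

    ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, setup)
    it := client.NewWrsIterator(ns, candidate, selected, weight)
    ns.Start()
    defer ns.Stop()

    n := enode.SignNull(new(enr.Record), enode.ID{1})
    ns.SetState(n, candidate, nodestate.Flags{}, 0) // node becomes selectable
    ns.SetField(n, weight, uint64(5))               // selection probability proportional to this

    if it.Next() {
        fmt.Println("picked", it.Node().ID())
    }
    it.Close()
}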
@@ -0,0 +1,103 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
    "reflect"
    "testing"
    "time"

    "github.com/ethereum/go-ethereum/common/mclock"
    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

var (
    testSetup     = &nodestate.Setup{}
    sfTest1       = testSetup.NewFlag("test1")
    sfTest2       = testSetup.NewFlag("test2")
    sfTest3       = testSetup.NewFlag("test3")
    sfTest4       = testSetup.NewFlag("test4")
    sfiTestWeight = testSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0)))
)

const iterTestNodeCount = 6

func TestWrsIterator(t *testing.T) {
    ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
    w := NewWrsIterator(ns, sfTest2, sfTest3.Or(sfTest4), sfiTestWeight)
    ns.Start()
    for i := 1; i <= iterTestNodeCount; i++ {
        ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0)
        ns.SetField(testNode(i), sfiTestWeight, uint64(1))
    }
    next := func() int {
        ch := make(chan struct{})
        go func() {
            w.Next()
            close(ch)
        }()
        select {
        case <-ch:
        case <-time.After(time.Second * 5):
            t.Fatalf("Iterator.Next() timeout")
        }
        node := w.Node()
        ns.SetState(node, sfTest4, nodestate.Flags{}, 0)
        return testNodeIndex(node.ID())
    }
    set := make(map[int]bool)
    expset := func() {
        for len(set) > 0 {
            n := next()
            if !set[n] {
                t.Errorf("Item returned by iterator not in the expected set (got %d)", n)
            }
            delete(set, n)
        }
    }

    ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0)
    ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0)
    ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0)
    set[1] = true
    set[2] = true
    set[3] = true
    expset()
    ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0)
    ns.SetState(testNode(5), sfTest2.Or(sfTest3), nodestate.Flags{}, 0)
    ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0)
    set[4] = true
    set[6] = true
    expset()
    ns.SetField(testNode(2), sfiTestWeight, uint64(0))
    ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
    ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0)
    ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
    set[1] = true
    set[3] = true
    expset()
    ns.SetField(testNode(2), sfiTestWeight, uint64(1))
    ns.SetState(testNode(2), nodestate.Flags{}, sfTest2, 0)
    ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
    ns.SetState(testNode(2), sfTest2, sfTest4, 0)
    ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
    set[1] = true
    set[2] = true
    set[3] = true
    expset()
    ns.Stop()
}
File diff suppressed because it is too large
@@ -0,0 +1,352 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package les

import (
    "math/rand"
    "sync/atomic"
    "testing"
    "time"

    "github.com/ethereum/go-ethereum/common/mclock"
    "github.com/ethereum/go-ethereum/ethdb"
    "github.com/ethereum/go-ethereum/ethdb/memorydb"
    lpc "github.com/ethereum/go-ethereum/les/lespay/client"
    "github.com/ethereum/go-ethereum/p2p"
    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/enr"
)

const (
    spTestNodes  = 1000
    spTestTarget = 5
    spTestLength = 10000
    spMinTotal   = 40000
    spMaxTotal   = 50000
)

func testNodeID(i int) enode.ID {
    return enode.ID{42, byte(i % 256), byte(i / 256)}
}

func testNodeIndex(id enode.ID) int {
    if id[0] != 42 {
        return -1
    }
    return int(id[1]) + int(id[2])*256
}

type serverPoolTest struct {
    db                   ethdb.KeyValueStore
    clock                *mclock.Simulated
    quit                 chan struct{}
    preNeg, preNegFail   bool
    vt                   *lpc.ValueTracker
    sp                   *serverPool
    input                enode.Iterator
    testNodes            []spTestNode
    trusted              []string
    waitCount, waitEnded int32

    cycle, conn, servedConn  int
    serviceCycles, dialCount int
    disconnect               map[int][]int
}

type spTestNode struct {
    connectCycles, waitCycles int
    nextConnCycle, totalConn  int
    connected, service        bool
    peer                      *serverPeer
}

func newServerPoolTest(preNeg, preNegFail bool) *serverPoolTest {
    nodes := make([]*enode.Node, spTestNodes)
    for i := range nodes {
        nodes[i] = enode.SignNull(&enr.Record{}, testNodeID(i))
    }
    return &serverPoolTest{
        clock:      &mclock.Simulated{},
        db:         memorydb.New(),
        input:      enode.CycleNodes(nodes),
        testNodes:  make([]spTestNode, spTestNodes),
        preNeg:     preNeg,
        preNegFail: preNegFail,
    }
}

func (s *serverPoolTest) beginWait() {
    // ensure that dialIterator and the maximal number of pre-neg queries are not all stuck in a waiting state
    for atomic.AddInt32(&s.waitCount, 1) > preNegLimit {
        atomic.AddInt32(&s.waitCount, -1)
        s.clock.Run(time.Second)
    }
}

func (s *serverPoolTest) endWait() {
    atomic.AddInt32(&s.waitCount, -1)
    atomic.AddInt32(&s.waitEnded, 1)
}

func (s *serverPoolTest) addTrusted(i int) {
    s.trusted = append(s.trusted, enode.SignNull(&enr.Record{}, testNodeID(i)).String())
}

func (s *serverPoolTest) start() {
    var testQuery queryFunc
    if s.preNeg {
        testQuery = func(node *enode.Node) int {
            idx := testNodeIndex(node.ID())
            n := &s.testNodes[idx]
            canConnect := !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle
            if s.preNegFail {
                // simulate a scenario where UDP queries never work
                s.beginWait()
                s.clock.Sleep(time.Second * 5)
                s.endWait()
                return -1
            } else {
                switch idx % 3 {
                case 0:
                    // pre-neg returns true only if connection is possible
                    if canConnect {
                        return 1
                    } else {
                        return 0
                    }
                case 1:
                    // pre-neg returns true but connection might still fail
                    return 1
                case 2:
                    // pre-neg returns true if connection is possible, otherwise timeout (node unresponsive)
                    if canConnect {
                        return 1
                    } else {
                        s.beginWait()
                        s.clock.Sleep(time.Second * 5)
                        s.endWait()
                        return -1
                    }
                }
                return -1
            }
        }
    }

    s.vt = lpc.NewValueTracker(s.db, s.clock, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000))
    s.sp = newServerPool(s.db, []byte("serverpool:"), s.vt, s.input, 0, testQuery, s.clock, s.trusted)
    s.sp.validSchemes = enode.ValidSchemesForTesting
    s.sp.unixTime = func() int64 { return int64(s.clock.Now()) / int64(time.Second) }
    s.disconnect = make(map[int][]int)
    s.sp.start()
    s.quit = make(chan struct{})
    go func() {
        last := int32(-1)
        for {
            select {
            case <-time.After(time.Millisecond * 100):
                c := atomic.LoadInt32(&s.waitEnded)
                if c == last {
                    // advance clock if test is stuck (might happen in rare cases)
                    s.clock.Run(time.Second)
                }
                last = c
            case <-s.quit:
                return
            }
        }
    }()
}

func (s *serverPoolTest) stop() {
    close(s.quit)
    s.sp.stop()
    s.vt.Stop()
    for i := range s.testNodes {
        n := &s.testNodes[i]
        if n.connected {
            n.totalConn += s.cycle
        }
        n.connected = false
        n.peer = nil
        n.nextConnCycle = 0
    }
    s.conn, s.servedConn = 0, 0
}

func (s *serverPoolTest) run() {
    for count := spTestLength; count > 0; count-- {
        if dcList := s.disconnect[s.cycle]; dcList != nil {
            for _, idx := range dcList {
                n := &s.testNodes[idx]
                s.sp.unregisterPeer(n.peer)
                n.totalConn += s.cycle
                n.connected = false
                n.peer = nil
                s.conn--
                if n.service {
                    s.servedConn--
                }
                n.nextConnCycle = s.cycle + n.waitCycles
            }
            delete(s.disconnect, s.cycle)
        }
        if s.conn < spTestTarget {
            s.dialCount++
            s.beginWait()
            s.sp.dialIterator.Next()
            s.endWait()
            dial := s.sp.dialIterator.Node()
            id := dial.ID()
            idx := testNodeIndex(id)
            n := &s.testNodes[idx]
            if !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle {
                s.conn++
                if n.service {
                    s.servedConn++
                }
                n.totalConn -= s.cycle
                n.connected = true
                dc := s.cycle + n.connectCycles
                s.disconnect[dc] = append(s.disconnect[dc], idx)
                n.peer = &serverPeer{peerCommons: peerCommons{Peer: p2p.NewPeer(id, "", nil)}}
                s.sp.registerPeer(n.peer)
                if n.service {
                    s.vt.Served(s.vt.GetNode(id), []lpc.ServedRequest{{ReqType: 0, Amount: 100}}, 0)
                }
            }
        }
        s.serviceCycles += s.servedConn
        s.clock.Run(time.Second)
        s.cycle++
    }
}

func (s *serverPoolTest) setNodes(count, conn, wait int, service, trusted bool) (res []int) {
    for ; count > 0; count-- {
        idx := rand.Intn(spTestNodes)
        for s.testNodes[idx].connectCycles != 0 || s.testNodes[idx].connected {
            idx = rand.Intn(spTestNodes)
        }
        res = append(res, idx)
        s.testNodes[idx] = spTestNode{
            connectCycles: conn,
            waitCycles:    wait,
            service:       service,
        }
        if trusted {
            s.addTrusted(idx)
        }
    }
    return
}

func (s *serverPoolTest) resetNodes() {
    for i, n := range s.testNodes {
        if n.connected {
            n.totalConn += s.cycle
            s.sp.unregisterPeer(n.peer)
        }
        s.testNodes[i] = spTestNode{totalConn: n.totalConn}
    }
    s.conn, s.servedConn = 0, 0
    s.disconnect = make(map[int][]int)
    s.trusted = nil
}

func (s *serverPoolTest) checkNodes(t *testing.T, nodes []int) {
    var sum int
    for _, idx := range nodes {
        n := &s.testNodes[idx]
        if n.connected {
            n.totalConn += s.cycle
        }
        sum += n.totalConn
        n.totalConn = 0
        if n.connected {
            n.totalConn -= s.cycle
        }
    }
    if sum < spMinTotal || sum > spMaxTotal {
        t.Errorf("Total connection amount %d outside expected range %d to %d", sum, spMinTotal, spMaxTotal)
    }
}

func TestServerPool(t *testing.T)               { testServerPool(t, false, false) }
func TestServerPoolWithPreNeg(t *testing.T)     { testServerPool(t, true, false) }
func TestServerPoolWithPreNegFail(t *testing.T) { testServerPool(t, true, true) }
func testServerPool(t *testing.T, preNeg, fail bool) {
    s := newServerPoolTest(preNeg, fail)
    nodes := s.setNodes(100, 200, 200, true, false)
    s.setNodes(100, 20, 20, false, false)
    s.start()
    s.run()
    s.stop()
    s.checkNodes(t, nodes)
}

func TestServerPoolChangedNodes(t *testing.T)           { testServerPoolChangedNodes(t, false) }
func TestServerPoolChangedNodesWithPreNeg(t *testing.T) { testServerPoolChangedNodes(t, true) }
func testServerPoolChangedNodes(t *testing.T, preNeg bool) {
    s := newServerPoolTest(preNeg, false)
    nodes := s.setNodes(100, 200, 200, true, false)
    s.setNodes(100, 20, 20, false, false)
    s.start()
    s.run()
    s.checkNodes(t, nodes)
    for i := 0; i < 3; i++ {
        s.resetNodes()
        nodes := s.setNodes(100, 200, 200, true, false)
        s.setNodes(100, 20, 20, false, false)
        s.run()
        s.checkNodes(t, nodes)
    }
    s.stop()
}

func TestServerPoolRestartNoDiscovery(t *testing.T) { testServerPoolRestartNoDiscovery(t, false) }
func TestServerPoolRestartNoDiscoveryWithPreNeg(t *testing.T) {
    testServerPoolRestartNoDiscovery(t, true)
}
func testServerPoolRestartNoDiscovery(t *testing.T, preNeg bool) {
    s := newServerPoolTest(preNeg, false)
    nodes := s.setNodes(100, 200, 200, true, false)
    s.setNodes(100, 20, 20, false, false)
    s.start()
    s.run()
    s.stop()
    s.checkNodes(t, nodes)
    s.input = nil
    s.start()
    s.run()
    s.stop()
    s.checkNodes(t, nodes)
}

func TestServerPoolTrustedNoDiscovery(t *testing.T) { testServerPoolTrustedNoDiscovery(t, false) }
func TestServerPoolTrustedNoDiscoveryWithPreNeg(t *testing.T) {
    testServerPoolTrustedNoDiscovery(t, true)
}
func testServerPoolTrustedNoDiscovery(t *testing.T, preNeg bool) {
    s := newServerPoolTest(preNeg, false)
    trusted := s.setNodes(200, 200, 200, true, true)
    s.input = nil
    s.start()
    s.run()
    s.stop()
    s.checkNodes(t, trusted)
}
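The tests above never touch wall-clock time; everything runs on mclock.Simulated, with a background loop nudging the simulated clock forward whenever the test would otherwise be stuck. A standalone, hedged illustration of that pattern (not from the PR; durations are arbitrary):

// Minimal illustration of driving timers with mclock.Simulated, the pattern the
// server pool tests rely on.
package main

import (
    "fmt"
    "time"

    "github.com/ethereum/go-ethereum/common/mclock"
)

func main() {
    clock := &mclock.Simulated{}
    done := make(chan struct{})

    go func() {
        clock.Sleep(5 * time.Second) // blocks until simulated time has advanced by 5s
        close(done)
    }()

    // Keep nudging simulated time forward until the sleeper wakes, mirroring the
    // "advance clock if test is stuck" loop in the tests above.
    for {
        select {
        case <-done:
            fmt.Println("woke at simulated time", time.Duration(clock.Now()))
            return
        default:
            clock.Run(time.Second)
        }
    }
}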
@@ -0,0 +1,880 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package nodestate

import (
    "errors"
    "reflect"
    "sync"
    "time"
    "unsafe"

    "github.com/ethereum/go-ethereum/common/mclock"
    "github.com/ethereum/go-ethereum/ethdb"
    "github.com/ethereum/go-ethereum/log"
    "github.com/ethereum/go-ethereum/metrics"
    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/ethereum/go-ethereum/p2p/enr"
    "github.com/ethereum/go-ethereum/rlp"
)

type (
    // NodeStateMachine connects different system components operating on subsets of
    // network nodes. Node states are represented by 64 bit vectors with each bit assigned
    // to a state flag. Each state flag has a descriptor structure and the mapping is
    // created automatically. It is possible to subscribe to subsets of state flags and
    // receive a callback if one of the nodes has a relevant state flag changed.
    // Callbacks can also modify further flags of the same node or other nodes. State
    // updates only return after all immediate effects throughout the system have happened
    // (deadlocks should be avoided by design of the implemented state logic). The caller
    // can also add timeouts assigned to a certain node and a subset of state flags.
    // If the timeout elapses, the flags are reset. If all relevant flags are reset then
    // the timer is dropped. State flags with no timeout are persisted in the database
    // if the flag descriptor enables saving. If a node has no state flags set at any
    // moment then it is discarded.
    //
    // Extra node fields can also be registered so system components can also store more
    // complex state for each node that is relevant to them, without creating a custom
    // peer set. Fields can be shared across multiple components if they all know the
    // field ID. Subscription to fields is also possible. Persistent fields should have
    // an encoder and a decoder function.
    NodeStateMachine struct {
        started, stopped    bool
        lock                sync.Mutex
        clock               mclock.Clock
        db                  ethdb.KeyValueStore
        dbNodeKey           []byte
        nodes               map[enode.ID]*nodeInfo
        offlineCallbackList []offlineCallback

        // Registered state flags or fields. Modifications are allowed
        // only when the node state machine has not been started.
        setup     *Setup
        fields    []*fieldInfo
        saveFlags bitMask

        // Installed callbacks. Modifications are allowed only when the
        // node state machine has not been started.
        stateSubs []stateSub

        // Testing hooks, only for testing purposes.
        saveNodeHook func(*nodeInfo)
    }

    // Flags represents a set of flags from a certain setup
    Flags struct {
        mask  bitMask
        setup *Setup
    }

    // Field represents a field from a certain setup
    Field struct {
        index int
        setup *Setup
    }

    // flagDefinition describes a node state flag. Each registered instance is automatically
    // mapped to a bit of the 64 bit node states.
    // If persistent is true then the node is saved when state machine is shutdown.
    flagDefinition struct {
        name       string
        persistent bool
    }

    // fieldDefinition describes an optional node field of the given type. The contents
    // of the field are only retained for each node as long as at least one of the
    // state flags is set.
    fieldDefinition struct {
        name   string
        ftype  reflect.Type
        encode func(interface{}) ([]byte, error)
        decode func([]byte) (interface{}, error)
    }

    // Setup contains the list of flags and fields used by the application
    Setup struct {
        Version uint
        flags   []flagDefinition
        fields  []fieldDefinition
    }

    // bitMask describes a node state or state mask. It represents a subset
    // of node flags with each bit assigned to a flag index (LSB represents flag 0).
    bitMask uint64

    // StateCallback is a subscription callback which is called when one of the
    // state flags that is included in the subscription state mask is changed.
    // Note: oldState and newState are also masked with the subscription mask so only
    // the relevant bits are included.
    StateCallback func(n *enode.Node, oldState, newState Flags)

    // FieldCallback is a subscription callback which is called when the value of
    // a specific field is changed.
    FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{})

    // nodeInfo contains node state, fields and state timeouts
    nodeInfo struct {
        node      *enode.Node
        state     bitMask
        timeouts  []*nodeStateTimeout
        fields    []interface{}
        db, dirty bool
    }

    nodeInfoEnc struct {
        Enr     enr.Record
        Version uint
        State   bitMask
        Fields  [][]byte
    }

    stateSub struct {
        mask     bitMask
        callback StateCallback
    }

    nodeStateTimeout struct {
        mask  bitMask
        timer mclock.Timer
    }

    fieldInfo struct {
        fieldDefinition
        subs []FieldCallback
    }

    offlineCallback struct {
        node   *enode.Node
        state  bitMask
        fields []interface{}
    }
)

// offlineState is a special state that is assumed to be set before a node is loaded from
// the database and after it is shut down.
const offlineState = bitMask(1)

// NewFlag creates a new node state flag
func (s *Setup) NewFlag(name string) Flags {
    if s.flags == nil {
        s.flags = []flagDefinition{{name: "offline"}}
    }
    f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s}
    s.flags = append(s.flags, flagDefinition{name: name})
    return f
}

// NewPersistentFlag creates a new persistent node state flag
func (s *Setup) NewPersistentFlag(name string) Flags {
    if s.flags == nil {
        s.flags = []flagDefinition{{name: "offline"}}
    }
    f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s}
    s.flags = append(s.flags, flagDefinition{name: name, persistent: true})
    return f
}

// OfflineFlag returns the system-defined offline flag belonging to the given setup
func (s *Setup) OfflineFlag() Flags {
    return Flags{mask: offlineState, setup: s}
}

// NewField creates a new node state field
func (s *Setup) NewField(name string, ftype reflect.Type) Field {
    f := Field{index: len(s.fields), setup: s}
    s.fields = append(s.fields, fieldDefinition{
        name:  name,
        ftype: ftype,
    })
    return f
}

// NewPersistentField creates a new persistent node field
func (s *Setup) NewPersistentField(name string, ftype reflect.Type, encode func(interface{}) ([]byte, error), decode func([]byte) (interface{}, error)) Field {
    f := Field{index: len(s.fields), setup: s}
    s.fields = append(s.fields, fieldDefinition{
        name:   name,
        ftype:  ftype,
        encode: encode,
        decode: decode,
    })
    return f
}

// flagOp implements binary flag operations and also checks whether the operands belong to the same setup
func flagOp(a, b Flags, trueIfA, trueIfB, trueIfBoth bool) Flags {
    if a.setup == nil {
        if a.mask != 0 {
            panic("Node state flags have no setup reference")
        }
        a.setup = b.setup
    }
    if b.setup == nil {
        if b.mask != 0 {
            panic("Node state flags have no setup reference")
        }
        b.setup = a.setup
    }
    if a.setup != b.setup {
        panic("Node state flags belong to a different setup")
    }
    res := Flags{setup: a.setup}
    if trueIfA {
        res.mask |= a.mask & ^b.mask
    }
    if trueIfB {
        res.mask |= b.mask & ^a.mask
    }
    if trueIfBoth {
        res.mask |= a.mask & b.mask
    }
    return res
}

// And returns the set of flags present in both a and b
func (a Flags) And(b Flags) Flags { return flagOp(a, b, false, false, true) }

// AndNot returns the set of flags present in a but not in b
func (a Flags) AndNot(b Flags) Flags { return flagOp(a, b, true, false, false) }

// Or returns the set of flags present in either a or b
func (a Flags) Or(b Flags) Flags { return flagOp(a, b, true, true, true) }

// Xor returns the set of flags present in either a or b but not both
func (a Flags) Xor(b Flags) Flags { return flagOp(a, b, true, true, false) }

// HasAll returns true if b is a subset of a
func (a Flags) HasAll(b Flags) bool { return flagOp(a, b, false, true, false).mask == 0 }

// HasNone returns true if a and b have no shared flags
func (a Flags) HasNone(b Flags) bool { return flagOp(a, b, false, false, true).mask == 0 }

// Equals returns true if a and b have the same flags set
func (a Flags) Equals(b Flags) bool { return flagOp(a, b, true, true, false).mask == 0 }

// IsEmpty returns true if a has no flags set
func (a Flags) IsEmpty() bool { return a.mask == 0 }

// MergeFlags merges multiple sets of state flags
func MergeFlags(list ...Flags) Flags {
    if len(list) == 0 {
        return Flags{}
    }
    res := list[0]
    for i := 1; i < len(list); i++ {
        res = res.Or(list[i])
    }
    return res
}

// String returns a list of the names of the flags specified in the bit mask
func (f Flags) String() string {
    if f.mask == 0 {
        return "[]"
    }
    s := "["
    comma := false
    for index, flag := range f.setup.flags {
        if f.mask&(bitMask(1)<<uint(index)) != 0 {
            if comma {
                s = s + ", "
            }
            s = s + flag.name
            comma = true
        }
    }
    s = s + "]"
    return s
}
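Since every flag is bound to its Setup, the set algebra above can be exercised in isolation. A small illustrative example (flag names are arbitrary, not from the PR):

// Small illustration of the Flags set algebra defined above.
package main

import (
    "fmt"

    "github.com/ethereum/go-ethereum/p2p/nodestate"
)

func main() {
    setup := &nodestate.Setup{}
    a := setup.NewFlag("a")
    b := setup.NewFlag("b")
    c := setup.NewFlag("c")

    ab := a.Or(b)
    fmt.Println(ab)                     // [a, b]
    fmt.Println(ab.HasAll(a))           // true: a is a subset of a|b
    fmt.Println(ab.HasNone(c))          // true: no shared flags with c
    fmt.Println(ab.AndNot(b).Equals(a)) // true
}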
||||
// NewNodeStateMachine creates a new node state machine.
|
||||
// If db is not nil then the node states, fields and active timeouts are persisted.
|
||||
// Persistence can be enabled or disabled for each state flag and field.
|
||||
func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Clock, setup *Setup) *NodeStateMachine { |
||||
if setup.flags == nil { |
||||
panic("No state flags defined") |
||||
} |
||||
if len(setup.flags) > 8*int(unsafe.Sizeof(bitMask(0))) { |
||||
panic("Too many node state flags") |
||||
} |
||||
ns := &NodeStateMachine{ |
||||
db: db, |
||||
dbNodeKey: dbKey, |
||||
clock: clock, |
||||
setup: setup, |
||||
nodes: make(map[enode.ID]*nodeInfo), |
||||
fields: make([]*fieldInfo, len(setup.fields)), |
||||
} |
||||
stateNameMap := make(map[string]int) |
||||
for index, flag := range setup.flags { |
||||
if _, ok := stateNameMap[flag.name]; ok { |
||||
panic("Node state flag name collision") |
||||
} |
||||
stateNameMap[flag.name] = index |
||||
if flag.persistent { |
||||
ns.saveFlags |= bitMask(1) << uint(index) |
||||
} |
||||
} |
||||
fieldNameMap := make(map[string]int) |
||||
for index, field := range setup.fields { |
||||
if _, ok := fieldNameMap[field.name]; ok { |
||||
panic("Node field name collision") |
||||
} |
||||
ns.fields[index] = &fieldInfo{fieldDefinition: field} |
||||
fieldNameMap[field.name] = index |
||||
} |
||||
return ns |
||||
} |
||||
|
||||
// stateMask checks whether the set of flags belongs to the same setup and returns its internal bit mask
|
||||
func (ns *NodeStateMachine) stateMask(flags Flags) bitMask { |
||||
if flags.setup != ns.setup && flags.mask != 0 { |
||||
panic("Node state flags belong to a different setup") |
||||
} |
||||
return flags.mask |
||||
} |
||||
|
||||
// fieldIndex checks whether the field belongs to the same setup and returns its internal index
|
||||
func (ns *NodeStateMachine) fieldIndex(field Field) int { |
||||
if field.setup != ns.setup { |
||||
panic("Node field belongs to a different setup") |
||||
} |
||||
return field.index |
||||
} |
||||
|
||||
// SubscribeState adds a node state subscription. The callback is called while the state
|
||||
// machine mutex is not held and it is allowed to make further state updates. All immediate
|
||||
// changes throughout the system are processed in the same thread/goroutine. It is the
|
||||
// responsibility of the implemented state logic to avoid deadlocks caused by the callbacks,
|
||||
// infinite toggling of flags or hazardous/non-deterministic state changes.
|
||||
// State subscriptions should be installed before loading the node database or making the
|
||||
// first state update.
|
||||
func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) { |
||||
ns.lock.Lock() |
||||
defer ns.lock.Unlock() |
||||
|
||||
if ns.started { |
||||
panic("state machine already started") |
||||
} |
||||
ns.stateSubs = append(ns.stateSubs, stateSub{ns.stateMask(flags), callback}) |
||||
} |
||||
|
||||
// SubscribeField adds a node field subscription. Same rules apply as for SubscribeState.
|
||||
func (ns *NodeStateMachine) SubscribeField(field Field, callback FieldCallback) { |
||||
ns.lock.Lock() |
||||
defer ns.lock.Unlock() |
||||
|
||||
if ns.started { |
||||
panic("state machine already started") |
||||
} |
||||
f := ns.fields[ns.fieldIndex(field)] |
||||
f.subs = append(f.subs, callback) |
||||
} |
||||
|
||||
// newNode creates a new nodeInfo
|
||||
func (ns *NodeStateMachine) newNode(n *enode.Node) *nodeInfo { |
||||
return &nodeInfo{node: n, fields: make([]interface{}, len(ns.fields))} |
||||
} |
||||
|
||||
// checkStarted checks whether the state machine has already been started and panics otherwise.
|
||||
func (ns *NodeStateMachine) checkStarted() { |
||||
if !ns.started { |
||||
panic("state machine not started yet") |
||||
} |
||||
} |
||||
|
||||
// Start starts the state machine, enabling state and field operations and disabling
|
||||
// further subscriptions.
|
||||
func (ns *NodeStateMachine) Start() { |
||||
ns.lock.Lock() |
||||
if ns.started { |
||||
panic("state machine already started") |
||||
} |
||||
ns.started = true |
||||
if ns.db != nil { |
||||
ns.loadFromDb() |
||||
} |
||||
ns.lock.Unlock() |
||||
ns.offlineCallbacks(true) |
||||
} |
||||
|
||||
// Stop stops the state machine and saves its state if a database was supplied
|
||||
func (ns *NodeStateMachine) Stop() { |
||||
ns.lock.Lock() |
||||
for _, node := range ns.nodes { |
||||
fields := make([]interface{}, len(node.fields)) |
||||
copy(fields, node.fields) |
||||
ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields}) |
||||
} |
||||
ns.stopped = true |
||||
if ns.db != nil { |
||||
ns.saveToDb() |
||||
ns.lock.Unlock() |
||||
} else { |
||||
ns.lock.Unlock() |
||||
} |
||||
ns.offlineCallbacks(false) |
||||
} |
||||
|
||||
// loadFromDb loads persisted node states from the database
|
||||
func (ns *NodeStateMachine) loadFromDb() { |
||||
it := ns.db.NewIterator(ns.dbNodeKey, nil) |
||||
for it.Next() { |
||||
var id enode.ID |
||||
if len(it.Key()) != len(ns.dbNodeKey)+len(id) { |
||||
log.Error("Node state db entry with invalid length", "found", len(it.Key()), "expected", len(ns.dbNodeKey)+len(id)) |
||||
continue |
||||
} |
||||
copy(id[:], it.Key()[len(ns.dbNodeKey):]) |
||||
ns.decodeNode(id, it.Value()) |
||||
} |
||||
} |
||||
|
||||
type dummyIdentity enode.ID |
||||
|
||||
func (id dummyIdentity) Verify(r *enr.Record, sig []byte) error { return nil } |
||||
func (id dummyIdentity) NodeAddr(r *enr.Record) []byte { return id[:] } |
||||
|
||||
// decodeNode decodes a node database entry and adds it to the node set if successful
|
||||
func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) { |
||||
var enc nodeInfoEnc |
||||
if err := rlp.DecodeBytes(data, &enc); err != nil { |
||||
log.Error("Failed to decode node info", "id", id, "error", err) |
||||
return |
||||
} |
||||
n, _ := enode.New(dummyIdentity(id), &enc.Enr) |
||||
node := ns.newNode(n) |
||||
node.db = true |
||||
|
||||
if enc.Version != ns.setup.Version { |
||||
log.Debug("Removing stored node with unknown version", "current", ns.setup.Version, "stored", enc.Version) |
||||
ns.deleteNode(id) |
||||
return |
||||
} |
||||
if len(enc.Fields) > len(ns.setup.fields) { |
||||
log.Error("Invalid node field count", "id", id, "stored", len(enc.Fields)) |
||||
return |
||||
} |
||||
// Resolve persisted node fields
|
||||
for i, encField := range enc.Fields { |
||||
if len(encField) == 0 { |
||||
continue |
||||
} |
||||
if decode := ns.fields[i].decode; decode != nil { |
||||
if field, err := decode(encField); err == nil { |
||||
node.fields[i] = field |
||||
} else { |
||||
log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err) |
||||
return |
||||
} |
||||
} else { |
||||
log.Error("Cannot decode node field", "id", id, "field name", ns.fields[i].name) |
||||
return |
||||
} |
||||
} |
||||
// It's a compatible node record, add it to set.
|
||||
ns.nodes[id] = node |
||||
node.state = enc.State |
||||
fields := make([]interface{}, len(node.fields)) |
||||
copy(fields, node.fields) |
||||
ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields}) |
||||
log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) |
||||
} |
||||
|
||||
// saveNode saves the given node info to the database
|
||||
func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error { |
||||
if ns.db == nil { |
||||
return nil |
||||
} |
||||
|
||||
storedState := node.state & ns.saveFlags |
||||
for _, t := range node.timeouts { |
||||
storedState &= ^t.mask |
||||
} |
||||
if storedState == 0 { |
||||
if node.db { |
||||
node.db = false |
||||
ns.deleteNode(id) |
||||
} |
||||
node.dirty = false |
||||
return nil |
||||
} |
||||
|
||||
enc := nodeInfoEnc{ |
||||
Enr: *node.node.Record(), |
||||
Version: ns.setup.Version, |
||||
State: storedState, |
||||
Fields: make([][]byte, len(ns.fields)), |
||||
} |
||||
log.Debug("Saved node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) |
||||
lastIndex := -1 |
||||
for i, f := range node.fields { |
||||
if f == nil { |
||||
continue |
||||
} |
||||
encode := ns.fields[i].encode |
||||
if encode == nil { |
||||
continue |
||||
} |
||||
blob, err := encode(f) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
enc.Fields[i] = blob |
||||
lastIndex = i |
||||
} |
||||
enc.Fields = enc.Fields[:lastIndex+1] |
||||
data, err := rlp.EncodeToBytes(&enc) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err := ns.db.Put(append(ns.dbNodeKey, id[:]...), data); err != nil { |
||||
return err |
||||
} |
||||
node.dirty, node.db = false, true |
||||
|
||||
if ns.saveNodeHook != nil { |
||||
ns.saveNodeHook(node) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// deleteNode removes a node info from the database
|
||||
func (ns *NodeStateMachine) deleteNode(id enode.ID) { |
||||
ns.db.Delete(append(ns.dbNodeKey, id[:]...)) |
||||
} |
||||
|
||||
// saveToDb saves the persistent flags and fields of all nodes that have been changed
|
||||
func (ns *NodeStateMachine) saveToDb() { |
||||
for id, node := range ns.nodes { |
||||
if node.dirty { |
||||
err := ns.saveNode(id, node) |
||||
if err != nil { |
||||
log.Error("Failed to save node", "id", id, "error", err) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// updateEnode updates the enode entry belonging to the given node if it already exists
|
||||
func (ns *NodeStateMachine) updateEnode(n *enode.Node) (enode.ID, *nodeInfo) { |
||||
id := n.ID() |
||||
node := ns.nodes[id] |
||||
if node != nil && n.Seq() > node.node.Seq() { |
||||
node.node = n |
||||
} |
||||
return id, node |
||||
} |
||||
|
||||
// Persist saves the persistent state and fields of the given node immediately
|
||||
func (ns *NodeStateMachine) Persist(n *enode.Node) error { |
||||
ns.lock.Lock() |
||||
defer ns.lock.Unlock() |
||||
|
||||
ns.checkStarted() |
||||
if id, node := ns.updateEnode(n); node != nil && node.dirty { |
||||
err := ns.saveNode(id, node) |
||||
if err != nil { |
||||
log.Error("Failed to save node", "id", id, "error", err) |
||||
} |
||||
return err |
||||
} |
||||
return nil |
||||
} |
||||
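
For illustration, a minimal hedged sketch of how a caller outside this package might force an immediate write with Persist instead of waiting for saveToDb. The handles sfTracked and sfiBalance are hypothetical and assumed to be registered on the same Setup that was passed to NewNodeStateMachine; the usual p2p/enode and p2p/nodestate imports are assumed.

// Sketch only: sfTracked / sfiBalance are hypothetical handles created via
// Setup.NewPersistentFlag / Setup.NewPersistentField on the Setup used by ns.
func persistNow(ns *nodestate.NodeStateMachine, n *enode.Node,
	sfTracked nodestate.Flags, sfiBalance nodestate.Field, balance uint64) error {
	// Setting a persistent flag and field only marks the node dirty; the data
	// is normally written by saveToDb at shutdown or by an explicit Persist.
	ns.SetState(n, sfTracked, nodestate.Flags{}, 0)
	if err := ns.SetField(n, sfiBalance, balance); err != nil {
		return err
	}
	// Persist writes the node record, persistent flags and encoded fields now.
	return ns.Persist(n)
}
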
|
||||
// SetState updates the given node state flags and processes all resulting callbacks.
|
||||
// It only returns after all subsequent immediate changes (including those made by the
|
||||
// callbacks) have been processed. If a flag with a timeout is set again, the operation
|
||||
// removes or replaces the existing timeout.
|
||||
func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { |
||||
ns.lock.Lock() |
||||
ns.checkStarted() |
||||
if ns.stopped { |
||||
ns.lock.Unlock() |
||||
return |
||||
} |
||||
|
||||
set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags) |
||||
id, node := ns.updateEnode(n) |
||||
if node == nil { |
||||
if set == 0 { |
||||
ns.lock.Unlock() |
||||
return |
||||
} |
||||
node = ns.newNode(n) |
||||
ns.nodes[id] = node |
||||
} |
||||
oldState := node.state |
||||
newState := (node.state & (^reset)) | set |
||||
changed := oldState ^ newState |
||||
node.state = newState |
||||
|
||||
// Remove the timeout callbacks for all reset and set flags,
|
||||
// even if they are not present (this is a no-op).
|
||||
ns.removeTimeouts(node, set|reset) |
||||
|
||||
// Register the timeout callback if the new state is not empty
|
||||
// and a timeout is requested.
|
||||
if timeout != 0 && newState != 0 { |
||||
ns.addTimeout(n, set, timeout) |
||||
} |
||||
if newState == oldState { |
||||
ns.lock.Unlock() |
||||
return |
||||
} |
||||
if newState == 0 { |
||||
delete(ns.nodes, id) |
||||
if node.db { |
||||
ns.deleteNode(id) |
||||
} |
||||
} else { |
||||
if changed&ns.saveFlags != 0 { |
||||
node.dirty = true |
||||
} |
||||
} |
||||
ns.lock.Unlock() |
||||
// call state update subscription callbacks without holding the mutex
|
||||
for _, sub := range ns.stateSubs { |
||||
if changed&sub.mask != 0 { |
||||
sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) |
||||
} |
||||
} |
||||
if newState == 0 { |
||||
// call field subscriptions for discarded fields
|
||||
for i, v := range node.fields { |
||||
if v != nil { |
||||
f := ns.fields[i] |
||||
if len(f.subs) > 0 { |
||||
for _, cb := range f.subs { |
||||
cb(n, Flags{setup: ns.setup}, v, nil) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
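
As a usage note, a minimal hedged sketch of the set/reset/timeout semantics described above; sfQueried and sfCanDial are hypothetical flags and ns is assumed to be a started NodeStateMachine.

// Sketch only: set sfQueried and clear sfCanDial in a single atomic transition.
// The timeout replaces any previous timeout on sfQueried, so the flag resets
// itself after ten seconds and the state subscriptions fire again.
func markQueried(ns *nodestate.NodeStateMachine, n *enode.Node,
	sfQueried, sfCanDial nodestate.Flags) {
	ns.SetState(n, sfQueried, sfCanDial, 10*time.Second)
}
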
|
||||
// offlineCallbacks calls state update callbacks at startup or shutdown
|
||||
func (ns *NodeStateMachine) offlineCallbacks(start bool) { |
||||
for _, cb := range ns.offlineCallbackList { |
||||
for _, sub := range ns.stateSubs { |
||||
offState := offlineState & sub.mask |
||||
onState := cb.state & sub.mask |
||||
if offState != onState { |
||||
if start { |
||||
sub.callback(cb.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}) |
||||
} else { |
||||
sub.callback(cb.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}) |
||||
} |
||||
} |
||||
} |
||||
for i, f := range cb.fields { |
||||
if f != nil && ns.fields[i].subs != nil { |
||||
for _, fsub := range ns.fields[i].subs { |
||||
if start { |
||||
fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, nil, f) |
||||
} else { |
||||
fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, f, nil) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
ns.offlineCallbackList = nil |
||||
} |
||||
|
||||
// AddTimeout adds a node state timeout associated with the given state flag(s).
|
||||
// After the specified time interval, the relevant states will be reset.
|
||||
func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) { |
||||
ns.lock.Lock() |
||||
defer ns.lock.Unlock() |
||||
|
||||
ns.checkStarted() |
||||
if ns.stopped { |
||||
return |
||||
} |
||||
ns.addTimeout(n, ns.stateMask(flags), timeout) |
||||
} |
||||
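
A short hedged sketch of attaching a timeout after the fact; sfRedialWait is a hypothetical flag, and the effect is equivalent to passing the timeout to SetState directly.

// Sketch only: set the flag without a timeout first, then arm (or re-arm)
// the automatic reset separately.
func setRedialWait(ns *nodestate.NodeStateMachine, n *enode.Node, sfRedialWait nodestate.Flags) {
	ns.SetState(n, sfRedialWait, nodestate.Flags{}, 0)
	ns.AddTimeout(n, sfRedialWait, time.Minute) // cleared automatically after one minute
}
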
|
||||
// addTimeout adds a node state timeout associated with the given state flag(s).
|
||||
func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time.Duration) { |
||||
_, node := ns.updateEnode(n) |
||||
if node == nil { |
||||
return |
||||
} |
||||
mask &= node.state |
||||
if mask == 0 { |
||||
return |
||||
} |
||||
ns.removeTimeouts(node, mask) |
||||
t := &nodeStateTimeout{mask: mask} |
||||
t.timer = ns.clock.AfterFunc(timeout, func() { |
||||
ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) |
||||
}) |
||||
node.timeouts = append(node.timeouts, t) |
||||
if mask&ns.saveFlags != 0 { |
||||
node.dirty = true |
||||
} |
||||
} |
||||
|
||||
// removeTimeouts removes node state timeouts associated with the given state flag(s).
|
||||
// If a timeout was associated with multiple flags that are not all included in the
|
||||
// specified remove mask then only the included flags are de-associated and the timer
|
||||
// stays active.
|
||||
func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) { |
||||
for i := 0; i < len(node.timeouts); i++ { |
||||
t := node.timeouts[i] |
||||
match := t.mask & mask |
||||
if match == 0 { |
||||
continue |
||||
} |
||||
t.mask -= match |
||||
if t.mask != 0 { |
||||
continue |
||||
} |
||||
t.timer.Stop() |
||||
node.timeouts[i] = node.timeouts[len(node.timeouts)-1] |
||||
node.timeouts = node.timeouts[:len(node.timeouts)-1] |
||||
i-- |
||||
if match&ns.saveFlags != 0 { |
||||
node.dirty = true |
||||
} |
||||
} |
||||
} |
||||
|
||||
// GetField retrieves the given field of the given node
|
||||
func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { |
||||
ns.lock.Lock() |
||||
defer ns.lock.Unlock() |
||||
|
||||
ns.checkStarted() |
||||
if ns.stopped { |
||||
return nil |
||||
} |
||||
if _, node := ns.updateEnode(n); node != nil { |
||||
return node.fields[ns.fieldIndex(field)] |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// SetField sets the given field of the given node
|
||||
func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error { |
||||
ns.lock.Lock() |
||||
ns.checkStarted() |
||||
if ns.stopped { |
||||
ns.lock.Unlock() |
||||
return nil |
||||
} |
||||
_, node := ns.updateEnode(n) |
||||
if node == nil { |
||||
ns.lock.Unlock() |
||||
return nil |
||||
} |
||||
fieldIndex := ns.fieldIndex(field) |
||||
f := ns.fields[fieldIndex] |
||||
if value != nil && reflect.TypeOf(value) != f.ftype { |
||||
log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype) |
||||
ns.lock.Unlock() |
||||
return errors.New("invalid field type") |
||||
} |
||||
oldValue := node.fields[fieldIndex] |
||||
if value == oldValue { |
||||
ns.lock.Unlock() |
||||
return nil |
||||
} |
||||
node.fields[fieldIndex] = value |
||||
if f.encode != nil { |
||||
node.dirty = true |
||||
} |
||||
|
||||
state := node.state |
||||
ns.lock.Unlock() |
||||
if len(f.subs) > 0 { |
||||
for _, cb := range f.subs { |
||||
cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
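
For context, a hedged sketch pairing SetField with SubscribeField (the subscription mechanism exercised by TestFieldSub below); fConnAddress is a hypothetical string-typed field, subscriptions are assumed to be installed before Start(), and the go-ethereum log package is assumed to be imported.

// Sketch only: react to field changes; oldValue/newValue are nil when the
// field is first created or discarded together with the node.
func watchAddress(ns *nodestate.NodeStateMachine, fConnAddress nodestate.Field) {
	ns.SubscribeField(fConnAddress, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
		log.Debug("Dial address changed", "id", n.ID(), "old", oldValue, "new", newValue)
	})
}
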
|
||||
// ForEach calls the callback for each node having all of the required and none of the
|
||||
// disabled flags set
|
||||
func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) { |
||||
ns.lock.Lock() |
||||
ns.checkStarted() |
||||
type callback struct { |
||||
node *enode.Node |
||||
state bitMask |
||||
} |
||||
require, disable := ns.stateMask(requireFlags), ns.stateMask(disableFlags) |
||||
var callbacks []callback |
||||
for _, node := range ns.nodes { |
||||
if node.state&require == require && node.state&disable == 0 { |
||||
callbacks = append(callbacks, callback{node.node, node.state & (require | disable)}) |
||||
} |
||||
} |
||||
ns.lock.Unlock() |
||||
for _, c := range callbacks { |
||||
cb(c.node, Flags{mask: c.state, setup: ns.setup}) |
||||
} |
||||
} |
||||
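
A minimal sketch of ForEach; sfConnected and sfBanned are hypothetical flags, and the callback only sees the part of the state covered by the require/disable masks.

// Sketch only: count nodes that are currently connected and not banned.
func countConnected(ns *nodestate.NodeStateMachine, sfConnected, sfBanned nodestate.Flags) int {
	var count int
	ns.ForEach(sfConnected, sfBanned, func(n *enode.Node, state nodestate.Flags) {
		count++
	})
	return count
}
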
|
||||
// GetNode returns the enode currently associated with the given ID
|
||||
func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node { |
||||
ns.lock.Lock() |
||||
defer ns.lock.Unlock() |
||||
|
||||
ns.checkStarted() |
||||
if node := ns.nodes[id]; node != nil { |
||||
return node.node |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// AddLogMetrics adds logging and/or metrics for nodes entering, exiting and currently
|
||||
// present in a given set specified by required and disabled state flags
|
||||
func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) { |
||||
var count int64 |
||||
ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) { |
||||
oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) |
||||
newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) |
||||
if newMatch == oldMatch { |
||||
return |
||||
} |
||||
|
||||
if newMatch { |
||||
count++ |
||||
if name != "" { |
||||
log.Debug("Node entered", "set", name, "id", n.ID(), "count", count) |
||||
} |
||||
if inMeter != nil { |
||||
inMeter.Mark(1) |
||||
} |
||||
} else { |
||||
count-- |
||||
if name != "" { |
||||
log.Debug("Node left", "set", name, "id", n.ID(), "count", count) |
||||
} |
||||
if outMeter != nil { |
||||
outMeter.Mark(1) |
||||
} |
||||
} |
||||
if gauge != nil { |
||||
gauge.Update(count) |
||||
} |
||||
}) |
||||
} |
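
To round off this file, a hedged sketch of how a client might wire AddLogMetrics before Start(); the flag names and metric names are hypothetical, and the meters/gauge are assumed to come from github.com/ethereum/go-ethereum/metrics.

// Sketch only: log and meter nodes that are connected and not banned.
func addConnectedMetrics(ns *nodestate.NodeStateMachine, sfConnected, sfBanned nodestate.Flags) {
	ns.AddLogMetrics(sfConnected, sfBanned, "connected",
		metrics.NewRegisteredMeter("serverpool/connected/in", nil),
		metrics.NewRegisteredMeter("serverpool/connected/out", nil),
		metrics.NewRegisteredGauge("serverpool/connected/count", nil))
}
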
@ -0,0 +1,389 @@ |
||||
// Copyright 2020 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package nodestate |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/common/mclock" |
||||
"github.com/ethereum/go-ethereum/core/rawdb" |
||||
"github.com/ethereum/go-ethereum/p2p/enode" |
||||
"github.com/ethereum/go-ethereum/p2p/enr" |
||||
"github.com/ethereum/go-ethereum/rlp" |
||||
) |
||||
|
||||
func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) { |
||||
setup := &Setup{} |
||||
flags := make([]Flags, len(flagPersist)) |
||||
for i, persist := range flagPersist { |
||||
if persist { |
||||
flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i)) |
||||
} else { |
||||
flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i)) |
||||
} |
||||
} |
||||
fields := make([]Field, len(fieldType)) |
||||
for i, ftype := range fieldType { |
||||
switch ftype { |
||||
case reflect.TypeOf(uint64(0)): |
||||
fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec) |
||||
case reflect.TypeOf(""): |
||||
fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec) |
||||
default: |
||||
fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype) |
||||
} |
||||
} |
||||
return setup, flags, fields |
||||
} |
||||
|
||||
func testNode(b byte) *enode.Node { |
||||
r := &enr.Record{} |
||||
r.SetSig(dummyIdentity{b}, []byte{42}) |
||||
n, _ := enode.New(dummyIdentity{b}, r) |
||||
return n |
||||
} |
||||
|
||||
func TestCallback(t *testing.T) { |
||||
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} |
||||
|
||||
s, flags, _ := testSetup([]bool{false, false, false}, nil) |
||||
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
set0 := make(chan struct{}, 1) |
||||
set1 := make(chan struct{}, 1) |
||||
set2 := make(chan struct{}, 1) |
||||
ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} }) |
||||
ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} }) |
||||
ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} }) |
||||
|
||||
ns.Start() |
||||
|
||||
ns.SetState(testNode(1), flags[0], Flags{}, 0) |
||||
ns.SetState(testNode(1), flags[1], Flags{}, time.Second) |
||||
ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second) |
||||
|
||||
for i := 0; i < 3; i++ { |
||||
select { |
||||
case <-set0: |
||||
case <-set1: |
||||
case <-set2: |
||||
case <-time.After(time.Second): |
||||
t.Fatalf("failed to invoke callback") |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestPersistentFlags(t *testing.T) { |
||||
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} |
||||
|
||||
s, flags, _ := testSetup([]bool{true, true, true, false}, nil) |
||||
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
saveNode := make(chan *nodeInfo, 5) |
||||
ns.saveNodeHook = func(node *nodeInfo) { |
||||
saveNode <- node |
||||
} |
||||
|
||||
ns.Start() |
||||
|
||||
ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved
|
||||
ns.SetState(testNode(2), flags[1], Flags{}, 0) |
||||
ns.SetState(testNode(3), flags[2], Flags{}, 0) |
||||
ns.SetState(testNode(4), flags[3], Flags{}, 0) |
||||
ns.SetState(testNode(5), flags[0], Flags{}, 0) |
||||
ns.Persist(testNode(5)) |
||||
select { |
||||
case <-saveNode: |
||||
case <-time.After(time.Second): |
||||
t.Fatalf("Timeout") |
||||
} |
||||
ns.Stop() |
||||
|
||||
for i := 0; i < 2; i++ { |
||||
select { |
||||
case <-saveNode: |
||||
case <-time.After(time.Second): |
||||
t.Fatalf("Timeout") |
||||
} |
||||
} |
||||
select { |
||||
case <-saveNode: |
||||
t.Fatalf("Unexpected saveNode") |
||||
case <-time.After(time.Millisecond * 100): |
||||
} |
||||
} |
||||
|
||||
func TestSetField(t *testing.T) { |
||||
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} |
||||
|
||||
s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")}) |
||||
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
saveNode := make(chan *nodeInfo, 1) |
||||
ns.saveNodeHook = func(node *nodeInfo) { |
||||
saveNode <- node |
||||
} |
||||
|
||||
ns.Start() |
||||
|
||||
// Set field before setting state
|
||||
ns.SetField(testNode(1), fields[0], "hello world") |
||||
field := ns.GetField(testNode(1), fields[0]) |
||||
if field != nil { |
||||
t.Fatalf("Field shouldn't be set before setting states") |
||||
} |
||||
// Set field after setting state
|
||||
ns.SetState(testNode(1), flags[0], Flags{}, 0) |
||||
ns.SetField(testNode(1), fields[0], "hello world") |
||||
field = ns.GetField(testNode(1), fields[0]) |
||||
if field == nil { |
||||
t.Fatalf("Field should be set after setting states") |
||||
} |
||||
if err := ns.SetField(testNode(1), fields[0], 123); err == nil { |
||||
t.Fatalf("Invalid field should be rejected") |
||||
} |
||||
// Dirty node should be written back
|
||||
ns.Stop() |
||||
select { |
||||
case <-saveNode: |
||||
case <-time.After(time.Second): |
||||
t.Fatalf("Timeout") |
||||
} |
||||
} |
||||
|
||||
func TestUnsetField(t *testing.T) { |
||||
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} |
||||
|
||||
s, flags, fields := testSetup([]bool{false}, []reflect.Type{reflect.TypeOf("")}) |
||||
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
ns.Start() |
||||
|
||||
ns.SetState(testNode(1), flags[0], Flags{}, time.Second) |
||||
ns.SetField(testNode(1), fields[0], "hello world") |
||||
|
||||
ns.SetState(testNode(1), Flags{}, flags[0], 0) |
||||
if field := ns.GetField(testNode(1), fields[0]); field != nil { |
||||
t.Fatalf("Field should be unset") |
||||
} |
||||
} |
||||
|
||||
func TestSetState(t *testing.T) { |
||||
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} |
||||
|
||||
s, flags, _ := testSetup([]bool{false, false, false}, nil) |
||||
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
type change struct{ old, new Flags } |
||||
set := make(chan change, 1) |
||||
ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) { |
||||
set <- change{ |
||||
old: oldState, |
||||
new: newState, |
||||
} |
||||
}) |
||||
|
||||
ns.Start() |
||||
|
||||
check := func(expectOld, expectNew Flags, expectChange bool) { |
||||
if expectChange { |
||||
select { |
||||
case c := <-set: |
||||
if !c.old.Equals(expectOld) { |
||||
t.Fatalf("Old state mismatch") |
||||
} |
||||
if !c.new.Equals(expectNew) { |
||||
t.Fatalf("New state mismatch") |
||||
} |
||||
case <-time.After(time.Second): |
||||
} |
||||
return |
||||
} |
||||
select { |
||||
case <-set: |
||||
t.Fatalf("Unexpected change") |
||||
case <-time.After(time.Millisecond * 100): |
||||
return |
||||
} |
||||
} |
||||
ns.SetState(testNode(1), flags[0], Flags{}, 0) |
||||
check(Flags{}, flags[0], true) |
||||
|
||||
ns.SetState(testNode(1), flags[1], Flags{}, 0) |
||||
check(flags[0], flags[0].Or(flags[1]), true) |
||||
|
||||
ns.SetState(testNode(1), flags[2], Flags{}, 0) |
||||
check(Flags{}, Flags{}, false) |
||||
|
||||
ns.SetState(testNode(1), Flags{}, flags[0], 0) |
||||
check(flags[0].Or(flags[1]), flags[1], true) |
||||
|
||||
ns.SetState(testNode(1), Flags{}, flags[1], 0) |
||||
check(flags[1], Flags{}, true) |
||||
|
||||
ns.SetState(testNode(1), Flags{}, flags[2], 0) |
||||
check(Flags{}, Flags{}, false) |
||||
|
||||
ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second) |
||||
check(Flags{}, flags[0].Or(flags[1]), true) |
||||
clock.Run(time.Second) |
||||
check(flags[0].Or(flags[1]), Flags{}, true) |
||||
} |
||||
|
||||
func uint64FieldEnc(field interface{}) ([]byte, error) { |
||||
if u, ok := field.(uint64); ok { |
||||
enc, err := rlp.EncodeToBytes(&u) |
||||
return enc, err |
||||
} else { |
||||
return nil, errors.New("invalid field type") |
||||
} |
||||
} |
||||
|
||||
func uint64FieldDec(enc []byte) (interface{}, error) { |
||||
var u uint64 |
||||
err := rlp.DecodeBytes(enc, &u) |
||||
return u, err |
||||
} |
||||
|
||||
func stringFieldEnc(field interface{}) ([]byte, error) { |
||||
if s, ok := field.(string); ok { |
||||
return []byte(s), nil |
||||
} else { |
||||
return nil, errors.New("invalid field type") |
||||
} |
||||
} |
||||
|
||||
func stringFieldDec(enc []byte) (interface{}, error) { |
||||
return string(enc), nil |
||||
} |
||||
|
||||
func TestPersistentFields(t *testing.T) { |
||||
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} |
||||
|
||||
s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")}) |
||||
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
ns.Start() |
||||
ns.SetState(testNode(1), flags[0], Flags{}, 0) |
||||
ns.SetField(testNode(1), fields[0], uint64(100)) |
||||
ns.SetField(testNode(1), fields[1], "hello world") |
||||
ns.Stop() |
||||
|
||||
ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
ns2.Start() |
||||
field0 := ns2.GetField(testNode(1), fields[0]) |
||||
if !reflect.DeepEqual(field0, uint64(100)) { |
||||
t.Fatalf("Field changed") |
||||
} |
||||
field1 := ns2.GetField(testNode(1), fields[1]) |
||||
if !reflect.DeepEqual(field1, "hello world") { |
||||
t.Fatalf("Field changed") |
||||
} |
||||
|
||||
s.Version++ |
||||
ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
ns3.Start() |
||||
if ns3.GetField(testNode(1), fields[0]) != nil { |
||||
t.Fatalf("Old field version should have been discarded") |
||||
} |
||||
} |
||||
|
||||
func TestFieldSub(t *testing.T) { |
||||
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} |
||||
|
||||
s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))}) |
||||
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
var ( |
||||
lastState Flags |
||||
lastOldValue, lastNewValue interface{} |
||||
) |
||||
ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { |
||||
lastState, lastOldValue, lastNewValue = state, oldValue, newValue |
||||
}) |
||||
check := func(state Flags, oldValue, newValue interface{}) { |
||||
if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue { |
||||
t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue) |
||||
} |
||||
} |
||||
ns.Start() |
||||
ns.SetState(testNode(1), flags[0], Flags{}, 0) |
||||
ns.SetField(testNode(1), fields[0], uint64(100)) |
||||
check(flags[0], nil, uint64(100)) |
||||
ns.Stop() |
||||
check(s.OfflineFlag(), uint64(100), nil) |
||||
|
||||
ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { |
||||
lastState, lastOldValue, lastNewValue = state, oldValue, newValue |
||||
}) |
||||
ns2.Start() |
||||
check(s.OfflineFlag(), nil, uint64(100)) |
||||
ns2.SetState(testNode(1), Flags{}, flags[0], 0) |
||||
check(Flags{}, uint64(100), nil) |
||||
ns2.Stop() |
||||
} |
||||
|
||||
func TestDuplicatedFlags(t *testing.T) { |
||||
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} |
||||
|
||||
s, flags, _ := testSetup([]bool{true}, nil) |
||||
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) |
||||
|
||||
type change struct{ old, new Flags } |
||||
set := make(chan change, 1) |
||||
ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { |
||||
set <- change{oldState, newState} |
||||
}) |
||||
|
||||
ns.Start() |
||||
defer ns.Stop() |
||||
|
||||
check := func(expectOld, expectNew Flags, expectChange bool) { |
||||
if expectChange { |
||||
select { |
||||
case c := <-set: |
||||
if !c.old.Equals(expectOld) { |
||||
t.Fatalf("Old state mismatch") |
||||
} |
||||
if !c.new.Equals(expectNew) { |
||||
t.Fatalf("New state mismatch") |
||||
} |
||||
case <-time.After(time.Second): |
||||
} |
||||
return |
||||
} |
||||
select { |
||||
case <-set: |
||||
t.Fatalf("Unexpected change") |
||||
case <-time.After(time.Millisecond * 100): |
||||
return |
||||
} |
||||
} |
||||
ns.SetState(testNode(1), flags[0], Flags{}, time.Second) |
||||
check(Flags{}, flags[0], true) |
||||
ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s
|
||||
check(Flags{}, flags[0], false) |
||||
|
||||
clock.Run(2 * time.Second) |
||||
check(flags[0], Flags{}, true) |
||||
} |