From c420dcb39c342842f6c115376774c79162465c64 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 11 Jun 2019 12:45:33 +0200 Subject: [PATCH] p2p: enforce connection retry limit on server side (#19684) The dialer limits itself to one attempt every 30s. Apply the same limit in Server and reject peers which try to connect too eagerly. The check against the limit happens right after accepting the connection. Further changes in this commit ensure we pass the Server logger down to Peer instances, discovery and dialState. Unit test logging now works in all Server tests. --- p2p/dial.go | 107 ++++++--------------- p2p/dial_test.go | 160 ++++++++++++++------------------ p2p/netutil/addrutil.go | 33 +++++++ p2p/peer.go | 4 +- p2p/peer_test.go | 4 +- p2p/server.go | 200 ++++++++++++++++++++++------------------ p2p/server_test.go | 170 +++++++++++++++++++++++++++------- p2p/util.go | 82 ++++++++++++++++ p2p/util_test.go | 54 +++++++++++ 9 files changed, 520 insertions(+), 294 deletions(-) create mode 100644 p2p/netutil/addrutil.go create mode 100644 p2p/util.go create mode 100644 p2p/util_test.go diff --git a/p2p/dial.go b/p2p/dial.go index 936887a1a1..8dee5063f1 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -17,7 +17,6 @@ package p2p import ( - "container/heap" "errors" "fmt" "net" @@ -29,9 +28,10 @@ import ( ) const ( - // This is the amount of time spent waiting in between - // redialing a certain node. - dialHistoryExpiration = 30 * time.Second + // This is the amount of time spent waiting in between redialing a certain node. The + // limit is a bit higher than inboundThrottleTime to prevent failing dials in small + // private networks. + dialHistoryExpiration = inboundThrottleTime + 5*time.Second // Discovery lookups are throttled and can only run // once every few seconds. @@ -72,16 +72,16 @@ type dialstate struct { ntab discoverTable netrestrict *netutil.Netlist self enode.ID + bootnodes []*enode.Node // default dials when there are no peers + log log.Logger + start time.Time // time when the dialer was first used lookupRunning bool dialing map[enode.ID]connFlag lookupBuf []*enode.Node // current discovery lookup results randomNodes []*enode.Node // filled from Table static map[enode.ID]*dialTask - hist *dialHistory - - start time.Time // time when the dialer was first used - bootnodes []*enode.Node // default dials when there are no peers + hist expHeap } type discoverTable interface { @@ -91,15 +91,6 @@ type discoverTable interface { ReadRandomNodes([]*enode.Node) int } -// the dial history remembers recent dials. -type dialHistory []pastDial - -// pastDial is an entry in the dial history. -type pastDial struct { - id enode.ID - exp time.Time -} - type task interface { Do(*Server) } @@ -126,20 +117,23 @@ type waitExpireTask struct { time.Duration } -func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { +func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate { s := &dialstate{ maxDynDials: maxdyn, ntab: ntab, self: self, - netrestrict: netrestrict, + netrestrict: cfg.NetRestrict, + log: cfg.Logger, static: make(map[enode.ID]*dialTask), dialing: make(map[enode.ID]connFlag), - bootnodes: make([]*enode.Node, len(bootnodes)), + bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)), randomNodes: make([]*enode.Node, maxdyn/2), - hist: new(dialHistory), } - copy(s.bootnodes, bootnodes) - for _, n := range static { + copy(s.bootnodes, cfg.BootstrapNodes) + if s.log == nil { + s.log = log.Root() + } + for _, n := range cfg.StaticNodes { s.addStatic(n) } return s @@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) { func (s *dialstate) removeStatic(n *enode.Node) { // This removes a task so future attempts to connect will not be made. delete(s.static, n.ID()) - // This removes a previous dial timestamp so that application - // can force a server to reconnect with chosen peer immediately. - s.hist.remove(n.ID()) } func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { @@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti var newtasks []task addDial := func(flag connFlag, n *enode.Node) bool { if err := s.checkDial(n, peers); err != nil { - log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) + s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) return false } s.dialing[n.ID()] = flag @@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti err := s.checkDial(t.dest, peers) switch err { case errNotWhitelisted, errSelf: - log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err) + s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err) delete(s.static, t.dest.ID()) case nil: s.dialing[id] = t.flags @@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti // This should prevent cases where the dialer logic is not ticked // because there are no pending events. if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { - t := &waitExpireTask{s.hist.min().exp.Sub(now)} + t := &waitExpireTask{s.hist.nextExpiry().Sub(now)} newtasks = append(newtasks, t) } return newtasks @@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { return errSelf case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()): return errNotWhitelisted - case s.hist.contains(n.ID()): + case s.hist.contains(string(n.ID().Bytes())): return errRecentlyDialed } return nil @@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { func (s *dialstate) taskDone(t task, now time.Time) { switch t := t.(type) { case *dialTask: - s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration)) + s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration)) delete(s.dialing, t.dest.ID()) case *discoverTask: s.lookupRunning = false @@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) { } err := t.dial(srv, t.dest) if err != nil { - log.Trace("Dial error", "task", t, "err", err) + srv.log.Trace("Dial error", "task", t, "err", err) // Try resolving the ID of static nodes if dialing failed. if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { if t.resolve(srv) { @@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) { // The backoff delay resets when the node is found. func (t *dialTask) resolve(srv *Server) bool { if srv.ntab == nil { - log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") + srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") return false } if t.resolveDelay == 0 { @@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool { if t.resolveDelay > maxResolveDelay { t.resolveDelay = maxResolveDelay } - log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) + srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) return false } // The node was found. t.resolveDelay = initialResolveDelay t.dest = resolved - log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) + srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) return true } @@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) { func (t waitExpireTask) String() string { return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) } - -// Use only these methods to access or modify dialHistory. -func (h dialHistory) min() pastDial { - return h[0] -} -func (h *dialHistory) add(id enode.ID, exp time.Time) { - heap.Push(h, pastDial{id, exp}) - -} -func (h *dialHistory) remove(id enode.ID) bool { - for i, v := range *h { - if v.id == id { - heap.Remove(h, i) - return true - } - } - return false -} -func (h dialHistory) contains(id enode.ID) bool { - for _, v := range h { - if v.id == id { - return true - } - } - return false -} -func (h *dialHistory) expire(now time.Time) { - for h.Len() > 0 && h.min().exp.Before(now) { - heap.Pop(h) - } -} - -// heap.Interface boilerplate -func (h dialHistory) Len() int { return len(h) } -func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } -func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] } -func (h *dialHistory) Push(x interface{}) { - *h = append(*h, x.(pastDial)) -} -func (h *dialHistory) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x -} diff --git a/p2p/dial_test.go b/p2p/dial_test.go index f41ab77524..de8fc4a6e3 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -20,10 +20,13 @@ import ( "encoding/binary" "net" "reflect" + "strings" "testing" "time" "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" @@ -67,10 +70,10 @@ func runDialTest(t *testing.T, test dialtest) { new := test.init.newTasks(running, pm(round.peers), vtime) if !sametasks(new, round.new) { - t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n", + t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v", i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) } - t.Log("tasks:", spew.Sdump(new)) + t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new))) // Time advances by 16 seconds on every round. vtime = vtime.Add(16 * time.Second) @@ -88,8 +91,9 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) // This test checks that dynamic dials are launched from discovery results. func TestDialStateDynDial(t *testing.T) { + config := &Config{Logger: testlog.Logger(t, log.LvlTrace)} runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil), + init: newDialState(enode.ID{}, fakeTable{}, 5, config), rounds: []round{ // A discovery query is launched. { @@ -153,7 +157,7 @@ func TestDialStateDynDial(t *testing.T) { &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ - &waitExpireTask{Duration: 14 * time.Second}, + &waitExpireTask{Duration: 19 * time.Second}, }, }, // In this round, the peer with id 2 drops off. The query @@ -223,10 +227,13 @@ func TestDialStateDynDial(t *testing.T) { // Tests that bootnodes are dialed if no peers are connectd, but not otherwise. func TestDialStateDynDialBootnode(t *testing.T) { - bootnodes := []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), + config := &Config{ + BootstrapNodes: []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), + }, + Logger: testlog.Logger(t, log.LvlTrace), } table := fakeTable{ newNode(uintID(4), nil), @@ -236,7 +243,7 @@ func TestDialStateDynDialBootnode(t *testing.T) { newNode(uintID(8), nil), } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil), + init: newDialState(enode.ID{}, table, 5, config), rounds: []round{ // 2 dynamic dials attempted, bootnodes pending fallback interval { @@ -259,25 +266,24 @@ func TestDialStateDynDialBootnode(t *testing.T) { { new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, 2nd bootnode is attempted { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, 3rd bootnode is attempted { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, @@ -288,21 +294,19 @@ func TestDialStateDynDialBootnode(t *testing.T) { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - }, + new: []task{}, }, // Random dial succeeds, no more bootnodes are attempted { + new: []task{ + &waitExpireTask{3 * time.Second}, + }, peers: []*Peer{ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, }, done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, }, @@ -324,7 +328,7 @@ func TestDialStateDynDialFromTable(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, nil, nil, table, 10, nil), + init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}), rounds: []round{ // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. { @@ -430,7 +434,7 @@ func TestDialStateNetRestrict(t *testing.T) { restrict.Add("127.0.2.0/24") runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, nil, nil, table, 10, restrict), + init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}), rounds: []round{ { new: []task{ @@ -444,16 +448,18 @@ func TestDialStateNetRestrict(t *testing.T) { // This test checks that static dials are launched. func TestDialStateStaticDial(t *testing.T) { - wantStatic := []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), + config := &Config{ + StaticNodes: []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), + newNode(uintID(4), nil), + newNode(uintID(5), nil), + }, + Logger: testlog.Logger(t, log.LvlTrace), } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), + init: newDialState(enode.ID{}, fakeTable{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -495,7 +501,7 @@ func TestDialStateStaticDial(t *testing.T) { &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ - &waitExpireTask{Duration: 14 * time.Second}, + &waitExpireTask{Duration: 19 * time.Second}, }, }, // Wait a round for dial history to expire, no new tasks should spawn. @@ -511,6 +517,9 @@ func TestDialStateStaticDial(t *testing.T) { // If a static node is dropped, it should be immediately redialed, // irrespective whether it was originally static or dynamic. { + done: []task{ + &waitExpireTask{Duration: 19 * time.Second}, + }, peers: []*Peer{ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, @@ -518,67 +527,24 @@ func TestDialStateStaticDial(t *testing.T) { }, new: []task{ &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, }, }, }, }) } -// This test checks that static peers will be redialed immediately if they were re-added to a static list. -func TestDialStaticAfterReset(t *testing.T) { - wantStatic := []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - } - - rounds := []round{ - // Static dials are launched for the nodes that aren't yet connected. - { - peers: nil, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - }, - }, - // No new dial tasks, all peers are connected. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - }, - new: []task{ - &waitExpireTask{Duration: 30 * time.Second}, - }, - }, - } - dTest := dialtest{ - init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), - rounds: rounds, - } - runDialTest(t, dTest) - for _, n := range wantStatic { - dTest.init.removeStatic(n) - dTest.init.addStatic(n) - } - // without removing peers they will be considered recently dialed - runDialTest(t, dTest) -} - // This test checks that past dials are not retried for some time. func TestDialStateCache(t *testing.T) { - wantStatic := []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), + config := &Config{ + StaticNodes: []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), + }, + Logger: testlog.Logger(t, log.LvlTrace), } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), + init: newDialState(enode.ID{}, fakeTable{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -606,28 +572,37 @@ func TestDialStateCache(t *testing.T) { // entry to expire. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, }, done: []task{ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, }, new: []task{ - &waitExpireTask{Duration: 14 * time.Second}, + &waitExpireTask{Duration: 19 * time.Second}, }, }, // Still waiting for node 3's entry to expire in the cache. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, + }, + }, + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, }, }, // The cache entry for node 3 has expired and is retried. { + done: []task{ + &waitExpireTask{Duration: 19 * time.Second}, + }, peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, }, new: []task{ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, @@ -638,9 +613,13 @@ func TestDialStateCache(t *testing.T) { } func TestDialResolve(t *testing.T) { + config := &Config{ + Logger: testlog.Logger(t, log.LvlTrace), + Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}, + } resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) table := &resolveMock{answer: resolved} - state := newDialState(enode.ID{}, nil, nil, table, 0, nil) + state := newDialState(enode.ID{}, table, 0, config) // Check that the task is generated with an incomplete ID. dest := newNode(uintID(1), nil) @@ -651,8 +630,7 @@ func TestDialResolve(t *testing.T) { } // Now run the task, it should resolve the ID once. - config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}} - srv := &Server{ntab: table, Config: config} + srv := &Server{ntab: table, log: config.Logger, Config: *config} tasks[0].Do(srv) if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) { t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) diff --git a/p2p/netutil/addrutil.go b/p2p/netutil/addrutil.go new file mode 100644 index 0000000000..b261a52955 --- /dev/null +++ b/p2p/netutil/addrutil.go @@ -0,0 +1,33 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package netutil + +import "net" + +// AddrIP gets the IP address contained in addr. It returns nil if no address is present. +func AddrIP(addr net.Addr) net.IP { + switch a := addr.(type) { + case *net.IPAddr: + return a.IP + case *net.TCPAddr: + return a.IP + case *net.UDPAddr: + return a.IP + default: + return nil + } +} diff --git a/p2p/peer.go b/p2p/peer.go index af019d07a8..98ea6835d0 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer { pipe, _ := net.Pipe() node := enode.SignNull(new(enr.Record), id) conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name} - peer := newPeer(conn, nil) + peer := newPeer(log.Root(), conn, nil) close(peer.closed) // ensures Disconnect doesn't block return peer } @@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool { return p.rw.is(inboundConn) } -func newPeer(conn *conn, protocols []Protocol) *Peer { +func newPeer(log log.Logger, conn *conn, protocols []Protocol) *Peer { protomap := matchProtocols(protocols, conn.caps, conn) p := &Peer{ rw: conn, diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 5aa64a32e1..984cc411ad 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -24,6 +24,8 @@ import ( "reflect" "testing" "time" + + "github.com/ethereum/go-ethereum/log" ) var discard = Protocol{ @@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) { c2.caps = append(c2.caps, p.cap()) } - peer := newPeer(c1, protos) + peer := newPeer(log.Root(), c1, protos) errc := make(chan error, 1) go func() { _, err := peer.run() diff --git a/p2p/server.go b/p2p/server.go index b3494cc888..a373904fc1 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -22,6 +22,7 @@ import ( "crypto/ecdsa" "encoding/hex" "errors" + "fmt" "net" "sort" "sync" @@ -49,6 +50,9 @@ const ( defaultMaxPendingPeers = 50 defaultDialRatio = 3 + // This time limits inbound connection attempts per source IP. + inboundThrottleTime = 30 * time.Second + // Maximum time allowed for reading a complete message. // This is effectively the amount of time a connection can be idle. frameReadTimeout = 30 * time.Second @@ -158,6 +162,7 @@ type Server struct { // the whole protocol stack. newTransport func(net.Conn) transport newPeerHook func(*Peer) + listenFunc func(network, addr string) (net.Listener, error) lock sync.Mutex // protects running running bool @@ -167,24 +172,26 @@ type Server struct { ntab discoverTable listener net.Listener ourHandshake *protoHandshake - lastLookup time.Time DiscV5 *discv5.Network - - // These are for Peers, PeerCount (and nothing else). - peerOp chan peerOpFunc - peerOpDone chan struct{} - - quit chan struct{} - addstatic chan *enode.Node - removestatic chan *enode.Node - addtrusted chan *enode.Node - removetrusted chan *enode.Node - posthandshake chan *conn - addpeer chan *conn - delpeer chan peerDrop - loopWG sync.WaitGroup // loop, listenLoop - peerFeed event.Feed - log log.Logger + loopWG sync.WaitGroup // loop, listenLoop + peerFeed event.Feed + log log.Logger + + // Channels into the run loop. + quit chan struct{} + addstatic chan *enode.Node + removestatic chan *enode.Node + addtrusted chan *enode.Node + removetrusted chan *enode.Node + peerOp chan peerOpFunc + peerOpDone chan struct{} + delpeer chan peerDrop + checkpointPostHandshake chan *conn + checkpointAddPeer chan *conn + + // State of run loop and listenLoop. + lastLookup time.Time + inboundHistory expHeap } type peerOpFunc func(map[enode.ID]*Peer) @@ -415,7 +422,7 @@ func (srv *Server) Start() (err error) { srv.running = true srv.log = srv.Config.Logger if srv.log == nil { - srv.log = log.New() + srv.log = log.Root() } if srv.NoDial && srv.ListenAddr == "" { srv.log.Warn("P2P server will be useless, neither dialing nor listening") @@ -428,13 +435,16 @@ func (srv *Server) Start() (err error) { if srv.newTransport == nil { srv.newTransport = newRLPX } + if srv.listenFunc == nil { + srv.listenFunc = net.Listen + } if srv.Dialer == nil { srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}} } srv.quit = make(chan struct{}) - srv.addpeer = make(chan *conn) srv.delpeer = make(chan peerDrop) - srv.posthandshake = make(chan *conn) + srv.checkpointPostHandshake = make(chan *conn) + srv.checkpointAddPeer = make(chan *conn) srv.addstatic = make(chan *enode.Node) srv.removestatic = make(chan *enode.Node) srv.addtrusted = make(chan *enode.Node) @@ -455,7 +465,7 @@ func (srv *Server) Start() (err error) { } dynPeers := srv.maxDialedConns() - dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) + dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config) srv.loopWG.Add(1) go srv.run(dialer) return nil @@ -541,6 +551,7 @@ func (srv *Server) setupDiscovery() error { NetRestrict: srv.NetRestrict, Bootnodes: srv.BootstrapNodes, Unhandled: unhandled, + Log: srv.log, } ntab, err := discover.ListenUDP(conn, srv.localnode, cfg) if err != nil { @@ -569,27 +580,28 @@ func (srv *Server) setupDiscovery() error { } func (srv *Server) setupListening() error { - // Launch the TCP listener. - listener, err := net.Listen("tcp", srv.ListenAddr) + // Launch the listener. + listener, err := srv.listenFunc("tcp", srv.ListenAddr) if err != nil { return err } - laddr := listener.Addr().(*net.TCPAddr) - srv.ListenAddr = laddr.String() srv.listener = listener - srv.localnode.Set(enr.TCP(laddr.Port)) + srv.ListenAddr = listener.Addr().String() + + // Update the local node record and map the TCP listening port if NAT is configured. + if tcp, ok := listener.Addr().(*net.TCPAddr); ok { + srv.localnode.Set(enr.TCP(tcp.Port)) + if !tcp.IP.IsLoopback() && srv.NAT != nil { + srv.loopWG.Add(1) + go func() { + nat.Map(srv.NAT, srv.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p") + srv.loopWG.Done() + }() + } + } srv.loopWG.Add(1) go srv.listenLoop() - - // Map the TCP listening port if NAT is configured. - if !laddr.IP.IsLoopback() && srv.NAT != nil { - srv.loopWG.Add(1) - go func() { - nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p") - srv.loopWG.Done() - }() - } return nil } @@ -657,12 +669,14 @@ running: case <-srv.quit: // The server was stopped. Run the cleanup logic. break running + case n := <-srv.addstatic: // This channel is used by AddPeer to add to the // ephemeral static peer list. Add it to the dialer, // it will keep the node connected. srv.log.Trace("Adding static node", "node", n) dialstate.addStatic(n) + case n := <-srv.removestatic: // This channel is used by RemovePeer to send a // disconnect request to a peer and begin the @@ -672,6 +686,7 @@ running: if p, ok := peers[n.ID()]; ok { p.Disconnect(DiscRequested) } + case n := <-srv.addtrusted: // This channel is used by AddTrustedPeer to add an enode // to the trusted node set. @@ -681,6 +696,7 @@ running: if p, ok := peers[n.ID()]; ok { p.rw.set(trustedConn, true) } + case n := <-srv.removetrusted: // This channel is used by RemoveTrustedPeer to remove an enode // from the trusted node set. @@ -691,10 +707,12 @@ running: if p, ok := peers[n.ID()]; ok { p.rw.set(trustedConn, false) } + case op := <-srv.peerOp: // This channel is used by Peers and PeerCount. op(peers) srv.peerOpDone <- struct{}{} + case t := <-taskdone: // A task got done. Tell dialstate about it so it // can update its state and remove it from the active @@ -702,7 +720,8 @@ running: srv.log.Trace("Dial task done", "task", t) dialstate.taskDone(t, time.Now()) delTask(t) - case c := <-srv.posthandshake: + + case c := <-srv.checkpointPostHandshake: // A connection has passed the encryption handshake so // the remote identity is known (but hasn't been verified yet). if trusted[c.node.ID()] { @@ -710,18 +729,15 @@ running: c.flags |= trustedConn } // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. - select { - case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c): - case <-srv.quit: - break running - } - case c := <-srv.addpeer: + c.cont <- srv.postHandshakeChecks(peers, inboundCount, c) + + case c := <-srv.checkpointAddPeer: // At this point the connection is past the protocol handshake. // Its capabilities are known and the remote identity is verified. - err := srv.protoHandshakeChecks(peers, inboundCount, c) + err := srv.addPeerChecks(peers, inboundCount, c) if err == nil { // The handshakes are done and it passed all checks. - p := newPeer(c, srv.Protocols) + p := newPeer(srv.log, c, srv.Protocols) // If message events are enabled, pass the peerFeed // to the peer if srv.EnableMsgEvents { @@ -738,11 +754,8 @@ running: // The dialer logic relies on the assumption that // dial tasks complete after the peer has been added or // discarded. Unblock the task last. - select { - case c.cont <- err: - case <-srv.quit: - break running - } + c.cont <- err + case pd := <-srv.delpeer: // A peer disconnected. d := common.PrettyDuration(mclock.Now() - pd.created) @@ -777,17 +790,7 @@ running: } } -func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { - // Drop connections with no matching protocols. - if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { - return DiscUselessPeer - } - // Repeat the encryption handshake checks because the - // peer set might have changed between the handshakes. - return srv.encHandshakeChecks(peers, inboundCount, c) -} - -func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { +func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { switch { case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers @@ -802,9 +805,20 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int } } +func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { + // Drop connections with no matching protocols. + if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { + return DiscUselessPeer + } + // Repeat the post-handshake checks because the + // peer set might have changed since those checks were performed. + return srv.postHandshakeChecks(peers, inboundCount, c) +} + func (srv *Server) maxInboundConns() int { return srv.MaxPeers - srv.maxDialedConns() } + func (srv *Server) maxDialedConns() int { if srv.NoDiscovery || srv.NoDial { return 0 @@ -832,7 +846,7 @@ func (srv *Server) listenLoop() { } for { - // Wait for a handshake slot before accepting. + // Wait for a free slot before accepting. <-slots var ( @@ -851,21 +865,16 @@ func (srv *Server) listenLoop() { break } - // Reject connections that do not match NetRestrict. - if srv.NetRestrict != nil { - if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) { - srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr()) - fd.Close() - slots <- struct{}{} - continue - } + remoteIP := netutil.AddrIP(fd.RemoteAddr()) + if err := srv.checkInboundConn(fd, remoteIP); err != nil { + srv.log.Debug("Rejected inbound connnection", "addr", fd.RemoteAddr(), "err", err) + fd.Close() + slots <- struct{}{} + continue } - - var ip net.IP - if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok { - ip = tcp.IP + if remoteIP != nil { + fd = newMeteredConn(fd, true, remoteIP) } - fd = newMeteredConn(fd, true, ip) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) go func() { srv.SetupConn(fd, inboundConn, nil) @@ -874,6 +883,22 @@ func (srv *Server) listenLoop() { } } +func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error { + if remoteIP != nil { + // Reject connections that do not match NetRestrict. + if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) { + return fmt.Errorf("not whitelisted in NetRestrict") + } + // Reject Internet peers that try too often. + srv.inboundHistory.expire(time.Now()) + if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) { + return fmt.Errorf("too many attempts") + } + srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime)) + } + return nil +} + // SetupConn runs the handshakes and attempts to add the connection // as a peer. It returns when the connection has been added as a peer // or the handshakes have failed. @@ -895,6 +920,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro if !running { return errServerStopped } + // If dialing, figure out the remote public key. var dialPubkey *ecdsa.PublicKey if dialDest != nil { @@ -903,7 +929,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro return errors.New("dial destination doesn't have a secp256k1 public key") } } - // Run the encryption handshake. + + // Run the RLPx handshake. remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey) if err != nil { srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) @@ -922,12 +949,13 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro conn.handshakeDone(c.node.ID()) } clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags) - err = srv.checkpoint(c, srv.posthandshake) + err = srv.checkpoint(c, srv.checkpointPostHandshake) if err != nil { - clog.Trace("Rejected peer before protocol handshake", "err", err) + clog.Trace("Rejected peer", "err", err) return err } - // Run the protocol handshake + + // Run the capability negotiation handshake. phs, err := c.doProtoHandshake(srv.ourHandshake) if err != nil { clog.Trace("Failed proto handshake", "err", err) @@ -938,14 +966,15 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro return DiscUnexpectedIdentity } c.caps, c.name = phs.Caps, phs.Name - err = srv.checkpoint(c, srv.addpeer) + err = srv.checkpoint(c, srv.checkpointAddPeer) if err != nil { clog.Trace("Rejected peer", "err", err) return err } - // If the checks completed successfully, runPeer has now been - // launched by run. - clog.Trace("connection set up", "inbound", dialDest == nil) + + // If the checks completed successfully, the connection has been added as a peer and + // runPeer has been launched. + clog.Trace("Connection set up", "inbound", dialDest == nil) return nil } @@ -974,12 +1003,7 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error { case <-srv.quit: return errServerStopped } - select { - case err := <-c.cont: - return err - case <-srv.quit: - return errServerStopped - } + return <-c.cont } // runPeer runs in its own goroutine for each peer. diff --git a/p2p/server_test.go b/p2p/server_test.go index f665c14245..e8bc627e1d 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -19,6 +19,7 @@ package p2p import ( "crypto/ecdsa" "errors" + "io" "math/rand" "net" "reflect" @@ -26,6 +27,7 @@ import ( "time" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" @@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) * MaxPeers: 10, ListenAddr: "127.0.0.1:0", PrivateKey: newkey(), + Logger: testlog.Logger(t, log.LvlTrace), } server := &Server{ Config: config, @@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) { PrivateKey: newkey(), MaxPeers: 10, NoDial: true, + NoDiscovery: true, TrustedNodes: []*enode.Node{newNode(trustedID, nil)}, }, } @@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) { // Inject a few connections to fill up the peer set. for i := 0; i < 10; i++ { c := newconn(randomID()) - if err := srv.checkpoint(c, srv.addpeer); err != nil { + if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil { t.Fatalf("could not add conn %d: %v", i, err) } } // Try inserting a non-trusted connection. anotherID := randomID() c := newconn(anotherID) - if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { + if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers { t.Error("wrong error for insert:", err) } // Try inserting a trusted connection. c = newconn(trustedID) - if err := srv.checkpoint(c, srv.posthandshake); err != nil { + if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil { t.Error("unexpected error for trusted conn @posthandshake:", err) } if !c.is(trustedConn) { @@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) { // Remove from trusted set and try again srv.RemoveTrustedPeer(newNode(trustedID, nil)) c = newconn(trustedID) - if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { + if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers { t.Error("wrong error for insert:", err) } // Add anotherID to trusted set and try again srv.AddTrustedPeer(newNode(anotherID, nil)) c = newconn(anotherID) - if err := srv.checkpoint(c, srv.posthandshake); err != nil { + if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil { t.Error("unexpected error for trusted conn @posthandshake:", err) } if !c.is(trustedConn) { @@ -430,10 +434,11 @@ func TestServerPeerLimits(t *testing.T) { srv := &Server{ Config: Config{ - PrivateKey: srvkey, - MaxPeers: 0, - NoDial: true, - Protocols: []Protocol{discard}, + PrivateKey: srvkey, + MaxPeers: 0, + NoDial: true, + NoDiscovery: true, + Protocols: []Protocol{discard}, }, newTransport: func(fd net.Conn) transport { return tp }, log: log.New(), @@ -541,29 +546,35 @@ func TestServerSetupConn(t *testing.T) { } for i, test := range tests { - srv := &Server{ - Config: Config{ - PrivateKey: srvkey, - MaxPeers: 10, - NoDial: true, - Protocols: []Protocol{discard}, - }, - newTransport: func(fd net.Conn) transport { return test.tt }, - log: log.New(), - } - if !test.dontstart { - if err := srv.Start(); err != nil { - t.Fatalf("couldn't start server: %v", err) + t.Run(test.wantCalls, func(t *testing.T) { + cfg := Config{ + PrivateKey: srvkey, + MaxPeers: 10, + NoDial: true, + NoDiscovery: true, + Protocols: []Protocol{discard}, + Logger: testlog.Logger(t, log.LvlTrace), } - } - p1, _ := net.Pipe() - srv.SetupConn(p1, test.flags, test.dialDest) - if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { - t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) - } - if test.tt.calls != test.wantCalls { - t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) - } + srv := &Server{ + Config: cfg, + newTransport: func(fd net.Conn) transport { return test.tt }, + log: cfg.Logger, + } + if !test.dontstart { + if err := srv.Start(); err != nil { + t.Fatalf("couldn't start server: %v", err) + } + defer srv.Stop() + } + p1, _ := net.Pipe() + srv.SetupConn(p1, test.flags, test.dialDest) + if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { + t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) + } + if test.tt.calls != test.wantCalls { + t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) + } + }) } } @@ -616,3 +627,100 @@ func randomID() (id enode.ID) { } return id } + +// This test checks that inbound connections are throttled by IP. +func TestServerInboundThrottle(t *testing.T) { + const timeout = 5 * time.Second + newTransportCalled := make(chan struct{}) + srv := &Server{ + Config: Config{ + PrivateKey: newkey(), + ListenAddr: "127.0.0.1:0", + MaxPeers: 10, + NoDial: true, + NoDiscovery: true, + Protocols: []Protocol{discard}, + Logger: testlog.Logger(t, log.LvlTrace), + }, + newTransport: func(fd net.Conn) transport { + newTransportCalled <- struct{}{} + return newRLPX(fd) + }, + listenFunc: func(network, laddr string) (net.Listener, error) { + fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444} + return listenFakeAddr(network, laddr, fakeAddr) + }, + } + if err := srv.Start(); err != nil { + t.Fatal("can't start: ", err) + } + defer srv.Stop() + + // Dial the test server. + conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + select { + case <-newTransportCalled: + // OK + case <-time.After(timeout): + t.Error("newTransport not called") + } + conn.Close() + + // Dial again. This time the server should close the connection immediately. + connClosed := make(chan struct{}) + conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + defer conn.Close() + go func() { + conn.SetDeadline(time.Now().Add(timeout)) + buf := make([]byte, 10) + if n, err := conn.Read(buf); err != io.EOF || n != 0 { + t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n) + } + connClosed <- struct{}{} + }() + select { + case <-connClosed: + // OK + case <-newTransportCalled: + t.Error("newTransport called for second attempt") + case <-time.After(timeout): + t.Error("connection not closed within timeout") + } +} + +func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) { + l, err := net.Listen(network, laddr) + if err == nil { + l = &fakeAddrListener{l, remoteAddr} + } + return l, err +} + +// fakeAddrListener is a listener that creates connections with a mocked remote address. +type fakeAddrListener struct { + net.Listener + remoteAddr net.Addr +} + +type fakeAddrConn struct { + net.Conn + remoteAddr net.Addr +} + +func (l *fakeAddrListener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return &fakeAddrConn{c, l.remoteAddr}, nil +} + +func (c *fakeAddrConn) RemoteAddr() net.Addr { + return c.remoteAddr +} diff --git a/p2p/util.go b/p2p/util.go new file mode 100644 index 0000000000..2a6edf5cee --- /dev/null +++ b/p2p/util.go @@ -0,0 +1,82 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package p2p + +import ( + "container/heap" + "time" +) + +// expHeap tracks strings and their expiry time. +type expHeap []expItem + +// expItem is an entry in addrHistory. +type expItem struct { + item string + exp time.Time +} + +// nextExpiry returns the next expiry time. +func (h *expHeap) nextExpiry() time.Time { + return (*h)[0].exp +} + +// add adds an item and sets its expiry time. +func (h *expHeap) add(item string, exp time.Time) { + heap.Push(h, expItem{item, exp}) +} + +// remove removes an item. +func (h *expHeap) remove(item string) bool { + for i, v := range *h { + if v.item == item { + heap.Remove(h, i) + return true + } + } + return false +} + +// contains checks whether an item is present. +func (h expHeap) contains(item string) bool { + for _, v := range h { + if v.item == item { + return true + } + } + return false +} + +// expire removes items with expiry time before 'now'. +func (h *expHeap) expire(now time.Time) { + for h.Len() > 0 && h.nextExpiry().Before(now) { + heap.Pop(h) + } +} + +// heap.Interface boilerplate +func (h expHeap) Len() int { return len(h) } +func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } +func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) } +func (h *expHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/p2p/util_test.go b/p2p/util_test.go new file mode 100644 index 0000000000..c9f2648dc9 --- /dev/null +++ b/p2p/util_test.go @@ -0,0 +1,54 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package p2p + +import ( + "testing" + "time" +) + +func TestExpHeap(t *testing.T) { + var h expHeap + + var ( + basetime = time.Unix(4000, 0) + exptimeA = basetime.Add(2 * time.Second) + exptimeB = basetime.Add(3 * time.Second) + exptimeC = basetime.Add(4 * time.Second) + ) + h.add("a", exptimeA) + h.add("b", exptimeB) + h.add("c", exptimeC) + + if !h.nextExpiry().Equal(exptimeA) { + t.Fatal("wrong nextExpiry") + } + if !h.contains("a") || !h.contains("b") || !h.contains("c") { + t.Fatal("heap doesn't contain all live items") + } + + h.expire(exptimeA.Add(1)) + if !h.nextExpiry().Equal(exptimeB) { + t.Fatal("wrong nextExpiry") + } + if h.contains("a") { + t.Fatal("heap contains a even though it has already expired") + } + if !h.contains("b") || !h.contains("c") { + t.Fatal("heap doesn't contain all live items") + } +}