diff --git a/p2p/discover/database.go b/p2p/discover/database.go index 6f98de9b42..22554145f4 100644 --- a/p2p/discover/database.go +++ b/p2p/discover/database.go @@ -42,6 +42,7 @@ var ( nodeDBNilNodeID = NodeID{} // Special node ID to use as a nil element. nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. nodeDBCleanupCycle = time.Hour // Time period for running the expiration task. + nodeDBVersion = 5 ) // nodeDB stores all nodes we know about. @@ -257,7 +258,7 @@ func (db *nodeDB) expireNodes() error { } // Skip the node if not expired yet (and not self) if !bytes.Equal(id[:], db.self[:]) { - if seen := db.bondTime(id); seen.After(threshold) { + if seen := db.lastPongReceived(id); seen.After(threshold) { continue } } @@ -267,29 +268,28 @@ func (db *nodeDB) expireNodes() error { return nil } -// lastPing retrieves the time of the last ping packet send to a remote node, -// requesting binding. -func (db *nodeDB) lastPing(id NodeID) time.Time { +// lastPingReceived retrieves the time of the last ping packet sent by the remote node. +func (db *nodeDB) lastPingReceived(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0) } -// updateLastPing updates the last time we tried contacting a remote node. -func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { +// updateLastPing updates the last time remote node pinged us. +func (db *nodeDB) updateLastPingReceived(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) } -// bondTime retrieves the time of the last successful pong from remote node. -func (db *nodeDB) bondTime(id NodeID) time.Time { +// lastPongReceived retrieves the time of the last successful pong from remote node. +func (db *nodeDB) lastPongReceived(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) } // hasBond reports whether the given node is considered bonded. func (db *nodeDB) hasBond(id NodeID) bool { - return time.Since(db.bondTime(id)) < nodeDBNodeExpiration + return time.Since(db.lastPongReceived(id)) < nodeDBNodeExpiration } -// updateBondTime updates the last pong time of a node. -func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { +// updateLastPongReceived updates the last pong time of a node. +func (db *nodeDB) updateLastPongReceived(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) } @@ -332,7 +332,7 @@ seek: if n.ID == db.self { continue seek } - if now.Sub(db.bondTime(n.ID)) > maxAge { + if now.Sub(db.lastPongReceived(n.ID)) > maxAge { continue seek } for i := range nodes { diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go index c4fa44d099..27974344e0 100644 --- a/p2p/discover/database_test.go +++ b/p2p/discover/database_test.go @@ -79,7 +79,7 @@ var nodeDBInt64Tests = []struct { } func TestNodeDBInt64(t *testing.T) { - db, _ := newNodeDB("", Version, NodeID{}) + db, _ := newNodeDB("", nodeDBVersion, NodeID{}) defer db.close() tests := nodeDBInt64Tests @@ -111,27 +111,27 @@ func TestNodeDBFetchStore(t *testing.T) { inst := time.Now() num := 314 - db, _ := newNodeDB("", Version, NodeID{}) + db, _ := newNodeDB("", nodeDBVersion, NodeID{}) defer db.close() // Check fetch/store operations on a node ping object - if stored := db.lastPing(node.ID); stored.Unix() != 0 { + if stored := db.lastPingReceived(node.ID); stored.Unix() != 0 { t.Errorf("ping: non-existing object: %v", stored) } - if err := db.updateLastPing(node.ID, inst); err != nil { + if err := db.updateLastPingReceived(node.ID, inst); err != nil { t.Errorf("ping: failed to update: %v", err) } - if stored := db.lastPing(node.ID); stored.Unix() != inst.Unix() { + if stored := db.lastPingReceived(node.ID); stored.Unix() != inst.Unix() { t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node pong object - if stored := db.bondTime(node.ID); stored.Unix() != 0 { + if stored := db.lastPongReceived(node.ID); stored.Unix() != 0 { t.Errorf("pong: non-existing object: %v", stored) } - if err := db.updateBondTime(node.ID, inst); err != nil { + if err := db.updateLastPongReceived(node.ID, inst); err != nil { t.Errorf("pong: failed to update: %v", err) } - if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() { + if stored := db.lastPongReceived(node.ID); stored.Unix() != inst.Unix() { t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node findnode-failure object @@ -216,7 +216,7 @@ var nodeDBSeedQueryNodes = []struct { } func TestNodeDBSeedQuery(t *testing.T) { - db, _ := newNodeDB("", Version, nodeDBSeedQueryNodes[1].node.ID) + db, _ := newNodeDB("", nodeDBVersion, nodeDBSeedQueryNodes[1].node.ID) defer db.close() // Insert a batch of nodes for querying @@ -224,7 +224,7 @@ func TestNodeDBSeedQuery(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil { t.Fatalf("node %d: failed to insert bondTime: %v", i, err) } } @@ -267,7 +267,7 @@ func TestNodeDBPersistency(t *testing.T) { ) // Create a persistent database and store some values - db, err := newNodeDB(filepath.Join(root, "database"), Version, NodeID{}) + db, err := newNodeDB(filepath.Join(root, "database"), nodeDBVersion, NodeID{}) if err != nil { t.Fatalf("failed to create persistent database: %v", err) } @@ -277,7 +277,7 @@ func TestNodeDBPersistency(t *testing.T) { db.close() // Reopen the database and check the value - db, err = newNodeDB(filepath.Join(root, "database"), Version, NodeID{}) + db, err = newNodeDB(filepath.Join(root, "database"), nodeDBVersion, NodeID{}) if err != nil { t.Fatalf("failed to open persistent database: %v", err) } @@ -287,7 +287,7 @@ func TestNodeDBPersistency(t *testing.T) { db.close() // Change the database version and check flush - db, err = newNodeDB(filepath.Join(root, "database"), Version+1, NodeID{}) + db, err = newNodeDB(filepath.Join(root, "database"), nodeDBVersion+1, NodeID{}) if err != nil { t.Fatalf("failed to open persistent database: %v", err) } @@ -324,7 +324,7 @@ var nodeDBExpirationNodes = []struct { } func TestNodeDBExpiration(t *testing.T) { - db, _ := newNodeDB("", Version, NodeID{}) + db, _ := newNodeDB("", nodeDBVersion, NodeID{}) defer db.close() // Add all the test nodes and set their last pong time @@ -332,7 +332,7 @@ func TestNodeDBExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil { t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } @@ -357,7 +357,7 @@ func TestNodeDBSelfExpiration(t *testing.T) { break } } - db, _ := newNodeDB("", Version, self) + db, _ := newNodeDB("", nodeDBVersion, self) defer db.close() // Add all the test nodes and set their last pong time @@ -365,7 +365,7 @@ func TestNodeDBSelfExpiration(t *testing.T) { if err := db.updateNode(seed.node); err != nil { t.Fatalf("node %d: failed to insert: %v", i, err) } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { + if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil { t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 18920ccfdd..8803daa56e 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -25,7 +25,6 @@ package discover import ( crand "crypto/rand" "encoding/binary" - "errors" "fmt" mrand "math/rand" "net" @@ -54,15 +53,13 @@ const ( bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 tableIPLimit, tableSubnet = 10, 24 - maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions - maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped - - refreshInterval = 30 * time.Minute - revalidateInterval = 10 * time.Second - copyNodesInterval = 30 * time.Second - seedMinTableTime = 5 * time.Minute - seedCount = 30 - seedMaxAge = 5 * 24 * time.Hour + maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped + refreshInterval = 30 * time.Minute + revalidateInterval = 10 * time.Second + copyNodesInterval = 30 * time.Second + seedMinTableTime = 5 * time.Minute + seedCount = 30 + seedMaxAge = 5 * 24 * time.Hour ) type Table struct { @@ -78,28 +75,17 @@ type Table struct { closeReq chan struct{} closed chan struct{} - bondmu sync.Mutex - bonding map[NodeID]*bondproc - bondslots chan struct{} // limits total number of active bonding processes - nodeAddedHook func(*Node) // for testing net transport self *Node // metadata of the local node } -type bondproc struct { - err error - n *Node - done chan struct{} -} - // transport is implemented by the UDP transport. // it is an interface so we can test without opening lots of UDP // sockets and without generating a private key. type transport interface { ping(NodeID, *net.UDPAddr) error - waitping(NodeID) error findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error) close() } @@ -114,7 +100,7 @@ type bucket struct { func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) { // If no node database was given, use an in-memory one - db, err := newNodeDB(nodeDBPath, Version, ourID) + db, err := newNodeDB(nodeDBPath, nodeDBVersion, ourID) if err != nil { return nil, err } @@ -122,8 +108,6 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string net: t, db: db, self: NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)), - bonding: make(map[NodeID]*bondproc), - bondslots: make(chan struct{}, maxBondingPingPongs), refreshReq: make(chan chan struct{}), initDone: make(chan struct{}), closeReq: make(chan struct{}), @@ -134,16 +118,13 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string if err := tab.setFallbackNodes(bootnodes); err != nil { return nil, err } - for i := 0; i < cap(tab.bondslots); i++ { - tab.bondslots <- struct{}{} - } for i := range tab.buckets { tab.buckets[i] = &bucket{ ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, } } tab.seedRand() - tab.loadSeedNodes(false) + tab.loadSeedNodes() // Start the background expiration goroutine after loading seeds so that the search for // seed nodes also considers older nodes that would otherwise be removed by the // expiration. @@ -315,22 +296,7 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node { if !asked[n.ID] { asked[n.ID] = true pendingQueries++ - go func() { - // Find potential neighbors to bond with - r, err := tab.net.findnode(n.ID, n.addr(), targetID) - if err != nil { - // Bump the failure counter to detect and evacuate non-bonded entries - fails := tab.db.findFails(n.ID) + 1 - tab.db.updateFindFails(n.ID, fails) - log.Trace("Bumping findnode failure counter", "id", n.ID, "failcount", fails) - - if fails >= maxFindnodeFailures { - log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails) - tab.delete(n) - } - } - reply <- tab.bondall(r) - }() + go tab.findnode(n, targetID, reply) } } if pendingQueries == 0 { @@ -349,6 +315,29 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node { return result.entries } +func (tab *Table) findnode(n *Node, targetID NodeID, reply chan<- []*Node) { + fails := tab.db.findFails(n.ID) + r, err := tab.net.findnode(n.ID, n.addr(), targetID) + if err != nil || len(r) == 0 { + fails++ + tab.db.updateFindFails(n.ID, fails) + log.Trace("Findnode failed", "id", n.ID, "failcount", fails, "err", err) + if fails >= maxFindnodeFailures { + log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails) + tab.delete(n) + } + } else if fails > 0 { + tab.db.updateFindFails(n.ID, fails-1) + } + + // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll + // just remove those again during revalidation. + for _, n := range r { + tab.add(n) + } + reply <- r +} + func (tab *Table) refresh() <-chan struct{} { done := make(chan struct{}) select { @@ -401,7 +390,7 @@ loop: case <-revalidateDone: revalidate.Reset(tab.nextRevalidateTime()) case <-copyNodes.C: - go tab.copyBondedNodes() + go tab.copyLiveNodes() case <-tab.closeReq: break loop } @@ -429,7 +418,7 @@ func (tab *Table) doRefresh(done chan struct{}) { // Load nodes from the database and insert // them. This should yield a few previously seen nodes that are // (hopefully) still alive. - tab.loadSeedNodes(true) + tab.loadSeedNodes() // Run self lookup to discover new neighbor nodes. tab.lookup(tab.self.ID, false) @@ -447,15 +436,12 @@ func (tab *Table) doRefresh(done chan struct{}) { } } -func (tab *Table) loadSeedNodes(bond bool) { +func (tab *Table) loadSeedNodes() { seeds := tab.db.querySeeds(seedCount, seedMaxAge) seeds = append(seeds, tab.nursery...) - if bond { - seeds = tab.bondall(seeds) - } for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPongReceived(seed.ID)) }} log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) tab.add(seed) } @@ -473,7 +459,7 @@ func (tab *Table) doRevalidate(done chan<- struct{}) { } // Ping the selected node and wait for a pong. - err := tab.ping(last.ID, last.addr()) + err := tab.net.ping(last.ID, last.addr()) tab.mutex.Lock() defer tab.mutex.Unlock() @@ -515,9 +501,9 @@ func (tab *Table) nextRevalidateTime() time.Duration { return time.Duration(tab.rand.Int63n(int64(revalidateInterval))) } -// copyBondedNodes adds nodes from the table to the database if they have been in the table +// copyLiveNodes adds nodes from the table to the database if they have been in the table // longer then minTableTime. -func (tab *Table) copyBondedNodes() { +func (tab *Table) copyLiveNodes() { tab.mutex.Lock() defer tab.mutex.Unlock() @@ -553,120 +539,6 @@ func (tab *Table) len() (n int) { return n } -// bondall bonds with all given nodes concurrently and returns -// those nodes for which bonding has probably succeeded. -func (tab *Table) bondall(nodes []*Node) (result []*Node) { - rc := make(chan *Node, len(nodes)) - for i := range nodes { - go func(n *Node) { - nn, _ := tab.bond(false, n.ID, n.addr(), n.TCP) - rc <- nn - }(nodes[i]) - } - for range nodes { - if n := <-rc; n != nil { - result = append(result, n) - } - } - return result -} - -// bond ensures the local node has a bond with the given remote node. -// It also attempts to insert the node into the table if bonding succeeds. -// The caller must not hold tab.mutex. -// -// A bond is must be established before sending findnode requests. -// Both sides must have completed a ping/pong exchange for a bond to -// exist. The total number of active bonding processes is limited in -// order to restrain network use. -// -// bond is meant to operate idempotently in that bonding with a remote -// node which still remembers a previously established bond will work. -// The remote node will simply not send a ping back, causing waitping -// to time out. -// -// If pinged is true, the remote node has just pinged us and one half -// of the process can be skipped. -func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) { - if id == tab.self.ID { - return nil, errors.New("is self") - } - if pinged && !tab.isInitDone() { - return nil, errors.New("still initializing") - } - // Start bonding if we haven't seen this node for a while or if it failed findnode too often. - node, fails := tab.db.node(id), tab.db.findFails(id) - age := time.Since(tab.db.bondTime(id)) - var result error - if fails > 0 || age > nodeDBNodeExpiration { - log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) - - tab.bondmu.Lock() - w := tab.bonding[id] - if w != nil { - // Wait for an existing bonding process to complete. - tab.bondmu.Unlock() - <-w.done - } else { - // Register a new bonding process. - w = &bondproc{done: make(chan struct{})} - tab.bonding[id] = w - tab.bondmu.Unlock() - // Do the ping/pong. The result goes into w. - tab.pingpong(w, pinged, id, addr, tcpPort) - // Unregister the process after it's done. - tab.bondmu.Lock() - delete(tab.bonding, id) - tab.bondmu.Unlock() - } - // Retrieve the bonding results - result = w.err - if result == nil { - node = w.n - } - } - // Add the node to the table even if the bonding ping/pong - // fails. It will be relaced quickly if it continues to be - // unresponsive. - if node != nil { - tab.add(node) - tab.db.updateFindFails(id, 0) - } - return node, result -} - -func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) { - // Request a bonding slot to limit network usage - <-tab.bondslots - defer func() { tab.bondslots <- struct{}{} }() - - // Ping the remote side and wait for a pong. - if w.err = tab.ping(id, addr); w.err != nil { - close(w.done) - return - } - if !pinged { - // Give the remote node a chance to ping us before we start - // sending findnode requests. If they still remember us, - // waitping will simply time out. - tab.net.waitping(id) - } - // Bonding succeeded, update the node database. - w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort) - close(w.done) -} - -// ping a remote endpoint and wait for a reply, also updating the node -// database accordingly. -func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { - tab.db.updateLastPing(id, time.Now()) - if err := tab.net.ping(id, addr); err != nil { - return err - } - tab.db.updateBondTime(id, time.Now()) - return nil -} - // bucket returns the bucket for the given node ID hash. func (tab *Table) bucket(sha common.Hash) *bucket { d := logdist(tab.self.sha, sha) @@ -676,21 +548,33 @@ func (tab *Table) bucket(sha common.Hash) *bucket { return tab.buckets[d-bucketMinDistance-1] } -// add attempts to add the given node its corresponding bucket. If the -// bucket has space available, adding the node succeeds immediately. -// Otherwise, the node is added if the least recently active node in -// the bucket does not respond to a ping packet. +// add attempts to add the given node to its corresponding bucket. If the bucket has space +// available, adding the node succeeds immediately. Otherwise, the node is added if the +// least recently active node in the bucket does not respond to a ping packet. // // The caller must not hold tab.mutex. -func (tab *Table) add(new *Node) { +func (tab *Table) add(n *Node) { tab.mutex.Lock() defer tab.mutex.Unlock() - b := tab.bucket(new.sha) - if !tab.bumpOrAdd(b, new) { + b := tab.bucket(n.sha) + if !tab.bumpOrAdd(b, n) { // Node is not in table. Add it to the replacement list. - tab.addReplacement(b, new) + tab.addReplacement(b, n) + } +} + +// addThroughPing adds the given node to the table. Compared to plain +// 'add' there is an additional safety measure: if the table is still +// initializing the node is not added. This prevents an attack where the +// table could be filled by just sending ping repeatedly. +// +// The caller must not hold tab.mutex. +func (tab *Table) addThroughPing(n *Node) { + if !tab.isInitDone() { + return } + tab.add(n) } // stuff adds nodes the table to the end of their corresponding bucket @@ -710,8 +594,7 @@ func (tab *Table) stuff(nodes []*Node) { } } -// delete removes an entry from the node table (used to evacuate -// failed/non-bonded discovery peers). +// delete removes an entry from the node table. It is used to evacuate dead nodes. func (tab *Table) delete(node *Node) { tab.mutex.Lock() defer tab.mutex.Unlock() diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index f2d3f9a2ad..ed55ebd9a9 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -52,27 +52,22 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) defer tab.Close() - // Wait for init so bond is accepted. <-tab.initDone - // fill up the sender's bucket. + // Fill up the sender's bucket. pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) last := fillBucket(tab, pingSender) - // this call to bond should replace the last node - // in its bucket if the node is not responding. + // Add the sender as if it just pinged us. Revalidate should replace the last node in + // its bucket if it is unresponsive. Revalidate again to ensure that transport.dead[last.ID] = !lastInBucketIsResponding transport.dead[pingSender.ID] = !newNodeIsResponding - tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0) + tab.add(pingSender) + tab.doRevalidate(make(chan struct{}, 1)) tab.doRevalidate(make(chan struct{}, 1)) - // first ping goes to sender (bonding pingback) - if !transport.pinged[pingSender.ID] { - t.Error("table did not ping back sender") - } if !transport.pinged[last.ID] { - // second ping goes to oldest node in bucket - // to see whether it is still alive. + // Oldest node in bucket is pinged to see whether it is still alive. t.Error("table did not ping last node in bucket") } @@ -83,7 +78,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding wantSize-- } if l := len(tab.bucket(pingSender.sha).entries); l != wantSize { - t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) + t.Errorf("wrong bucket size after add: got %d, want %d", l, wantSize) } if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding { t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) @@ -206,10 +201,7 @@ func newPingRecorder() *pingRecorder { func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { return nil, nil } -func (t *pingRecorder) close() {} -func (t *pingRecorder) waitping(from NodeID) error { - return nil // remote always pings -} + func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { t.mu.Lock() defer t.mu.Unlock() @@ -222,6 +214,8 @@ func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { } } +func (t *pingRecorder) close() {} + func TestTable_closest(t *testing.T) { t.Parallel() diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index f6bcd97085..0ff47c5e46 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -32,8 +32,6 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) -const Version = 4 - // Errors var ( errPacketTooSmall = errors.New("too small") @@ -272,21 +270,33 @@ func (t *udp) close() { // ping sends a ping message to the given node and waits for a reply. func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { + return <-t.sendPing(toid, toaddr, nil) +} + +// sendPing sends a ping message to the given node and invokes the callback +// when the reply arrives. +func (t *udp) sendPing(toid NodeID, toaddr *net.UDPAddr, callback func()) <-chan error { req := &ping{ - Version: Version, + Version: 4, From: t.ourEndpoint, To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB Expiration: uint64(time.Now().Add(expiration).Unix()), } packet, hash, err := encodePacket(t.priv, pingPacket, req) if err != nil { - return err + errc := make(chan error, 1) + errc <- err + return errc } errc := t.pending(toid, pongPacket, func(p interface{}) bool { - return bytes.Equal(p.(*pong).ReplyTok, hash) + ok := bytes.Equal(p.(*pong).ReplyTok, hash) + if ok && callback != nil { + callback() + } + return ok }) t.write(toaddr, req.name(), packet) - return <-errc + return errc } func (t *udp) waitping(from NodeID) error { @@ -296,6 +306,13 @@ func (t *udp) waitping(from NodeID) error { // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { + // If we haven't seen a ping from the destination node for a while, it won't remember + // our endpoint proof and reject findnode. Solicit a ping first. + if time.Since(t.db.lastPingReceived(toid)) > nodeDBNodeExpiration { + t.ping(toid, toaddr) + t.waitping(toid) + } + nodes := make([]*Node, 0, bucketSize) nreceived := 0 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { @@ -315,8 +332,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - err := <-errc - return nodes, err + return nodes, <-errc } // pending adds a reply callback to the pending reply queue. @@ -587,10 +603,17 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er ReplyTok: mac, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - if !t.handleReply(fromID, pingPacket, req) { - // Note: we're ignoring the provided IP address right now - go t.bond(true, fromID, from, req.From.TCP) + t.handleReply(fromID, pingPacket, req) + + // Add the node to the table. Before doing so, ensure that we have a recent enough pong + // recorded in the database so their findnode requests will be accepted later. + n := NewNode(fromID, from.IP, uint16(from.Port), req.From.TCP) + if time.Since(t.db.lastPongReceived(fromID)) > nodeDBNodeExpiration { + t.sendPing(fromID, from, func() { t.addThroughPing(n) }) + } else { + t.addThroughPing(n) } + t.db.updateLastPingReceived(fromID, time.Now()) return nil } @@ -603,6 +626,7 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er if !t.handleReply(fromID, pongPacket, req) { return errUnsolicitedReply } + t.db.updateLastPongReceived(fromID, time.Now()) return nil } @@ -613,13 +637,12 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte return errExpired } if !t.db.hasBond(fromID) { - // No bond exists, we don't process the packet. This prevents - // an attack vector where the discovery protocol could be used - // to amplify traffic in a DDOS attack. A malicious actor - // would send a findnode request with the IP address and UDP - // port of the target as the source address. The recipient of - // the findnode packet would then send a neighbors packet - // (which is a much bigger packet than findnode) to the victim. + // No endpoint proof pong exists, we don't process the packet. This prevents an + // attack vector where the discovery protocol could be used to amplify traffic in a + // DDOS attack. A malicious actor would send a findnode request with the IP address + // and UDP port of the target as the source address. The recipient of the findnode + // packet would then send a neighbors packet (which is a much bigger packet than + // findnode) to the victim. return errUnknownNode } target := crypto.Keccak256Hash(req.Target[:]) diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index db9804f7bc..b4363a12b8 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -124,7 +124,7 @@ func TestUDP_packetErrors(t *testing.T) { test := newUDPTest(t) defer test.table.Close() - test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version}) + test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4}) test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp}) test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp}) test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp}) @@ -247,7 +247,7 @@ func TestUDP_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. - test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now()) + test.table.db.updateLastPongReceived(PubkeyID(&test.remotekey.PublicKey), time.Now()) // check that closest neighbors are returned. test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) @@ -273,10 +273,12 @@ func TestUDP_findnodeMultiReply(t *testing.T) { test := newUDPTest(t) defer test.table.Close() + rid := PubkeyID(&test.remotekey.PublicKey) + test.table.db.updateLastPingReceived(rid, time.Now()) + // queue a pending findnode request resultc, errc := make(chan []*Node), make(chan error) go func() { - rid := PubkeyID(&test.remotekey.PublicKey) ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget) if err != nil && len(ns) == 0 { errc <- err @@ -328,7 +330,7 @@ func TestUDP_successfulPing(t *testing.T) { defer test.table.Close() // The remote side sends a ping packet to initiate the exchange. - go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp}) + go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) // the ping is replied to. test.waitPacketOut(func(p *pong) {