diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index e1130e0b58..99cb549a59 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -13,6 +13,8 @@ import (
 	"net/url"
 	"strconv"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/ethereum/go-ethereum/crypto"
@@ -30,7 +32,8 @@ type Node struct {
 	DiscPort int // UDP listening port for discovery protocol
 	TCPPort  int // TCP listening port for RLPx

-	active time.Time
+	// this must be set/read using atomic load and store.
+	activeStamp int64
 }

 func newNode(id NodeID, addr *net.UDPAddr) *Node {
@@ -39,7 +42,6 @@ func newNode(id NodeID, addr *net.UDPAddr) *Node {
 		IP:       addr.IP,
 		DiscPort: addr.Port,
 		TCPPort:  addr.Port,
-		active:   time.Now(),
 	}
 }

@@ -48,6 +50,20 @@ func (n *Node) isValid() bool {
 	return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
 }

+func (n *Node) bumpActive() {
+	stamp := time.Now().Unix()
+	atomic.StoreInt64(&n.activeStamp, stamp)
+}
+
+func (n *Node) active() time.Time {
+	stamp := atomic.LoadInt64(&n.activeStamp)
+	return time.Unix(stamp, 0)
+}
+
+func (n *Node) addr() *net.UDPAddr {
+	return &net.UDPAddr{IP: n.IP, Port: n.DiscPort}
+}
+
 // The string representation of a Node is a URL.
 // Please see ParseNode for a description of the format.
 func (n *Node) String() string {
@@ -304,3 +320,26 @@ func randomID(a NodeID, n int) (b NodeID) {
 	}
 	return b
 }
+
+// nodeDB stores all nodes we know about.
+type nodeDB struct {
+	mu   sync.RWMutex
+	byID map[NodeID]*Node
+}
+
+func (db *nodeDB) get(id NodeID) *Node {
+	db.mu.RLock()
+	defer db.mu.RUnlock()
+	return db.byID[id]
+}
+
+func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	if db.byID == nil {
+		db.byID = make(map[NodeID]*Node)
+	}
+	n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
+	db.byID[n.ID] = n
+	return n
+}
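The new activeStamp field replaces the old active time.Time because the stamp is now touched from multiple goroutines without holding tab.mutex. A minimal standalone sketch of the same pattern; the Record type and the driver below are illustrative, not part of this change:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// Record mirrors Node's approach: a Unix timestamp stored as an int64 so
// that concurrent readers and writers never race on a time.Time value.
type Record struct{ stamp int64 }

func (r *Record) bump() { atomic.StoreInt64(&r.stamp, time.Now().Unix()) }

func (r *Record) active() time.Time { return time.Unix(atomic.LoadInt64(&r.stamp), 0) }

func main() {
	var r Record
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); r.bump() }() // safe without a mutex
	}
	wg.Wait()
	fmt.Println("last active:", r.active())
}
```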
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 33b705a12b..842f55d9f3 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -14,9 +14,10 @@ import (
 )

 const (
-	alpha      = 3              // Kademlia concurrency factor
-	bucketSize = 16             // Kademlia bucket size
-	nBuckets   = nodeIDBits + 1 // Number of buckets
+	alpha               = 3              // Kademlia concurrency factor
+	bucketSize          = 16             // Kademlia bucket size
+	nBuckets            = nodeIDBits + 1 // Number of buckets
+	maxBondingPingPongs = 10
 )

 type Table struct {
@@ -24,27 +25,50 @@ type Table struct {
 	buckets [nBuckets]*bucket // index of known nodes by distance
 	nursery []*Node           // bootstrap nodes

+	bondmu    sync.Mutex
+	bonding   map[NodeID]*bondproc
+	bondslots chan struct{} // limits total number of active bonding processes
+
 	net  transport
 	self *Node // metadata of the local node
+	db   *nodeDB
+}
+
+type bondproc struct {
+	err  error
+	n    *Node
+	done chan struct{}
 }

 // transport is implemented by the UDP transport.
 // it is an interface so we can test without opening lots of UDP
 // sockets and without generating a private key.
 type transport interface {
-	ping(*Node) error
-	findnode(e *Node, target NodeID) ([]*Node, error)
+	ping(NodeID, *net.UDPAddr) error
+	waitping(NodeID) error
+	findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
 	close()
 }

 // bucket contains nodes, ordered by their last activity.
+// the entry that was most recently active is the first element
+// in entries.
 type bucket struct {
 	lastLookup time.Time
 	entries    []*Node
 }

 func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
-	tab := &Table{net: t, self: newNode(ourID, ourAddr)}
+	tab := &Table{
+		net:       t,
+		db:        new(nodeDB),
+		self:      newNode(ourID, ourAddr),
+		bonding:   make(map[NodeID]*bondproc),
+		bondslots: make(chan struct{}, maxBondingPingPongs),
+	}
+	for i := 0; i < cap(tab.bondslots); i++ {
+		tab.bondslots <- struct{}{}
+	}
 	for i := range tab.buckets {
 		tab.buckets[i] = new(bucket)
 	}
@@ -107,8 +131,8 @@ func (tab *Table) Lookup(target NodeID) []*Node {
 				asked[n.ID] = true
 				pendingQueries++
 				go func() {
-					result, _ := tab.net.findnode(n, target)
-					reply <- result
+					r, _ := tab.net.findnode(n.ID, n.addr(), target)
+					reply <- tab.bondall(r)
 				}()
 			}
 		}
@@ -116,13 +140,11 @@ func (tab *Table) Lookup(target NodeID) []*Node {
 			// we have asked all closest nodes, stop the search
 			break
 		}
-
 		// wait for the next reply
 		for _, n := range <-reply {
-			cn := n
-			if !seen[n.ID] {
+			if n != nil && !seen[n.ID] {
 				seen[n.ID] = true
-				result.push(cn, bucketSize)
+				result.push(n, bucketSize)
 			}
 		}
 		pendingQueries--
@@ -145,8 +167,9 @@ func (tab *Table) refresh() {
 	result := tab.Lookup(randomID(tab.self.ID, ld))
 	if len(result) == 0 {
 		// bootstrap the table with a self lookup
+		all := tab.bondall(tab.nursery)
 		tab.mutex.Lock()
-		tab.add(tab.nursery)
+		tab.add(all)
 		tab.mutex.Unlock()
 		tab.Lookup(tab.self.ID)
 		// TODO: the Kademlia paper says that we're supposed to perform
@@ -176,45 +199,105 @@ func (tab *Table) len() (n int) {
 	return n
 }

-// bumpOrAdd updates the activity timestamp for the given node and
-// attempts to insert the node into a bucket. The returned Node might
-// not be part of the table. The caller must hold tab.mutex.
-func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
-	b := tab.buckets[logdist(tab.self.ID, node)]
-	if n = b.bump(node); n == nil {
-		n = newNode(node, from)
-		if len(b.entries) == bucketSize {
-			tab.pingReplace(n, b)
-		} else {
-			b.entries = append(b.entries, n)
+// bondall bonds with all given nodes concurrently and returns
+// those nodes for which bonding has probably succeeded.
+func (tab *Table) bondall(nodes []*Node) (result []*Node) {
+	rc := make(chan *Node, len(nodes))
+	for i := range nodes {
+		go func(n *Node) {
+			nn, _ := tab.bond(false, n.ID, n.addr(), uint16(n.TCPPort))
+			rc <- nn
+		}(nodes[i])
+	}
+	for _ = range nodes {
+		if n := <-rc; n != nil {
+			result = append(result, n)
 		}
 	}
-	return n
+	return result
 }
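bondall fans out one goroutine per node and collects the survivors over a buffered channel, while bondslots (filled token by token in newTable above) acts as a counting semaphore capping concurrent ping/pongs at maxBondingPingPongs. The semaphore idiom in a self-contained sketch, with illustrative names:

```go
package main

import (
	"fmt"
	"sync"
)

const maxActive = 10 // plays the role of maxBondingPingPongs

func main() {
	slots := make(chan struct{}, maxActive)
	for i := 0; i < cap(slots); i++ {
		slots <- struct{}{} // fill the semaphore: one token per slot
	}

	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			<-slots                                // acquire a slot
			defer func() { slots <- struct{}{} }() // release it when done
			_ = id                                 // ... rate-limited work goes here ...
		}(i)
	}
	wg.Wait()
	fmt.Println("done; at most", maxActive, "workers ran at once")
}
```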

-func (tab *Table) pingReplace(n *Node, b *bucket) {
-	old := b.entries[bucketSize-1]
-	go func() {
-		if err := tab.net.ping(old); err == nil {
-			// it responded, we don't need to replace it.
-			return
+// bond ensures the local node has a bond with the given remote node.
+// It also attempts to insert the node into the table if bonding succeeds.
+// The caller must not hold tab.mutex.
+//
+// A bond must be established before sending findnode requests.
+// Both sides must have completed a ping/pong exchange for a bond to
+// exist. The total number of active bonding processes is limited in
+// order to restrain network use.
+//
+// bond is meant to operate idempotently in that bonding with a remote
+// node which still remembers a previously established bond will work.
+// The remote node will simply not send a ping back, causing waitping
+// to time out.
+//
+// If pinged is true, the remote node has just pinged us and one half
+// of the process can be skipped.
+func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
+	var n *Node
+	if n = tab.db.get(id); n == nil {
+		tab.bondmu.Lock()
+		w := tab.bonding[id]
+		if w != nil {
+			// Wait for an existing bonding process to complete.
+			tab.bondmu.Unlock()
+			<-w.done
+		} else {
+			// Register a new bonding process.
+			w = &bondproc{done: make(chan struct{})}
+			tab.bonding[id] = w
+			tab.bondmu.Unlock()
+			// Do the ping/pong. The result goes into w.
+			tab.pingpong(w, pinged, id, addr, tcpPort)
+			// Unregister the process after it's done.
+			tab.bondmu.Lock()
+			delete(tab.bonding, id)
+			tab.bondmu.Unlock()
 		}
-		// it didn't respond, replace the node if it is still the oldest node.
-		tab.mutex.Lock()
-		if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
-			// slide down other entries and put the new one in front.
-			// TODO: insert in correct position to keep the order
-			copy(b.entries[1:], b.entries)
-			b.entries[0] = n
+		n = w.n
+		if w.err != nil {
+			return nil, w.err
 		}
-		tab.mutex.Unlock()
-	}()
+	}
+	tab.mutex.Lock()
+	defer tab.mutex.Unlock()
+	if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) {
+		tab.pingreplace(n, b)
+	}
+	return n, nil
+}
+
+func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
+	<-tab.bondslots
+	defer func() { tab.bondslots <- struct{}{} }()
+	if w.err = tab.net.ping(id, addr); w.err != nil {
+		close(w.done)
+		return
+	}
+	if !pinged {
+		// Give the remote node a chance to ping us before we start
+		// sending findnode requests. If they still remember us,
+		// waitping will simply time out.
+		tab.net.waitping(id)
+	}
+	w.n = tab.db.add(id, addr, tcpPort)
+	close(w.done)
 }

-// bump updates the activity timestamp for the given node.
-// The caller must hold tab.mutex.
-func (tab *Table) bump(node NodeID) {
-	tab.buckets[logdist(tab.self.ID, node)].bump(node)
+func (tab *Table) pingreplace(new *Node, b *bucket) {
+	if len(b.entries) == bucketSize {
+		oldest := b.entries[bucketSize-1]
+		if err := tab.net.ping(oldest.ID, oldest.addr()); err == nil {
+			// The node responded, we don't need to replace it.
+			return
+		}
+	} else {
+		// Add a slot at the end so the last entry doesn't
+		// fall off when adding the new node.
+		b.entries = append(b.entries, nil)
+	}
+	copy(b.entries[1:], b.entries)
+	b.entries[0] = new
 }
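bond also deduplicates concurrent attempts for the same ID: the first caller registers a bondproc in tab.bonding and every later caller just waits on its done channel. The same single-flight pattern as a standalone sketch; all names here are illustrative:

```go
package main

import (
	"fmt"
	"sync"
)

type call struct {
	done chan struct{}
	val  int
}

type group struct {
	mu     sync.Mutex
	active map[string]*call
}

// do runs fn at most once per key at a time; duplicate callers block
// on the in-flight call and share its result.
func (g *group) do(key string, fn func() int) int {
	g.mu.Lock()
	if c, ok := g.active[key]; ok {
		g.mu.Unlock()
		<-c.done // another goroutine is already working on this key
		return c.val
	}
	c := &call{done: make(chan struct{})}
	g.active[key] = c
	g.mu.Unlock()

	c.val = fn()
	close(c.done)

	g.mu.Lock()
	delete(g.active, key)
	g.mu.Unlock()
	return c.val
}

func main() {
	g := &group{active: make(map[string]*call)}
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); fmt.Println(g.do("node-1", func() int { return 42 })) }()
	}
	wg.Wait()
}
```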

 // add puts the entries into the table if their corresponding
@@ -240,17 +323,17 @@ outer:
 	}
 }

-func (b *bucket) bump(id NodeID) *Node {
-	for i, n := range b.entries {
-		if n.ID == id {
-			n.active = time.Now()
+func (b *bucket) bump(n *Node) bool {
+	for i := range b.entries {
+		if b.entries[i].ID == n.ID {
+			n.bumpActive()
 			// move it to the front
 			copy(b.entries[1:], b.entries[:i+1])
 			b.entries[0] = n
-			return n
+			return true
 		}
 	}
-	return nil
+	return false
 }

 // nodesByDistance is a list of nodes, ordered by
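bump relies on copy's memmove-style handling of overlapping slices: copy(b.entries[1:], b.entries[:i+1]) shifts the leading entries one slot right, and the subsequent write to index 0 completes a rotate-to-front. The idiom in isolation:

```go
package main

import "fmt"

// moveToFront rotates s[:i+1] right by one so that s[i] ends up at s[0].
func moveToFront(s []string, i int) {
	x := s[i]
	copy(s[1:], s[:i]) // shift everything before position i one slot right
	s[0] = x
}

func main() {
	s := []string{"a", "b", "c", "d"}
	moveToFront(s, 2)
	fmt.Println(s) // [c a b d]
}
```

bump can use the slightly different copy(b.entries[1:], b.entries[:i+1]) form because it still holds the bumped node in n and overwrites slot 0 with it afterwards.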
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 08faea68e9..95ec30bea4 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -2,79 +2,68 @@ package discover

 import (
 	"crypto/ecdsa"
-	"errors"
 	"fmt"
 	"math/rand"
 	"net"
 	"reflect"
 	"testing"
 	"testing/quick"
-	"time"

 	"github.com/ethereum/go-ethereum/crypto"
 )

-func TestTable_bumpOrAddBucketAssign(t *testing.T) {
-	tab := newTable(nil, NodeID{}, &net.UDPAddr{})
-	for i := 1; i < len(tab.buckets); i++ {
-		tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
-	}
-	for i, b := range tab.buckets {
-		if i > 0 && len(b.entries) != 1 {
-			t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
+func TestTable_pingReplace(t *testing.T) {
+	doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
+		transport := newPingRecorder()
+		tab := newTable(transport, NodeID{}, &net.UDPAddr{})
+		last := fillBucket(tab, 200)
+		pingSender := randomID(tab.self.ID, 200)
+
+		// this gotPing should replace the last node
+		// if the last node is not responding.
+		transport.responding[last.ID] = lastInBucketIsResponding
+		transport.responding[pingSender] = newNodeIsResponding
+		tab.bond(true, pingSender, &net.UDPAddr{}, 0)
+
+		// first ping goes to sender (bonding pingback)
+		if !transport.pinged[pingSender] {
+			t.Error("table did not ping back sender")
+		}
+		if newNodeIsResponding {
+			// second ping goes to oldest node in bucket
+			// to see whether it is still alive.
+			if !transport.pinged[last.ID] {
+				t.Error("table did not ping last node in bucket")
+			}
 		}
-	}
-}
-
-func TestTable_bumpOrAddPingReplace(t *testing.T) {
-	pingC := make(pingC)
-	tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
-	last := fillBucket(tab, 200)
-
-	// this bumpOrAdd should not replace the last node
-	// because the node replies to ping.
-	new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
-	pinged := <-pingC
-	if pinged != last.ID {
-		t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
-	}
+		tab.mutex.Lock()
+		defer tab.mutex.Unlock()
+		if l := len(tab.buckets[200].entries); l != bucketSize {
+			t.Errorf("wrong bucket size after gotPing: got %d, want %d", l, bucketSize)
+		}

-	tab.mutex.Lock()
-	defer tab.mutex.Unlock()
-	if l := len(tab.buckets[200].entries); l != bucketSize {
-		t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
-	}
-	if !contains(tab.buckets[200].entries, last.ID) {
-		t.Error("last entry was removed")
-	}
-	if contains(tab.buckets[200].entries, new.ID) {
-		t.Error("new entry was added")
+		if lastInBucketIsResponding || !newNodeIsResponding {
+			if !contains(tab.buckets[200].entries, last.ID) {
+				t.Error("last entry was removed")
+			}
+			if contains(tab.buckets[200].entries, pingSender) {
+				t.Error("new entry was added")
+			}
+		} else {
+			if contains(tab.buckets[200].entries, last.ID) {
+				t.Error("last entry was not removed")
+			}
+			if !contains(tab.buckets[200].entries, pingSender) {
+				t.Error("new entry was not added")
+			}
+		}
 	}
-}
-
-func TestTable_bumpOrAddPingTimeout(t *testing.T) {
-	tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
-	last := fillBucket(tab, 200)
-
-	// this bumpOrAdd should replace the last node
-	// because the node does not reply to ping.
-	new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
-
-	// wait for async bucket update. damn. this needs to go away.
-	time.Sleep(2 * time.Millisecond)
-
-	tab.mutex.Lock()
-	defer tab.mutex.Unlock()
-	if l := len(tab.buckets[200].entries); l != bucketSize {
-		t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
-	}
-	if contains(tab.buckets[200].entries, last.ID) {
-		t.Error("last entry was not removed")
-	}
-	if !contains(tab.buckets[200].entries, new.ID) {
-		t.Error("new entry was not added")
-	}
+	doit(true, true)
+	doit(false, true)
+	doit(true, false)
+	doit(false, false)
 }
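The four doit calls enumerate every combination of the two responding flags. Assuming doit keeps this signature, a loop-based equivalent inside the same test would be:

```go
for _, newResp := range []bool{true, false} {
	for _, lastResp := range []bool{true, false} {
		doit(newResp, lastResp)
	}
}
```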

 func fillBucket(tab *Table, ld int) (last *Node) {
@@ -85,44 +74,27 @@ func fillBucket(tab *Table, ld int) (last *Node) {
 	return b.entries[bucketSize-1]
 }

-type pingC chan NodeID
+type pingRecorder struct{ responding, pinged map[NodeID]bool }

-func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
+func newPingRecorder() *pingRecorder {
+	return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
+}
+
+func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
 	panic("findnode called on pingRecorder")
 }

-func (t pingC) close() {
+func (t *pingRecorder) close() {
 	panic("close called on pingRecorder")
 }

-func (t pingC) ping(n *Node) error {
-	if t == nil {
-		return errTimeout
-	}
-	t <- n.ID
-	return nil
+func (t *pingRecorder) waitping(from NodeID) error {
+	return nil // remote always pings
 }
-
-func TestTable_bump(t *testing.T) {
-	tab := newTable(nil, NodeID{}, &net.UDPAddr{})
-
-	// add an old entry and two recent ones
-	oldactive := time.Now().Add(-2 * time.Minute)
-	old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
-	others := []*Node{
-		&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
-		&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
-	}
-	tab.add(append(others, old))
-	if tab.buckets[200].entries[0] == old {
-		t.Fatal("old entry is at front of bucket")
-	}
-
-	// bumping the old entry should move it to the front
-	tab.bump(old.ID)
-	if old.active == oldactive {
-		t.Error("activity timestamp not updated")
-	}
-	if tab.buckets[200].entries[0] != old {
-		t.Errorf("bumped entry did not move to the front of bucket")
+func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
+	t.pinged[toid] = true
+	if t.responding[toid] {
+		return nil
+	} else {
+		return errTimeout
 	}
 }

@@ -210,7 +182,7 @@ func TestTable_Lookup(t *testing.T) {
 		t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
 	}
 	// seed table with initial node (otherwise lookup will terminate immediately)
-	tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
+	tab.add([]*Node{newNode(randomID(target, 200), &net.UDPAddr{Port: 200})})

 	results := tab.Lookup(target)
 	t.Logf("results:")
@@ -238,16 +210,16 @@ type findnodeOracle struct {
 	target NodeID
 }

-func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
-	t.t.Logf("findnode query at dist %d", n.DiscPort)
+func (t findnodeOracle) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+	t.t.Logf("findnode query at dist %d", toaddr.Port)
 	// current log distance is encoded in port number
 	var result []*Node
-	switch n.DiscPort {
+	switch toaddr.Port {
 	case 0:
 		panic("query to node at distance 0")
 	default:
 		// TODO: add more randomness to distances
-		next := n.DiscPort - 1
+		next := toaddr.Port - 1
 		for i := 0; i < bucketSize; i++ {
 			result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
 		}
@@ -255,11 +227,9 @@
 	return result, nil
 }

-func (t findnodeOracle) close() {}
-
-func (t findnodeOracle) ping(n *Node) error {
-	return errors.New("ping is not supported by this transport")
-}
+func (t findnodeOracle) close() {}
+func (t findnodeOracle) waitping(from NodeID) error { return nil }
+func (t findnodeOracle) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil }

 func hasDuplicates(slice []*Node) bool {
 	seen := make(map[NodeID]bool)
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 738a01fb7e..e9ede13972 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -20,12 +20,14 @@ const Version = 3

 // Errors
 var (
-	errPacketTooSmall = errors.New("too small")
-	errBadHash        = errors.New("bad hash")
-	errExpired        = errors.New("expired")
-	errBadVersion     = errors.New("version mismatch")
-	errTimeout        = errors.New("RPC timeout")
-	errClosed         = errors.New("socket closed")
+	errPacketTooSmall   = errors.New("too small")
+	errBadHash          = errors.New("bad hash")
+	errExpired          = errors.New("expired")
+	errBadVersion       = errors.New("version mismatch")
+	errUnsolicitedReply = errors.New("unsolicited reply")
+	errUnknownNode      = errors.New("unknown node")
+	errTimeout          = errors.New("RPC timeout")
+	errClosed           = errors.New("socket closed")
 )

 // Timeouts
@@ -80,14 +82,27 @@ type rpcNode struct {
 	ID NodeID
 }

+type packet interface {
+	handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+}
+
+type conn interface {
+	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
+	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
+	Close() error
+	LocalAddr() net.Addr
+}
+
 // udp implements the RPC protocol.
 type udp struct {
-	conn *net.UDPConn
-	priv *ecdsa.PrivateKey
+	conn conn
+	priv *ecdsa.PrivateKey
+
 	addpending chan *pending
-	replies    chan reply
-	closing    chan struct{}
-	nat        nat.Interface
+	gotreply   chan reply
+
+	closing chan struct{}
+	nat     nat.Interface

 	*Table
 }
@@ -124,6 +139,9 @@ type reply struct {
 	from  NodeID
 	ptype byte
 	data  interface{}
+	// loop indicates whether there was
+	// a matching request by sending on this channel.
+	matched chan<- bool
 }

 // ListenUDP returns a new table that listens for UDP packets on laddr.
@@ -136,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
 	if err != nil {
 		return nil, err
 	}
+	tab, _ := newUDP(priv, conn, natm)
+	log.Infoln("Listening,", tab.self)
+	return tab, nil
+}
+
+func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
 	udp := &udp{
-		conn:       conn,
+		conn:       c,
 		priv:       priv,
 		closing:    make(chan struct{}),
+		gotreply:   make(chan reply),
 		addpending: make(chan *pending),
-		replies:    make(chan reply),
 	}
-
-	realaddr := conn.LocalAddr().(*net.UDPAddr)
+	realaddr := c.LocalAddr().(*net.UDPAddr)
 	if natm != nil {
 		if !realaddr.IP.IsLoopback() {
 			go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
@@ -155,11 +178,9 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
 		}
 	}
 	udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
-
 	go udp.loop()
 	go udp.readLoop()
-	log.Infoln("Listening, ", udp.self)
-	return udp.Table, nil
+	return udp.Table, udp
 }

 func (t *udp) close() {
@@ -169,10 +190,10 @@ func (t *udp) close() {
 }

 // ping sends a ping message to the given node and waits for a reply.
-func (t *udp) ping(e *Node) error {
+func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
 	// TODO: maybe check for ReplyTo field in callback to measure RTT
-	errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
-	t.send(e, pingPacket, ping{
+	errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
+	t.send(toaddr, pingPacket, ping{
 		Version:    Version,
 		IP:         t.self.IP.String(),
 		Port:       uint16(t.self.TCPPort),
@@ -181,12 +202,16 @@ func (t *udp) ping(e *Node) error {
 	return <-errc
 }

+func (t *udp) waitping(from NodeID) error {
+	return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
+}
+
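ping and waitping both register their expectation with pending before any packet is written to the socket, so a reply can never arrive ahead of its callback. The ordering, distilled from ping above (a sketch of the existing flow, not a new helper):

```go
// Wrong: send first, then register. A fast reply may land before the
// callback exists and be treated as unsolicited.
//
// Right (as in ping above): register, then send, then block.
errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
t.send(toaddr, pingPacket, req)
return <-errc
```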
 // findnode sends a findnode request to the given node and waits until
 // the node has sent up to k neighbors.
-func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
 	nodes := make([]*Node, 0, bucketSize)
 	nreceived := 0
-	errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+	errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
 		reply := r.(*neighbors)
 		for _, n := range reply.Nodes {
 			nreceived++
@@ -196,8 +221,7 @@
 		}
 		return nreceived >= bucketSize
 	})
-
-	t.send(to, findnodePacket, findnode{
+	t.send(toaddr, findnodePacket, findnode{
 		Target:     target,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
@@ -219,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
 	return ch
 }

+func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
+	matched := make(chan bool)
+	select {
+	case t.gotreply <- reply{from, ptype, req, matched}:
+		// loop will handle it
+		return <-matched
+	case <-t.closing:
+		return false
+	}
+}
+
 // loop runs in its own goroutine. it keeps track of
 // the refresh timer and the pending reply queue.
 func (t *udp) loop() {
@@ -249,6 +284,7 @@ func (t *udp) loop() {
 			for _, p := range pending {
 				p.errc <- errClosed
 			}
+			pending = nil
 			return

 		case p := <-t.addpending:
@@ -256,18 +292,21 @@ func (t *udp) loop() {
 			pending = append(pending, p)
 			rearmTimeout()

-		case reply := <-t.replies:
-			// run matching callbacks, remove if they return false.
+		case r := <-t.gotreply:
+			var matched bool
 			for i := 0; i < len(pending); i++ {
-				p := pending[i]
-				if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
-					p.errc <- nil
-					copy(pending[i:], pending[i+1:])
-					pending = pending[:len(pending)-1]
-					i--
+				if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
+					matched = true
+					if p.callback(r.data) {
+						// callback indicates the request is done, remove it.
+						p.errc <- nil
+						copy(pending[i:], pending[i+1:])
+						pending = pending[:len(pending)-1]
+						i--
+					}
 				}
 			}
-			rearmTimeout()
+			r.matched <- matched

 		case now := <-timeout.C:
 			// notify and remove callbacks whose deadline is in the past.
@@ -292,33 +331,38 @@ const (
 )

 var headSpace = make([]byte, headSize)

-func (t *udp) send(to *Node, ptype byte, req interface{}) error {
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
+	packet, err := encodePacket(t.priv, ptype, req)
+	if err != nil {
+		return err
+	}
+	log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
+		log.DebugDetailln("UDP send failed:", err)
+	}
+	return err
+}
+
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
 	b := new(bytes.Buffer)
 	b.Write(headSpace)
 	b.WriteByte(ptype)
 	if err := rlp.Encode(b, req); err != nil {
 		log.Errorln("error encoding packet:", err)
-		return err
+		return nil, err
 	}
 	packet := b.Bytes()
-	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
 	if err != nil {
 		log.Errorln("could not sign packet:", err)
-		return err
+		return nil, err
 	}
 	copy(packet[macSize:], sig)
 	// add the hash to the front. Note: this doesn't protect the
 	// packet in any way. Our public key will be part of this hash in
 	// the future.
 	copy(packet, crypto.Sha3(packet[macSize:]))
-
-	toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
-	log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
-	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
-		log.DebugDetailln("UDP send failed:", err)
-	}
-	return err
+	return packet, nil
 }
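encodePacket builds the datagram front-to-back and decodePacket below walks the same offsets in reverse. As a reading aid, the layout looks like this; the offset names come from the const block above, while the exact byte sizes are not shown in this hunk and should be taken as assumptions:

```go
// packet[:macSize]          hash of everything after it (integrity check only)
// packet[macSize:headSize]  recoverable ECDSA signature -> sender's NodeID
// packet[headSize]          packet type byte (ping, pong, findnode, neighbors)
// packet[headSize+1:]       RLP-encoded request body
```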

 // readLoop runs in its own goroutine. it handles incoming UDP packets.
@@ -330,29 +374,34 @@ func (t *udp) readLoop() {
 		if err != nil {
 			return
 		}
-		if err := t.packetIn(from, buf[:nbytes]); err != nil {
+		packet, fromID, hash, err := decodePacket(buf[:nbytes])
+		if err != nil {
 			log.Debugf("Bad packet from %v: %v\n", from, err)
+			continue
 		}
+		log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
+		go func() {
+			if err := packet.handle(t, from, fromID, hash); err != nil {
+				log.Debugf("error handling %T from %v: %v", packet, from, err)
+			}
+		}()
 	}
 }

-func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
 	if len(buf) < headSize+1 {
-		return errPacketTooSmall
+		return nil, NodeID{}, nil, errPacketTooSmall
 	}
 	hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
 	shouldhash := crypto.Sha3(buf[macSize:])
 	if !bytes.Equal(hash, shouldhash) {
-		return errBadHash
+		return nil, NodeID{}, nil, errBadHash
 	}
 	fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
 	if err != nil {
-		return err
-	}
-
-	var req interface {
-		handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+		return nil, NodeID{}, hash, err
 	}
+	var req packet
 	switch ptype := sigdata[0]; ptype {
 	case pingPacket:
 		req = new(ping)
@@ -363,13 +412,10 @@
 	case neighborsPacket:
 		req = new(neighbors)
 	default:
-		return fmt.Errorf("unknown type: %d", ptype)
+		return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
 	}
-	if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
-		return err
-	}
-	log.DebugDetailf("<<< %v %T %v\n", from, req, req)
-	return req.handle(t, from, fromID, hash)
+	err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
+	return req, fromID, hash, err
 }

 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
@@ -379,18 +425,14 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
 	if req.Version != Version {
 		return errBadVersion
 	}
-	t.mutex.Lock()
-	// Note: we're ignoring the provided IP address right now
-	n := t.bumpOrAdd(fromID, from)
-	if req.Port != 0 {
-		n.TCPPort = int(req.Port)
-	}
-	t.mutex.Unlock()
-
-	t.send(n, pongPacket, pong{
+	t.send(from, pongPacket, pong{
 		ReplyTok:   mac,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
+	if !t.handleReply(fromID, pingPacket, req) {
+		// Note: we're ignoring the provided IP address right now
+		t.bond(true, fromID, from, req.Port)
+	}
 	return nil
 }

@@ -398,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
 	if expired(req.Expiration) {
 		return errExpired
 	}
-	t.mutex.Lock()
-	t.bump(fromID)
-	t.mutex.Unlock()
-
-	t.replies <- reply{fromID, pongPacket, req}
+	if !t.handleReply(fromID, pongPacket, req) {
+		return errUnsolicitedReply
+	}
 	return nil
 }

@@ -410,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
 	if expired(req.Expiration) {
 		return errExpired
 	}
+	if t.db.get(fromID) == nil {
+		// No bond exists, we don't process the packet. This prevents
+		// an attack vector where the discovery protocol could be used
+		// to amplify traffic in a DDOS attack. A malicious actor
+		// would send a findnode request with the IP address and UDP
+		// port of the target as the source address. The recipient of
+		// the findnode packet would then send a neighbors packet
+		// (which is a much bigger packet than findnode) to the victim.
+		return errUnknownNode
+	}
 	t.mutex.Lock()
-	e := t.bumpOrAdd(fromID, from)
 	closest := t.closest(req.Target, bucketSize).entries
 	t.mutex.Unlock()

-	t.send(e, neighborsPacket, neighbors{
+	t.send(from, neighborsPacket, neighbors{
 		Nodes:      closest,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
@@ -426,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
 	if expired(req.Expiration) {
 		return errExpired
 	}
-	t.mutex.Lock()
-	t.bump(fromID)
-	t.add(req.Nodes)
-	t.mutex.Unlock()
-
-	t.replies <- reply{fromID, neighborsPacket, req}
+	if !t.handleReply(fromID, neighborsPacket, req) {
+		return errUnsolicitedReply
+	}
 	return nil
 }
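The bond check above is what stops findnode from acting as a traffic amplifier for spoofed source addresses. Rough, purely illustrative numbers (nothing here is measured from this code):

```go
package main

import "fmt"

func main() {
	// Assumed sizes in bytes, for illustration only.
	const findnodeSize = 170   // header + 64-byte target + expiration
	const neighborsSize = 1400 // header + up to 16 RLP node records
	fmt.Printf("a spoofed %d B request could draw ~%d B to the victim (~%.0fx)\n",
		findnodeSize, neighborsSize, float64(neighborsSize)/float64(findnodeSize))
}
```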
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index 0a8ff63589..c6c4d78e3d 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -1,10 +1,18 @@ package discover

 import (
+	"bytes"
+	"crypto/ecdsa"
+	"errors"
 	"fmt"
+	"io"
 	logpkg "log"
 	"net"
 	"os"
+	"path"
+	"reflect"
+	"runtime"
+	"sync"
 	"testing"
 	"time"

@@ -15,197 +23,317 @@ func init() {
 	logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
 }

-func TestUDP_ping(t *testing.T) {
-	t.Parallel()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	defer n2.Close()
+type udpTest struct {
+	t                   *testing.T
+	pipe                *dgramPipe
+	table               *Table
+	udp                 *udp
+	sent                [][]byte
+	localkey, remotekey *ecdsa.PrivateKey
+	remoteaddr          *net.UDPAddr
+}

-	if err := n1.net.ping(n2.self); err != nil {
-		t.Fatalf("ping error: %v", err)
+func newUDPTest(t *testing.T) *udpTest {
+	test := &udpTest{
+		t:          t,
+		pipe:       newpipe(),
+		localkey:   newkey(),
+		remotekey:  newkey(),
+		remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
 	}
-	if find(n2, n1.self.ID) == nil {
-		t.Errorf("node 2 does not contain id of node 1")
+	test.table, test.udp = newUDP(test.localkey, test.pipe, nil)
+	return test
+}
+
+// handles a packet as if it had been sent to the transport.
+func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
+	enc, err := encodePacket(test.remotekey, ptype, data)
+	if err != nil {
+		return test.errorf("packet (%d) encode error: %v", ptype, err)
 	}
-	if e := find(n1, n2.self.ID); e != nil {
-		t.Errorf("node 1 does contains id of node 2: %v", e)
+	test.sent = append(test.sent, enc)
+	err = data.handle(test.udp, test.remoteaddr, PubkeyID(&test.remotekey.PublicKey), enc[:macSize])
+	if err != wantError {
+		return test.errorf("error mismatch: got %q, want %q", err, wantError)
 	}
+	return nil
 }

-func find(tab *Table, id NodeID) *Node {
-	for _, b := range tab.buckets {
-		for _, e := range b.entries {
-			if e.ID == id {
-				return e
-			}
-		}
-	}
-	return nil
-}
-
-func TestUDP_findnode(t *testing.T) {
+// waits for a packet to be sent by the transport.
+// validate should have type func(X) or func(X) error, where X is a packet type.
+func (test *udpTest) waitPacketOut(validate interface{}) error {
+	dgram := test.pipe.waitPacketOut()
+	p, _, _, err := decodePacket(dgram)
+	if err != nil {
+		return test.errorf("sent packet decode error: %v", err)
+	}
+	fn := reflect.ValueOf(validate)
+	exptype := fn.Type().In(0)
+	if reflect.TypeOf(p) != exptype {
+		return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+	}
+	fn.Call([]reflect.Value{reflect.ValueOf(p)})
+	return nil
+}
+
+func (test *udpTest) errorf(format string, args ...interface{}) error {
+	_, file, line, ok := runtime.Caller(2) // errorf + waitPacketOut
+	if ok {
+		file = path.Base(file)
+	} else {
+		file = "???"
+		line = 1
+	}
+	err := fmt.Errorf(format, args...)
+	fmt.Printf("\t%s:%d: %v\n", file, line, err)
+	test.t.Fail()
+	return err
+}
+
+// shared test variables
+var (
+	futureExp  = uint64(time.Now().Add(10 * time.Hour).Unix())
+	testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
+)
+
+func TestUDP_packetErrors(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	test.packetIn(errExpired, pingPacket, &ping{IP: "foo", Port: 99, Version: Version})
+	test.packetIn(errBadVersion, pingPacket, &ping{IP: "foo", Port: 99, Version: 99, Expiration: futureExp})
+	test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
+	test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
+	test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
+}
+
+func TestUDP_pingTimeout(t *testing.T) {
+	t.Parallel()
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+	toid := NodeID{1, 2, 3, 4}
+	if err := test.udp.ping(toid, toaddr); err != errTimeout {
+		t.Error("expected timeout error, got", err)
+	}
+}
+
+func TestUDP_findnodeTimeout(t *testing.T) {
 	t.Parallel()
+	test := newUDPTest(t)
+	defer test.table.Close()

-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	defer n2.Close()
+	toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+	toid := NodeID{1, 2, 3, 4}
+	target := NodeID{4, 5, 6, 7}
+	result, err := test.udp.findnode(toid, toaddr, target)
+	if err != errTimeout {
+		t.Error("expected timeout error, got", err)
+	}
+	if len(result) > 0 {
+		t.Error("expected empty result, got", result)
+	}
+}
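waitPacketOut accepts any func(X) callback and uses reflection to assert that the next sent packet really has type X before invoking it. The trick reduced to a self-contained sketch (dispatch and pingMsg are illustrative names):

```go
package main

import (
	"fmt"
	"reflect"
)

// dispatch calls cb(msg) only if msg's dynamic type matches cb's
// single parameter type; otherwise it reports a mismatch.
func dispatch(cb interface{}, msg interface{}) bool {
	fn := reflect.ValueOf(cb)
	if reflect.TypeOf(msg) != fn.Type().In(0) {
		return false // wrong packet kind was sent
	}
	fn.Call([]reflect.Value{reflect.ValueOf(msg)})
	return true
}

type pingMsg struct{ seq int }

func main() {
	ok := dispatch(func(p *pingMsg) { fmt.Println("got ping", p.seq) }, &pingMsg{7})
	fmt.Println("matched:", ok)
}
```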

-	// put a few nodes into n2. the exact distribution shouldn't
-	// matter much, altough we need to take care not to overflow
-	// any bucket.
-	target := randomID(n1.self.ID, 100)
+func TestUDP_findnode(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	// put a few nodes into the table. their exact
+	// distribution shouldn't matter much, although we need to
+	// take care not to overflow any bucket.
+	target := testTarget
 	nodes := &nodesByDistance{target: target}
 	for i := 0; i < bucketSize; i++ {
-		n2.add([]*Node{&Node{
+		nodes.push(&Node{
 			IP:       net.IP{1, 2, 3, byte(i)},
 			DiscPort: i + 2,
 			TCPPort:  i + 2,
-			ID:       randomID(n2.self.ID, i+2),
-		}})
+			ID:       randomID(test.table.self.ID, i+2),
+		}, bucketSize)
 	}
-	n2.add(nodes.entries)
-	n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
-	expected := n2.closest(target, bucketSize)
+	test.table.add(nodes.entries)
+
+	// ensure there's a bond with the test node,
+	// findnode won't be accepted otherwise.
+	test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99)

-	err := runUDP(10, func() error {
-		result, _ := n1.net.findnode(n2.self, target)
-		if len(result) != bucketSize {
-			return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
+	// check that closest neighbors are returned.
+	test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
+	test.waitPacketOut(func(p *neighbors) {
+		expected := test.table.closest(testTarget, bucketSize)
+		if len(p.Nodes) != bucketSize {
+			t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
 		}
-		for i := range result {
-			if result[i].ID != expected.entries[i].ID {
-				return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
+		for i := range p.Nodes {
+			if p.Nodes[i].ID != expected.entries[i].ID {
+				t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
 			}
 		}
-		return nil
 	})
-	if err != nil {
-		t.Error(err)
-	}
 }
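findnode's pending callback (see udp.go above) runs once per matching neighbors packet and only returns true once bucketSize nodes have accumulated, which is exactly the behavior the next test exercises with a reply split in two. The contract in miniature, as a sketch of the existing call rather than new API:

```go
// Inside findnode: accumulate nodes across replies until enough arrived.
nreceived := 0
errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
	nreceived += len(r.(*neighbors).Nodes)
	return nreceived >= bucketSize // returning false keeps the request pending
})
```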

-func TestUDP_replytimeout(t *testing.T) {
-	t.Parallel()
+func TestUDP_findnodeMultiReply(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()

-	// reserve a port so we don't talk to an existing service by accident
-	addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
-	fd, err := net.ListenUDP("udp", addr)
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer fd.Close()
+	// queue a pending findnode request
+	resultc, errc := make(chan []*Node), make(chan error)
+	go func() {
+		rid := PubkeyID(&test.remotekey.PublicKey)
+		ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
+		if err != nil && len(ns) == 0 {
+			errc <- err
+		} else {
+			resultc <- ns
+		}
+	}()

-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
+	// wait for the findnode to be sent.
+	// after it is sent, the transport is waiting for a reply
+	test.waitPacketOut(func(p *findnode) {
+		if p.Target != testTarget {
+			t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
+		}
+	})

-	if err := n1.net.ping(n2); err != errTimeout {
-		t.Error("expected timeout error, got", err)
-	}
+	// send the reply as two packets.
+	list := []*Node{
+		MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303"),
+		MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
+		MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301"),
+		MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
+	}
+	test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[:2]})
+	test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[2:]})

-	if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
-		t.Error("expected timeout error, got", err)
-	} else if len(result) > 0 {
-		t.Error("expected empty result, got", result)
+	// check that the sent neighbors are all returned by findnode
+	select {
+	case result := <-resultc:
+		if !reflect.DeepEqual(result, list) {
+			t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list)
+		}
+	case err := <-errc:
+		t.Errorf("findnode error: %v", err)
+	case <-time.After(5 * time.Second):
+		t.Error("findnode did not return within 5 seconds")
 	}
 }

-func TestUDP_findnodeMultiReply(t *testing.T) {
-	t.Parallel()
+func TestUDP_successfulPing(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()

-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	udp2 := n2.net.(*udp)
-	defer n1.Close()
-	defer n2.Close()
-
-	err := runUDP(10, func() error {
-		nodes := make([]*Node, bucketSize)
-		for i := range nodes {
-			nodes[i] = &Node{
-				IP:       net.IP{1, 2, 3, 4},
-				DiscPort: i + 1,
-				TCPPort:  i + 1,
-				ID:       randomID(n2.self.ID, i+1),
-			}
-		}
+	done := make(chan struct{})
+	go func() {
+		test.packetIn(nil, pingPacket, &ping{IP: "foo", Port: 99, Version: Version, Expiration: futureExp})
+		close(done)
+	}()

-		// ask N2 for neighbors. it will send an empty reply back.
-		// the request will wait for up to bucketSize replies.
-		resultc := make(chan []*Node)
-		errc := make(chan error)
-		go func() {
-			ns, err := n1.net.findnode(n2.self, n1.self.ID)
-			if err != nil {
-				errc <- err
-			} else {
-				resultc <- ns
-			}
-		}()
+	// the ping is replied to.
+	test.waitPacketOut(func(p *pong) {
+		pinghash := test.sent[0][:macSize]
+		if !bytes.Equal(p.ReplyTok, pinghash) {
+			t.Errorf("got ReplyTok %x, want %x", p.ReplyTok, pinghash)
+		}
+	})

-		// send a few more neighbors packets to N1.
-		// it should collect those.
-		for end := 0; end < len(nodes); {
-			off := end
-			if end = end + 5; end > len(nodes) {
-				end = len(nodes)
-			}
-			udp2.send(n1.self, neighborsPacket, neighbors{
-				Nodes:      nodes[off:end],
-				Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
-			})
-		}
+	// remote is unknown, the table pings back.
+	test.waitPacketOut(func(p *ping) error { return nil })
+	test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})

-		// check that they are all returned. we cannot just check for
-		// equality because they might not be returned in the order they
-		// were sent.
-		var result []*Node
-		select {
-		case result = <-resultc:
-		case err := <-errc:
-			return err
-		}
-		if hasDuplicates(result) {
-			return fmt.Errorf("result slice contains duplicates")
-		}
-		if len(result) != len(nodes) {
-			return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
-		}
-		matched := make(map[NodeID]bool)
-		for _, n := range result {
-			for _, expn := range nodes {
-				if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
-					matched[n.ID] = true
-				}
-			}
-		}
-		if len(matched) != len(nodes) {
-			return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
-		}
-		return nil
-	})
-	if err != nil {
-		t.Error(err)
-	}
+	// ping should return shortly after getting the pong packet.
+	<-done
+
+	// check that the node was added.
+	rid := PubkeyID(&test.remotekey.PublicKey)
+	rnode := find(test.table, rid)
+	if rnode == nil {
+		t.Fatalf("node %v not found in table", rid)
+	}
+	if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
+		t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
+	}
+	if rnode.DiscPort != test.remoteaddr.Port {
+		t.Errorf("node has wrong Port: got %v, want: %v", rnode.DiscPort, test.remoteaddr.Port)
+	}
+	if rnode.TCPPort != 99 {
+		t.Errorf("node has wrong TCP port: got %v, want: %v", rnode.TCPPort, 99)
+	}
 }

+func find(tab *Table, id NodeID) *Node {
+	for _, b := range tab.buckets {
+		for _, e := range b.entries {
+			if e.ID == id {
+				return e
+			}
+		}
+	}
+	return nil
+}
+
-// runUDP runs a test n times and returns an error if the test failed
-// in all n runs. This is necessary because UDP is unreliable even for
-// connections on the local machine, causing test failures.
-func runUDP(n int, test func() error) error {
-	errcount := 0
-	errors := ""
-	for i := 0; i < n; i++ {
-		if err := test(); err != nil {
-			errors += fmt.Sprintf("\n#%d: %v", i, err)
-			errcount++
-		}
-	}
-	if errcount == n {
-		return fmt.Errorf("failed on all %d iterations:%s", n, errors)
-	}
-	return nil
-}
+// dgramPipe is a fake UDP socket. It queues all sent datagrams.
+type dgramPipe struct {
+	mu      *sync.Mutex
+	cond    *sync.Cond
+	closing chan struct{}
+	closed  bool
+	queue   [][]byte
+}
+
+func newpipe() *dgramPipe {
+	mu := new(sync.Mutex)
+	return &dgramPipe{
+		closing: make(chan struct{}),
+		cond:    &sync.Cond{L: mu},
+		mu:      mu,
+	}
+}
+
+// WriteToUDP queues a datagram.
+func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
+	msg := make([]byte, len(b))
+	copy(msg, b)
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.closed {
+		return 0, errors.New("closed")
+	}
+	c.queue = append(c.queue, msg)
+	c.cond.Signal()
+	return len(b), nil
+}
+
+// ReadFromUDP just hangs until the pipe is closed.
+func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
+	<-c.closing
+	return 0, nil, io.EOF
+}
+
+func (c *dgramPipe) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if !c.closed {
+		close(c.closing)
+		c.closed = true
+	}
+	return nil
+}
+
+func (c *dgramPipe) LocalAddr() net.Addr {
+	return &net.UDPAddr{}
+}
+
+func (c *dgramPipe) waitPacketOut() []byte {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	for len(c.queue) == 0 {
+		c.cond.Wait()
+	}
+	p := c.queue[0]
+	copy(c.queue, c.queue[1:])
+	c.queue = c.queue[:len(c.queue)-1]
+	return p
+}
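dgramPipe pairs a sync.Cond with a mutex so that waitPacketOut can sleep until WriteToUDP signals a queued datagram. The same blocking-queue shape as a standalone example (queue and its methods are illustrative names):

```go
package main

import (
	"fmt"
	"sync"
)

type queue struct {
	mu    sync.Mutex
	cond  *sync.Cond
	items []string
}

func newQueue() *queue {
	q := &queue{}
	q.cond = sync.NewCond(&q.mu)
	return q
}

func (q *queue) put(s string) {
	q.mu.Lock()
	q.items = append(q.items, s)
	q.cond.Signal() // wake one waiting consumer
	q.mu.Unlock()
}

func (q *queue) get() string {
	q.mu.Lock()
	defer q.mu.Unlock()
	for len(q.items) == 0 {
		q.cond.Wait() // releases the lock while asleep, reacquires on wakeup
	}
	s := q.items[0]
	q.items = q.items[1:]
	return s
}

func main() {
	q := newQueue()
	go q.put("datagram")
	fmt.Println(q.get())
}
```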