p2p/discover: implement node bonding

This is a fix for an attack vector where the discovery protocol could be
used to amplify traffic in a DDOS attack. A malicious actor would send a
findnode request with the IP address and UDP port of the target as the
source address. The recipient of the findnode packet would then send a
neighbors packet (which is 16x the size of findnode) to the victim.

Our solution is to require a 'bond' with the sender of findnode. If no
bond exists, the findnode packet is not processed. A bond between nodes
α and β is created when α replies to a ping from β.

This (initial) version of the bonding implementation might still be
vulnerable to replay attacks during the expiration time window.
We will add stricter source address validation later.
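
To make the 16x figure concrete, here is a small back-of-the-envelope
sketch (illustration only, not part of the change; the byte counts are
assumed, only the 16x ratio comes from the text above):

package main

import "fmt"

func main() {
	// Assumed sizes: a findnode request is small; a neighbors reply
	// carrying up to 16 node records is roughly 16x larger.
	const findnodeSize = 100                // bytes per spoofed request
	const neighborsSize = 16 * findnodeSize // bytes delivered to the victim

	const requests = 1000 // spoofed findnode packets sent by the attacker
	attackerBytes := requests * findnodeSize
	victimBytes := requests * neighborsSize
	fmt.Printf("amplification: %dx (%d bytes sent, %d bytes reflected)\n",
		victimBytes/attackerBytes, attackerBytes, victimBytes)
}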
pull/592/head
Felix Lange, 10 years ago
commit de7af720d6, parent 92928309b2

Changed files:
  1. p2p/discover/node.go        (43)
  2. p2p/discover/table.go      (173)
  3. p2p/discover/table_test.go (138)
  4. p2p/discover/udp.go        (188)
  5. p2p/discover/udp_test.go   (396)

diff --git a/p2p/discover/node.go b/p2p/discover/node.go
@@ -13,6 +13,8 @@ import (
 	"net/url"
 	"strconv"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/ethereum/go-ethereum/crypto"
@@ -30,7 +32,8 @@ type Node struct {
 	DiscPort int // UDP listening port for discovery protocol
 	TCPPort  int // TCP listening port for RLPx
 
-	active time.Time
+	// this must be set/read using atomic load and store.
+	activeStamp int64
 }
 
 func newNode(id NodeID, addr *net.UDPAddr) *Node {
@@ -39,7 +42,6 @@ func newNode(id NodeID, addr *net.UDPAddr) *Node {
 		IP:       addr.IP,
 		DiscPort: addr.Port,
 		TCPPort:  addr.Port,
-		active:   time.Now(),
 	}
 }
 
@@ -48,6 +50,20 @@ func (n *Node) isValid() bool {
 	return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
 }
 
+func (n *Node) bumpActive() {
+	stamp := time.Now().Unix()
+	atomic.StoreInt64(&n.activeStamp, stamp)
+}
+
+func (n *Node) active() time.Time {
+	stamp := atomic.LoadInt64(&n.activeStamp)
+	return time.Unix(stamp, 0)
+}
+
+func (n *Node) addr() *net.UDPAddr {
+	return &net.UDPAddr{IP: n.IP, Port: n.DiscPort}
+}
+
 // The string representation of a Node is a URL.
 // Please see ParseNode for a description of the format.
 func (n *Node) String() string {
@@ -304,3 +320,26 @@ func randomID(a NodeID, n int) (b NodeID) {
 	}
 	return b
 }
+
+// nodeDB stores all nodes we know about.
+type nodeDB struct {
+	mu   sync.RWMutex
+	byID map[NodeID]*Node
+}
+
+func (db *nodeDB) get(id NodeID) *Node {
+	db.mu.RLock()
+	defer db.mu.RUnlock()
+	return db.byID[id]
+}
+
+func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	if db.byID == nil {
+		db.byID = make(map[NodeID]*Node)
+	}
+	n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
+	db.byID[n.ID] = n
+	return n
+}

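The nodeDB added above doubles as the bond registry: a node is entered
only after a completed ping/pong, and findnode handling checks
membership. A self-contained demo of that flow (types trimmed to the
essentials for illustration; the 8-byte NodeID is a stand-in for the
real 64-byte one):

package main

import (
	"fmt"
	"net"
	"sync"
)

type NodeID [8]byte // stand-in; the real NodeID is much larger

type Node struct {
	ID       NodeID
	IP       net.IP
	DiscPort int
	TCPPort  int
}

type nodeDB struct {
	mu   sync.RWMutex
	byID map[NodeID]*Node
}

func (db *nodeDB) get(id NodeID) *Node {
	db.mu.RLock()
	defer db.mu.RUnlock()
	return db.byID[id]
}

func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
	db.mu.Lock()
	defer db.mu.Unlock()
	if db.byID == nil {
		db.byID = make(map[NodeID]*Node)
	}
	n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
	db.byID[n.ID] = n
	return n
}

func main() {
	db := new(nodeDB)
	id := NodeID{1}
	if db.get(id) == nil {
		fmt.Println("no bond yet: findnode would be rejected")
	}
	// a completed ping/pong records the bond...
	db.add(id, &net.UDPAddr{IP: net.IPv4(10, 0, 1, 16), Port: 30303}, 30303)
	// ...after which findnode from this node is accepted.
	if db.get(id) != nil {
		fmt.Println("bond recorded: findnode is now accepted")
	}
}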
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
@@ -17,6 +17,7 @@ const (
 	alpha      = 3               // Kademlia concurrency factor
 	bucketSize = 16              // Kademlia bucket size
 	nBuckets   = nodeIDBits + 1  // Number of buckets
+	maxBondingPingPongs = 10
 )
 
 type Table struct {
@@ -24,27 +25,50 @@ type Table struct {
 	buckets [nBuckets]*bucket // index of known nodes by distance
 	nursery []*Node           // bootstrap nodes
 
+	bondmu    sync.Mutex
+	bonding   map[NodeID]*bondproc
+	bondslots chan struct{} // limits total number of active bonding processes
+
 	net  transport
 	self *Node // metadata of the local node
+	db   *nodeDB
+}
+
+type bondproc struct {
+	err  error
+	n    *Node
+	done chan struct{}
 }
 
 // transport is implemented by the UDP transport.
 // it is an interface so we can test without opening lots of UDP
 // sockets and without generating a private key.
 type transport interface {
-	ping(*Node) error
-	findnode(e *Node, target NodeID) ([]*Node, error)
+	ping(NodeID, *net.UDPAddr) error
+	waitping(NodeID) error
+	findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
 	close()
 }
 
 // bucket contains nodes, ordered by their last activity.
+// the entry that was most recently active is the first element
+// in entries.
 type bucket struct {
 	lastLookup time.Time
	entries    []*Node
 }
 
 func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
-	tab := &Table{net: t, self: newNode(ourID, ourAddr)}
+	tab := &Table{
+		net:       t,
+		db:        new(nodeDB),
+		self:      newNode(ourID, ourAddr),
+		bonding:   make(map[NodeID]*bondproc),
+		bondslots: make(chan struct{}, maxBondingPingPongs),
+	}
+	for i := 0; i < cap(tab.bondslots); i++ {
+		tab.bondslots <- struct{}{}
+	}
 	for i := range tab.buckets {
 		tab.buckets[i] = new(bucket)
 	}
@@ -107,8 +131,8 @@ func (tab *Table) Lookup(target NodeID) []*Node {
 				asked[n.ID] = true
 				pendingQueries++
 				go func() {
-					result, _ := tab.net.findnode(n, target)
-					reply <- result
+					r, _ := tab.net.findnode(n.ID, n.addr(), target)
+					reply <- tab.bondall(r)
 				}()
 			}
 		}
@@ -116,13 +140,11 @@ func (tab *Table) Lookup(target NodeID) []*Node {
 			// we have asked all closest nodes, stop the search
 			break
 		}
-		// wait for the next reply
 		for _, n := range <-reply {
-			cn := n
-			if !seen[n.ID] {
+			if n != nil && !seen[n.ID] {
 				seen[n.ID] = true
-				result.push(cn, bucketSize)
+				result.push(n, bucketSize)
 			}
 		}
 		pendingQueries--
@@ -145,8 +167,9 @@ func (tab *Table) refresh() {
 	result := tab.Lookup(randomID(tab.self.ID, ld))
 	if len(result) == 0 {
 		// bootstrap the table with a self lookup
+		all := tab.bondall(tab.nursery)
 		tab.mutex.Lock()
-		tab.add(tab.nursery)
+		tab.add(all)
 		tab.mutex.Unlock()
 		tab.Lookup(tab.self.ID)
 		// TODO: the Kademlia paper says that we're supposed to perform
@@ -176,45 +199,105 @@ func (tab *Table) len() (n int) {
 	return n
 }
 
-// bumpOrAdd updates the activity timestamp for the given node and
-// attempts to insert the node into a bucket. The returned Node might
-// not be part of the table. The caller must hold tab.mutex.
-func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
-	b := tab.buckets[logdist(tab.self.ID, node)]
-	if n = b.bump(node); n == nil {
-		n = newNode(node, from)
-		if len(b.entries) == bucketSize {
-			tab.pingReplace(n, b)
-		} else {
-			b.entries = append(b.entries, n)
-		}
-	}
-	return n
-}
-
-func (tab *Table) pingReplace(n *Node, b *bucket) {
-	old := b.entries[bucketSize-1]
-	go func() {
-		if err := tab.net.ping(old); err == nil {
-			// it responded, we don't need to replace it.
-			return
-		}
-		// it didn't respond, replace the node if it is still the oldest node.
-		tab.mutex.Lock()
-		if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
-			// slide down other entries and put the new one in front.
-			// TODO: insert in correct position to keep the order
-			copy(b.entries[1:], b.entries)
-			b.entries[0] = n
-		}
-		tab.mutex.Unlock()
-	}()
-}
-
-// bump updates the activity timestamp for the given node.
-// The caller must hold tab.mutex.
-func (tab *Table) bump(node NodeID) {
-	tab.buckets[logdist(tab.self.ID, node)].bump(node)
-}
+// bondall bonds with all given nodes concurrently and returns
+// those nodes for which bonding has probably succeeded.
+func (tab *Table) bondall(nodes []*Node) (result []*Node) {
+	rc := make(chan *Node, len(nodes))
+	for i := range nodes {
+		go func(n *Node) {
+			nn, _ := tab.bond(false, n.ID, n.addr(), uint16(n.TCPPort))
+			rc <- nn
+		}(nodes[i])
+	}
+	for _ = range nodes {
+		if n := <-rc; n != nil {
+			result = append(result, n)
+		}
+	}
+	return result
+}
+
+// bond ensures the local node has a bond with the given remote node.
+// It also attempts to insert the node into the table if bonding succeeds.
+// The caller must not hold tab.mutex.
+//
+// A bond must be established before sending findnode requests.
+// Both sides must have completed a ping/pong exchange for a bond to
+// exist. The total number of active bonding processes is limited in
+// order to restrain network use.
+//
+// bond is meant to operate idempotently in that bonding with a remote
+// node which still remembers a previously established bond will work.
+// The remote node will simply not send a ping back, causing waitping
+// to time out.
+//
+// If pinged is true, the remote node has just pinged us and one half
+// of the process can be skipped.
+func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
+	var n *Node
+	if n = tab.db.get(id); n == nil {
+		tab.bondmu.Lock()
+		w := tab.bonding[id]
+		if w != nil {
+			// Wait for an existing bonding process to complete.
+			tab.bondmu.Unlock()
+			<-w.done
+		} else {
+			// Register a new bonding process.
+			w = &bondproc{done: make(chan struct{})}
+			tab.bonding[id] = w
+			tab.bondmu.Unlock()
+			// Do the ping/pong. The result goes into w.
+			tab.pingpong(w, pinged, id, addr, tcpPort)
+			// Unregister the process after it's done.
+			tab.bondmu.Lock()
+			delete(tab.bonding, id)
+			tab.bondmu.Unlock()
+		}
+		n = w.n
+		if w.err != nil {
+			return nil, w.err
+		}
+	}
+	tab.mutex.Lock()
+	defer tab.mutex.Unlock()
+	if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) {
+		tab.pingreplace(n, b)
+	}
+	return n, nil
+}
+
+func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
+	<-tab.bondslots
+	defer func() { tab.bondslots <- struct{}{} }()
+	if w.err = tab.net.ping(id, addr); w.err != nil {
+		close(w.done)
+		return
+	}
+	if !pinged {
+		// Give the remote node a chance to ping us before we start
+		// sending findnode requests. If they still remember us,
+		// waitping will simply time out.
+		tab.net.waitping(id)
+	}
+	w.n = tab.db.add(id, addr, tcpPort)
+	close(w.done)
+}
+
+func (tab *Table) pingreplace(new *Node, b *bucket) {
+	if len(b.entries) == bucketSize {
+		oldest := b.entries[bucketSize-1]
+		if err := tab.net.ping(oldest.ID, oldest.addr()); err == nil {
+			// The node responded, we don't need to replace it.
+			return
+		}
+	} else {
+		// Add a slot at the end so the last entry doesn't
+		// fall off when adding the new node.
+		b.entries = append(b.entries, nil)
+	}
+	copy(b.entries[1:], b.entries)
+	b.entries[0] = new
+}
 
 // add puts the entries into the table if their corresponding
@@ -240,17 +323,17 @@ outer:
 		}
 	}
 }
 
-func (b *bucket) bump(id NodeID) *Node {
-	for i, n := range b.entries {
-		if n.ID == id {
-			n.active = time.Now()
+func (b *bucket) bump(n *Node) bool {
+	for i := range b.entries {
+		if b.entries[i].ID == n.ID {
+			n.bumpActive()
 			// move it to the front
 			copy(b.entries[1:], b.entries[:i+1])
 			b.entries[0] = n
-			return n
+			return true
 		}
 	}
-	return nil
+	return false
 }
 
 // nodesByDistance is a list of nodes, ordered by

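The bonding map in bond() is a single-flight pattern: concurrent
callers for the same node ID share one in-flight ping/pong instead of
each starting their own. A reduced, self-contained sketch of the same
idea (names invented for the example; unlike bond(), nothing here
prevents a later caller from re-running fn, which bond() avoids via
the nodeDB check):

package main

import (
	"fmt"
	"sync"
	"time"
)

// proc mirrors bondproc: one in-flight operation per key.
type proc struct {
	result string
	done   chan struct{}
}

type singleFlight struct {
	mu     sync.Mutex
	active map[string]*proc
}

// do runs fn once per key at a time; concurrent callers for the same
// key block on the existing run and share its result.
func (sf *singleFlight) do(key string, fn func() string) string {
	sf.mu.Lock()
	if w, ok := sf.active[key]; ok {
		sf.mu.Unlock()
		<-w.done // wait for the run already in progress
		return w.result
	}
	w := &proc{done: make(chan struct{})}
	sf.active[key] = w
	sf.mu.Unlock()

	w.result = fn() // e.g. the ping/pong exchange
	close(w.done)

	sf.mu.Lock()
	delete(sf.active, key)
	sf.mu.Unlock()
	return w.result
}

func main() {
	sf := &singleFlight{active: make(map[string]*proc)}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(sf.do("node-a", func() string {
				time.Sleep(10 * time.Millisecond) // pretend to ping/pong
				return "bonded"
			}))
		}()
	}
	wg.Wait()
}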
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
@@ -2,80 +2,69 @@ package discover
 
 import (
 	"crypto/ecdsa"
-	"errors"
 	"fmt"
 	"math/rand"
 	"net"
 	"reflect"
 	"testing"
 	"testing/quick"
-	"time"
 
 	"github.com/ethereum/go-ethereum/crypto"
 )
 
-func TestTable_bumpOrAddBucketAssign(t *testing.T) {
-	tab := newTable(nil, NodeID{}, &net.UDPAddr{})
-	for i := 1; i < len(tab.buckets); i++ {
-		tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
-	}
-	for i, b := range tab.buckets {
-		if i > 0 && len(b.entries) != 1 {
-			t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
-		}
-	}
-}
-
-func TestTable_bumpOrAddPingReplace(t *testing.T) {
-	pingC := make(pingC)
-	tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
-	last := fillBucket(tab, 200)
-
-	// this bumpOrAdd should not replace the last node
-	// because the node replies to ping.
-	new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
-
-	pinged := <-pingC
-	if pinged != last.ID {
-		t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
-	}
-
-	tab.mutex.Lock()
-	defer tab.mutex.Unlock()
-	if l := len(tab.buckets[200].entries); l != bucketSize {
-		t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
-	}
-	if !contains(tab.buckets[200].entries, last.ID) {
-		t.Error("last entry was removed")
-	}
-	if contains(tab.buckets[200].entries, new.ID) {
-		t.Error("new entry was added")
-	}
-}
-
-func TestTable_bumpOrAddPingTimeout(t *testing.T) {
-	tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
-	last := fillBucket(tab, 200)
-
-	// this bumpOrAdd should replace the last node
-	// because the node does not reply to ping.
-	new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
-
-	// wait for async bucket update. damn. this needs to go away.
-	time.Sleep(2 * time.Millisecond)
-
-	tab.mutex.Lock()
-	defer tab.mutex.Unlock()
-	if l := len(tab.buckets[200].entries); l != bucketSize {
-		t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
-	}
-	if contains(tab.buckets[200].entries, last.ID) {
-		t.Error("last entry was not removed")
-	}
-	if !contains(tab.buckets[200].entries, new.ID) {
-		t.Error("new entry was not added")
-	}
-}
+func TestTable_pingReplace(t *testing.T) {
+	doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
+		transport := newPingRecorder()
+		tab := newTable(transport, NodeID{}, &net.UDPAddr{})
+		last := fillBucket(tab, 200)
+		pingSender := randomID(tab.self.ID, 200)
+
+		// this gotPing should replace the last node
+		// if the last node is not responding.
+		transport.responding[last.ID] = lastInBucketIsResponding
+		transport.responding[pingSender] = newNodeIsResponding
+		tab.bond(true, pingSender, &net.UDPAddr{}, 0)
+
+		// first ping goes to sender (bonding pingback)
+		if !transport.pinged[pingSender] {
+			t.Error("table did not ping back sender")
+		}
+		if newNodeIsResponding {
+			// second ping goes to oldest node in bucket
+			// to see whether it is still alive.
+			if !transport.pinged[last.ID] {
+				t.Error("table did not ping last node in bucket")
+			}
+		}
+
+		tab.mutex.Lock()
+		defer tab.mutex.Unlock()
+		if l := len(tab.buckets[200].entries); l != bucketSize {
+			t.Errorf("wrong bucket size after gotPing: got %d, want %d", bucketSize, l)
+		}
+
+		if lastInBucketIsResponding || !newNodeIsResponding {
+			if !contains(tab.buckets[200].entries, last.ID) {
+				t.Error("last entry was removed")
+			}
+			if contains(tab.buckets[200].entries, pingSender) {
+				t.Error("new entry was added")
+			}
+		} else {
+			if contains(tab.buckets[200].entries, last.ID) {
+				t.Error("last entry was not removed")
+			}
+			if !contains(tab.buckets[200].entries, pingSender) {
+				t.Error("new entry was not added")
+			}
+		}
+	}
+	doit(true, true)
+	doit(false, true)
+	doit(true, false)
+	doit(false, false)
+}
 
 func fillBucket(tab *Table, ld int) (last *Node) {
 	b := tab.buckets[ld]
@@ -85,44 +74,27 @@ func fillBucket(tab *Table, ld int) (last *Node) {
 	return b.entries[bucketSize-1]
 }
 
-type pingC chan NodeID
-
-func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
-	panic("findnode called on pingRecorder")
-}
-func (t pingC) close() {
-	panic("close called on pingRecorder")
-}
-func (t pingC) ping(n *Node) error {
-	if t == nil {
-		return errTimeout
-	}
-	t <- n.ID
-	return nil
-}
-
-func TestTable_bump(t *testing.T) {
-	tab := newTable(nil, NodeID{}, &net.UDPAddr{})
-
-	// add an old entry and two recent ones
-	oldactive := time.Now().Add(-2 * time.Minute)
-	old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
-	others := []*Node{
-		&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
-		&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
-	}
-	tab.add(append(others, old))
-	if tab.buckets[200].entries[0] == old {
-		t.Fatal("old entry is at front of bucket")
-	}
-
-	// bumping the old entry should move it to the front
-	tab.bump(old.ID)
-	if old.active == oldactive {
-		t.Error("activity timestamp not updated")
-	}
-	if tab.buckets[200].entries[0] != old {
-		t.Errorf("bumped entry did not move to the front of bucket")
-	}
-}
+type pingRecorder struct{ responding, pinged map[NodeID]bool }
+
+func newPingRecorder() *pingRecorder {
+	return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
+}
+
+func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+	panic("findnode called on pingRecorder")
+}
+func (t *pingRecorder) close() {
+	panic("close called on pingRecorder")
+}
+func (t *pingRecorder) waitping(from NodeID) error {
+	return nil // remote always pings
+}
+func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
+	t.pinged[toid] = true
+	if t.responding[toid] {
+		return nil
+	} else {
+		return errTimeout
+	}
+}
@@ -210,7 +182,7 @@ func TestTable_Lookup(t *testing.T) {
 		t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
 	}
 	// seed table with initial node (otherwise lookup will terminate immediately)
-	tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
+	tab.add([]*Node{newNode(randomID(target, 200), &net.UDPAddr{Port: 200})})
 
 	results := tab.Lookup(target)
 	t.Logf("results:")
@@ -238,16 +210,16 @@ type findnodeOracle struct {
 	target NodeID
 }
 
-func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
-	t.t.Logf("findnode query at dist %d", n.DiscPort)
+func (t findnodeOracle) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+	t.t.Logf("findnode query at dist %d", toaddr.Port)
 	// current log distance is encoded in port number
 	var result []*Node
-	switch n.DiscPort {
+	switch toaddr.Port {
 	case 0:
 		panic("query to node at distance 0")
 	default:
 		// TODO: add more randomness to distances
-		next := n.DiscPort - 1
+		next := toaddr.Port - 1
 		for i := 0; i < bucketSize; i++ {
 			result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
 		}
@@ -256,10 +228,8 @@ func (t findnodeOracle) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
 }
 
 func (t findnodeOracle) close() {}
-func (t findnodeOracle) ping(n *Node) error {
-	return errors.New("ping is not supported by this transport")
-}
+
+func (t findnodeOracle) waitping(from NodeID) error                  { return nil }
+func (t findnodeOracle) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil }
 
 func hasDuplicates(slice []*Node) bool {
 	seen := make(map[NodeID]bool)

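pingRecorder works because table.go reaches the network only through
the small transport interface, so any value with those methods can
stand in during tests. A stripped-down illustration of the pattern
(all names invented for the example):

package main

import "fmt"

// The production code depends only on a narrow interface...
type transport interface {
	ping(id string) error
}

type table struct{ net transport }

func (t *table) bond(id string) error { return t.net.ping(id) }

// ...so a test can substitute a recorder for the real UDP transport.
type pingRecorder struct{ pinged map[string]bool }

func (p *pingRecorder) ping(id string) error {
	p.pinged[id] = true
	return nil
}

func main() {
	rec := &pingRecorder{pinged: make(map[string]bool)}
	tab := &table{net: rec}
	tab.bond("n1")
	fmt.Println(rec.pinged["n1"]) // true: the ping was observed, no socket needed
}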
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
@@ -24,6 +24,8 @@ var (
 	errBadHash    = errors.New("bad hash")
 	errExpired    = errors.New("expired")
 	errBadVersion = errors.New("version mismatch")
+	errUnsolicitedReply = errors.New("unsolicited reply")
+	errUnknownNode      = errors.New("unknown node")
 	errTimeout    = errors.New("RPC timeout")
 	errClosed     = errors.New("socket closed")
 )
@@ -80,12 +82,25 @@ type rpcNode struct {
 	ID NodeID
 }
 
+type packet interface {
+	handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+}
+
+type conn interface {
+	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
+	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
+	Close() error
+	LocalAddr() net.Addr
+}
+
 // udp implements the RPC protocol.
 type udp struct {
-	conn       *net.UDPConn
+	conn       conn
 	priv       *ecdsa.PrivateKey
 	addpending chan *pending
-	replies    chan reply
+	gotreply   chan reply
 	closing    chan struct{}
 	nat        nat.Interface
@@ -124,6 +139,9 @@ type reply struct {
 	from  NodeID
 	ptype byte
 	data  interface{}
+	// loop indicates whether there was
+	// a matching request by sending on this channel.
+	matched chan<- bool
 }
 
 // ListenUDP returns a new table that listens for UDP packets on laddr.
@@ -136,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
 	if err != nil {
 		return nil, err
 	}
+	tab, _ := newUDP(priv, conn, natm)
+	log.Infoln("Listening,", tab.self)
+	return tab, nil
+}
+
+func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
 	udp := &udp{
-		conn:       conn,
+		conn:       c,
 		priv:       priv,
 		closing:    make(chan struct{}),
+		gotreply:   make(chan reply),
 		addpending: make(chan *pending),
-		replies:    make(chan reply),
 	}
-	realaddr := conn.LocalAddr().(*net.UDPAddr)
+	realaddr := c.LocalAddr().(*net.UDPAddr)
 	if natm != nil {
 		if !realaddr.IP.IsLoopback() {
 			go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
@@ -155,11 +178,9 @@ func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
 		}
 	}
 	udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
 	go udp.loop()
 	go udp.readLoop()
-	log.Infoln("Listening, ", udp.self)
-	return udp.Table, nil
+	return udp.Table, udp
 }
 
 func (t *udp) close() {
@@ -169,10 +190,10 @@ func (t *udp) close() {
 }
 
 // ping sends a ping message to the given node and waits for a reply.
-func (t *udp) ping(e *Node) error {
+func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
 	// TODO: maybe check for ReplyTo field in callback to measure RTT
-	errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
-	t.send(e, pingPacket, ping{
+	errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
+	t.send(toaddr, pingPacket, ping{
 		Version: Version,
 		IP:      t.self.IP.String(),
 		Port:    uint16(t.self.TCPPort),
@@ -181,12 +202,16 @@ func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
 	return <-errc
 }
 
+func (t *udp) waitping(from NodeID) error {
+	return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
+}
+
 // findnode sends a findnode request to the given node and waits until
 // the node has sent up to k neighbors.
-func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
 	nodes := make([]*Node, 0, bucketSize)
 	nreceived := 0
-	errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+	errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
 		reply := r.(*neighbors)
 		for _, n := range reply.Nodes {
 			nreceived++
@@ -196,8 +221,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
 		}
 		return nreceived >= bucketSize
 	})
-
-	t.send(to, findnodePacket, findnode{
+	t.send(toaddr, findnodePacket, findnode{
 		Target:     target,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
@@ -219,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
 	return ch
 }
 
+func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
+	matched := make(chan bool)
+	select {
+	case t.gotreply <- reply{from, ptype, req, matched}:
+		// loop will handle it
+		return <-matched
+	case <-t.closing:
+		return false
+	}
+}
+
 // loop runs in its own goroutine. it keeps track of
 // the refresh timer and the pending reply queue.
 func (t *udp) loop() {
@@ -249,6 +284,7 @@ func (t *udp) loop() {
 			for _, p := range pending {
 				p.errc <- errClosed
 			}
+			pending = nil
 			return
 
 		case p := <-t.addpending:
@@ -256,18 +292,21 @@ func (t *udp) loop() {
 			pending = append(pending, p)
 			rearmTimeout()
 
-		case reply := <-t.replies:
-			// run matching callbacks, remove if they return false.
-			for i := 0; i < len(pending); i++ {
-				p := pending[i]
-				if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
-					p.errc <- nil
-					copy(pending[i:], pending[i+1:])
-					pending = pending[:len(pending)-1]
-					i--
-				}
-			}
-			rearmTimeout()
+		case r := <-t.gotreply:
+			var matched bool
+			for i := 0; i < len(pending); i++ {
+				if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
+					matched = true
+					if p.callback(r.data) {
+						// callback indicates the request is done, remove it.
+						p.errc <- nil
+						copy(pending[i:], pending[i+1:])
+						pending = pending[:len(pending)-1]
+						i--
+					}
+				}
+			}
+			r.matched <- matched
 
 		case now := <-timeout.C:
 			// notify and remove callbacks whose deadline is in the past.
@@ -292,33 +331,38 @@ const (
 
 var headSpace = make([]byte, headSize)
 
-func (t *udp) send(to *Node, ptype byte, req interface{}) error {
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
+	packet, err := encodePacket(t.priv, ptype, req)
+	if err != nil {
+		return err
+	}
+	log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
+		log.DebugDetailln("UDP send failed:", err)
+	}
+	return err
+}
+
+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
 	b := new(bytes.Buffer)
 	b.Write(headSpace)
 	b.WriteByte(ptype)
 	if err := rlp.Encode(b, req); err != nil {
 		log.Errorln("error encoding packet:", err)
-		return err
+		return nil, err
 	}
 	packet := b.Bytes()
-	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
 	if err != nil {
 		log.Errorln("could not sign packet:", err)
-		return err
+		return nil, err
 	}
 	copy(packet[macSize:], sig)
 	// add the hash to the front. Note: this doesn't protect the
 	// packet in any way. Our public key will be part of this hash in
 	// the future.
 	copy(packet, crypto.Sha3(packet[macSize:]))
-
-	toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
-	log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
-	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
-		log.DebugDetailln("UDP send failed:", err)
-	}
-	return err
+	return packet, nil
 }
 
 // readLoop runs in its own goroutine. it handles incoming UDP packets.
@@ -330,29 +374,34 @@ func (t *udp) readLoop() {
 		if err != nil {
 			return
 		}
-		if err := t.packetIn(from, buf[:nbytes]); err != nil {
+		packet, fromID, hash, err := decodePacket(buf[:nbytes])
+		if err != nil {
 			log.Debugf("Bad packet from %v: %v\n", from, err)
+			continue
 		}
+		log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
+		go func() {
+			if err := packet.handle(t, from, fromID, hash); err != nil {
+				log.Debugf("error handling %T from %v: %v", packet, from, err)
+			}
+		}()
 	}
 }
 
-func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
 	if len(buf) < headSize+1 {
-		return errPacketTooSmall
+		return nil, NodeID{}, nil, errPacketTooSmall
 	}
 	hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
 	shouldhash := crypto.Sha3(buf[macSize:])
 	if !bytes.Equal(hash, shouldhash) {
-		return errBadHash
+		return nil, NodeID{}, nil, errBadHash
 	}
 	fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
 	if err != nil {
-		return err
+		return nil, NodeID{}, hash, err
 	}
-	var req interface {
-		handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
-	}
+	var req packet
 	switch ptype := sigdata[0]; ptype {
 	case pingPacket:
 		req = new(ping)
@@ -363,13 +412,10 @@ func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
 	case neighborsPacket:
 		req = new(neighbors)
 	default:
-		return fmt.Errorf("unknown type: %d", ptype)
+		return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
 	}
-	if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
-		return err
-	}
-	log.DebugDetailf("<<< %v %T %v\n", from, req, req)
-	return req.handle(t, from, fromID, hash)
+	err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
+	return req, fromID, hash, err
 }
 
@@ -379,18 +425,14 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
 	if req.Version != Version {
 		return errBadVersion
 	}
-	t.mutex.Lock()
-	// Note: we're ignoring the provided IP address right now
-	n := t.bumpOrAdd(fromID, from)
-	if req.Port != 0 {
-		n.TCPPort = int(req.Port)
-	}
-	t.mutex.Unlock()
-	t.send(n, pongPacket, pong{
+	t.send(from, pongPacket, pong{
 		ReplyTok:   mac,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
+	if !t.handleReply(fromID, pingPacket, req) {
+		// Note: we're ignoring the provided IP address right now
+		t.bond(true, fromID, from, req.Port)
+	}
 	return nil
 }
 
@@ -398,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
 	if expired(req.Expiration) {
 		return errExpired
 	}
-	t.mutex.Lock()
-	t.bump(fromID)
-	t.mutex.Unlock()
-	t.replies <- reply{fromID, pongPacket, req}
+	if !t.handleReply(fromID, pongPacket, req) {
+		return errUnsolicitedReply
+	}
 	return nil
 }
 
@@ -410,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
 	if expired(req.Expiration) {
 		return errExpired
 	}
+	if t.db.get(fromID) == nil {
+		// No bond exists, we don't process the packet. This prevents
+		// an attack vector where the discovery protocol could be used
+		// to amplify traffic in a DDOS attack. A malicious actor
+		// would send a findnode request with the IP address and UDP
+		// port of the target as the source address. The recipient of
+		// the findnode packet would then send a neighbors packet
+		// (which is a much bigger packet than findnode) to the victim.
+		return errUnknownNode
+	}
 	t.mutex.Lock()
-	e := t.bumpOrAdd(fromID, from)
 	closest := t.closest(req.Target, bucketSize).entries
 	t.mutex.Unlock()
-	t.send(e, neighborsPacket, neighbors{
+	t.send(from, neighborsPacket, neighbors{
 		Nodes:      closest,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
@@ -426,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
 	if expired(req.Expiration) {
 		return errExpired
 	}
-	t.mutex.Lock()
-	t.bump(fromID)
-	t.add(req.Nodes)
-	t.mutex.Unlock()
-	t.replies <- reply{fromID, neighborsPacket, req}
+	if !t.handleReply(fromID, neighborsPacket, req) {
+		return errUnsolicitedReply
+	}
 	return nil
 }

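The gotreply/matched exchange above lets the packet handlers learn
synchronously whether a reply was solicited, which is what makes
errUnsolicitedReply possible. A reduced model of the handshake
(illustrative names, not the real types):

package main

import "fmt"

// reply carries a result channel so the sender learns whether the
// reply matched a pending request.
type reply struct {
	from    string
	matched chan<- bool
}

func main() {
	gotreply := make(chan reply)
	pending := map[string]bool{"alice": true} // one outstanding request

	// Stand-in for (*udp).loop: the only goroutine touching pending.
	go func() {
		for r := range gotreply {
			r.matched <- pending[r.from] // report whether it matched
			delete(pending, r.from)
		}
	}()

	// Stand-in for handleReply, called from the packet handlers.
	handleReply := func(from string) bool {
		matched := make(chan bool)
		gotreply <- reply{from, matched}
		return <-matched
	}

	fmt.Println(handleReply("alice"))   // true: solicited
	fmt.Println(handleReply("mallory")) // false: would yield errUnsolicitedReply
}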
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
@@ -1,10 +1,18 @@
 package discover
 
 import (
+	"bytes"
+	"crypto/ecdsa"
+	"errors"
 	"fmt"
+	"io"
 	logpkg "log"
 	"net"
 	"os"
+	"path"
+	"reflect"
+	"runtime"
+	"sync"
 	"testing"
 	"time"
 
@@ -15,197 +23,317 @@ func init() {
 	logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
 }
 
-func TestUDP_ping(t *testing.T) {
-	t.Parallel()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	defer n2.Close()
-
-	if err := n1.net.ping(n2.self); err != nil {
-		t.Fatalf("ping error: %v", err)
-	}
-	if find(n2, n1.self.ID) == nil {
-		t.Errorf("node 2 does not contain id of node 1")
-	}
-	if e := find(n1, n2.self.ID); e != nil {
-		t.Errorf("node 1 does contains id of node 2: %v", e)
-	}
-}
-
-func find(tab *Table, id NodeID) *Node {
-	for _, b := range tab.buckets {
-		for _, e := range b.entries {
-			if e.ID == id {
-				return e
-			}
-		}
-	}
-	return nil
-}
-
-func TestUDP_findnode(t *testing.T) {
-	t.Parallel()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	defer n2.Close()
-
-	// put a few nodes into n2. the exact distribution shouldn't
-	// matter much, although we need to take care not to overflow
-	// any bucket.
-	target := randomID(n1.self.ID, 100)
-	nodes := &nodesByDistance{target: target}
-	for i := 0; i < bucketSize; i++ {
-		n2.add([]*Node{&Node{
-			IP:       net.IP{1, 2, 3, byte(i)},
-			DiscPort: i + 2,
-			TCPPort:  i + 2,
-			ID:       randomID(n2.self.ID, i+2),
-		}})
-	}
-	n2.add(nodes.entries)
-	n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
-	expected := n2.closest(target, bucketSize)
-
-	err := runUDP(10, func() error {
-		result, _ := n1.net.findnode(n2.self, target)
-		if len(result) != bucketSize {
-			return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
-		}
-		for i := range result {
-			if result[i].ID != expected.entries[i].ID {
-				return fmt.Errorf("result mismatch at %d:\n  got: %v\n want: %v", i, result[i], expected.entries[i])
-			}
-		}
-		return nil
-	})
-	if err != nil {
-		t.Error(err)
-	}
-}
-
-func TestUDP_replytimeout(t *testing.T) {
-	t.Parallel()
-
-	// reserve a port so we don't talk to an existing service by accident
-	addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
-	fd, err := net.ListenUDP("udp", addr)
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer fd.Close()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
-
-	if err := n1.net.ping(n2); err != errTimeout {
-		t.Error("expected timeout error, got", err)
-	}
-
-	if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
-		t.Error("expected timeout error, got", err)
-	} else if len(result) > 0 {
-		t.Error("expected empty result, got", result)
-	}
-}
-
-func TestUDP_findnodeMultiReply(t *testing.T) {
-	t.Parallel()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	udp2 := n2.net.(*udp)
-	defer n1.Close()
-	defer n2.Close()
-
-	err := runUDP(10, func() error {
-		nodes := make([]*Node, bucketSize)
-		for i := range nodes {
-			nodes[i] = &Node{
-				IP:       net.IP{1, 2, 3, 4},
-				DiscPort: i + 1,
-				TCPPort:  i + 1,
-				ID:       randomID(n2.self.ID, i+1),
-			}
-		}
-
-		// ask N2 for neighbors. it will send an empty reply back.
-		// the request will wait for up to bucketSize replies.
-		resultc := make(chan []*Node)
-		errc := make(chan error)
-		go func() {
-			ns, err := n1.net.findnode(n2.self, n1.self.ID)
-			if err != nil {
-				errc <- err
-			} else {
-				resultc <- ns
-			}
-		}()
-
-		// send a few more neighbors packets to N1.
-		// it should collect those.
-		for end := 0; end < len(nodes); {
-			off := end
-			if end = end + 5; end > len(nodes) {
-				end = len(nodes)
-			}
-			udp2.send(n1.self, neighborsPacket, neighbors{
-				Nodes:      nodes[off:end],
-				Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
-			})
-		}
-
-		// check that they are all returned. we cannot just check for
-		// equality because they might not be returned in the order they
-		// were sent.
-		var result []*Node
-		select {
-		case result = <-resultc:
-		case err := <-errc:
-			return err
-		}
-		if hasDuplicates(result) {
-			return fmt.Errorf("result slice contains duplicates")
-		}
-		if len(result) != len(nodes) {
-			return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
-		}
-		matched := make(map[NodeID]bool)
-		for _, n := range result {
-			for _, expn := range nodes {
-				if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
-					matched[n.ID] = true
-				}
-			}
-		}
-		if len(matched) != len(nodes) {
-			return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
-		}
-		return nil
-	})
-	if err != nil {
-		t.Error(err)
-	}
-}
-
-// runUDP runs a test n times and returns an error if the test failed
-// in all n runs. This is necessary because UDP is unreliable even for
-// connections on the local machine, causing test failures.
-func runUDP(n int, test func() error) error {
-	errcount := 0
-	errors := ""
-	for i := 0; i < n; i++ {
-		if err := test(); err != nil {
-			errors += fmt.Sprintf("\n#%d: %v", i, err)
-			errcount++
-		}
-	}
-	if errcount == n {
-		return fmt.Errorf("failed on all %d iterations:%s", n, errors)
-	}
-	return nil
-}
+type udpTest struct {
+	t                   *testing.T
+	pipe                *dgramPipe
+	table               *Table
+	udp                 *udp
+	sent                [][]byte
+	localkey, remotekey *ecdsa.PrivateKey
+	remoteaddr          *net.UDPAddr
+}
+
+func newUDPTest(t *testing.T) *udpTest {
+	test := &udpTest{
+		t:          t,
+		pipe:       newpipe(),
+		localkey:   newkey(),
+		remotekey:  newkey(),
+		remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
+	}
+	test.table, test.udp = newUDP(test.localkey, test.pipe, nil)
+	return test
+}
+
+// handles a packet as if it had been sent to the transport.
+func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
+	enc, err := encodePacket(test.remotekey, ptype, data)
+	if err != nil {
+		return test.errorf("packet (%d) encode error: %v", ptype, err)
+	}
+	test.sent = append(test.sent, enc)
+	err = data.handle(test.udp, test.remoteaddr, PubkeyID(&test.remotekey.PublicKey), enc[:macSize])
+	if err != wantError {
+		return test.errorf("error mismatch: got %q, want %q", err, wantError)
+	}
+	return nil
+}
+
+// waits for a packet to be sent by the transport.
+// validate should have type func(*udpTest, X) error, where X is a packet type.
+func (test *udpTest) waitPacketOut(validate interface{}) error {
+	dgram := test.pipe.waitPacketOut()
+	p, _, _, err := decodePacket(dgram)
+	if err != nil {
+		return test.errorf("sent packet decode error: %v", err)
+	}
+	fn := reflect.ValueOf(validate)
+	exptype := fn.Type().In(0)
+	if reflect.TypeOf(p) != exptype {
+		return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+	}
+	fn.Call([]reflect.Value{reflect.ValueOf(p)})
+	return nil
+}
+
+func (test *udpTest) errorf(format string, args ...interface{}) error {
+	_, file, line, ok := runtime.Caller(2) // errorf + waitPacketOut
+	if ok {
+		file = path.Base(file)
+	} else {
+		file = "???"
+		line = 1
+	}
+	err := fmt.Errorf(format, args...)
+	fmt.Printf("\t%s:%d: %v\n", file, line, err)
+	test.t.Fail()
+	return err
+}
+
+// shared test variables
+var (
+	futureExp  = uint64(time.Now().Add(10 * time.Hour).Unix())
+	testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
+)
+
+func TestUDP_packetErrors(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	test.packetIn(errExpired, pingPacket, &ping{IP: "foo", Port: 99, Version: Version})
+	test.packetIn(errBadVersion, pingPacket, &ping{IP: "foo", Port: 99, Version: 99, Expiration: futureExp})
+	test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
+	test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
+	test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
+}
+
+func TestUDP_pingTimeout(t *testing.T) {
+	t.Parallel()
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+	toid := NodeID{1, 2, 3, 4}
+	if err := test.udp.ping(toid, toaddr); err != errTimeout {
+		t.Error("expected timeout error, got", err)
+	}
+}
+
+func TestUDP_findnodeTimeout(t *testing.T) {
+	t.Parallel()
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+	toid := NodeID{1, 2, 3, 4}
+	target := NodeID{4, 5, 6, 7}
+	result, err := test.udp.findnode(toid, toaddr, target)
+	if err != errTimeout {
+		t.Error("expected timeout error, got", err)
+	}
+	if len(result) > 0 {
+		t.Error("expected empty result, got", result)
+	}
+}
+
+func TestUDP_findnode(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	// put a few nodes into the table. their exact
+	// distribution shouldn't matter much, although we need to
+	// take care not to overflow any bucket.
+	target := testTarget
+	nodes := &nodesByDistance{target: target}
+	for i := 0; i < bucketSize; i++ {
+		nodes.push(&Node{
+			IP:       net.IP{1, 2, 3, byte(i)},
+			DiscPort: i + 2,
+			TCPPort:  i + 2,
+			ID:       randomID(test.table.self.ID, i+2),
+		}, bucketSize)
+	}
+	test.table.add(nodes.entries)
+
+	// ensure there's a bond with the test node,
+	// findnode won't be accepted otherwise.
+	test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99)
+
+	// check that closest neighbors are returned.
+	test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
+	test.waitPacketOut(func(p *neighbors) {
+		expected := test.table.closest(testTarget, bucketSize)
+		if len(p.Nodes) != bucketSize {
+			t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
+		}
+		for i := range p.Nodes {
+			if p.Nodes[i].ID != expected.entries[i].ID {
+				t.Errorf("result mismatch at %d:\n  got: %v\n  want: %v", i, p.Nodes[i], expected.entries[i])
+			}
+		}
+	})
+}
+
+func TestUDP_findnodeMultiReply(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	// queue a pending findnode request
+	resultc, errc := make(chan []*Node), make(chan error)
+	go func() {
+		rid := PubkeyID(&test.remotekey.PublicKey)
+		ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
+		if err != nil && len(ns) == 0 {
+			errc <- err
+		} else {
+			resultc <- ns
+		}
+	}()
+
+	// wait for the findnode to be sent.
+	// after it is sent, the transport is waiting for a reply
+	test.waitPacketOut(func(p *findnode) {
+		if p.Target != testTarget {
+			t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
+		}
+	})
+
+	// send the reply as two packets.
+	list := []*Node{
+		MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303"),
+		MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
+		MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301"),
+		MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
+	}
+	test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[:2]})
+	test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[2:]})
+
+	// check that the sent neighbors are all returned by findnode
+	select {
+	case result := <-resultc:
+		if !reflect.DeepEqual(result, list) {
+			t.Errorf("neighbors mismatch:\n  got: %v\n  want: %v", result, list)
+		}
+	case err := <-errc:
+		t.Errorf("findnode error: %v", err)
+	case <-time.After(5 * time.Second):
+		t.Error("findnode did not return within 5 seconds")
+	}
+}
+
+func TestUDP_successfulPing(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	done := make(chan struct{})
+	go func() {
+		test.packetIn(nil, pingPacket, &ping{IP: "foo", Port: 99, Version: Version, Expiration: futureExp})
+		close(done)
+	}()
+
+	// the ping is replied to.
+	test.waitPacketOut(func(p *pong) {
+		pinghash := test.sent[0][:macSize]
+		if !bytes.Equal(p.ReplyTok, pinghash) {
+			t.Errorf("got ReplyTok %x, want %x", p.ReplyTok, pinghash)
+		}
+	})
+
+	// remote is unknown, the table pings back.
+	test.waitPacketOut(func(p *ping) error { return nil })
+	test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
+
+	// ping should return shortly after getting the pong packet.
+	<-done
+
+	// check that the node was added.
+	rid := PubkeyID(&test.remotekey.PublicKey)
+	rnode := find(test.table, rid)
+	if rnode == nil {
+		t.Fatalf("node %v not found in table", rid)
+	}
+	if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
+		t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
+	}
+	if rnode.DiscPort != test.remoteaddr.Port {
+		t.Errorf("node has wrong Port: got %v, want: %v", rnode.DiscPort, test.remoteaddr.Port)
+	}
+	if rnode.TCPPort != 99 {
+		t.Errorf("node has wrong Port: got %v, want: %v", rnode.TCPPort, 99)
+	}
+}
+
+func find(tab *Table, id NodeID) *Node {
+	for _, b := range tab.buckets {
+		for _, e := range b.entries {
+			if e.ID == id {
+				return e
+			}
+		}
+	}
+	return nil
+}
+
+// dgramPipe is a fake UDP socket. It queues all sent datagrams.
+type dgramPipe struct {
+	mu      *sync.Mutex
+	cond    *sync.Cond
+	closing chan struct{}
+	closed  bool
+	queue   [][]byte
+}
+
+func newpipe() *dgramPipe {
+	mu := new(sync.Mutex)
+	return &dgramPipe{
+		closing: make(chan struct{}),
+		cond:    &sync.Cond{L: mu},
+		mu:      mu,
+	}
+}
+
+// WriteToUDP queues a datagram.
+func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
+	msg := make([]byte, len(b))
+	copy(msg, b)
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.closed {
+		return 0, errors.New("closed")
+	}
+	c.queue = append(c.queue, msg)
+	c.cond.Signal()
+	return len(b), nil
+}
+
+// ReadFromUDP just hangs until the pipe is closed.
+func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
+	<-c.closing
+	return 0, nil, io.EOF
+}
+
+func (c *dgramPipe) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if !c.closed {
+		close(c.closing)
+		c.closed = true
+	}
+	return nil
+}
+
+func (c *dgramPipe) LocalAddr() net.Addr {
+	return &net.UDPAddr{}
+}
+
+func (c *dgramPipe) waitPacketOut() []byte {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	for len(c.queue) == 0 {
+		c.cond.Wait()
+	}
+	p := c.queue[0]
+	copy(c.queue, c.queue[1:])
+	c.queue = c.queue[:len(c.queue)-1]
+	return p
+}

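As an aside on the test harness: dgramPipe's waitPacketOut blocks on a
sync.Cond until the code under test has written a datagram. The same
queue pattern in isolation (a sketch with invented names):

package main

import (
	"fmt"
	"sync"
)

// queue lets a test block until something is produced.
type queue struct {
	mu   sync.Mutex
	cond *sync.Cond
	msgs [][]byte
}

func newQueue() *queue {
	q := &queue{}
	q.cond = sync.NewCond(&q.mu)
	return q
}

// put appends a message and wakes one waiter.
func (q *queue) put(b []byte) {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.msgs = append(q.msgs, b)
	q.cond.Signal()
}

// wait blocks until a message is available, then dequeues it.
func (q *queue) wait() []byte {
	q.mu.Lock()
	defer q.mu.Unlock()
	for len(q.msgs) == 0 {
		q.cond.Wait()
	}
	m := q.msgs[0]
	q.msgs = q.msgs[1:]
	return m
}

func main() {
	q := newQueue()
	go q.put([]byte("pong"))
	fmt.Printf("%s\n", q.wait()) // blocks until the goroutine delivers
}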