p2p/discover: add node URL functions, distinguish TCP/UDP ports

The discovery RPC protocol does not yet distinguish TCP and UDP ports.
But it can't hurt to do so in our internal model.
poc8
Felix Lange 10 years ago
parent 56f777b2fc
commit 8564eb9f7e
  1. 289
      p2p/discover/node.go
  2. 201
      p2p/discover/node_test.go
  3. 197
      p2p/discover/table.go
  4. 116
      p2p/discover/table_test.go
  5. 32
      p2p/discover/udp.go
  6. 11
      p2p/discover/udp_test.go
  7. 7
      p2p/server.go
  8. 3
      p2p/server_test.go

@ -0,0 +1,289 @@
package discover
import (
"crypto/ecdsa"
"crypto/elliptic"
"encoding/hex"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/rlp"
)
// Node represents a host on the network.
type Node struct {
ID NodeID
IP net.IP
DiscPort int // UDP listening port for discovery protocol
TCPPort int // TCP listening port for RLPx
active time.Time
}
func newNode(id NodeID, addr *net.UDPAddr) *Node {
return &Node{
ID: id,
IP: addr.IP,
DiscPort: addr.Port,
TCPPort: addr.Port,
active: time.Now(),
}
}
func (n *Node) isValid() bool {
// TODO: don't accept localhost, LAN addresses from internet hosts
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
}
// The string representation of a Node is a URL.
// Please see ParseNode for a description of the format.
func (n *Node) String() string {
addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
u := url.URL{
Scheme: "enode",
User: url.User(fmt.Sprintf("%x", n.ID[:])),
Host: addr.String(),
}
if n.DiscPort != n.TCPPort {
u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
}
return u.String()
}
// ParseNode parses a node URL.
//
// A node URL has scheme "enode".
//
// The hexadecimal node ID is encoded in the username portion of the
// URL, separated from the host by an @ sign. The hostname can only be
// given as an IP address, DNS domain names are not allowed. The port
// in the host name section is the TCP listening port. If the TCP and
// UDP (discovery) ports differ, the UDP port is specified as query
// parameter "discport".
//
// In the following example, the node URL describes
// a node with IP address 10.3.58.6, TCP listening port 30303
// and UDP discovery port 30301.
//
// enode://<hex node id>@10.3.58.6:30303?discport=30301
func ParseNode(rawurl string) (*Node, error) {
var n Node
u, err := url.Parse(rawurl)
if u.Scheme != "enode" {
return nil, errors.New("invalid URL scheme, want \"enode\"")
}
if u.User == nil {
return nil, errors.New("does not contain node ID")
}
if n.ID, err = HexID(u.User.String()); err != nil {
return nil, fmt.Errorf("invalid node ID (%v)", err)
}
ip, port, err := net.SplitHostPort(u.Host)
if err != nil {
return nil, fmt.Errorf("invalid host: %v", err)
}
if n.IP = net.ParseIP(ip); n.IP == nil {
return nil, errors.New("invalid IP address")
}
if n.TCPPort, err = strconv.Atoi(port); err != nil {
return nil, errors.New("invalid port")
}
qv := u.Query()
if qv.Get("discport") == "" {
n.DiscPort = n.TCPPort
} else {
if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
return nil, errors.New("invalid discport in query")
}
}
return &n, nil
}
// MustParseNode parses a node URL. It panics if the URL is not valid.
func MustParseNode(rawurl string) *Node {
n, err := ParseNode(rawurl)
if err != nil {
panic("invalid node URL: " + err.Error())
}
return n
}
func (n Node) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
}
func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
var ext rpcNode
if err = s.Decode(&ext); err == nil {
n.TCPPort = int(ext.Port)
n.DiscPort = int(ext.Port)
n.ID = ext.ID
if n.IP = net.ParseIP(ext.IP); n.IP == nil {
return errors.New("invalid IP string")
}
}
return err
}
// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key.
type NodeID [512 / 8]byte
// NodeID prints as a long hexadecimal number.
func (n NodeID) String() string {
return fmt.Sprintf("%#x", n[:])
}
// The Go syntax representation of a NodeID is a call to HexID.
func (n NodeID) GoString() string {
return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
}
// HexID converts a hex string to a NodeID.
// The string may be prefixed with 0x.
func HexID(in string) (NodeID, error) {
if strings.HasPrefix(in, "0x") {
in = in[2:]
}
var id NodeID
b, err := hex.DecodeString(in)
if err != nil {
return id, err
} else if len(b) != len(id) {
return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
}
copy(id[:], b)
return id, nil
}
// MustHexID converts a hex string to a NodeID.
// It panics if the string is not a valid NodeID.
func MustHexID(in string) NodeID {
id, err := HexID(in)
if err != nil {
panic(err)
}
return id
}
// PubkeyID returns a marshaled representation of the given public key.
func PubkeyID(pub *ecdsa.PublicKey) NodeID {
var id NodeID
pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
if len(pbytes)-1 != len(id) {
panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
}
copy(id[:], pbytes[1:])
return id
}
// recoverNodeID computes the public key used to sign the
// given hash from the signature.
func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
pubkey, err := secp256k1.RecoverPubkey(hash, sig)
if err != nil {
return id, err
}
if len(pubkey)-1 != len(id) {
return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
}
for i := range id {
id[i] = pubkey[i+1]
}
return id, nil
}
// distcmp compares the distances a->target and b->target.
// Returns -1 if a is closer to target, 1 if b is closer to target
// and 0 if they are equal.
func distcmp(target, a, b NodeID) int {
for i := range target {
da := a[i] ^ target[i]
db := b[i] ^ target[i]
if da > db {
return 1
} else if da < db {
return -1
}
}
return 0
}
// table of leading zero counts for bytes [0..255]
var lzcount = [256]int{
8, 7, 6, 6, 5, 5, 5, 5,
4, 4, 4, 4, 4, 4, 4, 4,
3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
}
// logdist returns the logarithmic distance between a and b, log2(a ^ b).
func logdist(a, b NodeID) int {
lz := 0
for i := range a {
x := a[i] ^ b[i]
if x == 0 {
lz += 8
} else {
lz += lzcount[x]
break
}
}
return len(a)*8 - lz
}
// randomID returns a random NodeID such that logdist(a, b) == n
func randomID(a NodeID, n int) (b NodeID) {
if n == 0 {
return a
}
// flip bit at position n, fill the rest with random bits
b = a
pos := len(a) - n/8 - 1
bit := byte(0x01) << (byte(n%8) - 1)
if bit == 0 {
pos++
bit = 0x80
}
b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
for i := pos + 1; i < len(a); i++ {
b[i] = byte(rand.Intn(255))
}
return b
}

@ -0,0 +1,201 @@
package discover
import (
"math/big"
"math/rand"
"net"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/ethereum/go-ethereum/crypto"
)
var (
quickrand = rand.New(rand.NewSource(time.Now().Unix()))
quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
)
var parseNodeTests = []struct {
rawurl string
wantError string
wantResult *Node
}{
{
rawurl: "http://foobar",
wantError: `invalid URL scheme, want "enode"`,
},
{
rawurl: "enode://foobar",
wantError: `does not contain node ID`,
},
{
rawurl: "enode://01010101@123.124.125.126:3",
wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
wantError: `invalid IP address`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
wantError: `invalid port`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
wantError: `invalid discport in query`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("127.0.0.1"),
DiscPort: 52150,
TCPPort: 52150,
},
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("::"),
DiscPort: 52150,
TCPPort: 52150,
},
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("127.0.0.1"),
DiscPort: 223344,
TCPPort: 52150,
},
},
}
func TestParseNode(t *testing.T) {
for i, test := range parseNodeTests {
n, err := ParseNode(test.rawurl)
if err == nil && test.wantError != "" {
t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
continue
}
if err != nil && err.Error() != test.wantError {
t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
continue
}
if !reflect.DeepEqual(n, test.wantResult) {
t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
}
}
}
func TestNodeString(t *testing.T) {
for i, test := range parseNodeTests {
if test.wantError != "" {
continue
}
str := test.wantResult.String()
if str != test.rawurl {
t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
}
}
}
func TestHexID(t *testing.T) {
ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
if id1 != ref {
t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
}
if id2 != ref {
t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
}
}
func TestNodeID_recover(t *testing.T) {
prv := newkey()
hash := make([]byte, 32)
sig, err := crypto.Sign(hash, prv)
if err != nil {
t.Fatalf("signing error: %v", err)
}
pub := PubkeyID(&prv.PublicKey)
recpub, err := recoverNodeID(hash, sig)
if err != nil {
t.Fatalf("recovery error: %v", err)
}
if pub != recpub {
t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
}
}
func TestNodeID_distcmp(t *testing.T) {
distcmpBig := func(target, a, b NodeID) int {
tbig := new(big.Int).SetBytes(target[:])
abig := new(big.Int).SetBytes(a[:])
bbig := new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
}
if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
t.Error(err)
}
}
// the random tests is likely to miss the case where they're equal.
func TestNodeID_distcmpEqual(t *testing.T) {
base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
if distcmp(base, x, x) != 0 {
t.Errorf("distcmp(base, x, x) != 0")
}
}
func TestNodeID_logdist(t *testing.T) {
logdistBig := func(a, b NodeID) int {
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(abig, bbig).BitLen()
}
if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
t.Error(err)
}
}
// the random tests is likely to miss the case where they're equal.
func TestNodeID_logdistEqual(t *testing.T) {
x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
if logdist(x, x) != 0 {
t.Errorf("logdist(x, x) != 0")
}
}
func TestNodeID_randomID(t *testing.T) {
// we don't use quick.Check here because its output isn't
// very helpful when the test fails.
for i := 0; i < quickcfg.MaxCount; i++ {
a := gen(NodeID{}, quickrand).(NodeID)
dist := quickrand.Intn(len(NodeID{}) * 8)
result := randomID(a, dist)
actualdist := logdist(result, a)
if dist != actualdist {
t.Log("a: ", a)
t.Log("result:", result)
t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
}
}
}
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
var id NodeID
m := rand.Intn(len(id))
for i := len(id) - 1; i > m; i-- {
id[i] = byte(rand.Uint32())
}
return reflect.ValueOf(id)
}

@ -7,20 +7,10 @@
package discover
import (
"crypto/ecdsa"
"crypto/elliptic"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net"
"sort"
"strings"
"sync"
"time"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/rlp"
)
const (
@ -53,36 +43,10 @@ type bucket struct {
entries []*Node
}
// Node represents node metadata that is stored in the table.
type Node struct {
Addr *net.UDPAddr
ID NodeID
active time.Time
}
type rpcNode struct {
IP string
Port uint16
ID NodeID
}
func (n Node) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, rpcNode{IP: n.Addr.IP.String(), Port: uint16(n.Addr.Port), ID: n.ID})
}
func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
var ext rpcNode
if err = s.Decode(&ext); err == nil {
n.Addr = &net.UDPAddr{IP: net.ParseIP(ext.IP), Port: int(ext.Port)}
n.ID = ext.ID
}
return err
}
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
tab := &Table{net: t, self: &Node{ID: ourID, Addr: ourAddr}}
tab := &Table{net: t, self: newNode(ourID, ourAddr)}
for i := range tab.buckets {
tab.buckets[i] = &bucket{}
tab.buckets[i] = new(bucket)
}
return tab
}
@ -217,7 +181,7 @@ func (tab *Table) len() (n int) {
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
b := tab.buckets[logdist(tab.self.ID, node)]
if n = b.bump(node); n == nil {
n = &Node{ID: node, Addr: from, active: time.Now()}
n = newNode(node, from)
if len(b.entries) == bucketSize {
tab.pingReplace(n, b)
} else {
@ -238,6 +202,7 @@ func (tab *Table) pingReplace(n *Node, b *bucket) {
tab.mutex.Lock()
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
// slide down other entries and put the new one in front.
// TODO: insert in correct position to keep the order
copy(b.entries[1:], b.entries)
b.entries[0] = n
}
@ -312,157 +277,3 @@ func (h *nodesByDistance) push(n *Node, maxElems int) {
h.entries[ix] = n
}
}
// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key.
type NodeID [512 / 8]byte
// NodeID prints as a long hexadecimal number.
func (n NodeID) String() string {
return fmt.Sprintf("%#x", n[:])
}
// The Go syntax representation of a NodeID is a call to HexID.
func (n NodeID) GoString() string {
return fmt.Sprintf("HexID(\"%#x\")", n[:])
}
// HexID converts a hex string to a NodeID.
// The string may be prefixed with 0x.
func HexID(in string) (NodeID, error) {
if strings.HasPrefix(in, "0x") {
in = in[2:]
}
var id NodeID
b, err := hex.DecodeString(in)
if err != nil {
return id, err
} else if len(b) != len(id) {
return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
}
copy(id[:], b)
return id, nil
}
// MustHexID converts a hex string to a NodeID.
// It panics if the string is not a valid NodeID.
func MustHexID(in string) NodeID {
id, err := HexID(in)
if err != nil {
panic(err)
}
return id
}
func PubkeyID(pub *ecdsa.PublicKey) NodeID {
var id NodeID
pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
if len(pbytes)-1 != len(id) {
panic(fmt.Errorf("invalid key: need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
}
copy(id[:], pbytes[1:])
return id
}
// recoverNodeID computes the public key used to sign the
// given hash from the signature.
func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
pubkey, err := secp256k1.RecoverPubkey(hash, sig)
if err != nil {
return id, err
}
if len(pubkey)-1 != len(id) {
return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
}
for i := range id {
id[i] = pubkey[i+1]
}
return id, nil
}
// distcmp compares the distances a->target and b->target.
// Returns -1 if a is closer to target, 1 if b is closer to target
// and 0 if they are equal.
func distcmp(target, a, b NodeID) int {
for i := range target {
da := a[i] ^ target[i]
db := b[i] ^ target[i]
if da > db {
return 1
} else if da < db {
return -1
}
}
return 0
}
// table of leading zero counts for bytes [0..255]
var lzcount = [256]int{
8, 7, 6, 6, 5, 5, 5, 5,
4, 4, 4, 4, 4, 4, 4, 4,
3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
}
// logdist returns the logarithmic distance between a and b, log2(a ^ b).
func logdist(a, b NodeID) int {
lz := 0
for i := range a {
x := a[i] ^ b[i]
if x == 0 {
lz += 8
} else {
lz += lzcount[x]
break
}
}
return len(a)*8 - lz
}
// randomID returns a random NodeID such that logdist(a, b) == n
func randomID(a NodeID, n int) (b NodeID) {
if n == 0 {
return a
}
// flip bit at position n, fill the rest with random bits
b = a
pos := len(a) - n/8 - 1
bit := byte(0x01) << (byte(n%8) - 1)
if bit == 0 {
pos++
bit = 0x80
}
b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
for i := pos + 1; i < len(a); i++ {
b[i] = byte(rand.Intn(255))
}
return b
}

@ -4,7 +4,6 @@ import (
"crypto/ecdsa"
"errors"
"fmt"
"math/big"
"math/rand"
"net"
"reflect"
@ -15,107 +14,6 @@ import (
"github.com/ethereum/go-ethereum/crypto"
)
var (
quickrand = rand.New(rand.NewSource(time.Now().Unix()))
quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
)
func TestHexID(t *testing.T) {
ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
if id1 != ref {
t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
}
if id2 != ref {
t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
}
}
func TestNodeID_recover(t *testing.T) {
prv := newkey()
hash := make([]byte, 32)
sig, err := crypto.Sign(hash, prv)
if err != nil {
t.Fatalf("signing error: %v", err)
}
pub := PubkeyID(&prv.PublicKey)
recpub, err := recoverNodeID(hash, sig)
if err != nil {
t.Fatalf("recovery error: %v", err)
}
if pub != recpub {
t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
}
}
func TestNodeID_distcmp(t *testing.T) {
distcmpBig := func(target, a, b NodeID) int {
tbig := new(big.Int).SetBytes(target[:])
abig := new(big.Int).SetBytes(a[:])
bbig := new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
}
if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
t.Error(err)
}
}
// the random tests is likely to miss the case where they're equal.
func TestNodeID_distcmpEqual(t *testing.T) {
base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
if distcmp(base, x, x) != 0 {
t.Errorf("distcmp(base, x, x) != 0")
}
}
func TestNodeID_logdist(t *testing.T) {
logdistBig := func(a, b NodeID) int {
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(abig, bbig).BitLen()
}
if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
t.Error(err)
}
}
// the random tests is likely to miss the case where they're equal.
func TestNodeID_logdistEqual(t *testing.T) {
x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
if logdist(x, x) != 0 {
t.Errorf("logdist(x, x) != 0")
}
}
func TestNodeID_randomID(t *testing.T) {
// we don't use quick.Check here because its output isn't
// very helpful when the test fails.
for i := 0; i < quickcfg.MaxCount; i++ {
a := gen(NodeID{}, quickrand).(NodeID)
dist := quickrand.Intn(len(NodeID{}) * 8)
result := randomID(a, dist)
actualdist := logdist(result, a)
if dist != actualdist {
t.Log("a: ", a)
t.Log("result:", result)
t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
}
}
}
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
var id NodeID
m := rand.Intn(len(id))
for i := len(id) - 1; i > m; i-- {
id[i] = byte(rand.Uint32())
}
return reflect.ValueOf(id)
}
func TestTable_bumpOrAddPingReplace(t *testing.T) {
pingC := make(pingC)
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
@ -123,7 +21,7 @@ func TestTable_bumpOrAddPingReplace(t *testing.T) {
// this bumpOrAdd should not replace the last node
// because the node replies to ping.
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), nil)
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
pinged := <-pingC
if pinged != last.ID {
@ -149,7 +47,7 @@ func TestTable_bumpOrAddPingTimeout(t *testing.T) {
// this bumpOrAdd should replace the last node
// because the node does not reply to ping.
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), nil)
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
// wait for async bucket update. damn. this needs to go away.
time.Sleep(2 * time.Millisecond)
@ -329,19 +227,17 @@ type findnodeOracle struct {
}
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
t.t.Logf("findnode query at dist %d", n.Addr.Port)
t.t.Logf("findnode query at dist %d", n.DiscPort)
// current log distance is encoded in port number
var result []*Node
switch port := n.Addr.Port; port {
switch n.DiscPort {
case 0:
panic("query to node at distance 0")
case 1:
result = append(result, &Node{ID: t.target, Addr: &net.UDPAddr{Port: 0}})
default:
// TODO: add more randomness to distances
port--
next := n.DiscPort - 1
for i := 0; i < bucketSize; i++ {
result = append(result, &Node{ID: randomID(t.target, port), Addr: &net.UDPAddr{Port: port}})
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
}
}
return result, nil

@ -69,6 +69,12 @@ type (
}
)
type rpcNode struct {
IP string
Port uint16
ID NodeID
}
// udp implements the RPC protocol.
type udp struct {
conn *net.UDPConn
@ -121,7 +127,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string) (*Table, error) {
return nil, err
}
net.Table = newTable(net, PubkeyID(&priv.PublicKey), realaddr)
log.Debugf("Listening on %v, my ID %x\n", realaddr, net.self.ID[:])
log.Debugf("Listening, %v\n", net.self)
return net.Table, nil
}
@ -159,8 +165,8 @@ func (t *udp) ping(e *Node) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT
errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
t.send(e, pingPacket, ping{
IP: t.self.Addr.String(),
Port: uint16(t.self.Addr.Port),
IP: t.self.IP.String(),
Port: uint16(t.self.TCPPort),
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return <-errc
@ -176,7 +182,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
for i := 0; i < len(reply.Nodes); i++ {
nreceived++
n := reply.Nodes[i]
if validAddr(n.Addr) && n.ID != t.self.ID {
if n.ID != t.self.ID && n.isValid() {
nodes = append(nodes, n)
}
}
@ -191,10 +197,6 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
return nodes, err
}
func validAddr(a *net.UDPAddr) bool {
return !a.IP.IsMulticast() && !a.IP.IsUnspecified() && a.Port != 0
}
// pending adds a reply callback to the pending reply queue.
// see the documentation of type pending for a detailed explanation.
func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
@ -302,8 +304,9 @@ func (t *udp) send(to *Node, ptype byte, req interface{}) error {
// the future.
copy(packet, crypto.Sha3(packet[macSize:]))
log.DebugDetailf(">>> %v %T %v\n", to.Addr, req, req)
if _, err = t.conn.WriteToUDP(packet, to.Addr); err != nil {
toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
log.DebugDetailln("UDP send failed:", err)
}
return err
@ -365,11 +368,14 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
return errExpired
}
t.mutex.Lock()
// Note: we're ignoring the provided IP/Port right now.
e := t.bumpOrAdd(fromID, from)
// Note: we're ignoring the provided IP address right now
n := t.bumpOrAdd(fromID, from)
if req.Port != 0 {
n.TCPPort = int(req.Port)
}
t.mutex.Unlock()
t.send(e, pongPacket, pong{
t.send(n, pongPacket, pong{
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})

@ -4,6 +4,7 @@ import (
logpkg "log"
"net"
"os"
"reflect"
"testing"
"time"
@ -11,7 +12,7 @@ import (
)
func init() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.DebugLevel))
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
}
func TestUDP_ping(t *testing.T) {
@ -52,7 +53,7 @@ func TestUDP_findnode(t *testing.T) {
defer n1.Close()
defer n2.Close()
entry := &Node{ID: NodeID{1}, Addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 15}}
entry := MustParseNode("enode://9d8a19597e312ef32d76e6b4903bb43d7bcd892d17b769d30b404bd3a4c2dca6c86184b17d0fdeeafe3b01e0e912d990ddc853db3f325d5419f31446543c30be@127.0.0.1:54194")
n2.add([]*Node{entry})
target := randomID(n1.self.ID, 100)
@ -60,7 +61,7 @@ func TestUDP_findnode(t *testing.T) {
if len(result) != 1 {
t.Fatalf("wrong number of results: got %d, want 1", len(result))
}
if result[0].ID != entry.ID {
if !reflect.DeepEqual(result[0], entry) {
t.Errorf("wrong result: got %v, want %v", result[0], entry)
}
}
@ -103,7 +104,9 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
nodes := make([]*Node, bucketSize)
for i := range nodes {
nodes[i] = &Node{
Addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: i + 1},
IP: net.IP{1, 2, 3, 4},
DiscPort: i + 1,
TCPPort: i + 1,
ID: randomID(n2.self.ID, i+1),
}
}

@ -136,7 +136,7 @@ func (srv *Server) PeerCount() int {
// SuggestPeer creates a connection to the given Node if it
// is not already connected.
func (srv *Server) SuggestPeer(ip net.IP, port int, id discover.NodeID) {
srv.peerConnect <- &discover.Node{ID: id, Addr: &net.UDPAddr{IP: ip, Port: port}}
srv.peerConnect <- &discover.Node{ID: id, IP: ip, TCPPort: port}
}
// Broadcast sends an RLP-encoded message to all connected peers.
@ -364,8 +364,9 @@ func (srv *Server) dialLoop() {
}
func (srv *Server) dialNode(dest *discover.Node) {
srvlog.Debugf("Dialing %v\n", dest.Addr)
conn, err := srv.Dialer.Dial("tcp", dest.Addr.String())
addr := &net.TCPAddr{IP: dest.IP, Port: dest.TCPPort}
srvlog.Debugf("Dialing %v\n", dest)
conn, err := srv.Dialer.Dial("tcp", addr.String())
if err != nil {
srvlog.DebugDetailf("dial error: %v", err)
return

@ -91,8 +91,7 @@ func TestServerDial(t *testing.T) {
// tell the server to connect
tcpAddr := listener.Addr().(*net.TCPAddr)
connAddr := &discover.Node{Addr: &net.UDPAddr{IP: tcpAddr.IP, Port: tcpAddr.Port}}
srv.peerConnect <- connAddr
srv.peerConnect <- &discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port}
select {
case conn := <-accepted:

Loading…
Cancel
Save