diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go index dda9f34d4..26912525d 100644 --- a/cmd/bootnode/main.go +++ b/cmd/bootnode/main.go @@ -71,7 +71,7 @@ func main() { } } - if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm); err != nil { + if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil { log.Fatal(err) } select {} diff --git a/eth/backend.go b/eth/backend.go index 466912899..c5fa328b0 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -125,6 +125,7 @@ type Ethereum struct { blockDb common.Database // Block chain database stateDb common.Database // State changes database extraDb common.Database // Extra database (txs, etc) + // Closed when databases are flushed and closed databasesClosed chan bool @@ -179,6 +180,7 @@ func New(config *Config) (*Ethereum, error) { if err != nil { return nil, err } + nodeDb := path.Join(config.DataDir, "nodes") // Perform database sanity checks d, _ := blockDb.Get([]byte("ProtocolVersion")) @@ -245,6 +247,7 @@ func New(config *Config) (*Ethereum, error) { NAT: config.NAT, NoDial: !config.Dial, BootstrapNodes: config.parseBootNodes(), + NodeDatabase: nodeDb, } if len(config.Port) > 0 { eth.net.ListenAddr = ":" + config.Port diff --git a/p2p/discover/database.go b/p2p/discover/database.go new file mode 100644 index 000000000..d966a6ac1 --- /dev/null +++ b/p2p/discover/database.go @@ -0,0 +1,304 @@ +// Contains the node database, storing previously seen nodes and any collected +// metadata about them for QoS purposes. + +package discover + +import ( + "bytes" + "encoding/binary" + "os" + "sync" + "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/rlp" + "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/storage" +) + +var ( + nodeDBNilNodeID = NodeID{} // Special node ID to use as a nil element. + nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. + nodeDBCleanupCycle = time.Hour // Time period for running the expiration task. +) + +// nodeDB stores all nodes we know about. +type nodeDB struct { + lvl *leveldb.DB // Interface to the database itself + seeder iterator.Iterator // Iterator for fetching possible seed nodes + + runner sync.Once // Ensures we can start at most one expirer + quit chan struct{} // Channel to signal the expiring thread to stop +} + +// Schema layout for the node database +var ( + nodeDBVersionKey = []byte("version") // Version of the database to flush if changes + nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with + + nodeDBDiscoverRoot = ":discover" + nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping" + nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong" +) + +// newNodeDB creates a new node database for storing and retrieving infos about +// known peers in the network. If no path is given, an in-memory, temporary +// database is constructed. +func newNodeDB(path string, version int) (*nodeDB, error) { + if path == "" { + return newMemoryNodeDB() + } + return newPersistentNodeDB(path, version) +} + +// newMemoryNodeDB creates a new in-memory node database without a persistent +// backend. +func newMemoryNodeDB() (*nodeDB, error) { + db, err := leveldb.Open(storage.NewMemStorage(), nil) + if err != nil { + return nil, err + } + return &nodeDB{ + lvl: db, + quit: make(chan struct{}), + }, nil +} + +// newPersistentNodeDB creates/opens a leveldb backed persistent node database, +// also flushing its contents in case of a version mismatch. +func newPersistentNodeDB(path string, version int) (*nodeDB, error) { + // Try to open the cache, recovering any corruption + db, err := leveldb.OpenFile(path, nil) + if _, iscorrupted := err.(leveldb.ErrCorrupted); iscorrupted { + db, err = leveldb.RecoverFile(path, nil) + } + if err != nil { + return nil, err + } + // The nodes contained in the cache correspond to a certain protocol version. + // Flush all nodes if the version doesn't match. + currentVer := make([]byte, binary.MaxVarintLen64) + currentVer = currentVer[:binary.PutVarint(currentVer, int64(version))] + + blob, err := db.Get(nodeDBVersionKey, nil) + switch err { + case leveldb.ErrNotFound: + // Version not found (i.e. empty cache), insert it + if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil { + db.Close() + return nil, err + } + + case nil: + // Version present, flush if different + if !bytes.Equal(blob, currentVer) { + db.Close() + if err = os.RemoveAll(path); err != nil { + return nil, err + } + return newPersistentNodeDB(path, version) + } + } + return &nodeDB{ + lvl: db, + quit: make(chan struct{}), + }, nil +} + +// makeKey generates the leveldb key-blob from a node id and its particular +// field of interest. +func makeKey(id NodeID, field string) []byte { + if bytes.Equal(id[:], nodeDBNilNodeID[:]) { + return []byte(field) + } + return append(nodeDBItemPrefix, append(id[:], field...)...) +} + +// splitKey tries to split a database key into a node id and a field part. +func splitKey(key []byte) (id NodeID, field string) { + // If the key is not of a node, return it plainly + if !bytes.HasPrefix(key, nodeDBItemPrefix) { + return NodeID{}, string(key) + } + // Otherwise split the id and field + item := key[len(nodeDBItemPrefix):] + copy(id[:], item[:len(id)]) + field = string(item[len(id):]) + + return id, field +} + +// fetchInt64 retrieves an integer instance associated with a particular +// database key. +func (db *nodeDB) fetchInt64(key []byte) int64 { + blob, err := db.lvl.Get(key, nil) + if err != nil { + return 0 + } + val, read := binary.Varint(blob) + if read <= 0 { + return 0 + } + return val +} + +// storeInt64 update a specific database entry to the current time instance as a +// unix timestamp. +func (db *nodeDB) storeInt64(key []byte, n int64) error { + blob := make([]byte, binary.MaxVarintLen64) + blob = blob[:binary.PutVarint(blob, n)] + + return db.lvl.Put(key, blob, nil) +} + +// node retrieves a node with a given id from the database. +func (db *nodeDB) node(id NodeID) *Node { + blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil) + if err != nil { + glog.V(logger.Detail).Infof("failed to retrieve node %v: %v", id, err) + return nil + } + node := new(Node) + if err := rlp.DecodeBytes(blob, node); err != nil { + glog.V(logger.Warn).Infof("failed to decode node RLP: %v", err) + return nil + } + return node +} + +// updateNode inserts - potentially overwriting - a node into the peer database. +func (db *nodeDB) updateNode(node *Node) error { + blob, err := rlp.EncodeToBytes(node) + if err != nil { + return err + } + return db.lvl.Put(makeKey(node.ID, nodeDBDiscoverRoot), blob, nil) +} + +// ensureExpirer is a small helper method ensuring that the data expiration +// mechanism is running. If the expiration goroutine is already running, this +// method simply returns. +// +// The goal is to start the data evacuation only after the network successfully +// bootstrapped itself (to prevent dumping potentially useful seed nodes). Since +// it would require significant overhead to exactly trace the first successful +// convergence, it's simpler to "ensure" the correct state when an appropriate +// condition occurs (i.e. a successful bonding), and discard further events. +func (db *nodeDB) ensureExpirer() { + db.runner.Do(func() { go db.expirer() }) +} + +// expirer should be started in a go routine, and is responsible for looping ad +// infinitum and dropping stale data from the database. +func (db *nodeDB) expirer() { + tick := time.Tick(nodeDBCleanupCycle) + for { + select { + case <-tick: + if err := db.expireNodes(); err != nil { + glog.V(logger.Error).Infof("Failed to expire nodedb items: %v", err) + } + + case <-db.quit: + return + } + } +} + +// expireNodes iterates over the database and deletes all nodes that have not +// been seen (i.e. received a pong from) for some alloted time. +func (db *nodeDB) expireNodes() error { + threshold := time.Now().Add(-nodeDBNodeExpiration) + + // Find discovered nodes that are older than the allowance + it := db.lvl.NewIterator(nil, nil) + defer it.Release() + + for it.Next() { + // Skip the item if not a discovery node + id, field := splitKey(it.Key()) + if field != nodeDBDiscoverRoot { + continue + } + // Skip the node if not expired yet + if seen := db.lastPong(id); seen.After(threshold) { + continue + } + // Otherwise delete all associated information + prefix := makeKey(id, "") + for ok := it.Seek(prefix); ok && bytes.HasPrefix(it.Key(), prefix); ok = it.Next() { + if err := db.lvl.Delete(it.Key(), nil); err != nil { + return err + } + } + } + return nil +} + +// lastPing retrieves the time of the last ping packet send to a remote node, +// requesting binding. +func (db *nodeDB) lastPing(id NodeID) time.Time { + return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0) +} + +// updateLastPing updates the last time we tried contacting a remote node. +func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { + return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) +} + +// lastPong retrieves the time of the last successful contact from remote node. +func (db *nodeDB) lastPong(id NodeID) time.Time { + return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) +} + +// updateLastPong updates the last time a remote node successfully contacted. +func (db *nodeDB) updateLastPong(id NodeID, instance time.Time) error { + return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) +} + +// querySeeds retrieves a batch of nodes to be used as potential seed servers +// during bootstrapping the node into the network. +// +// Ideal seeds are the most recently seen nodes (highest probability to be still +// alive), but yet untried. However, since leveldb only supports dumb iteration +// we will instead start pulling in potential seeds that haven't been yet pinged +// since the start of the boot procedure. +// +// If the database runs out of potential seeds, we restart the startup counter +// and start iterating over the peers again. +func (db *nodeDB) querySeeds(n int) []*Node { + // Create a new seed iterator if none exists + if db.seeder == nil { + db.seeder = db.lvl.NewIterator(nil, nil) + } + // Iterate over the nodes and find suitable seeds + nodes := make([]*Node, 0, n) + for len(nodes) < n && db.seeder.Next() { + // Iterate until a discovery node is found + id, field := splitKey(db.seeder.Key()) + if field != nodeDBDiscoverRoot { + continue + } + // Load it as a potential seed + if node := db.node(id); node != nil { + nodes = append(nodes, node) + } + } + // Release the iterator if we reached the end + if len(nodes) == 0 { + db.seeder.Release() + db.seeder = nil + } + return nodes +} + +// close flushes and closes the database files. +func (db *nodeDB) close() { + if db.seeder != nil { + db.seeder.Release() + } + close(db.quit) + db.lvl.Close() +} diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go new file mode 100644 index 000000000..f327cf73b --- /dev/null +++ b/p2p/discover/database_test.go @@ -0,0 +1,313 @@ +package discover + +import ( + "bytes" + "io/ioutil" + "net" + "os" + "path/filepath" + "testing" + "time" +) + +var nodeDBKeyTests = []struct { + id NodeID + field string + key []byte +}{ + { + id: NodeID{}, + field: "version", + key: []byte{0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e}, // field + }, + { + id: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + field: ":discover", + key: []byte{0x6e, 0x3a, // prefix + 0x1d, 0xd9, 0xd6, 0x5c, 0x45, 0x52, 0xb5, 0xeb, // node id + 0x43, 0xd5, 0xad, 0x55, 0xa2, 0xee, 0x3f, 0x56, // + 0xc6, 0xcb, 0xc1, 0xc6, 0x4a, 0x5c, 0x8d, 0x65, // + 0x9f, 0x51, 0xfc, 0xd5, 0x1b, 0xac, 0xe2, 0x43, // + 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // + 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, // + 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, // + 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, // + 0x3a, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x76, 0x65, 0x72, // field + }, + }, +} + +func TestNodeDBKeys(t *testing.T) { + for i, tt := range nodeDBKeyTests { + if key := makeKey(tt.id, tt.field); !bytes.Equal(key, tt.key) { + t.Errorf("make test %d: key mismatch: have 0x%x, want 0x%x", i, key, tt.key) + } + id, field := splitKey(tt.key) + if !bytes.Equal(id[:], tt.id[:]) { + t.Errorf("split test %d: id mismatch: have 0x%x, want 0x%x", i, id, tt.id) + } + if field != tt.field { + t.Errorf("split test %d: field mismatch: have 0x%x, want 0x%x", i, field, tt.field) + } + } +} + +var nodeDBInt64Tests = []struct { + key []byte + value int64 +}{ + {key: []byte{0x01}, value: 1}, + {key: []byte{0x02}, value: 2}, + {key: []byte{0x03}, value: 3}, +} + +func TestNodeDBInt64(t *testing.T) { + db, _ := newNodeDB("", Version) + defer db.close() + + tests := nodeDBInt64Tests + for i := 0; i < len(tests); i++ { + // Insert the next value + if err := db.storeInt64(tests[i].key, tests[i].value); err != nil { + t.Errorf("test %d: failed to store value: %v", i, err) + } + // Check all existing and non existing values + for j := 0; j < len(tests); j++ { + num := db.fetchInt64(tests[j].key) + switch { + case j <= i && num != tests[j].value: + t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, tests[j].value) + case j > i && num != 0: + t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, 0) + } + } + } +} + +func TestNodeDBFetchStore(t *testing.T) { + node := &Node{ + ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: net.IP([]byte{192, 168, 0, 1}), + TCPPort: 30303, + } + inst := time.Now() + + db, _ := newNodeDB("", Version) + defer db.close() + + // Check fetch/store operations on a node ping object + if stored := db.lastPing(node.ID); stored.Unix() != 0 { + t.Errorf("ping: non-existing object: %v", stored) + } + if err := db.updateLastPing(node.ID, inst); err != nil { + t.Errorf("ping: failed to update: %v", err) + } + if stored := db.lastPing(node.ID); stored.Unix() != inst.Unix() { + t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) + } + // Check fetch/store operations on a node pong object + if stored := db.lastPong(node.ID); stored.Unix() != 0 { + t.Errorf("pong: non-existing object: %v", stored) + } + if err := db.updateLastPong(node.ID, inst); err != nil { + t.Errorf("pong: failed to update: %v", err) + } + if stored := db.lastPong(node.ID); stored.Unix() != inst.Unix() { + t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) + } + // Check fetch/store operations on an actual node object + if stored := db.node(node.ID); stored != nil { + t.Errorf("node: non-existing object: %v", stored) + } + if err := db.updateNode(node); err != nil { + t.Errorf("node: failed to update: %v", err) + } + if stored := db.node(node.ID); stored == nil { + t.Errorf("node: not found") + } else if !bytes.Equal(stored.ID[:], node.ID[:]) || !stored.IP.Equal(node.IP) || stored.TCPPort != node.TCPPort { + t.Errorf("node: data mismatch: have %v, want %v", stored, node) + } +} + +var nodeDBSeedQueryNodes = []struct { + node Node + pong time.Time +}{ + { + node: Node{ + ID: MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: []byte{127, 0, 0, 1}, + }, + pong: time.Now().Add(-2 * time.Second), + }, + { + node: Node{ + ID: MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: []byte{127, 0, 0, 2}, + }, + pong: time.Now().Add(-3 * time.Second), + }, + { + node: Node{ + ID: MustHexID("0x03d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: []byte{127, 0, 0, 3}, + }, + pong: time.Now().Add(-1 * time.Second), + }, +} + +func TestNodeDBSeedQuery(t *testing.T) { + db, _ := newNodeDB("", Version) + defer db.close() + + // Insert a batch of nodes for querying + for i, seed := range nodeDBSeedQueryNodes { + if err := db.updateNode(&seed.node); err != nil { + t.Fatalf("node %d: failed to insert: %v", i, err) + } + } + // Retrieve the entire batch and check for duplicates + seeds := db.querySeeds(2 * len(nodeDBSeedQueryNodes)) + if len(seeds) != len(nodeDBSeedQueryNodes) { + t.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(nodeDBSeedQueryNodes)) + } + have := make(map[NodeID]struct{}) + for _, seed := range seeds { + have[seed.ID] = struct{}{} + } + want := make(map[NodeID]struct{}) + for _, seed := range nodeDBSeedQueryNodes { + want[seed.node.ID] = struct{}{} + } + for id, _ := range have { + if _, ok := want[id]; !ok { + t.Errorf("extra seed: %v", id) + } + } + for id, _ := range want { + if _, ok := have[id]; !ok { + t.Errorf("missing seed: %v", id) + } + } + // Make sure the next batch is empty (seed EOF) + seeds = db.querySeeds(2 * len(nodeDBSeedQueryNodes)) + if len(seeds) != 0 { + t.Errorf("seed count mismatch: have %v, want %v", len(seeds), 0) + } +} + +func TestNodeDBSeedQueryContinuation(t *testing.T) { + db, _ := newNodeDB("", Version) + defer db.close() + + // Insert a batch of nodes for querying + for i, seed := range nodeDBSeedQueryNodes { + if err := db.updateNode(&seed.node); err != nil { + t.Fatalf("node %d: failed to insert: %v", i, err) + } + } + // Iteratively retrieve the batch, checking for an empty batch on reset + for i := 0; i < len(nodeDBSeedQueryNodes); i++ { + if seeds := db.querySeeds(1); len(seeds) != 1 { + t.Errorf("1st iteration %d: seed count mismatch: have %v, want %v", i, len(seeds), 1) + } + } + if seeds := db.querySeeds(1); len(seeds) != 0 { + t.Errorf("reset: seed count mismatch: have %v, want %v", len(seeds), 0) + } + for i := 0; i < len(nodeDBSeedQueryNodes); i++ { + if seeds := db.querySeeds(1); len(seeds) != 1 { + t.Errorf("2nd iteration %d: seed count mismatch: have %v, want %v", i, len(seeds), 1) + } + } +} + +func TestNodeDBPersistency(t *testing.T) { + root, err := ioutil.TempDir("", "nodedb-") + if err != nil { + t.Fatalf("failed to create temporary data folder: %v", err) + } + defer os.RemoveAll(root) + + var ( + testKey = []byte("somekey") + testInt = int64(314) + ) + + // Create a persistent database and store some values + db, err := newNodeDB(filepath.Join("root", "database"), Version) + if err != nil { + t.Fatalf("failed to create persistent database: %v", err) + } + if err := db.storeInt64(testKey, testInt); err != nil { + t.Fatalf("failed to store value: %v.", err) + } + db.close() + + // Reopen the database and check the value + db, err = newNodeDB(filepath.Join("root", "database"), Version) + if err != nil { + t.Fatalf("failed to open persistent database: %v", err) + } + if val := db.fetchInt64(testKey); val != testInt { + t.Fatalf("value mismatch: have %v, want %v", val, testInt) + } + db.close() + + // Change the database version and check flush + db, err = newNodeDB(filepath.Join("root", "database"), Version+1) + if err != nil { + t.Fatalf("failed to open persistent database: %v", err) + } + if val := db.fetchInt64(testKey); val != 0 { + t.Fatalf("value mismatch: have %v, want %v", val, 0) + } + db.close() +} + +var nodeDBExpirationNodes = []struct { + node Node + pong time.Time + exp bool +}{ + { + node: Node{ + ID: MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: []byte{127, 0, 0, 1}, + }, + pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute), + exp: false, + }, { + node: Node{ + ID: MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + IP: []byte{127, 0, 0, 2}, + }, + pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute), + exp: true, + }, +} + +func TestNodeDBExpiration(t *testing.T) { + db, _ := newNodeDB("", Version) + defer db.close() + + // Add all the test nodes and set their last pong time + for i, seed := range nodeDBExpirationNodes { + if err := db.updateNode(&seed.node); err != nil { + t.Fatalf("node %d: failed to insert: %v", i, err) + } + if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil { + t.Fatalf("node %d: failed to update pong: %v", i, err) + } + } + // Expire some of them, and check the rest + if err := db.expireNodes(); err != nil { + t.Fatalf("failed to expire nodes: %v", err) + } + for i, seed := range nodeDBExpirationNodes { + node := db.node(seed.node.ID) + if (node == nil && !seed.exp) || (node != nil && seed.exp) { + t.Errorf("node %d: expiration mismatch: have %v, want %v", i, node, seed.exp) + } + } +} diff --git a/p2p/discover/node.go b/p2p/discover/node.go index 6662a6cb7..e66ca37a4 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -13,7 +13,6 @@ import ( "net/url" "strconv" "strings" - "sync" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/secp256k1" @@ -305,26 +304,3 @@ func randomID(a NodeID, n int) (b NodeID) { } return b } - -// nodeDB stores all nodes we know about. -type nodeDB struct { - mu sync.RWMutex - byID map[NodeID]*Node -} - -func (db *nodeDB) get(id NodeID) *Node { - db.mu.RLock() - defer db.mu.RUnlock() - return db.byID[id] -} - -func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node { - db.mu.Lock() - defer db.mu.Unlock() - if db.byID == nil { - db.byID = make(map[NodeID]*Node) - } - n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)} - db.byID[n.ID] = n - return n -} diff --git a/p2p/discover/table.go b/p2p/discover/table.go index e2e846456..d3fe373f4 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -11,6 +11,9 @@ import ( "sort" "sync" "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" ) const ( @@ -24,6 +27,7 @@ type Table struct { mutex sync.Mutex // protects buckets, their content, and nursery buckets [nBuckets]*bucket // index of known nodes by distance nursery []*Node // bootstrap nodes + db *nodeDB // database of known nodes bondmu sync.Mutex bonding map[NodeID]*bondproc @@ -31,7 +35,6 @@ type Table struct { net transport self *Node // metadata of the local node - db *nodeDB } type bondproc struct { @@ -58,10 +61,16 @@ type bucket struct { entries []*Node } -func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table { +func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) *Table { + // If no node database was given, use an in-memory one + db, err := newNodeDB(nodeDBPath, Version) + if err != nil { + glog.V(logger.Warn).Infoln("Failed to open node database:", err) + db, _ = newNodeDB("", Version) + } tab := &Table{ net: t, - db: new(nodeDB), + db: db, self: newNode(ourID, ourAddr), bonding: make(map[NodeID]*bondproc), bondslots: make(chan struct{}, maxBondingPingPongs), @@ -80,9 +89,10 @@ func (tab *Table) Self() *Node { return tab.self } -// Close terminates the network listener. +// Close terminates the network listener and flushes the node database. func (tab *Table) Close() { tab.net.close() + tab.db.close() } // Bootstrap sets the bootstrap nodes. These nodes are used to connect @@ -166,8 +176,13 @@ func (tab *Table) refresh() { result := tab.Lookup(randomID(tab.self.ID, ld)) if len(result) == 0 { - // bootstrap the table with a self lookup - all := tab.bondall(tab.nursery) + // Pick a batch of previously know seeds to lookup with + seeds := tab.db.querySeeds(10) + for _, seed := range seeds { + glog.V(logger.Debug).Infoln("Seeding network with", seed) + } + // Bootstrap the table with a self lookup + all := tab.bondall(append(tab.nursery, seeds...)) tab.mutex.Lock() tab.add(all) tab.mutex.Unlock() @@ -235,7 +250,7 @@ func (tab *Table) bondall(nodes []*Node) (result []*Node) { // of the process can be skipped. func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) { var n *Node - if n = tab.db.get(id); n == nil { + if n = tab.db.node(id); n == nil { tab.bondmu.Lock() w := tab.bonding[id] if w != nil { @@ -268,9 +283,12 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 } func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) { + // Request a bonding slot to limit network usage <-tab.bondslots defer func() { tab.bondslots <- struct{}{} }() - if w.err = tab.net.ping(id, addr); w.err != nil { + + // Ping the remote side and wait for a pong + if w.err = tab.ping(id, addr); w.err != nil { close(w.done) return } @@ -280,14 +298,21 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd // waitping will simply time out. tab.net.waitping(id) } - w.n = tab.db.add(id, addr, tcpPort) + // Bonding succeeded, update the node database + w.n = &Node{ + ID: id, + IP: addr.IP, + DiscPort: addr.Port, + TCPPort: int(tcpPort), + } + tab.db.updateNode(w.n) close(w.done) } func (tab *Table) pingreplace(new *Node, b *bucket) { if len(b.entries) == bucketSize { oldest := b.entries[bucketSize-1] - if err := tab.net.ping(oldest.ID, oldest.addr()); err == nil { + if err := tab.ping(oldest.ID, oldest.addr()); err == nil { // The node responded, we don't need to replace it. return } @@ -300,6 +325,21 @@ func (tab *Table) pingreplace(new *Node, b *bucket) { b.entries[0] = new } +// ping a remote endpoint and wait for a reply, also updating the node database +// accordingly. +func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { + // Update the last ping and send the message + tab.db.updateLastPing(id, time.Now()) + if err := tab.net.ping(id, addr); err != nil { + return err + } + // Pong received, update the database and return + tab.db.updateLastPong(id, time.Now()) + tab.db.ensureExpirer() + + return nil +} + // add puts the entries into the table if their corresponding // bucket is not full. The caller must hold tab.mutex. func (tab *Table) add(entries []*Node) { diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index a98376bca..e2bd3c8ad 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -15,7 +15,7 @@ import ( func TestTable_pingReplace(t *testing.T) { doit := func(newNodeIsResponding, lastInBucketIsResponding bool) { transport := newPingRecorder() - tab := newTable(transport, NodeID{}, &net.UDPAddr{}) + tab := newTable(transport, NodeID{}, &net.UDPAddr{}, "") last := fillBucket(tab, 200) pingSender := randomID(tab.self.ID, 200) @@ -145,7 +145,7 @@ func TestTable_closest(t *testing.T) { test := func(test *closeTest) bool { // for any node table, Target and N - tab := newTable(nil, test.Self, &net.UDPAddr{}) + tab := newTable(nil, test.Self, &net.UDPAddr{}, "") tab.add(test.All) // check that doClosest(Target, N) returns nodes @@ -217,7 +217,7 @@ func TestTable_Lookup(t *testing.T) { self := gen(NodeID{}, quickrand).(NodeID) target := randomID(self, 200) transport := findnodeOracle{t, target} - tab := newTable(transport, self, &net.UDPAddr{}) + tab := newTable(transport, self, &net.UDPAddr{}, "") // lookup on empty table returns no nodes if results := tab.Lookup(target); len(results) > 0 { diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 07a1a739c..65741b5f5 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -144,7 +144,7 @@ type reply struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) { +func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) { addr, err := net.ResolveUDPAddr("udp", laddr) if err != nil { return nil, err @@ -153,12 +153,12 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table if err != nil { return nil, err } - tab, _ := newUDP(priv, conn, natm) + tab, _ := newUDP(priv, conn, natm, nodeDBPath) glog.V(logger.Info).Infoln("Listening,", tab.self) return tab, nil } -func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) { +func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp) { udp := &udp{ conn: c, priv: priv, @@ -176,7 +176,7 @@ func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) { realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} } } - udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr) + udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) go udp.loop() go udp.readLoop() return udp.Table, udp @@ -449,7 +449,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } - if t.db.get(fromID) == nil { + if t.db.node(fromID) == nil { // No bond exists, we don't process the packet. This prevents // an attack vector where the discovery protocol could be used // to amplify traffic in a DDOS attack. A malicious actor diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index c6c4d78e3..47e04b85a 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -41,7 +41,7 @@ func newUDPTest(t *testing.T) *udpTest { remotekey: newkey(), remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303}, } - test.table, test.udp = newUDP(test.localkey, test.pipe, nil) + test.table, test.udp = newUDP(test.localkey, test.pipe, nil, "") return test } @@ -157,8 +157,12 @@ func TestUDP_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. - test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99) - + test.table.db.updateNode(&Node{ + ID: PubkeyID(&test.remotekey.PublicKey), + IP: test.remoteaddr.IP, + DiscPort: test.remoteaddr.Port, + TCPPort: 99, + }) // check that closest neighbors are returned. test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) test.waitPacketOut(func(p *neighbors) { diff --git a/p2p/server.go b/p2p/server.go index ecf418d13..5c5883ae8 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -59,6 +59,10 @@ type Server struct { // with the rest of the network. BootstrapNodes []*discover.Node + // NodeDatabase is the path to the database containing the previously seen + // live nodes in the network. + NodeDatabase string + // Protocols should contain the protocols supported // by the server. Matching protocols are launched for // each peer. @@ -197,7 +201,7 @@ func (srv *Server) Start() (err error) { } // node table - ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT) + ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) if err != nil { return err }