diff --git a/cmd/devp2p/dnscmd.go b/cmd/devp2p/dnscmd.go index eb15764b0..f24510405 100644 --- a/cmd/devp2p/dnscmd.go +++ b/cmd/devp2p/dnscmd.go @@ -214,8 +214,7 @@ func dnsClient(ctx *cli.Context) *dnsdisc.Client { if commandHasFlag(ctx, dnsTimeoutFlag) { cfg.Timeout = ctx.Duration(dnsTimeoutFlag.Name) } - c, _ := dnsdisc.NewClient(cfg) // cannot fail because no URLs given - return c + return dnsdisc.NewClient(cfg) } // There are two file formats for DNS node trees on disk: diff --git a/go.mod b/go.mod index d4b420833..a280949e9 100644 --- a/go.mod +++ b/go.mod @@ -60,6 +60,7 @@ require ( golang.org/x/sync v0.0.0-20181108010431-42b317875d0f golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7 golang.org/x/text v0.3.2 + golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce gopkg.in/olebedev/go-duktape.v3 v3.0.0-20190213234257-ec84240a7772 gopkg.in/sourcemap.v1 v1.0.5 // indirect diff --git a/p2p/dnsdisc/client.go b/p2p/dnsdisc/client.go index 677c0aa92..a29f82cd8 100644 --- a/p2p/dnsdisc/client.go +++ b/p2p/dnsdisc/client.go @@ -23,6 +23,7 @@ import ( "math/rand" "net" "strings" + "sync" "time" "github.com/ethereum/go-ethereum/common/mclock" @@ -31,15 +32,13 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" lru "github.com/hashicorp/golang-lru" + "golang.org/x/time/rate" ) // Client discovers nodes by querying DNS servers. type Client struct { - cfg Config - clock mclock.Clock - linkCache linkCache - trees map[string]*clientTree - + cfg Config + clock mclock.Clock entries *lru.Cache } @@ -48,6 +47,7 @@ type Config struct { Timeout time.Duration // timeout used for DNS lookups (default 5s) RecheckInterval time.Duration // time between tree root update checks (default 30min) CacheLimit int // maximum number of cached records (default 1000) + RateLimit float64 // maximum DNS requests / second (default 3) ValidSchemes enr.IdentityScheme // acceptable ENR identity schemes (default enode.ValidSchemes) Resolver Resolver // the DNS resolver to use (defaults to system DNS) Logger log.Logger // destination of client log messages (defaults to root logger) @@ -60,9 +60,10 @@ type Resolver interface { func (cfg Config) withDefaults() Config { const ( - defaultTimeout = 5 * time.Second - defaultRecheck = 30 * time.Minute - defaultCache = 1000 + defaultTimeout = 5 * time.Second + defaultRecheck = 30 * time.Minute + defaultRateLimit = 3 + defaultCache = 1000 ) if cfg.Timeout == 0 { cfg.Timeout = defaultTimeout @@ -73,6 +74,9 @@ func (cfg Config) withDefaults() Config { if cfg.CacheLimit == 0 { cfg.CacheLimit = defaultCache } + if cfg.RateLimit == 0 { + cfg.RateLimit = defaultRateLimit + } if cfg.ValidSchemes == nil { cfg.ValidSchemes = enode.ValidSchemes } @@ -86,32 +90,24 @@ func (cfg Config) withDefaults() Config { } // NewClient creates a client. -func NewClient(cfg Config, urls ...string) (*Client, error) { - c := &Client{ - cfg: cfg.withDefaults(), - clock: mclock.System{}, - trees: make(map[string]*clientTree), - } - var err error - if c.entries, err = lru.New(c.cfg.CacheLimit); err != nil { - return nil, err - } - for _, url := range urls { - if err := c.AddTree(url); err != nil { - return nil, err - } +func NewClient(cfg Config) *Client { + cfg = cfg.withDefaults() + cache, err := lru.New(cfg.CacheLimit) + if err != nil { + panic(err) } - return c, nil + rlimit := rate.NewLimiter(rate.Limit(cfg.RateLimit), 10) + cfg.Resolver = &rateLimitResolver{cfg.Resolver, rlimit} + return &Client{cfg: cfg, entries: cache, clock: mclock.System{}} } -// SyncTree downloads the entire node tree at the given URL. This doesn't add the tree for -// later use, but any previously-synced entries are reused. +// SyncTree downloads the entire node tree at the given URL. func (c *Client) SyncTree(url string) (*Tree, error) { le, err := parseLink(url) if err != nil { return nil, fmt.Errorf("invalid enrtree URL: %v", err) } - ct := newClientTree(c, le) + ct := newClientTree(c, new(linkCache), le) t := &Tree{entries: make(map[string]entry)} if err := ct.syncAll(t.entries); err != nil { return nil, err @@ -120,75 +116,16 @@ func (c *Client) SyncTree(url string) (*Tree, error) { return t, nil } -// AddTree adds a enrtree:// URL to crawl. -func (c *Client) AddTree(url string) error { - le, err := parseLink(url) - if err != nil { - return fmt.Errorf("invalid enrtree URL: %v", err) - } - ct, err := c.ensureTree(le) - if err != nil { - return err - } - c.linkCache.add(ct) - return nil -} - -func (c *Client) ensureTree(le *linkEntry) (*clientTree, error) { - if tree, ok := c.trees[le.domain]; ok { - if !tree.matchPubkey(le.pubkey) { - return nil, fmt.Errorf("conflicting public keys for domain %q", le.domain) - } - return tree, nil - } - ct := newClientTree(c, le) - c.trees[le.domain] = ct - return ct, nil -} - -// RandomNode retrieves the next random node. -func (c *Client) RandomNode(ctx context.Context) *enode.Node { - for { - ct := c.randomTree() - if ct == nil { - return nil - } - n, err := ct.syncRandom(ctx) - if err != nil { - if err == ctx.Err() { - return nil // context canceled. - } - c.cfg.Logger.Debug("Error in DNS random node sync", "tree", ct.loc.domain, "err", err) - continue - } - if n != nil { - return n - } - } -} - -// randomTree returns a random tree. -func (c *Client) randomTree() *clientTree { - if !c.linkCache.valid() { - c.gcTrees() - } - limit := rand.Intn(len(c.trees)) - for _, ct := range c.trees { - if limit == 0 { - return ct +// NewIterator creates an iterator that visits all nodes at the +// given tree URLs. +func (c *Client) NewIterator(urls ...string) (enode.Iterator, error) { + it := c.newRandomIterator() + for _, url := range urls { + if err := it.addTree(url); err != nil { + return nil, err } - limit-- - } - return nil -} - -// gcTrees rebuilds the 'trees' map. -func (c *Client) gcTrees() { - trees := make(map[string]*clientTree) - for t := range c.linkCache.all() { - trees[t.loc.domain] = t } - c.trees = trees + return it, nil } // resolveRoot retrieves a root entry via DNS. @@ -258,3 +195,128 @@ func (c *Client) doResolveEntry(ctx context.Context, domain, hash string) (entry } return nil, nameError{name, errNoEntry} } + +// rateLimitResolver applies a rate limit to a Resolver. +type rateLimitResolver struct { + r Resolver + limiter *rate.Limiter +} + +func (r *rateLimitResolver) LookupTXT(ctx context.Context, domain string) ([]string, error) { + if err := r.limiter.Wait(ctx); err != nil { + return nil, err + } + return r.r.LookupTXT(ctx, domain) +} + +// randomIterator traverses a set of trees and returns nodes found in them. +type randomIterator struct { + cur *enode.Node + ctx context.Context + cancelFn context.CancelFunc + c *Client + + mu sync.Mutex + trees map[string]*clientTree // all trees + lc linkCache // tracks tree dependencies +} + +func (c *Client) newRandomIterator() *randomIterator { + ctx, cancel := context.WithCancel(context.Background()) + return &randomIterator{ + c: c, + ctx: ctx, + cancelFn: cancel, + trees: make(map[string]*clientTree), + } +} + +// Node returns the current node. +func (it *randomIterator) Node() *enode.Node { + return it.cur +} + +// Close closes the iterator. +func (it *randomIterator) Close() { + it.mu.Lock() + defer it.mu.Unlock() + + it.cancelFn() + it.trees = nil +} + +// Next moves the iterator to the next node. +func (it *randomIterator) Next() bool { + it.cur = it.nextNode() + return it.cur != nil +} + +// addTree adds a enrtree:// URL to the iterator. +func (it *randomIterator) addTree(url string) error { + le, err := parseLink(url) + if err != nil { + return fmt.Errorf("invalid enrtree URL: %v", err) + } + it.lc.addLink("", le.str) + return nil +} + +// nextNode syncs random tree entries until it finds a node. +func (it *randomIterator) nextNode() *enode.Node { + for { + ct := it.nextTree() + if ct == nil { + return nil + } + n, err := ct.syncRandom(it.ctx) + if err != nil { + if err == it.ctx.Err() { + return nil // context canceled. + } + it.c.cfg.Logger.Debug("Error in DNS random node sync", "tree", ct.loc.domain, "err", err) + continue + } + if n != nil { + return n + } + } +} + +// nextTree returns a random tree. +func (it *randomIterator) nextTree() *clientTree { + it.mu.Lock() + defer it.mu.Unlock() + + if it.lc.changed { + it.rebuildTrees() + it.lc.changed = false + } + if len(it.trees) == 0 { + return nil + } + limit := rand.Intn(len(it.trees)) + for _, ct := range it.trees { + if limit == 0 { + return ct + } + limit-- + } + return nil +} + +// rebuildTrees rebuilds the 'trees' map. +func (it *randomIterator) rebuildTrees() { + // Delete removed trees. + for loc := range it.trees { + if !it.lc.isReferenced(loc) { + delete(it.trees, loc) + } + } + // Add new trees. + for loc := range it.lc.backrefs { + if it.trees[loc] == nil { + link, _ := parseLink(linkPrefix + loc) + it.trees[loc] = newClientTree(it.c, &it.lc, link) + } + } +} diff --git a/p2p/dnsdisc/client_test.go b/p2p/dnsdisc/client_test.go index d8e3ecee3..1a1d5ade9 100644 --- a/p2p/dnsdisc/client_test.go +++ b/p2p/dnsdisc/client_test.go @@ -54,7 +54,7 @@ func TestClientSyncTree(t *testing.T) { wantSeq = uint(1) ) - c, _ := NewClient(Config{Resolver: r, Logger: testlog.Logger(t, log.LvlTrace)}) + c := NewClient(Config{Resolver: r, Logger: testlog.Logger(t, log.LvlTrace)}) stree, err := c.SyncTree("enrtree://AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@n") if err != nil { t.Fatal("sync error:", err) @@ -68,9 +68,6 @@ func TestClientSyncTree(t *testing.T) { if stree.Seq() != wantSeq { t.Errorf("synced tree has wrong seq: %d", stree.Seq()) } - if len(c.trees) > 0 { - t.Errorf("tree from SyncTree added to client") - } } // In this test, syncing the tree fails because it contains an invalid ENR entry. @@ -91,7 +88,7 @@ func TestClientSyncTreeBadNode(t *testing.T) { "C7HRFPF3BLGF3YR4DY5KX3SMBE.n": "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@morenodes.example.org", "INDMVBZEEQ4ESVYAKGIYU74EAA.n": "enr:-----", } - c, _ := NewClient(Config{Resolver: r, Logger: testlog.Logger(t, log.LvlTrace)}) + c := NewClient(Config{Resolver: r, Logger: testlog.Logger(t, log.LvlTrace)}) _, err := c.SyncTree("enrtree://AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@n") wantErr := nameError{name: "INDMVBZEEQ4ESVYAKGIYU74EAA.n", err: entryError{typ: "enr", err: errInvalidENR}} if err != wantErr { @@ -99,57 +96,89 @@ func TestClientSyncTreeBadNode(t *testing.T) { } } -// This test checks that RandomNode hits all entries. -func TestClientRandomNode(t *testing.T) { +// This test checks that randomIterator finds all entries. +func TestIterator(t *testing.T) { nodes := testNodes(nodesSeed1, 30) tree, url := makeTestTree("n", nodes, nil) r := mapResolver(tree.ToTXT("n")) - c, _ := NewClient(Config{Resolver: r, Logger: testlog.Logger(t, log.LvlTrace)}) - if err := c.AddTree(url); err != nil { + c := NewClient(Config{ + Resolver: r, + Logger: testlog.Logger(t, log.LvlTrace), + RateLimit: 500, + }) + it, err := c.NewIterator(url) + if err != nil { t.Fatal(err) } - checkRandomNode(t, c, nodes) + checkIterator(t, it, nodes) } -// This test checks that RandomNode traverses linked trees as well as explicitly added trees. -func TestClientRandomNodeLinks(t *testing.T) { +// This test checks if closing randomIterator races. +func TestIteratorClose(t *testing.T) { + nodes := testNodes(nodesSeed1, 500) + tree1, url1 := makeTestTree("t1", nodes, nil) + c := NewClient(Config{Resolver: newMapResolver(tree1.ToTXT("t1"))}) + it, err := c.NewIterator(url1) + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + go func() { + for it.Next() { + _ = it.Node() + } + close(done) + }() + + time.Sleep(50 * time.Millisecond) + it.Close() + <-done +} + +// This test checks that randomIterator traverses linked trees as well as explicitly added trees. +func TestIteratorLinks(t *testing.T) { nodes := testNodes(nodesSeed1, 40) tree1, url1 := makeTestTree("t1", nodes[:10], nil) tree2, url2 := makeTestTree("t2", nodes[10:], []string{url1}) - cfg := Config{ - Resolver: newMapResolver(tree1.ToTXT("t1"), tree2.ToTXT("t2")), - Logger: testlog.Logger(t, log.LvlTrace), - } - c, _ := NewClient(cfg) - if err := c.AddTree(url2); err != nil { + c := NewClient(Config{ + Resolver: newMapResolver(tree1.ToTXT("t1"), tree2.ToTXT("t2")), + Logger: testlog.Logger(t, log.LvlTrace), + RateLimit: 500, + }) + it, err := c.NewIterator(url2) + if err != nil { t.Fatal(err) } - checkRandomNode(t, c, nodes) + checkIterator(t, it, nodes) } -// This test verifies that RandomNode re-checks the root of the tree to catch +// This test verifies that randomIterator re-checks the root of the tree to catch // updates to nodes. -func TestClientRandomNodeUpdates(t *testing.T) { +func TestIteratorNodeUpdates(t *testing.T) { var ( clock = new(mclock.Simulated) nodes = testNodes(nodesSeed1, 30) resolver = newMapResolver() - cfg = Config{ + c = NewClient(Config{ Resolver: resolver, Logger: testlog.Logger(t, log.LvlTrace), RecheckInterval: 20 * time.Minute, - } - c, _ = NewClient(cfg) + RateLimit: 500, + }) ) c.clock = clock tree1, url := makeTestTree("n", nodes[:25], nil) + it, err := c.NewIterator(url) + if err != nil { + t.Fatal(err) + } - // Sync the original tree. + // sync the original tree. resolver.add(tree1.ToTXT("n")) - c.AddTree(url) - checkRandomNode(t, c, nodes[:25]) + checkIterator(t, it, nodes[:25]) // Update some nodes and ensure RandomNode returns the new nodes as well. keys := testKeys(nodesSeed1, len(nodes)) @@ -162,25 +191,25 @@ func TestClientRandomNodeUpdates(t *testing.T) { nodes[i] = n2 } tree2, _ := makeTestTree("n", nodes, nil) - clock.Run(cfg.RecheckInterval + 1*time.Second) + clock.Run(c.cfg.RecheckInterval + 1*time.Second) resolver.clear() resolver.add(tree2.ToTXT("n")) - checkRandomNode(t, c, nodes) + checkIterator(t, it, nodes) } -// This test verifies that RandomNode re-checks the root of the tree to catch +// This test verifies that randomIterator re-checks the root of the tree to catch // updates to links. -func TestClientRandomNodeLinkUpdates(t *testing.T) { +func TestIteratorLinkUpdates(t *testing.T) { var ( clock = new(mclock.Simulated) nodes = testNodes(nodesSeed1, 30) resolver = newMapResolver() - cfg = Config{ + c = NewClient(Config{ Resolver: resolver, Logger: testlog.Logger(t, log.LvlTrace), RecheckInterval: 20 * time.Minute, - } - c, _ = NewClient(cfg) + RateLimit: 500, + }) ) c.clock = clock tree3, url3 := makeTestTree("t3", nodes[20:30], nil) @@ -190,49 +219,53 @@ func TestClientRandomNodeLinkUpdates(t *testing.T) { resolver.add(tree2.ToTXT("t2")) resolver.add(tree3.ToTXT("t3")) + it, err := c.NewIterator(url1) + if err != nil { + t.Fatal(err) + } + // Sync tree1 using RandomNode. - c.AddTree(url1) - checkRandomNode(t, c, nodes[:20]) + checkIterator(t, it, nodes[:20]) // Add link to tree3, remove link to tree2. tree1, _ = makeTestTree("t1", nodes[:10], []string{url3}) resolver.add(tree1.ToTXT("t1")) - clock.Run(cfg.RecheckInterval + 1*time.Second) + clock.Run(c.cfg.RecheckInterval + 1*time.Second) t.Log("tree1 updated") var wantNodes []*enode.Node wantNodes = append(wantNodes, tree1.Nodes()...) wantNodes = append(wantNodes, tree3.Nodes()...) - checkRandomNode(t, c, wantNodes) + checkIterator(t, it, wantNodes) // Check that linked trees are GCed when they're no longer referenced. - if len(c.trees) != 2 { - t.Errorf("client knows %d trees, want 2", len(c.trees)) + knownTrees := it.(*randomIterator).trees + if len(knownTrees) != 2 { + t.Errorf("client knows %d trees, want 2", len(knownTrees)) } } -func checkRandomNode(t *testing.T, c *Client, wantNodes []*enode.Node) { +func checkIterator(t *testing.T, it enode.Iterator, wantNodes []*enode.Node) { t.Helper() var ( want = make(map[enode.ID]*enode.Node) - maxCalls = len(wantNodes) * 2 + maxCalls = len(wantNodes) * 3 calls = 0 - ctx = context.Background() ) for _, n := range wantNodes { want[n.ID()] = n } for ; len(want) > 0 && calls < maxCalls; calls++ { - n := c.RandomNode(ctx) - if n == nil { - t.Fatalf("RandomNode returned nil (call %d)", calls) + if !it.Next() { + t.Fatalf("Next returned false (call %d)", calls) } + n := it.Node() delete(want, n.ID()) } - t.Logf("checkRandomNode called RandomNode %d times to find %d nodes", calls, len(wantNodes)) + t.Logf("checkIterator called Next %d times to find %d nodes", calls, len(wantNodes)) for _, n := range want { - t.Errorf("RandomNode didn't discover node %v", n.ID()) + t.Errorf("iterator didn't discover node %v", n.ID()) } } diff --git a/p2p/dnsdisc/sync.go b/p2p/dnsdisc/sync.go index 533dacc65..53423527a 100644 --- a/p2p/dnsdisc/sync.go +++ b/p2p/dnsdisc/sync.go @@ -18,7 +18,6 @@ package dnsdisc import ( "context" - "crypto/ecdsa" "math/rand" "time" @@ -28,27 +27,21 @@ import ( // clientTree is a full tree being synced. type clientTree struct { - c *Client - loc *linkEntry - root *rootEntry + c *Client + loc *linkEntry // link to this tree + lastRootCheck mclock.AbsTime // last revalidation of root + root *rootEntry enrs *subtreeSync links *subtreeSync - linkCache linkCache -} - -func newClientTree(c *Client, loc *linkEntry) *clientTree { - ct := &clientTree{c: c, loc: loc} - ct.linkCache.self = ct - return ct -} -func (ct *clientTree) matchPubkey(key *ecdsa.PublicKey) bool { - return keysEqual(ct.loc.pubkey, key) + lc *linkCache // tracks all links between all trees + curLinks map[string]struct{} // links contained in this tree + linkGCRoot string // root on which last link GC has run } -func keysEqual(k1, k2 *ecdsa.PublicKey) bool { - return k1.Curve == k2.Curve && k1.X.Cmp(k2.X) == 0 && k1.Y.Cmp(k2.Y) == 0 +func newClientTree(c *Client, lc *linkCache, loc *linkEntry) *clientTree { + return &clientTree{c: c, lc: lc, loc: loc} } // syncAll retrieves all entries of the tree. @@ -78,6 +71,7 @@ func (ct *clientTree) syncRandom(ctx context.Context) (*enode.Node, error) { err := ct.syncNextLink(ctx) return nil, err } + ct.gcLinks() // Sync next random entry in ENR tree. Once every node has been visited, we simply // start over. This is fine because entries are cached. @@ -87,6 +81,16 @@ func (ct *clientTree) syncRandom(ctx context.Context) (*enode.Node, error) { return ct.syncNextRandomENR(ctx) } +// gcLinks removes outdated links from the global link cache. GC runs once +// when the link sync finishes. +func (ct *clientTree) gcLinks() { + if !ct.links.done() || ct.root.lroot == ct.linkGCRoot { + return + } + ct.lc.resetLinks(ct.loc.str, ct.curLinks) + ct.linkGCRoot = ct.root.lroot +} + func (ct *clientTree) syncNextLink(ctx context.Context) error { hash := ct.links.missing[0] e, err := ct.links.resolveNext(ctx, hash) @@ -95,12 +99,9 @@ func (ct *clientTree) syncNextLink(ctx context.Context) error { } ct.links.missing = ct.links.missing[1:] - if le, ok := e.(*linkEntry); ok { - lt, err := ct.c.ensureTree(le) - if err != nil { - return err - } - ct.linkCache.add(lt) + if dest, ok := e.(*linkEntry); ok { + ct.lc.addLink(ct.loc.str, dest.str) + ct.curLinks[dest.str] = struct{}{} } return nil } @@ -150,7 +151,7 @@ func (ct *clientTree) updateRoot() error { // Invalidate subtrees if changed. if ct.links == nil || root.lroot != ct.links.root { ct.links = newSubtreeSync(ct.c, ct.loc, root.lroot, true) - ct.linkCache.reset() + ct.curLinks = make(map[string]struct{}) } if ct.enrs == nil || root.eroot != ct.enrs.root { ct.enrs = newSubtreeSync(ct.c, ct.loc, root.eroot, false) @@ -215,63 +216,51 @@ func (ts *subtreeSync) resolveNext(ctx context.Context, hash string) (entry, err return e, nil } -// linkCache tracks the links of a tree. +// linkCache tracks links between trees. type linkCache struct { - self *clientTree - directM map[*clientTree]struct{} // direct links - allM map[*clientTree]struct{} // direct & transitive links + backrefs map[string]map[string]struct{} + changed bool } -// reset clears the cache. -func (lc *linkCache) reset() { - lc.directM = nil - lc.allM = nil +func (lc *linkCache) isReferenced(r string) bool { + return len(lc.backrefs[r]) != 0 } -// add adds a direct link to the cache. -func (lc *linkCache) add(ct *clientTree) { - if lc.directM == nil { - lc.directM = make(map[*clientTree]struct{}) - } - if _, ok := lc.directM[ct]; !ok { - lc.invalidate() +func (lc *linkCache) addLink(from, to string) { + if _, ok := lc.backrefs[to][from]; ok { + return } - lc.directM[ct] = struct{}{} -} - -// invalidate resets the cache of transitive links. -func (lc *linkCache) invalidate() { - lc.allM = nil -} -// valid returns true when the cache of transitive links is up-to-date. -func (lc *linkCache) valid() bool { - // Re-check validity of child caches to catch updates. - for ct := range lc.allM { - if ct != lc.self && !ct.linkCache.valid() { - lc.allM = nil - break - } + if lc.backrefs == nil { + lc.backrefs = make(map[string]map[string]struct{}) + } + if _, ok := lc.backrefs[to]; !ok { + lc.backrefs[to] = make(map[string]struct{}) } - return lc.allM != nil + lc.backrefs[to][from] = struct{}{} + lc.changed = true } -// all returns all trees reachable through the cache. -func (lc *linkCache) all() map[*clientTree]struct{} { - if lc.valid() { - return lc.allM - } - // Remake lc.allM it by taking the union of all() across children. - m := make(map[*clientTree]struct{}) - if lc.self != nil { - m[lc.self] = struct{}{} - } - for ct := range lc.directM { - m[ct] = struct{}{} - for lt := range ct.linkCache.all() { - m[lt] = struct{}{} +// resetLinks clears all links of the given tree. +func (lc *linkCache) resetLinks(from string, keep map[string]struct{}) { + stk := []string{from} + for len(stk) > 0 { + item := stk[len(stk)-1] + stk = stk[:len(stk)-1] + + for r, refs := range lc.backrefs { + if _, ok := keep[r]; ok { + continue + } + if _, ok := refs[item]; !ok { + continue + } + lc.changed = true + delete(refs, item) + if len(refs) == 0 { + delete(lc.backrefs, r) + stk = append(stk, r) + } } } - lc.allM = m - return m } diff --git a/p2p/dnsdisc/sync_test.go b/p2p/dnsdisc/sync_test.go new file mode 100644 index 000000000..32af3656e --- /dev/null +++ b/p2p/dnsdisc/sync_test.go @@ -0,0 +1,83 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package dnsdisc + +import ( + "math/rand" + "strconv" + "testing" +) + +func TestLinkCache(t *testing.T) { + var lc linkCache + + // Check adding links. + lc.addLink("1", "2") + if !lc.changed { + t.Error("changed flag not set") + } + lc.changed = false + lc.addLink("1", "2") + if lc.changed { + t.Error("changed flag set after adding link that's already present") + } + lc.addLink("2", "3") + lc.addLink("3", "1") + lc.addLink("2", "4") + lc.changed = false + + if !lc.isReferenced("3") { + t.Error("3 not referenced") + } + if lc.isReferenced("6") { + t.Error("6 is referenced") + } + + lc.resetLinks("1", nil) + if !lc.changed { + t.Error("changed flag not set") + } + if len(lc.backrefs) != 0 { + t.Logf("%+v", lc) + t.Error("reference maps should be empty") + } +} + +func TestLinkCacheRandom(t *testing.T) { + tags := make([]string, 1000) + for i := range tags { + tags[i] = strconv.Itoa(i) + } + + // Create random links. + var lc linkCache + var remove []string + for i := 0; i < 100; i++ { + a, b := tags[rand.Intn(len(tags))], tags[rand.Intn(len(tags))] + lc.addLink(a, b) + remove = append(remove, a) + } + + // Remove all the links. + for _, s := range remove { + lc.resetLinks(s, nil) + } + if len(lc.backrefs) != 0 { + t.Logf("%+v", lc) + t.Error("reference maps should be empty") + } +} diff --git a/p2p/dnsdisc/tree.go b/p2p/dnsdisc/tree.go index eba2ff9c0..82a935ca4 100644 --- a/p2p/dnsdisc/tree.go +++ b/p2p/dnsdisc/tree.go @@ -48,7 +48,7 @@ func (t *Tree) Sign(key *ecdsa.PrivateKey, domain string) (url string, err error } root.sig = sig t.root = &root - link := &linkEntry{domain, &key.PublicKey} + link := newLinkEntry(domain, &key.PublicKey) return link.String(), nil } @@ -209,6 +209,7 @@ type ( node *enode.Node } linkEntry struct { + str string domain string pubkey *ecdsa.PublicKey } @@ -246,7 +247,8 @@ func (e *rootEntry) sigHash() []byte { func (e *rootEntry) verifySignature(pubkey *ecdsa.PublicKey) bool { sig := e.sig[:crypto.RecoveryIDOffset] // remove recovery id - return crypto.VerifySignature(crypto.FromECDSAPub(pubkey), e.sigHash(), sig) + enckey := crypto.FromECDSAPub(pubkey) + return crypto.VerifySignature(enckey, e.sigHash(), sig) } func (e *branchEntry) String() string { @@ -258,8 +260,13 @@ func (e *enrEntry) String() string { } func (e *linkEntry) String() string { - pubkey := b32format.EncodeToString(crypto.CompressPubkey(e.pubkey)) - return fmt.Sprintf("%s%s@%s", linkPrefix, pubkey, e.domain) + return linkPrefix + e.str +} + +func newLinkEntry(domain string, pubkey *ecdsa.PublicKey) *linkEntry { + key := b32format.EncodeToString(crypto.CompressPubkey(pubkey)) + str := key + "@" + domain + return &linkEntry{str, domain, pubkey} } // Entry Parsing @@ -319,7 +326,7 @@ func parseLink(e string) (*linkEntry, error) { if err != nil { return nil, entryError{"link", errBadPubkey} } - return &linkEntry{domain, key}, nil + return &linkEntry{e, domain, key}, nil } func parseBranch(e string) (entry, error) { diff --git a/p2p/dnsdisc/tree_test.go b/p2p/dnsdisc/tree_test.go index b6d0a8433..4048c35d6 100644 --- a/p2p/dnsdisc/tree_test.go +++ b/p2p/dnsdisc/tree_test.go @@ -91,7 +91,7 @@ func TestParseEntry(t *testing.T) { // Links { input: "enrtree://AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@nodes.example.org", - e: &linkEntry{"nodes.example.org", &testkey.PublicKey}, + e: &linkEntry{"AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@nodes.example.org", "nodes.example.org", &testkey.PublicKey}, }, { input: "enrtree://nodes.example.org",