diff --git a/p2p/dial.go b/p2p/dial.go index 24d4dc2e89..225709427c 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -24,6 +24,7 @@ import ( "fmt" mrand "math/rand" "net" + "net/netip" "sync" "sync/atomic" "time" @@ -31,6 +32,7 @@ import ( "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" ) @@ -77,6 +79,7 @@ var ( errRecentlyDialed = errors.New("recently dialed") errNetRestrict = errors.New("not contained in netrestrict list") errNoPort = errors.New("node does not provide TCP port") + errNoResolvedIP = errors.New("node does not provide a resolved IP") ) // dialer creates outbound connections and submits them into Server. @@ -90,16 +93,17 @@ var ( // to create peer connections to nodes arriving through the iterator. type dialScheduler struct { dialConfig - setupFunc dialSetupFunc - wg sync.WaitGroup - cancel context.CancelFunc - ctx context.Context - nodesIn chan *enode.Node - doneCh chan *dialTask - addStaticCh chan *enode.Node - remStaticCh chan *enode.Node - addPeerCh chan *conn - remPeerCh chan *conn + setupFunc dialSetupFunc + dnsLookupFunc func(ctx context.Context, network string, name string) ([]netip.Addr, error) + wg sync.WaitGroup + cancel context.CancelFunc + ctx context.Context + nodesIn chan *enode.Node + doneCh chan *dialTask + addStaticCh chan *enode.Node + remStaticCh chan *enode.Node + addPeerCh chan *conn + remPeerCh chan *conn // Everything below here belongs to loop and // should only be accessed by code on the loop goroutine. @@ -159,18 +163,19 @@ func (cfg dialConfig) withDefaults() dialConfig { func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler { cfg := config.withDefaults() d := &dialScheduler{ - dialConfig: cfg, - historyTimer: mclock.NewAlarm(cfg.clock), - setupFunc: setupFunc, - dialing: make(map[enode.ID]*dialTask), - static: make(map[enode.ID]*dialTask), - peers: make(map[enode.ID]struct{}), - doneCh: make(chan *dialTask), - nodesIn: make(chan *enode.Node), - addStaticCh: make(chan *enode.Node), - remStaticCh: make(chan *enode.Node), - addPeerCh: make(chan *conn), - remPeerCh: make(chan *conn), + dialConfig: cfg, + historyTimer: mclock.NewAlarm(cfg.clock), + setupFunc: setupFunc, + dnsLookupFunc: net.DefaultResolver.LookupNetIP, + dialing: make(map[enode.ID]*dialTask), + static: make(map[enode.ID]*dialTask), + peers: make(map[enode.ID]struct{}), + doneCh: make(chan *dialTask), + nodesIn: make(chan *enode.Node), + addStaticCh: make(chan *enode.Node), + remStaticCh: make(chan *enode.Node), + addPeerCh: make(chan *conn), + remPeerCh: make(chan *conn), } d.lastStatsLog = d.clock.Now() d.ctx, d.cancel = context.WithCancel(context.Background()) @@ -274,7 +279,7 @@ loop: case node := <-d.addStaticCh: id := node.ID() _, exists := d.static[id] - d.log.Trace("Adding static node", "id", id, "ip", node.IPAddr(), "added", !exists) + d.log.Trace("Adding static node", "id", id, "endpoint", nodeEndpointForLog(node), "added", !exists) if exists { continue loop } @@ -433,10 +438,68 @@ func (d *dialScheduler) removeFromStaticPool(idx int) { task.staticPoolIndex = -1 } +// dnsResolveHostname updates the given node from its DNS hostname. +// This is used to resolve static dial targets. +func (d *dialScheduler) dnsResolveHostname(n *enode.Node) (*enode.Node, error) { + if n.Hostname() == "" { + return n, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + foundIPs, err := d.dnsLookupFunc(ctx, "ip", n.Hostname()) + if err != nil { + return n, err + } + + // Check for IP updates. + var ( + nodeIP4, nodeIP6 netip.Addr + foundIP4, foundIP6 netip.Addr + ) + n.Load((*enr.IPv4Addr)(&nodeIP4)) + n.Load((*enr.IPv6Addr)(&nodeIP6)) + for _, ip := range foundIPs { + if ip.Is4() && !foundIP4.IsValid() { + foundIP4 = ip + } + if ip.Is6() && !foundIP6.IsValid() { + foundIP6 = ip + } + } + + if !foundIP4.IsValid() && !foundIP6.IsValid() { + // Lookup failed. + return n, errNoResolvedIP + } + if foundIP4 == nodeIP4 && foundIP6 == nodeIP6 { + // No updates necessary. + d.log.Trace("Node DNS lookup had no update", "id", n.ID(), "name", n.Hostname(), "ip", foundIP4, "ip6", foundIP6) + return n, nil + } + + // Update the node. Note this invalidates the ENR signature, because we use SignNull + // to create a modified copy. But this should be OK, since we just use the node as a + // dial target. And nodes will usually only have a DNS hostname if they came from a + // enode:// URL, which has no signature anyway. If it ever becomes a problem, the + // resolved IP could also be stored into dialTask instead of the node. + rec := n.Record() + if foundIP4.IsValid() { + rec.Set(enr.IPv4Addr(foundIP4)) + } + if foundIP6.IsValid() { + rec.Set(enr.IPv6Addr(foundIP6)) + } + rec.SetSeq(n.Seq()) // ensure seq not bumped by update + newNode := enode.SignNull(rec, n.ID()).WithHostname(n.Hostname()) + d.log.Debug("Node updated from DNS lookup", "id", n.ID(), "name", n.Hostname(), "ip", newNode.IP()) + return newNode, nil +} + // startDial runs the given dial task in a separate goroutine. func (d *dialScheduler) startDial(task *dialTask) { node := task.dest() - d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IPAddr(), "flag", task.flags) + d.log.Trace("Starting p2p dial", "id", node.ID(), "endpoint", nodeEndpointForLog(node), "flag", task.flags) hkey := string(node.ID().Bytes()) d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration)) d.dialing[node.ID()] = task @@ -473,14 +536,29 @@ func (t *dialTask) dest() *enode.Node { } func (t *dialTask) run(d *dialScheduler) { - if t.needResolve() && !t.resolve(d) { - return + if t.isStatic() { + // Resolve DNS. + if n := t.dest(); n.Hostname() != "" { + resolved, err := d.dnsResolveHostname(n) + if err != nil { + d.log.Warn("DNS lookup of static node failed", "id", n.ID(), "name", n.Hostname(), "err", err) + } else { + t.destPtr.Store(resolved) + } + } + // Try resolving node ID through the DHT if there is no IP address. + if !t.dest().IPAddr().IsValid() { + if !t.resolve(d) { + return // DHT resolve failed, skip dial. + } + } } err := t.dial(d, t.dest()) if err != nil { // For static nodes, resolve one more time if dialing fails. - if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { + var dialErr *dialError + if errors.As(err, &dialErr) && t.isStatic() { if t.resolve(d) { t.dial(d, t.dest()) } @@ -488,8 +566,8 @@ func (t *dialTask) run(d *dialScheduler) { } } -func (t *dialTask) needResolve() bool { - return t.flags&staticDialedConn != 0 && !t.dest().IPAddr().IsValid() +func (t *dialTask) isStatic() bool { + return t.flags&staticDialedConn != 0 } // resolve attempts to find the current endpoint for the destination @@ -553,3 +631,10 @@ func cleanupDialErr(err error) error { } return err } + +func nodeEndpointForLog(n *enode.Node) string { + if n.Hostname() != "" { + return n.Hostname() + } + return n.IPAddr().String() +} diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 13908f11ea..f18dacce2a 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/rand" "net" + "net/netip" "reflect" "sync" "testing" @@ -394,6 +395,34 @@ func TestDialSchedResolve(t *testing.T) { }) } +func TestDialSchedDNSHostname(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 1, + maxDialPeers: 1, + } + node := newNode(uintID(0x01), ":30303").WithHostname("node-hostname") + resolved := newNode(uintID(0x01), "1.2.3.4:30303").WithHostname("node-hostname") + runDialTest(t, config, []dialTestRound{ + { + update: func(d *dialScheduler) { + d.dnsLookupFunc = func(ctx context.Context, network string, name string) ([]netip.Addr, error) { + if name != "node-hostname" { + t.Error("wrong hostname in DNS lookup:", name) + } + result := []netip.Addr{netip.MustParseAddr("1.2.3.4")} + return result, nil + } + d.addStatic(node) + }, + wantNewDials: []*enode.Node{ + resolved, + }, + }, + }) +} + // ------- // Code below here is the framework for the tests above. diff --git a/p2p/enode/node.go b/p2p/enode/node.go index 4d93d3f6be..d6f2ac7ff5 100644 --- a/p2p/enode/node.go +++ b/p2p/enode/node.go @@ -37,6 +37,10 @@ var errMissingPrefix = errors.New("missing 'enr:' prefix for base64-encoded reco type Node struct { r enr.Record id ID + + // hostname tracks the DNS name of the node. + hostname string + // endpoint information ip netip.Addr udp uint16 @@ -77,6 +81,8 @@ func newNodeWithID(r *enr.Record, id ID) *Node { n.setIP4(ip4) case valid6: n.setIP6(ip6) + default: + n.setIPv4Ports() } return n } @@ -103,6 +109,10 @@ func localityScore(ip netip.Addr) int { func (n *Node) setIP4(ip netip.Addr) { n.ip = ip + n.setIPv4Ports() +} + +func (n *Node) setIPv4Ports() { n.Load((*enr.UDP)(&n.udp)) n.Load((*enr.TCP)(&n.tcp)) } @@ -184,6 +194,18 @@ func (n *Node) TCP() int { return int(n.tcp) } +// WithHostname adds a DNS hostname to the node. +func (n *Node) WithHostname(hostname string) *Node { + cpy := *n + cpy.hostname = hostname + return &cpy +} + +// Hostname returns the DNS name assigned by WithHostname. +func (n *Node) Hostname() string { + return n.hostname +} + // UDPEndpoint returns the announced UDP endpoint. func (n *Node) UDPEndpoint() (netip.AddrPort, bool) { if !n.ip.IsValid() || n.ip.IsUnspecified() || n.udp == 0 { diff --git a/p2p/enode/node_test.go b/p2p/enode/node_test.go index f38c77415e..e9fe631f34 100644 --- a/p2p/enode/node_test.go +++ b/p2p/enode/node_test.go @@ -74,6 +74,7 @@ func TestNodeEndpoints(t *testing.T) { wantUDP int wantTCP int wantQUIC int + wantDNS string } tests := []endpointTest{ { @@ -90,6 +91,7 @@ func TestNodeEndpoints(t *testing.T) { r.Set(enr.UDP(9000)) return SignNull(&r, id) }(), + wantUDP: 9000, }, { name: "tcp-only", @@ -98,6 +100,7 @@ func TestNodeEndpoints(t *testing.T) { r.Set(enr.TCP(9000)) return SignNull(&r, id) }(), + wantTCP: 9000, }, { name: "quic-only", @@ -268,6 +271,19 @@ func TestNodeEndpoints(t *testing.T) { wantIP: netip.MustParseAddr("2001::ff00:0042:8329"), wantQUIC: 9001, }, + { + name: "dns-only", + node: func() *Node { + var r enr.Record + r.Set(enr.UDP(30303)) + r.Set(enr.TCP(30303)) + n := SignNull(&r, id).WithHostname("example.com") + return n + }(), + wantTCP: 30303, + wantUDP: 30303, + wantDNS: "example.com", + }, } for _, test := range tests { @@ -284,6 +300,9 @@ func TestNodeEndpoints(t *testing.T) { if quic, _ := test.node.QUICEndpoint(); test.wantQUIC != int(quic.Port()) { t.Errorf("node has wrong QUIC port %d, want %d", quic.Port(), test.wantQUIC) } + if test.wantDNS != test.node.Hostname() { + t.Errorf("node has wrong DNS name %s, want %s", test.node.Hostname(), test.wantDNS) + } }) } } diff --git a/p2p/enode/urlv4.go b/p2p/enode/urlv4.go index a55dfa6632..b455cd4533 100644 --- a/p2p/enode/urlv4.go +++ b/p2p/enode/urlv4.go @@ -33,7 +33,6 @@ import ( var ( incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$") - lookupIPFunc = net.LookupIP ) // MustParseV4 parses a node URL. It panics if the URL is not valid. @@ -126,20 +125,9 @@ func parseComplete(rawurl string) (*Node, error) { if id, err = parsePubkey(u.User.String()); err != nil { return nil, fmt.Errorf("invalid public key (%v)", err) } - // Parse the IP address. + + // Parse the IP and ports. ip := net.ParseIP(u.Hostname()) - if ip == nil { - ips, err := lookupIPFunc(u.Hostname()) - if err != nil { - return nil, err - } - ip = ips[0] - } - // Ensure the IP is 4 bytes long for IPv4 addresses. - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - } - // Parse the port numbers. if tcpPort, err = strconv.ParseUint(u.Port(), 10, 16); err != nil { return nil, errors.New("invalid port") } @@ -151,7 +139,13 @@ func parseComplete(rawurl string) (*Node, error) { return nil, errors.New("invalid discport in query") } } - return NewV4(id, ip, int(tcpPort), int(udpPort)), nil + + // Create the node. + node := NewV4(id, ip, int(tcpPort), int(udpPort)) + if ip == nil && u.Hostname() != "" { + node = node.WithHostname(u.Hostname()) + } + return node, nil } // parsePubkey parses a hex-encoded secp256k1 public key. @@ -181,15 +175,23 @@ func (n *Node) URLv4() string { nodeid = fmt.Sprintf("%s.%x", scheme, n.id[:]) } u := url.URL{Scheme: "enode"} - if !n.ip.IsValid() { - u.Host = nodeid - } else { + if n.Hostname() != "" { + // For nodes with a DNS name: include DNS name, TCP port, and optional UDP port + u.User = url.User(nodeid) + u.Host = fmt.Sprintf("%s:%d", n.Hostname(), n.TCP()) + if n.UDP() != n.TCP() { + u.RawQuery = "discport=" + strconv.Itoa(n.UDP()) + } + } else if n.ip.IsValid() { + // For IP-based nodes: include IP address, TCP port, and optional UDP port addr := net.TCPAddr{IP: n.IP(), Port: n.TCP()} u.User = url.User(nodeid) u.Host = addr.String() if n.UDP() != n.TCP() { u.RawQuery = "discport=" + strconv.Itoa(n.UDP()) } + } else { + u.Host = nodeid } return u.String() } diff --git a/p2p/enode/urlv4_test.go b/p2p/enode/urlv4_test.go index 33de96cc57..f39d5a2deb 100644 --- a/p2p/enode/urlv4_test.go +++ b/p2p/enode/urlv4_test.go @@ -18,7 +18,6 @@ package enode import ( "crypto/ecdsa" - "errors" "net" "reflect" "strings" @@ -28,15 +27,6 @@ import ( "github.com/ethereum/go-ethereum/p2p/enr" ) -func init() { - lookupIPFunc = func(name string) ([]net.IP, error) { - if name == "node.example.org" { - return []net.IP{{33, 44, 55, 66}}, nil - } - return nil, errors.New("no such host") - } -} - var parseNodeTests = []struct { input string wantError string @@ -70,10 +60,6 @@ var parseNodeTests = []struct { wantError: enr.ErrInvalidSig.Error(), }, // Complete node URLs with IP address and ports - { - input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@invalid.:3", - wantError: `no such host`, - }, { input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo", wantError: `invalid port`, @@ -91,6 +77,15 @@ var parseNodeTests = []struct { 52150, ), }, + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@valid.:3", + wantResult: NewV4( + hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + nil, + 3, + 3, + ).WithHostname("valid."), + }, { input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150", wantResult: NewV4( diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 4308bbd2eb..dea72875fe 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -60,22 +60,22 @@ func uintID(i uint16) enode.ID { // newNode creates a node record with the given address. func newNode(id enode.ID, addr string) *enode.Node { var r enr.Record - if addr != "" { - // Set the port if present. - if strings.Contains(addr, ":") { - hs, ps, err := net.SplitHostPort(addr) - if err != nil { - panic(fmt.Errorf("invalid address %q", addr)) - } - port, err := strconv.Atoi(ps) - if err != nil { - panic(fmt.Errorf("invalid port in %q", addr)) - } - r.Set(enr.TCP(port)) - r.Set(enr.UDP(port)) - addr = hs + // Set the port if present. + if strings.Contains(addr, ":") { + hs, ps, err := net.SplitHostPort(addr) + if err != nil { + panic(fmt.Errorf("invalid address %q", addr)) + } + port, err := strconv.Atoi(ps) + if err != nil { + panic(fmt.Errorf("invalid port in %q", addr)) } - // Set the IP. + r.Set(enr.TCP(port)) + r.Set(enr.UDP(port)) + addr = hs + } + // Set the IP. + if addr != "" { ip := net.ParseIP(addr) if ip == nil { panic(fmt.Errorf("invalid IP %q", addr))