diff --git a/cmd/devp2p/internal/ethtest/conn.go b/cmd/devp2p/internal/ethtest/conn.go index ba3c0585fd..757b137aa1 100644 --- a/cmd/devp2p/internal/ethtest/conn.go +++ b/cmd/devp2p/internal/ethtest/conn.go @@ -53,7 +53,8 @@ func (s *Suite) dial() (*Conn, error) { // dialAs attempts to dial a given node and perform a handshake using the given // private key. func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) { - fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP())) + tcpEndpoint, _ := s.Dest.TCPEndpoint() + fd, err := net.Dial("tcp", tcpEndpoint.String()) if err != nil { return nil, err } diff --git a/cmd/devp2p/internal/v4test/framework.go b/cmd/devp2p/internal/v4test/framework.go index df1f1f8aba..958fb71179 100644 --- a/cmd/devp2p/internal/v4test/framework.go +++ b/cmd/devp2p/internal/v4test/framework.go @@ -53,10 +53,12 @@ func newTestEnv(remote string, listen1, listen2 string) *testenv { if err != nil { panic(err) } - if node.IP() == nil || node.UDP() == 0 { + if !node.IPAddr().IsValid() || node.UDP() == 0 { var ip net.IP var tcpPort, udpPort int - if ip = node.IP(); ip == nil { + if node.IPAddr().IsValid() { + ip = node.IPAddr().AsSlice() + } else { ip = net.ParseIP("127.0.0.1") } if tcpPort = node.TCP(); tcpPort == 0 { diff --git a/cmd/devp2p/nodesetcmd.go b/cmd/devp2p/nodesetcmd.go index 6fbc185ad8..f0773edfb8 100644 --- a/cmd/devp2p/nodesetcmd.go +++ b/cmd/devp2p/nodesetcmd.go @@ -19,7 +19,7 @@ package main import ( "errors" "fmt" - "net" + "net/netip" "sort" "strconv" "strings" @@ -205,11 +205,11 @@ func trueFilter(args []string) (nodeFilter, error) { } func ipFilter(args []string) (nodeFilter, error) { - _, cidr, err := net.ParseCIDR(args[0]) + prefix, err := netip.ParsePrefix(args[0]) if err != nil { return nil, err } - f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) } + f := func(n nodeJSON) bool { return prefix.Contains(n.N.IPAddr()) } return f, nil } diff --git a/cmd/devp2p/rlpxcmd.go b/cmd/devp2p/rlpxcmd.go index aa7d065818..fb8066ee1a 100644 --- a/cmd/devp2p/rlpxcmd.go +++ b/cmd/devp2p/rlpxcmd.go @@ -77,7 +77,11 @@ var ( func rlpxPing(ctx *cli.Context) error { n := getNodeArg(ctx) - fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", n.IP(), n.TCP())) + tcpEndpoint, ok := n.TCPEndpoint() + if !ok { + return fmt.Errorf("node has no TCP endpoint") + } + fd, err := net.Dial("tcp", tcpEndpoint.String()) if err != nil { return err } diff --git a/p2p/dial.go b/p2p/dial.go index 08e1db2877..24d4dc2e89 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -65,11 +65,8 @@ type tcpDialer struct { } func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) { - return t.d.DialContext(ctx, "tcp", nodeAddr(dest).String()) -} - -func nodeAddr(n *enode.Node) net.Addr { - return &net.TCPAddr{IP: n.IP(), Port: n.TCP()} + addr, _ := dest.TCPEndpoint() + return t.d.DialContext(ctx, "tcp", addr.String()) } // checkDial errors: @@ -243,7 +240,7 @@ loop: select { case node := <-nodesCh: if err := d.checkDial(node); err != nil { - d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IP(), "reason", err) + d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IPAddr(), "reason", err) } else { d.startDial(newDialTask(node, dynDialedConn)) } @@ -277,7 +274,7 @@ loop: case node := <-d.addStaticCh: id := node.ID() _, exists := d.static[id] - d.log.Trace("Adding static node", "id", id, "ip", node.IP(), "added", !exists) + d.log.Trace("Adding static node", "id", id, "ip", node.IPAddr(), "added", !exists) if exists { continue loop } @@ -376,7 +373,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error { if n.ID() == d.self { return errSelf } - if n.IP() != nil && n.TCP() == 0 { + if n.IPAddr().IsValid() && n.TCP() == 0 { // This check can trigger if a non-TCP node is found // by discovery. If there is no IP, the node is a static // node and the actual endpoint will be resolved later in dialTask. @@ -388,7 +385,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error { if _, ok := d.peers[n.ID()]; ok { return errAlreadyConnected } - if d.netRestrict != nil && !d.netRestrict.Contains(n.IP()) { + if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) { return errNetRestrict } if d.history.contains(string(n.ID().Bytes())) { @@ -439,7 +436,7 @@ func (d *dialScheduler) removeFromStaticPool(idx int) { // startDial runs the given dial task in a separate goroutine. func (d *dialScheduler) startDial(task *dialTask) { node := task.dest() - d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags) + d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IPAddr(), "flag", task.flags) hkey := string(node.ID().Bytes()) d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration)) d.dialing[node.ID()] = task @@ -492,7 +489,7 @@ func (t *dialTask) run(d *dialScheduler) { } func (t *dialTask) needResolve() bool { - return t.flags&staticDialedConn != 0 && t.dest().IP() == nil + return t.flags&staticDialedConn != 0 && !t.dest().IPAddr().IsValid() } // resolve attempts to find the current endpoint for the destination @@ -526,7 +523,8 @@ func (t *dialTask) resolve(d *dialScheduler) bool { // The node was found. t.resolveDelay = initialResolveDelay t.destPtr.Store(resolved) - d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()}) + resAddr, _ := resolved.TCPEndpoint() + d.log.Debug("Resolved node", "id", resolved.ID(), "addr", resAddr) return true } @@ -535,7 +533,8 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { dialMeter.Mark(1) fd, err := d.dialer.Dial(d.ctx, dest) if err != nil { - d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err)) + addr, _ := dest.TCPEndpoint() + d.log.Trace("Dial error", "id", dest.ID(), "addr", addr, "conn", t.flags, "err", cleanupDialErr(err)) dialConnectionError.Mark(1) return &dialError{err} } @@ -545,7 +544,7 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { func (t *dialTask) String() string { node := t.dest() id := node.ID() - return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP()) + return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IPAddr(), node.TCP()) } func cleanupDialErr(err error) error { diff --git a/p2p/discover/table.go b/p2p/discover/table.go index bd3c9b4143..bb5ab4f3fc 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -25,7 +25,7 @@ package discover import ( "context" "fmt" - "net" + "net/netip" "slices" "sync" "time" @@ -207,8 +207,8 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error { if err := n.ValidateComplete(); err != nil { return fmt.Errorf("bad bootstrap node %q: %v", n, err) } - if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.Contains(n.IP()) { - tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP()) + if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.ContainsAddr(n.IPAddr()) { + tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IPAddr()) continue } nursery = append(nursery, n) @@ -448,7 +448,7 @@ func (tab *Table) loadSeedNodes() { for i := range seeds { seed := seeds[i] if tab.log.Enabled(context.Background(), log.LevelTrace) { - age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) + age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IPAddr())) addr, _ := seed.UDPEndpoint() tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age) } @@ -474,31 +474,31 @@ func (tab *Table) bucketAtDistance(d int) *bucket { return tab.buckets[d-bucketMinDistance-1] } -func (tab *Table) addIP(b *bucket, ip net.IP) bool { - if len(ip) == 0 { +func (tab *Table) addIP(b *bucket, ip netip.Addr) bool { + if !ip.IsValid() || ip.IsUnspecified() { return false // Nodes without IP cannot be added. } - if netutil.IsLAN(ip) { + if netutil.AddrIsLAN(ip) { return true } - if !tab.ips.Add(ip) { + if !tab.ips.AddAddr(ip) { tab.log.Debug("IP exceeds table limit", "ip", ip) return false } - if !b.ips.Add(ip) { + if !b.ips.AddAddr(ip) { tab.log.Debug("IP exceeds bucket limit", "ip", ip) - tab.ips.Remove(ip) + tab.ips.RemoveAddr(ip) return false } return true } -func (tab *Table) removeIP(b *bucket, ip net.IP) { - if netutil.IsLAN(ip) { +func (tab *Table) removeIP(b *bucket, ip netip.Addr) { + if netutil.AddrIsLAN(ip) { return } - tab.ips.Remove(ip) - b.ips.Remove(ip) + tab.ips.RemoveAddr(ip) + b.ips.RemoveAddr(ip) } // handleAddNode adds the node in the request to the table, if there is space. @@ -524,7 +524,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool { tab.addReplacement(b, req.node) return false } - if !tab.addIP(b, req.node.IP()) { + if !tab.addIP(b, req.node.IPAddr()) { // Can't add: IP limit reached. return false } @@ -547,7 +547,7 @@ func (tab *Table) addReplacement(b *bucket, n *enode.Node) { // TODO: update ENR return } - if !tab.addIP(b, n.IP()) { + if !tab.addIP(b, n.IPAddr()) { return } @@ -555,7 +555,7 @@ func (tab *Table) addReplacement(b *bucket, n *enode.Node) { var removed *tableNode b.replacements, removed = pushNode(b.replacements, wn, maxReplacements) if removed != nil { - tab.removeIP(b, removed.IP()) + tab.removeIP(b, removed.IPAddr()) } } @@ -595,12 +595,12 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode { // Remove the node. n := b.entries[index] b.entries = slices.Delete(b.entries, index, index+1) - tab.removeIP(b, n.IP()) + tab.removeIP(b, n.IPAddr()) tab.nodeRemoved(b, n) // Add replacement. if len(b.replacements) == 0 { - tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IP()) + tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr()) return nil } rindex := tab.rand.Intn(len(b.replacements)) @@ -608,7 +608,7 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode { b.replacements = slices.Delete(b.replacements, rindex, rindex+1) b.entries = append(b.entries, rep) tab.nodeAdded(b, rep) - tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IP(), "r", rep.ID(), "rip", rep.IP()) + tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr(), "r", rep.ID(), "rip", rep.IPAddr()) return rep } @@ -635,10 +635,10 @@ func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) ipchanged := newRecord.IPAddr() != n.IPAddr() portchanged := newRecord.UDP() != n.UDP() if ipchanged { - tab.removeIP(b, n.IP()) - if !tab.addIP(b, newRecord.IP()) { + tab.removeIP(b, n.IPAddr()) + if !tab.addIP(b, newRecord.IPAddr()) { // It doesn't fit with the limit, put the previous record back. - tab.addIP(b, n.IP()) + tab.addIP(b, n.IPAddr()) return n, false } } @@ -657,11 +657,11 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) { var fails int if op.success { // Reset failure counter because it counts _consecutive_ failures. - tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), 0) + tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), 0) } else { - fails = tab.db.FindFails(op.node.ID(), op.node.IP()) + fails = tab.db.FindFails(op.node.ID(), op.node.IPAddr()) fails++ - tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), fails) + tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), fails) } tab.mutex.Lock() diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 30e7d56f4a..2f1797d1e2 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -188,7 +188,7 @@ func checkIPLimitInvariant(t *testing.T, tab *Table) { tabset := netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit} for _, b := range tab.buckets { for _, n := range b.entries { - tabset.Add(n.IP()) + tabset.AddAddr(n.IPAddr()) } } if tabset.String() != tab.ips.String() { @@ -268,7 +268,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { } for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { r := new(enr.Record) - r.Set(enr.IP(genIP(rand))) + r.Set(enr.IPv4Addr(netutil.RandomAddr(rand, true))) n := enode.SignNull(r, id) t.All = append(t.All, n) } @@ -385,11 +385,11 @@ func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) { } t.Log("wrong bucket content. have nodes:") for _, n := range b.entries { - t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP()) + t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr()) } t.Log("want nodes:") for _, n := range nodes { - t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP()) + t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr()) } t.FailNow() @@ -483,12 +483,6 @@ func gen(typ interface{}, rand *rand.Rand) interface{} { return v.Interface() } -func genIP(rand *rand.Rand) net.IP { - ip := make(net.IP, 4) - rand.Read(ip) - return ip -} - func quickcfg() *quick.Config { return &quick.Config{ MaxCount: 5000, diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 997ac37799..5b2699d460 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -100,8 +100,9 @@ func idAtDistance(a enode.ID, n int) (b enode.ID) { return b } +// intIP returns a LAN IP address based on i. func intIP(i int) net.IP { - return net.IP{byte(i), 0, 2, byte(i)} + return net.IP{10, 0, byte(i >> 8), byte(i & 0xFF)} } // fillBucket inserts nodes into the given bucket until it is full. @@ -254,7 +255,7 @@ NotEqual: } func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool { - return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP()) + return n1.ID() == n2.ID() && n1.IPAddr() == n2.IPAddr() } func sortByID[N nodeType](nodes []N) { diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index 3880ca34a7..6234fd93ca 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -25,7 +25,6 @@ import ( "errors" "fmt" "io" - "net" "net/netip" "sync" "time" @@ -250,8 +249,7 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr netip.AddrPort, callback func()) return matched, matched }) // Send the packet. - toUDPAddr := &net.UDPAddr{IP: toaddr.Addr().AsSlice()} - t.localNode.UDPContact(toUDPAddr) + t.localNode.UDPContact(toaddr) t.write(toaddr, toid, req.Name(), packet) return rm } @@ -383,7 +381,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { if respN.Seq() < n.Seq() { return n, nil // response record is older } - if err := netutil.CheckRelayIP(addr.Addr().AsSlice(), respN.IP()); err != nil { + if err := netutil.CheckRelayAddr(addr.Addr(), respN.IPAddr()); err != nil { return nil, fmt.Errorf("invalid IP in response record: %v", err) } return respN, nil @@ -578,15 +576,14 @@ func (t *UDPv4) handlePacket(from netip.AddrPort, buf []byte) error { // checkBond checks if the given node has a recent enough endpoint proof. func (t *UDPv4) checkBond(id enode.ID, ip netip.AddrPort) bool { - return time.Since(t.db.LastPongReceived(id, ip.Addr().AsSlice())) < bondExpiration + return time.Since(t.db.LastPongReceived(id, ip.Addr())) < bondExpiration } // ensureBond solicits a ping from a node if we haven't seen a ping from it for a while. // This ensures there is a valid endpoint proof on the remote end. func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) { - ip := toaddr.Addr().AsSlice() - tooOld := time.Since(t.db.LastPingReceived(toid, ip)) > bondExpiration - if tooOld || t.db.FindFails(toid, ip) > maxFindnodeFailures { + tooOld := time.Since(t.db.LastPingReceived(toid, toaddr.Addr())) > bondExpiration + if tooOld || t.db.FindFails(toid, toaddr.Addr()) > maxFindnodeFailures { rm := t.sendPing(toid, toaddr, nil) <-rm.errc // Wait for them to ping back and process our pong. @@ -687,7 +684,7 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode // Ping back if our last pong on file is too far in the past. fromIP := from.Addr().AsSlice() n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port())) - if time.Since(t.db.LastPongReceived(n.ID(), fromIP)) > bondExpiration { + if time.Since(t.db.LastPongReceived(n.ID(), from.Addr())) > bondExpiration { t.sendPing(fromID, from, func() { t.tab.addInboundNode(n) }) @@ -696,10 +693,9 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode } // Update node database and endpoint predictor. - t.db.UpdateLastPingReceived(n.ID(), fromIP, time.Now()) - fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())} - toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)} - t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) + t.db.UpdateLastPingReceived(n.ID(), from.Addr(), time.Now()) + toaddr := netip.AddrPortFrom(netutil.IPToAddr(req.To.IP), req.To.UDP) + t.localNode.UDPEndpointStatement(from, toaddr) } // PONG/v4 @@ -713,11 +709,9 @@ func (t *UDPv4) verifyPong(h *packetHandlerV4, from netip.AddrPort, fromID enode if !t.handleReply(fromID, from.Addr(), req) { return errUnsolicitedReply } - fromIP := from.Addr().AsSlice() - fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())} - toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)} - t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) - t.db.UpdateLastPongReceived(fromID, fromIP, time.Now()) + toaddr := netip.AddrPortFrom(netutil.IPToAddr(req.To.IP), req.To.UDP) + t.localNode.UDPEndpointStatement(from, toaddr) + t.db.UpdateLastPongReceived(fromID, from.Addr(), time.Now()) return nil } @@ -753,8 +747,7 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID e p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} var sent bool for _, n := range closest { - fromIP := from.Addr().AsSlice() - if netutil.CheckRelayIP(fromIP, n.IP()) == nil { + if netutil.CheckRelayAddr(from.Addr(), n.IPAddr()) == nil { p.Nodes = append(p.Nodes, nodeToRPC(n)) } if len(p.Nodes) == v4wire.MaxNeighbors { diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go index 28a6fb8675..9d6df08ead 100644 --- a/p2p/discover/v4_udp_test.go +++ b/p2p/discover/v4_udp_test.go @@ -274,7 +274,7 @@ func TestUDPv4_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID() - test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr().AsSlice(), time.Now()) + test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr(), time.Now()) // check that closest neighbors are returned. expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true) @@ -309,7 +309,7 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) { defer test.close() rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey) - test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr().AsSlice(), time.Now()) + test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr(), time.Now()) // queue a pending findnode request resultc, errc := make(chan []*enode.Node, 1), make(chan error, 1) @@ -437,8 +437,8 @@ func TestUDPv4_successfulPing(t *testing.T) { if n.ID() != rid { t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid) } - if !n.IP().Equal(test.remoteaddr.Addr().AsSlice()) { - t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.Addr()) + if n.IPAddr() != test.remoteaddr.Addr() { + t.Errorf("node has wrong IP: got %v, want: %v", n.IPAddr(), test.remoteaddr.Addr()) } if n.UDP() != int(test.remoteaddr.Port()) { t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port()) diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 9ba54b3d40..4cc21cd9c5 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -428,10 +428,10 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s if err != nil { return nil, err } - if err := netutil.CheckRelayIP(c.addr.Addr().AsSlice(), node.IP()); err != nil { + if err := netutil.CheckRelayAddr(c.addr.Addr(), node.IPAddr()); err != nil { return nil, err } - if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) { + if t.netrestrict != nil && !t.netrestrict.ContainsAddr(node.IPAddr()) { return nil, errors.New("not contained in netrestrict list") } if node.UDP() <= 1024 { @@ -754,9 +754,8 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort t.handlePing(p, fromID, fromAddr) case *v5wire.Pong: if t.handleCallResponse(fromID, fromAddr, p) { - fromUDPAddr := &net.UDPAddr{IP: fromAddr.Addr().AsSlice(), Port: int(fromAddr.Port())} - toUDPAddr := &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)} - t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) + toAddr := netip.AddrPortFrom(netutil.IPToAddr(p.ToIP), p.ToPort) + t.localNode.UDPEndpointStatement(fromAddr, toAddr) } case *v5wire.Findnode: t.handleFindnode(p, fromID, fromAddr) @@ -848,7 +847,6 @@ func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr net // collectTableNodes creates a FINDNODE result set for the given distances. func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) []*enode.Node { - ripSlice := rip.AsSlice() var bn []*enode.Node var nodes []*enode.Node var processed = make(map[uint]struct{}) @@ -863,7 +861,7 @@ func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) [ for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) { // Apply some pre-checks to avoid sending invalid nodes. // Note liveness is checked by appendLiveNodes. - if netutil.CheckRelayIP(ripSlice, n.IP()) != nil { + if netutil.CheckRelayAddr(rip, n.IPAddr()) != nil { continue } nodes = append(nodes, n) diff --git a/p2p/discover/v5wire/encoding_test.go b/p2p/discover/v5wire/encoding_test.go index 27966f2afc..8dd02620eb 100644 --- a/p2p/discover/v5wire/encoding_test.go +++ b/p2p/discover/v5wire/encoding_test.go @@ -606,7 +606,7 @@ func (n *handshakeTestNode) n() *enode.Node { } func (n *handshakeTestNode) addr() string { - return n.ln.Node().IP().String() + return n.ln.Node().IPAddr().String() } func (n *handshakeTestNode) id() enode.ID { diff --git a/p2p/enode/localnode.go b/p2p/enode/localnode.go index a18204e752..6e79c9cbdc 100644 --- a/p2p/enode/localnode.go +++ b/p2p/enode/localnode.go @@ -20,8 +20,8 @@ import ( "crypto/ecdsa" "fmt" "net" + "net/netip" "reflect" - "strconv" "sync" "sync/atomic" "time" @@ -175,8 +175,8 @@ func (ln *LocalNode) delete(e enr.Entry) { } } -func (ln *LocalNode) endpointForIP(ip net.IP) *lnEndpoint { - if ip.To4() != nil { +func (ln *LocalNode) endpointForIP(ip netip.Addr) *lnEndpoint { + if ip.Is4() { return &ln.endpoint4 } return &ln.endpoint6 @@ -188,7 +188,7 @@ func (ln *LocalNode) SetStaticIP(ip net.IP) { ln.mu.Lock() defer ln.mu.Unlock() - ln.endpointForIP(ip).staticIP = ip + ln.endpointForIP(netutil.IPToAddr(ip)).staticIP = ip ln.updateEndpoints() } @@ -198,7 +198,7 @@ func (ln *LocalNode) SetFallbackIP(ip net.IP) { ln.mu.Lock() defer ln.mu.Unlock() - ln.endpointForIP(ip).fallbackIP = ip + ln.endpointForIP(netutil.IPToAddr(ip)).fallbackIP = ip ln.updateEndpoints() } @@ -215,21 +215,21 @@ func (ln *LocalNode) SetFallbackUDP(port int) { // UDPEndpointStatement should be called whenever a statement about the local node's // UDP endpoint is received. It feeds the local endpoint predictor. -func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) { +func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint netip.AddrPort) { ln.mu.Lock() defer ln.mu.Unlock() - ln.endpointForIP(endpoint.IP).track.AddStatement(fromaddr.String(), endpoint.String()) + ln.endpointForIP(endpoint.Addr()).track.AddStatement(fromaddr.Addr(), endpoint) ln.updateEndpoints() } // UDPContact should be called whenever the local node has announced itself to another node // via UDP. It feeds the local endpoint predictor. -func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) { +func (ln *LocalNode) UDPContact(toaddr netip.AddrPort) { ln.mu.Lock() defer ln.mu.Unlock() - ln.endpointForIP(toaddr.IP).track.AddContact(toaddr.String()) + ln.endpointForIP(toaddr.Addr()).track.AddContact(toaddr.Addr()) ln.updateEndpoints() } @@ -268,29 +268,13 @@ func (e *lnEndpoint) get() (newIP net.IP, newPort uint16) { } if e.staticIP != nil { newIP = e.staticIP - } else if ip, port := predictAddr(e.track); ip != nil { - newIP = ip - newPort = port + } else if ap := e.track.PredictEndpoint(); ap.IsValid() { + newIP = ap.Addr().AsSlice() + newPort = ap.Port() } return newIP, newPort } -// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based -// endpoint representation to IP and port types. -func predictAddr(t *netutil.IPTracker) (net.IP, uint16) { - ep := t.PredictEndpoint() - if ep == "" { - return nil, 0 - } - ipString, portString, _ := net.SplitHostPort(ep) - ip := net.ParseIP(ipString) - port, err := strconv.ParseUint(portString, 10, 16) - if err != nil { - return nil, 0 - } - return ip, uint16(port) -} - func (ln *LocalNode) invalidate() { ln.cur.Store((*Node)(nil)) } @@ -314,7 +298,7 @@ func (ln *LocalNode) sign() { panic(fmt.Errorf("enode: can't verify local record: %v", err)) } ln.cur.Store(n) - log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP()) + log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IPAddr(), "udp", n.UDP(), "tcp", n.TCP()) } func (ln *LocalNode) bumpSeq() { diff --git a/p2p/enode/localnode_test.go b/p2p/enode/localnode_test.go index 7f97ad392f..86b962a74e 100644 --- a/p2p/enode/localnode_test.go +++ b/p2p/enode/localnode_test.go @@ -17,12 +17,14 @@ package enode import ( - "crypto/rand" + "math/rand" "net" + "net/netip" "testing" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/stretchr/testify/assert" ) @@ -88,6 +90,7 @@ func TestLocalNodeSeqPersist(t *testing.T) { // This test checks behavior of the endpoint predictor. func TestLocalNodeEndpoint(t *testing.T) { var ( + rng = rand.New(rand.NewSource(4)) fallback = &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 80} predicted = &net.UDPAddr{IP: net.IP{127, 0, 1, 2}, Port: 81} staticIP = net.IP{127, 0, 1, 2} @@ -96,6 +99,7 @@ func TestLocalNodeEndpoint(t *testing.T) { defer db.Close() // Nothing is set initially. + assert.Equal(t, netip.Addr{}, ln.Node().IPAddr()) assert.Equal(t, net.IP(nil), ln.Node().IP()) assert.Equal(t, 0, ln.Node().UDP()) initialSeq := ln.Node().Seq() @@ -103,26 +107,30 @@ func TestLocalNodeEndpoint(t *testing.T) { // Set up fallback address. ln.SetFallbackIP(fallback.IP) ln.SetFallbackUDP(fallback.Port) + assert.Equal(t, netutil.IPToAddr(fallback.IP), ln.Node().IPAddr()) assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, initialSeq+1, ln.Node().Seq()) // Add endpoint statements from random hosts. for i := 0; i < iptrackMinStatements; i++ { + assert.Equal(t, netutil.IPToAddr(fallback.IP), ln.Node().IPAddr()) assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, initialSeq+1, ln.Node().Seq()) - from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90} - rand.Read(from.IP) - ln.UDPEndpointStatement(from, predicted) + from := netip.AddrPortFrom(netutil.RandomAddr(rng, true), 9000) + endpoint := netip.AddrPortFrom(netutil.IPToAddr(predicted.IP), uint16(predicted.Port)) + ln.UDPEndpointStatement(from, endpoint) } + assert.Equal(t, netutil.IPToAddr(predicted.IP), ln.Node().IPAddr()) assert.Equal(t, predicted.IP, ln.Node().IP()) assert.Equal(t, predicted.Port, ln.Node().UDP()) assert.Equal(t, initialSeq+2, ln.Node().Seq()) // Static IP overrides prediction. ln.SetStaticIP(staticIP) + assert.Equal(t, netutil.IPToAddr(staticIP), ln.Node().IPAddr()) assert.Equal(t, staticIP, ln.Node().IP()) assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, initialSeq+3, ln.Node().Seq()) diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index 654d71d47b..1f31c98d22 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -21,7 +21,7 @@ import ( "crypto/rand" "encoding/binary" "fmt" - "net" + "net/netip" "os" "sync" "time" @@ -66,7 +66,7 @@ var ( errInvalidIP = errors.New("invalid IP") ) -var zeroIP = make(net.IP, 16) +var zeroIP = netip.IPv6Unspecified() // DB is the node database, storing previously seen nodes and any collected metadata about // them for QoS purposes. @@ -151,39 +151,37 @@ func splitNodeKey(key []byte) (id ID, rest []byte) { } // nodeItemKey returns the database key for a node metadata field. -func nodeItemKey(id ID, ip net.IP, field string) []byte { - ip16 := ip.To16() - if ip16 == nil { - panic(fmt.Errorf("invalid IP (length %d)", len(ip))) +func nodeItemKey(id ID, ip netip.Addr, field string) []byte { + if !ip.IsValid() { + panic("invalid IP") } - return bytes.Join([][]byte{nodeKey(id), ip16, []byte(field)}, []byte{':'}) + ip16 := ip.As16() + return bytes.Join([][]byte{nodeKey(id), ip16[:], []byte(field)}, []byte{':'}) } // splitNodeItemKey returns the components of a key created by nodeItemKey. -func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) { +func splitNodeItemKey(key []byte) (id ID, ip netip.Addr, field string) { id, key = splitNodeKey(key) // Skip discover root. if string(key) == dbDiscoverRoot { - return id, nil, "" + return id, netip.Addr{}, "" } key = key[len(dbDiscoverRoot)+1:] // Split out the IP. - ip = key[:16] - if ip4 := ip.To4(); ip4 != nil { - ip = ip4 - } + ip, _ = netip.AddrFromSlice(key[:16]) key = key[16+1:] // Field is the remainder of key. field = string(key) return id, ip, field } -func v5Key(id ID, ip net.IP, field string) []byte { +func v5Key(id ID, ip netip.Addr, field string) []byte { + ip16 := ip.As16() return bytes.Join([][]byte{ []byte(dbNodePrefix), id[:], []byte(dbDiscv5Root), - ip.To16(), + ip16[:], []byte(field), }, []byte{':'}) } @@ -364,24 +362,24 @@ func (db *DB) expireNodes() { // LastPingReceived retrieves the time of the last ping packet received from // a remote node. -func (db *DB) LastPingReceived(id ID, ip net.IP) time.Time { - if ip = ip.To16(); ip == nil { +func (db *DB) LastPingReceived(id ID, ip netip.Addr) time.Time { + if !ip.IsValid() { return time.Time{} } return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0) } // UpdateLastPingReceived updates the last time we tried contacting a remote node. -func (db *DB) UpdateLastPingReceived(id ID, ip net.IP, instance time.Time) error { - if ip = ip.To16(); ip == nil { +func (db *DB) UpdateLastPingReceived(id ID, ip netip.Addr, instance time.Time) error { + if !ip.IsValid() { return errInvalidIP } return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix()) } // LastPongReceived retrieves the time of the last successful pong from remote node. -func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time { - if ip = ip.To16(); ip == nil { +func (db *DB) LastPongReceived(id ID, ip netip.Addr) time.Time { + if !ip.IsValid() { return time.Time{} } // Launch expirer @@ -390,40 +388,40 @@ func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time { } // UpdateLastPongReceived updates the last pong time of a node. -func (db *DB) UpdateLastPongReceived(id ID, ip net.IP, instance time.Time) error { - if ip = ip.To16(); ip == nil { +func (db *DB) UpdateLastPongReceived(id ID, ip netip.Addr, instance time.Time) error { + if !ip.IsValid() { return errInvalidIP } return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix()) } // FindFails retrieves the number of findnode failures since bonding. -func (db *DB) FindFails(id ID, ip net.IP) int { - if ip = ip.To16(); ip == nil { +func (db *DB) FindFails(id ID, ip netip.Addr) int { + if !ip.IsValid() { return 0 } return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails))) } // UpdateFindFails updates the number of findnode failures since bonding. -func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error { - if ip = ip.To16(); ip == nil { +func (db *DB) UpdateFindFails(id ID, ip netip.Addr, fails int) error { + if !ip.IsValid() { return errInvalidIP } return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails)) } // FindFailsV5 retrieves the discv5 findnode failure counter. -func (db *DB) FindFailsV5(id ID, ip net.IP) int { - if ip = ip.To16(); ip == nil { +func (db *DB) FindFailsV5(id ID, ip netip.Addr) int { + if !ip.IsValid() { return 0 } return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails))) } // UpdateFindFailsV5 stores the discv5 findnode failure counter. -func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error { - if ip = ip.To16(); ip == nil { +func (db *DB) UpdateFindFailsV5(id ID, ip netip.Addr, fails int) error { + if !ip.IsValid() { return errInvalidIP } return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails)) @@ -470,7 +468,7 @@ seek: id[0] = 0 continue seek // iterator exhausted } - if now.Sub(db.LastPongReceived(n.ID(), n.IP())) > maxAge { + if now.Sub(db.LastPongReceived(n.ID(), n.IPAddr())) > maxAge { continue seek } for i := range nodes { diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go index 38764f31b1..bc0291665d 100644 --- a/p2p/enode/nodedb_test.go +++ b/p2p/enode/nodedb_test.go @@ -20,6 +20,7 @@ import ( "bytes" "fmt" "net" + "net/netip" "path/filepath" "reflect" "testing" @@ -48,8 +49,10 @@ func TestDBNodeKey(t *testing.T) { } func TestDBNodeItemKey(t *testing.T) { - wantIP := net.IP{127, 0, 0, 3} + wantIP := netip.MustParseAddr("127.0.0.3") + wantIP4in6 := netip.AddrFrom16(wantIP.As16()) wantField := "foobar" + enc := nodeItemKey(keytestID, wantIP, wantField) want := []byte{ 'n', ':', @@ -69,7 +72,7 @@ func TestDBNodeItemKey(t *testing.T) { if id != keytestID { t.Errorf("splitNodeItemKey returned wrong ID: %v", id) } - if !ip.Equal(wantIP) { + if ip != wantIP4in6 { t.Errorf("splitNodeItemKey returned wrong IP: %v", ip) } if field != wantField { @@ -123,33 +126,33 @@ func TestDBFetchStore(t *testing.T) { defer db.Close() // Check fetch/store operations on a node ping object - if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != 0 { + if stored := db.LastPingReceived(node.ID(), node.IPAddr()); stored.Unix() != 0 { t.Errorf("ping: non-existing object: %v", stored) } - if err := db.UpdateLastPingReceived(node.ID(), node.IP(), inst); err != nil { + if err := db.UpdateLastPingReceived(node.ID(), node.IPAddr(), inst); err != nil { t.Errorf("ping: failed to update: %v", err) } - if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() { + if stored := db.LastPingReceived(node.ID(), node.IPAddr()); stored.Unix() != inst.Unix() { t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node pong object - if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != 0 { + if stored := db.LastPongReceived(node.ID(), node.IPAddr()); stored.Unix() != 0 { t.Errorf("pong: non-existing object: %v", stored) } - if err := db.UpdateLastPongReceived(node.ID(), node.IP(), inst); err != nil { + if err := db.UpdateLastPongReceived(node.ID(), node.IPAddr(), inst); err != nil { t.Errorf("pong: failed to update: %v", err) } - if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() { + if stored := db.LastPongReceived(node.ID(), node.IPAddr()); stored.Unix() != inst.Unix() { t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) } // Check fetch/store operations on a node findnode-failure object - if stored := db.FindFails(node.ID(), node.IP()); stored != 0 { + if stored := db.FindFails(node.ID(), node.IPAddr()); stored != 0 { t.Errorf("find-node fails: non-existing object: %v", stored) } - if err := db.UpdateFindFails(node.ID(), node.IP(), num); err != nil { + if err := db.UpdateFindFails(node.ID(), node.IPAddr(), num); err != nil { t.Errorf("find-node fails: failed to update: %v", err) } - if stored := db.FindFails(node.ID(), node.IP()); stored != num { + if stored := db.FindFails(node.ID(), node.IPAddr()); stored != num { t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num) } // Check fetch/store operations on an actual node object @@ -266,7 +269,7 @@ func testSeedQuery() error { if err := db.UpdateNode(seed.node); err != nil { return fmt.Errorf("node %d: failed to insert: %v", i, err) } - if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil { + if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IPAddr(), seed.pong); err != nil { return fmt.Errorf("node %d: failed to insert bondTime: %v", i, err) } } @@ -427,7 +430,7 @@ func TestDBExpiration(t *testing.T) { t.Fatalf("node %d: failed to insert: %v", i, err) } } - if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil { + if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IPAddr(), seed.pong); err != nil { t.Fatalf("node %d: failed to update bondTime: %v", i, err) } } @@ -438,13 +441,13 @@ func TestDBExpiration(t *testing.T) { unixZeroTime := time.Unix(0, 0) for i, seed := range nodeDBExpirationNodes { node := db.Node(seed.node.ID()) - pong := db.LastPongReceived(seed.node.ID(), seed.node.IP()) + pong := db.LastPongReceived(seed.node.ID(), seed.node.IPAddr()) if seed.exp { if seed.storeNode && node != nil { t.Errorf("node %d (%s) shouldn't be present after expiration", i, seed.node.ID().TerminalString()) } if !pong.Equal(unixZeroTime) { - t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IP()) + t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IPAddr()) } } else { if seed.storeNode && node == nil { @@ -463,7 +466,7 @@ func TestDBExpireV5(t *testing.T) { db, _ := OpenDB("") defer db.Close() - ip := net.IP{127, 0, 0, 1} + ip := netip.MustParseAddr("127.0.0.1") db.UpdateFindFailsV5(ID{}, ip, 4) db.expireNodes() } diff --git a/p2p/netutil/addrutil.go b/p2p/netutil/addrutil.go index fb6d8d2731..b8b318571b 100644 --- a/p2p/netutil/addrutil.go +++ b/p2p/netutil/addrutil.go @@ -16,18 +16,53 @@ package netutil -import "net" +import ( + "fmt" + "math/rand" + "net" + "net/netip" +) -// AddrIP gets the IP address contained in addr. It returns nil if no address is present. -func AddrIP(addr net.Addr) net.IP { +// AddrAddr gets the IP address contained in addr. The result will be invalid if the +// address type is unsupported. +func AddrAddr(addr net.Addr) netip.Addr { switch a := addr.(type) { case *net.IPAddr: - return a.IP + return IPToAddr(a.IP) case *net.TCPAddr: - return a.IP + return IPToAddr(a.IP) case *net.UDPAddr: - return a.IP + return IPToAddr(a.IP) default: - return nil + return netip.Addr{} } } + +// IPToAddr converts net.IP to netip.Addr. Note that unlike netip.AddrFromSlice, this +// function will always ensure that the resulting Addr is IPv4 when the input is. +func IPToAddr(ip net.IP) netip.Addr { + if ip4 := ip.To4(); ip4 != nil { + addr, _ := netip.AddrFromSlice(ip4) + return addr + } else if ip6 := ip.To16(); ip6 != nil { + addr, _ := netip.AddrFromSlice(ip6) + return addr + } + return netip.Addr{} +} + +// RandomAddr creates a random IP address. +func RandomAddr(rng *rand.Rand, ipv4 bool) netip.Addr { + var bytes []byte + if ipv4 || rng.Intn(2) == 0 { + bytes = make([]byte, 4) + } else { + bytes = make([]byte, 16) + } + rng.Read(bytes) + addr, ok := netip.AddrFromSlice(bytes) + if !ok { + panic(fmt.Errorf("BUG! invalid IP %v", bytes)) + } + return addr +} diff --git a/p2p/netutil/iptrack.go b/p2p/netutil/iptrack.go index a070499e19..5140ac7539 100644 --- a/p2p/netutil/iptrack.go +++ b/p2p/netutil/iptrack.go @@ -17,6 +17,7 @@ package netutil import ( + "net/netip" "time" "github.com/ethereum/go-ethereum/common/mclock" @@ -29,14 +30,14 @@ type IPTracker struct { contactWindow time.Duration minStatements int clock mclock.Clock - statements map[string]ipStatement - contact map[string]mclock.AbsTime + statements map[netip.Addr]ipStatement + contact map[netip.Addr]mclock.AbsTime lastStatementGC mclock.AbsTime lastContactGC mclock.AbsTime } type ipStatement struct { - endpoint string + endpoint netip.AddrPort time mclock.AbsTime } @@ -51,9 +52,9 @@ func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTra return &IPTracker{ window: window, contactWindow: contactWindow, - statements: make(map[string]ipStatement), + statements: make(map[netip.Addr]ipStatement), minStatements: minStatements, - contact: make(map[string]mclock.AbsTime), + contact: make(map[netip.Addr]mclock.AbsTime), clock: mclock.System{}, } } @@ -74,12 +75,15 @@ func (it *IPTracker) PredictFullConeNAT() bool { } // PredictEndpoint returns the current prediction of the external endpoint. -func (it *IPTracker) PredictEndpoint() string { +func (it *IPTracker) PredictEndpoint() netip.AddrPort { it.gcStatements(it.clock.Now()) // The current strategy is simple: find the endpoint with most statements. - counts := make(map[string]int, len(it.statements)) - maxcount, max := 0, "" + var ( + counts = make(map[netip.AddrPort]int, len(it.statements)) + maxcount int + max netip.AddrPort + ) for _, s := range it.statements { c := counts[s.endpoint] + 1 counts[s.endpoint] = c @@ -91,7 +95,7 @@ func (it *IPTracker) PredictEndpoint() string { } // AddStatement records that a certain host thinks our external endpoint is the one given. -func (it *IPTracker) AddStatement(host, endpoint string) { +func (it *IPTracker) AddStatement(host netip.Addr, endpoint netip.AddrPort) { now := it.clock.Now() it.statements[host] = ipStatement{endpoint, now} if time.Duration(now-it.lastStatementGC) >= it.window { @@ -101,7 +105,7 @@ func (it *IPTracker) AddStatement(host, endpoint string) { // AddContact records that a packet containing our endpoint information has been sent to a // certain host. -func (it *IPTracker) AddContact(host string) { +func (it *IPTracker) AddContact(host netip.Addr) { now := it.clock.Now() it.contact[host] = now if time.Duration(now-it.lastContactGC) >= it.contactWindow { diff --git a/p2p/netutil/iptrack_test.go b/p2p/netutil/iptrack_test.go index ee3bba861e..81653a2733 100644 --- a/p2p/netutil/iptrack_test.go +++ b/p2p/netutil/iptrack_test.go @@ -19,6 +19,7 @@ package netutil import ( crand "crypto/rand" "fmt" + "net/netip" "testing" "time" @@ -42,37 +43,37 @@ func TestIPTracker(t *testing.T) { tests := map[string][]iptrackTestEvent{ "minStatements": { {opPredict, 0, "", ""}, - {opStatement, 0, "127.0.0.1", "127.0.0.2"}, + {opStatement, 0, "127.0.0.1:8000", "127.0.0.2"}, {opPredict, 1000, "", ""}, - {opStatement, 1000, "127.0.0.1", "127.0.0.3"}, + {opStatement, 1000, "127.0.0.1:8000", "127.0.0.3"}, {opPredict, 1000, "", ""}, - {opStatement, 1000, "127.0.0.1", "127.0.0.4"}, - {opPredict, 1000, "127.0.0.1", ""}, + {opStatement, 1000, "127.0.0.1:8000", "127.0.0.4"}, + {opPredict, 1000, "127.0.0.1:8000", ""}, }, "window": { - {opStatement, 0, "127.0.0.1", "127.0.0.2"}, - {opStatement, 2000, "127.0.0.1", "127.0.0.3"}, - {opStatement, 3000, "127.0.0.1", "127.0.0.4"}, - {opPredict, 10000, "127.0.0.1", ""}, + {opStatement, 0, "127.0.0.1:8000", "127.0.0.2"}, + {opStatement, 2000, "127.0.0.1:8000", "127.0.0.3"}, + {opStatement, 3000, "127.0.0.1:8000", "127.0.0.4"}, + {opPredict, 10000, "127.0.0.1:8000", ""}, {opPredict, 10001, "", ""}, // first statement expired - {opStatement, 10100, "127.0.0.1", "127.0.0.2"}, - {opPredict, 10200, "127.0.0.1", ""}, + {opStatement, 10100, "127.0.0.1:8000", "127.0.0.2"}, + {opPredict, 10200, "127.0.0.1:8000", ""}, }, "fullcone": { {opContact, 0, "", "127.0.0.2"}, - {opStatement, 10, "127.0.0.1", "127.0.0.2"}, + {opStatement, 10, "127.0.0.1:8000", "127.0.0.2"}, {opContact, 2000, "", "127.0.0.3"}, - {opStatement, 2010, "127.0.0.1", "127.0.0.3"}, + {opStatement, 2010, "127.0.0.1:8000", "127.0.0.3"}, {opContact, 3000, "", "127.0.0.4"}, - {opStatement, 3010, "127.0.0.1", "127.0.0.4"}, + {opStatement, 3010, "127.0.0.1:8000", "127.0.0.4"}, {opCheckFullCone, 3500, "false", ""}, }, "fullcone_2": { {opContact, 0, "", "127.0.0.2"}, - {opStatement, 10, "127.0.0.1", "127.0.0.2"}, + {opStatement, 10, "127.0.0.1:8000", "127.0.0.2"}, {opContact, 2000, "", "127.0.0.3"}, - {opStatement, 2010, "127.0.0.1", "127.0.0.3"}, - {opStatement, 3000, "127.0.0.1", "127.0.0.4"}, + {opStatement, 2010, "127.0.0.1:8000", "127.0.0.3"}, + {opStatement, 3000, "127.0.0.1:8000", "127.0.0.4"}, {opContact, 3010, "", "127.0.0.4"}, {opCheckFullCone, 3500, "true", ""}, }, @@ -93,12 +94,19 @@ func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) { clock.Run(evtime - time.Duration(clock.Now())) switch ev.op { case opStatement: - it.AddStatement(ev.from, ev.ip) + it.AddStatement(netip.MustParseAddr(ev.from), netip.MustParseAddrPort(ev.ip)) case opContact: - it.AddContact(ev.from) + it.AddContact(netip.MustParseAddr(ev.from)) case opPredict: - if pred := it.PredictEndpoint(); pred != ev.ip { - t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip) + pred := it.PredictEndpoint() + if ev.ip == "" { + if pred.IsValid() { + t.Errorf("op %d: wrong prediction %v, expected invalid", i, pred) + } + } else { + if pred != netip.MustParseAddrPort(ev.ip) { + t.Errorf("op %d: wrong prediction %v, want %q", i, pred, ev.ip) + } } case opCheckFullCone: pred := fmt.Sprintf("%t", it.PredictFullConeNAT()) @@ -121,12 +129,11 @@ func TestIPTrackerForceGC(t *testing.T) { it.clock = &clock for i := 0; i < 5*max; i++ { - e1 := make([]byte, 4) - e2 := make([]byte, 4) - crand.Read(e1) - crand.Read(e2) - it.AddStatement(string(e1), string(e2)) - it.AddContact(string(e1)) + var e1, e2 [4]byte + crand.Read(e1[:]) + crand.Read(e2[:]) + it.AddStatement(netip.AddrFrom4(e1), netip.AddrPortFrom(netip.AddrFrom4(e2), 9000)) + it.AddContact(netip.AddrFrom4(e1)) clock.Run(rate) } if len(it.contact) > 2*max { diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go index d5da3c694f..474912978e 100644 --- a/p2p/netutil/net.go +++ b/p2p/netutil/net.go @@ -22,21 +22,19 @@ import ( "errors" "fmt" "net" - "sort" + "net/netip" + "slices" "strings" + + "golang.org/x/exp/maps" ) -var lan4, lan6, special4, special6 Netlist +var special4, special6 Netlist func init() { // Lists from RFC 5735, RFC 5156, // https://www.iana.org/assignments/iana-ipv4-special-registry/ - lan4.Add("0.0.0.0/8") // "This" network - lan4.Add("10.0.0.0/8") // Private Use - lan4.Add("172.16.0.0/12") // Private Use - lan4.Add("192.168.0.0/16") // Private Use - lan6.Add("fe80::/10") // Link-Local - lan6.Add("fc00::/7") // Unique-Local + special4.Add("0.0.0.0/8") // "This" network. special4.Add("192.0.0.0/29") // IPv4 Service Continuity special4.Add("192.0.0.9/32") // PCP Anycast special4.Add("192.0.0.170/32") // NAT64/DNS64 Discovery @@ -66,7 +64,7 @@ func init() { } // Netlist is a list of IP networks. -type Netlist []net.IPNet +type Netlist []netip.Prefix // ParseNetlist parses a comma-separated list of CIDR masks. // Whitespace and extra commas are ignored. @@ -78,11 +76,11 @@ func ParseNetlist(s string) (*Netlist, error) { if mask == "" { continue } - _, n, err := net.ParseCIDR(mask) + prefix, err := netip.ParsePrefix(mask) if err != nil { return nil, err } - l = append(l, *n) + l = append(l, prefix) } return &l, nil } @@ -103,11 +101,11 @@ func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error { return err } for _, mask := range masks { - _, n, err := net.ParseCIDR(mask) + prefix, err := netip.ParsePrefix(mask) if err != nil { return err } - *l = append(*l, *n) + *l = append(*l, prefix) } return nil } @@ -115,15 +113,20 @@ func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error { // Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is // intended to be used for setting up static lists. func (l *Netlist) Add(cidr string) { - _, n, err := net.ParseCIDR(cidr) + prefix, err := netip.ParsePrefix(cidr) if err != nil { panic(err) } - *l = append(*l, *n) + *l = append(*l, prefix) } // Contains reports whether the given IP is contained in the list. func (l *Netlist) Contains(ip net.IP) bool { + return l.ContainsAddr(IPToAddr(ip)) +} + +// ContainsAddr reports whether the given IP is contained in the list. +func (l *Netlist) ContainsAddr(ip netip.Addr) bool { if l == nil { return false } @@ -137,25 +140,39 @@ func (l *Netlist) Contains(ip net.IP) bool { // IsLAN reports whether an IP is a local network address. func IsLAN(ip net.IP) bool { + return AddrIsLAN(IPToAddr(ip)) +} + +// AddrIsLAN reports whether an IP is a local network address. +func AddrIsLAN(ip netip.Addr) bool { + if ip.Is4In6() { + ip = netip.AddrFrom4(ip.As4()) + } if ip.IsLoopback() { return true } - if v4 := ip.To4(); v4 != nil { - return lan4.Contains(v4) - } - return lan6.Contains(ip) + return ip.IsPrivate() || ip.IsLinkLocalUnicast() } // IsSpecialNetwork reports whether an IP is located in a special-use network range // This includes broadcast, multicast and documentation addresses. func IsSpecialNetwork(ip net.IP) bool { + return AddrIsSpecialNetwork(IPToAddr(ip)) +} + +// AddrIsSpecialNetwork reports whether an IP is located in a special-use network range +// This includes broadcast, multicast and documentation addresses. +func AddrIsSpecialNetwork(ip netip.Addr) bool { + if ip.Is4In6() { + ip = netip.AddrFrom4(ip.As4()) + } if ip.IsMulticast() { return true } - if v4 := ip.To4(); v4 != nil { - return special4.Contains(v4) + if ip.Is4() { + return special4.ContainsAddr(ip) } - return special6.Contains(ip) + return special6.ContainsAddr(ip) } var ( @@ -175,19 +192,31 @@ var ( // - LAN addresses are OK if relayed by a LAN host. // - All other addresses are always acceptable. func CheckRelayIP(sender, addr net.IP) error { - if len(addr) != net.IPv4len && len(addr) != net.IPv6len { + return CheckRelayAddr(IPToAddr(sender), IPToAddr(addr)) +} + +// CheckRelayAddr reports whether an IP relayed from the given sender IP +// is a valid connection target. +// +// There are four rules: +// - Special network addresses are never valid. +// - Loopback addresses are OK if relayed by a loopback host. +// - LAN addresses are OK if relayed by a LAN host. +// - All other addresses are always acceptable. +func CheckRelayAddr(sender, addr netip.Addr) error { + if !addr.IsValid() { return errInvalid } if addr.IsUnspecified() { return errUnspecified } - if IsSpecialNetwork(addr) { + if AddrIsSpecialNetwork(addr) { return errSpecial } if addr.IsLoopback() && !sender.IsLoopback() { return errLoopback } - if IsLAN(addr) && !IsLAN(sender) { + if AddrIsLAN(addr) && !AddrIsLAN(sender) { return errLAN } return nil @@ -221,17 +250,22 @@ type DistinctNetSet struct { Subnet uint // number of common prefix bits Limit uint // maximum number of IPs in each subnet - members map[string]uint - buf net.IP + members map[netip.Prefix]uint } // Add adds an IP address to the set. It returns false (and doesn't add the IP) if the // number of existing IPs in the defined range exceeds the limit. func (s *DistinctNetSet) Add(ip net.IP) bool { + return s.AddAddr(IPToAddr(ip)) +} + +// AddAddr adds an IP address to the set. It returns false (and doesn't add the IP) if the +// number of existing IPs in the defined range exceeds the limit. +func (s *DistinctNetSet) AddAddr(ip netip.Addr) bool { key := s.key(ip) - n := s.members[string(key)] + n := s.members[key] if n < s.Limit { - s.members[string(key)] = n + 1 + s.members[key] = n + 1 return true } return false @@ -239,20 +273,30 @@ func (s *DistinctNetSet) Add(ip net.IP) bool { // Remove removes an IP from the set. func (s *DistinctNetSet) Remove(ip net.IP) { + s.RemoveAddr(IPToAddr(ip)) +} + +// RemoveAddr removes an IP from the set. +func (s *DistinctNetSet) RemoveAddr(ip netip.Addr) { key := s.key(ip) - if n, ok := s.members[string(key)]; ok { + if n, ok := s.members[key]; ok { if n == 1 { - delete(s.members, string(key)) + delete(s.members, key) } else { - s.members[string(key)] = n - 1 + s.members[key] = n - 1 } } } // Contains whether the given IP is contained in the set. func (s DistinctNetSet) Contains(ip net.IP) bool { + return s.ContainsAddr(IPToAddr(ip)) +} + +// ContainsAddr whether the given IP is contained in the set. +func (s DistinctNetSet) ContainsAddr(ip netip.Addr) bool { key := s.key(ip) - _, ok := s.members[string(key)] + _, ok := s.members[key] return ok } @@ -265,54 +309,30 @@ func (s DistinctNetSet) Len() int { return int(n) } -// key encodes the map key for an address into a temporary buffer. -// -// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types. -// The remainder of the key is the IP, truncated to the number of bits. -func (s *DistinctNetSet) key(ip net.IP) net.IP { +// key returns the map key for ip. +func (s *DistinctNetSet) key(ip netip.Addr) netip.Prefix { // Lazily initialize storage. if s.members == nil { - s.members = make(map[string]uint) - s.buf = make(net.IP, 17) - } - // Canonicalize ip and bits. - typ := byte('6') - if ip4 := ip.To4(); ip4 != nil { - typ, ip = '4', ip4 + s.members = make(map[netip.Prefix]uint) } - bits := s.Subnet - if bits > uint(len(ip)*8) { - bits = uint(len(ip) * 8) - } - // Encode the prefix into s.buf. - nb := int(bits / 8) - mask := ^byte(0xFF >> (bits % 8)) - s.buf[0] = typ - buf := append(s.buf[:1], ip[:nb]...) - if nb < len(ip) && mask != 0 { - buf = append(buf, ip[nb]&mask) + p, err := ip.Prefix(int(s.Subnet)) + if err != nil { + panic(err) } - return buf + return p } // String implements fmt.Stringer func (s DistinctNetSet) String() string { + keys := maps.Keys(s.members) + slices.SortFunc(keys, func(a, b netip.Prefix) int { + return strings.Compare(a.String(), b.String()) + }) + var buf bytes.Buffer buf.WriteString("{") - keys := make([]string, 0, len(s.members)) - for k := range s.members { - keys = append(keys, k) - } - sort.Strings(keys) for i, k := range keys { - var ip net.IP - if k[0] == '4' { - ip = make(net.IP, 4) - } else { - ip = make(net.IP, 16) - } - copy(ip, k[1:]) - fmt.Fprintf(&buf, "%v×%d", ip, s.members[k]) + fmt.Fprintf(&buf, "%v×%d", k, s.members[k]) if i != len(keys)-1 { buf.WriteString(" ") } diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go index 3a6aa081f2..569c7ac454 100644 --- a/p2p/netutil/net_test.go +++ b/p2p/netutil/net_test.go @@ -18,7 +18,9 @@ package netutil import ( "fmt" + "math/rand" "net" + "net/netip" "reflect" "testing" "testing/quick" @@ -29,7 +31,7 @@ import ( func TestParseNetlist(t *testing.T) { var tests = []struct { input string - wantErr error + wantErr string wantList *Netlist }{ { @@ -38,25 +40,27 @@ func TestParseNetlist(t *testing.T) { }, { input: "127.0.0.0/8", - wantErr: nil, - wantList: &Netlist{{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}}, + wantList: &Netlist{netip.MustParsePrefix("127.0.0.0/8")}, }, { input: "127.0.0.0/44", - wantErr: &net.ParseError{Type: "CIDR address", Text: "127.0.0.0/44"}, + wantErr: `netip.ParsePrefix("127.0.0.0/44"): prefix length out of range`, }, { input: "127.0.0.0/16, 23.23.23.23/24,", wantList: &Netlist{ - {IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(16, 32)}, - {IP: net.IP{23, 23, 23, 0}, Mask: net.CIDRMask(24, 32)}, + netip.MustParsePrefix("127.0.0.0/16"), + netip.MustParsePrefix("23.23.23.23/24"), }, }, } for _, test := range tests { l, err := ParseNetlist(test.input) - if !reflect.DeepEqual(err, test.wantErr) { + if err == nil && test.wantErr != "" { + t.Errorf("%q: got no error, expected %q", test.input, test.wantErr) + continue + } else if err != nil && err.Error() != test.wantErr { t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr) continue } @@ -70,14 +74,12 @@ func TestParseNetlist(t *testing.T) { func TestNilNetListContains(t *testing.T) { var list *Netlist - checkContains(t, list.Contains, nil, []string{"1.2.3.4"}) + checkContains(t, list.Contains, list.ContainsAddr, nil, []string{"1.2.3.4"}) } func TestIsLAN(t *testing.T) { - checkContains(t, IsLAN, + checkContains(t, IsLAN, AddrIsLAN, []string{ // included - "0.0.0.0", - "0.2.0.8", "127.0.0.1", "10.0.1.1", "10.22.0.3", @@ -86,25 +88,35 @@ func TestIsLAN(t *testing.T) { "fe80::f4a1:8eff:fec5:9d9d", "febf::ab32:2233", "fc00::4", + // 4-in-6 + "::ffff:127.0.0.1", + "::ffff:10.10.0.2", }, []string{ // excluded "192.0.2.1", "1.0.0.0", "172.32.0.1", "fec0::2233", + // 4-in-6 + "::ffff:88.99.100.2", }, ) } func TestIsSpecialNetwork(t *testing.T) { - checkContains(t, IsSpecialNetwork, + checkContains(t, IsSpecialNetwork, AddrIsSpecialNetwork, []string{ // included + "0.0.0.0", + "0.2.0.8", "192.0.2.1", "192.0.2.44", "2001:db8:85a3:8d3:1319:8a2e:370:7348", "255.255.255.255", "224.0.0.22", // IPv4 multicast "ff05::1:3", // IPv6 multicast + // 4-in-6 + "::ffff:255.255.255.255", + "::ffff:192.0.2.1", }, []string{ // excluded "192.0.3.1", @@ -115,15 +127,21 @@ func TestIsSpecialNetwork(t *testing.T) { ) } -func checkContains(t *testing.T, fn func(net.IP) bool, inc, exc []string) { +func checkContains(t *testing.T, fn func(net.IP) bool, fn2 func(netip.Addr) bool, inc, exc []string) { for _, s := range inc { if !fn(parseIP(s)) { - t.Error("returned false for included address", s) + t.Error("returned false for included net.IP", s) + } + if !fn2(netip.MustParseAddr(s)) { + t.Error("returned false for included netip.Addr", s) } } for _, s := range exc { if fn(parseIP(s)) { - t.Error("returned true for excluded address", s) + t.Error("returned true for excluded net.IP", s) + } + if fn2(netip.MustParseAddr(s)) { + t.Error("returned true for excluded netip.Addr", s) } } } @@ -244,14 +262,22 @@ func TestDistinctNetSet(t *testing.T) { } func TestDistinctNetSetAddRemove(t *testing.T) { - cfg := &quick.Config{} - fn := func(ips []net.IP) bool { + cfg := &quick.Config{ + Values: func(s []reflect.Value, rng *rand.Rand) { + slice := make([]netip.Addr, rng.Intn(20)+1) + for i := range slice { + slice[i] = RandomAddr(rng, false) + } + s[0] = reflect.ValueOf(slice) + }, + } + fn := func(ips []netip.Addr) bool { s := DistinctNetSet{Limit: 3, Subnet: 2} for _, ip := range ips { - s.Add(ip) + s.AddAddr(ip) } for _, ip := range ips { - s.Remove(ip) + s.RemoveAddr(ip) } return s.Len() == 0 } diff --git a/p2p/server.go b/p2p/server.go index 13eebed3f4..22e5f6cb94 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -905,14 +905,14 @@ func (srv *Server) listenLoop() { break } - remoteIP := netutil.AddrIP(fd.RemoteAddr()) + remoteIP := netutil.AddrAddr(fd.RemoteAddr()) if err := srv.checkInboundConn(remoteIP); err != nil { srv.log.Debug("Rejected inbound connection", "addr", fd.RemoteAddr(), "err", err) fd.Close() slots <- struct{}{} continue } - if remoteIP != nil { + if remoteIP.IsValid() { fd = newMeteredConn(fd) serveMeter.Mark(1) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) @@ -924,18 +924,19 @@ func (srv *Server) listenLoop() { } } -func (srv *Server) checkInboundConn(remoteIP net.IP) error { - if remoteIP == nil { +func (srv *Server) checkInboundConn(remoteIP netip.Addr) error { + if !remoteIP.IsValid() { + // This case happens for internal test connections without remote address. return nil } // Reject connections that do not match NetRestrict. - if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) { + if srv.NetRestrict != nil && !srv.NetRestrict.ContainsAddr(remoteIP) { return errors.New("not in netrestrict list") } // Reject Internet peers that try too often. now := srv.clock.Now() srv.inboundHistory.expire(now, nil) - if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) { + if !netutil.AddrIsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) { return errors.New("too many attempts") } srv.inboundHistory.add(remoteIP.String(), now.Add(inboundThrottleTime)) @@ -1108,7 +1109,7 @@ func (srv *Server) NodeInfo() *NodeInfo { Name: srv.Name, Enode: node.URLv4(), ID: node.ID().String(), - IP: node.IP().String(), + IP: node.IPAddr().String(), ListenAddr: srv.ListenAddr, Protocols: make(map[string]interface{}), } diff --git a/p2p/server_nat_test.go b/p2p/server_nat_test.go index de935fcfc5..cbb1f37e0a 100644 --- a/p2p/server_nat_test.go +++ b/p2p/server_nat_test.go @@ -18,6 +18,7 @@ package p2p import ( "net" + "net/netip" "sync/atomic" "testing" "time" @@ -64,8 +65,8 @@ func TestServerPortMapping(t *testing.T) { t.Error("wrong request count:", reqCount) } enr := srv.LocalNode().Node() - if enr.IP().String() != "192.0.2.0" { - t.Error("wrong IP in ENR:", enr.IP()) + if enr.IPAddr() != netip.MustParseAddr("192.0.2.0") { + t.Error("wrong IP in ENR:", enr.IPAddr()) } if enr.TCP() != 30000 { t.Error("wrong TCP port in ENR:", enr.TCP())