diff --git a/cmd/devp2p/internal/v4test/framework.go b/cmd/devp2p/internal/v4test/framework.go index 9286594181..e8f4c021b8 100644 --- a/cmd/devp2p/internal/v4test/framework.go +++ b/cmd/devp2p/internal/v4test/framework.go @@ -110,7 +110,7 @@ func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint { } func (te *testenv) remoteEndpoint() v4wire.Endpoint { - return v4wire.NewEndpoint(te.remoteAddr, 0) + return v4wire.NewEndpoint(te.remoteAddr.AddrPort(), 0) } func contains(ns []v4wire.Node, key v4wire.Pubkey) bool { diff --git a/p2p/discover/common.go b/p2p/discover/common.go index bebea8cc38..0716f7472f 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -22,6 +22,7 @@ import ( "encoding/binary" "math/rand" "net" + "net/netip" "sync" "time" @@ -34,8 +35,8 @@ import ( // UDPConn is a network connection on which discovery can operate. type UDPConn interface { - ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) - WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) + ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) + WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (n int, err error) Close() error LocalAddr() net.Addr } @@ -94,7 +95,7 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { // channel if configured. type ReadPacket struct { Data []byte - Addr *net.UDPAddr + Addr netip.AddrPort } type randomSource interface { diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go index 5c3d90d6c9..09808b71e0 100644 --- a/p2p/discover/lookup.go +++ b/p2p/discover/lookup.go @@ -29,16 +29,16 @@ import ( // not need to be an actual node identifier. type lookup struct { tab *Table - queryfunc func(*node) ([]*node, error) - replyCh chan []*node + queryfunc queryFunc + replyCh chan []*enode.Node cancelCh <-chan struct{} asked, seen map[enode.ID]bool result nodesByDistance - replyBuffer []*node + replyBuffer []*enode.Node queries int } -type queryFunc func(*node) ([]*node, error) +type queryFunc func(*enode.Node) ([]*enode.Node, error) func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup { it := &lookup{ @@ -47,7 +47,7 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l asked: make(map[enode.ID]bool), seen: make(map[enode.ID]bool), result: nodesByDistance{target: target}, - replyCh: make(chan []*node, alpha), + replyCh: make(chan []*enode.Node, alpha), cancelCh: ctx.Done(), queries: -1, } @@ -61,7 +61,7 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l func (it *lookup) run() []*enode.Node { for it.advance() { } - return unwrapNodes(it.result.entries) + return it.result.entries } // advance advances the lookup until any new nodes have been found. @@ -139,7 +139,7 @@ func (it *lookup) slowdown() { } } -func (it *lookup) query(n *node, reply chan<- []*node) { +func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) { r, err := it.queryfunc(n) if !errors.Is(err, errClosed) { // avoid recording failures on shutdown. success := len(r) > 0 @@ -154,7 +154,7 @@ func (it *lookup) query(n *node, reply chan<- []*node) { // lookupIterator performs lookup operations and iterates over all seen nodes. // When a lookup finishes, a new one is created through nextLookup. type lookupIterator struct { - buffer []*node + buffer []*enode.Node nextLookup lookupFunc ctx context.Context cancel func() @@ -173,7 +173,7 @@ func (it *lookupIterator) Node() *enode.Node { if len(it.buffer) == 0 { return nil } - return unwrapNode(it.buffer[0]) + return it.buffer[0] } // Next moves to the next node. diff --git a/p2p/discover/metrics.go b/p2p/discover/metrics.go index 3cd0ab0414..8deafbbce4 100644 --- a/p2p/discover/metrics.go +++ b/p2p/discover/metrics.go @@ -18,7 +18,7 @@ package discover import ( "fmt" - "net" + "net/netip" "github.com/ethereum/go-ethereum/metrics" ) @@ -58,16 +58,16 @@ func newMeteredConn(conn UDPConn) UDPConn { return &meteredUdpConn{UDPConn: conn} } -// ReadFromUDP delegates a network read to the underlying connection, bumping the udp ingress traffic meter along the way. -func (c *meteredUdpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { - n, addr, err = c.UDPConn.ReadFromUDP(b) +// ReadFromUDPAddrPort delegates a network read to the underlying connection, bumping the udp ingress traffic meter along the way. +func (c *meteredUdpConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + n, addr, err = c.UDPConn.ReadFromUDPAddrPort(b) ingressTrafficMeter.Mark(int64(n)) return n, addr, err } -// Write delegates a network write to the underlying connection, bumping the udp egress traffic meter along the way. -func (c *meteredUdpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { - n, err = c.UDPConn.WriteToUDP(b, addr) +// WriteToUDP delegates a network write to the underlying connection, bumping the udp egress traffic meter along the way. +func (c *meteredUdpConn) WriteToUDP(b []byte, addr netip.AddrPort) (n int, err error) { + n, err = c.UDPConn.WriteToUDPAddrPort(b, addr) egressTrafficMeter.Mark(int64(n)) return n, err } diff --git a/p2p/discover/node.go b/p2p/discover/node.go index 47788248f4..042619221b 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -21,7 +21,8 @@ import ( "crypto/elliptic" "errors" "math/big" - "net" + "slices" + "sort" "time" "github.com/ethereum/go-ethereum/common/math" @@ -37,9 +38,8 @@ type BucketNode struct { Live bool `json:"live"` } -// node represents a host on the network. -// The fields of Node may not be modified. -type node struct { +// tableNode is an entry in Table. +type tableNode struct { *enode.Node revalList *revalidationList addedToTable time.Time // first time node was added to bucket or replacement list @@ -75,34 +75,59 @@ func (e encPubkey) id() enode.ID { return enode.ID(crypto.Keccak256Hash(e[:])) } -func wrapNode(n *enode.Node) *node { - return &node{Node: n} -} - -func wrapNodes(ns []*enode.Node) []*node { - result := make([]*node, len(ns)) +func unwrapNodes(ns []*tableNode) []*enode.Node { + result := make([]*enode.Node, len(ns)) for i, n := range ns { - result[i] = wrapNode(n) + result[i] = n.Node } return result } -func unwrapNode(n *node) *enode.Node { - return n.Node +func (n *tableNode) String() string { + return n.Node.String() +} + +// nodesByDistance is a list of nodes, ordered by distance to target. +type nodesByDistance struct { + entries []*enode.Node + target enode.ID } -func unwrapNodes(ns []*node) []*enode.Node { - result := make([]*enode.Node, len(ns)) - for i, n := range ns { - result[i] = unwrapNode(n) +// push adds the given node to the list, keeping the total size below maxElems. +func (h *nodesByDistance) push(n *enode.Node, maxElems int) { + ix := sort.Search(len(h.entries), func(i int) bool { + return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 + }) + + end := len(h.entries) + if len(h.entries) < maxElems { + h.entries = append(h.entries, n) + } + if ix < end { + // Slide existing entries down to make room. + // This will overwrite the entry we just appended. + copy(h.entries[ix+1:], h.entries[ix:]) + h.entries[ix] = n } - return result } -func (n *node) addr() *net.UDPAddr { - return &net.UDPAddr{IP: n.IP(), Port: n.UDP()} +type nodeType interface { + ID() enode.ID } -func (n *node) String() string { - return n.Node.String() +// containsID reports whether ns contains a node with the given ID. +func containsID[N nodeType](ns []N, id enode.ID) bool { + for _, n := range ns { + if n.ID() == id { + return true + } + } + return false +} + +// deleteNode removes a node from the list. +func deleteNode[N nodeType](list []N, id enode.ID) []N { + return slices.DeleteFunc(list, func(n N) bool { + return n.ID() == id + }) } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 2b4ba7f5d8..bd3c9b4143 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -27,7 +27,6 @@ import ( "fmt" "net" "slices" - "sort" "sync" "time" @@ -65,7 +64,7 @@ const ( type Table struct { mutex sync.Mutex // protects buckets, bucket content, nursery, rand buckets [nBuckets]*bucket // index of known nodes by distance - nursery []*node // bootstrap nodes + nursery []*enode.Node // bootstrap nodes rand reseedingRandom // source of randomness, periodically reseeded ips netutil.DistinctNetSet revalidation tableRevalidation @@ -85,8 +84,8 @@ type Table struct { closeReq chan struct{} closed chan struct{} - nodeAddedHook func(*bucket, *node) - nodeRemovedHook func(*bucket, *node) + nodeAddedHook func(*bucket, *tableNode) + nodeRemovedHook func(*bucket, *tableNode) } // transport is implemented by the UDP transports. @@ -101,20 +100,21 @@ type transport interface { // bucket contains nodes, ordered by their last activity. the entry // that was most recently active is the first element in entries. type bucket struct { - entries []*node // live entries, sorted by time of last contact - replacements []*node // recently seen nodes to be used if revalidation fails + entries []*tableNode // live entries, sorted by time of last contact + replacements []*tableNode // recently seen nodes to be used if revalidation fails ips netutil.DistinctNetSet index int } type addNodeOp struct { - node *node - isInbound bool + node *enode.Node + isInbound bool + forceSetLive bool // for tests } type trackRequestOp struct { - node *node - foundNodes []*node + node *enode.Node + foundNodes []*enode.Node success bool } @@ -186,7 +186,7 @@ func (tab *Table) getNode(id enode.ID) *enode.Node { b := tab.bucket(id) for _, e := range b.entries { if e.ID() == id { - return unwrapNode(e) + return e.Node } } return nil @@ -202,7 +202,7 @@ func (tab *Table) close() { // are used to connect to the network if the table is empty and there // are no known nodes in the database. func (tab *Table) setFallbackNodes(nodes []*enode.Node) error { - nursery := make([]*node, 0, len(nodes)) + nursery := make([]*enode.Node, 0, len(nodes)) for _, n := range nodes { if err := n.ValidateComplete(); err != nil { return fmt.Errorf("bad bootstrap node %q: %v", n, err) @@ -211,7 +211,7 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error { tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP()) continue } - nursery = append(nursery, wrapNode(n)) + nursery = append(nursery, n) } tab.nursery = nursery return nil @@ -255,9 +255,9 @@ func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) * liveNodes := &nodesByDistance{target: target} for _, b := range &tab.buckets { for _, n := range b.entries { - nodes.push(n, nresults) + nodes.push(n.Node, nresults) if preferLive && n.isValidatedLive { - liveNodes.push(n, nresults) + liveNodes.push(n.Node, nresults) } } } @@ -309,8 +309,8 @@ func (tab *Table) len() (n int) { // list. // // The caller must not hold tab.mutex. -func (tab *Table) addFoundNode(n *node) bool { - op := addNodeOp{node: n, isInbound: false} +func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool { + op := addNodeOp{node: n, isInbound: false, forceSetLive: forceSetLive} select { case tab.addNodeCh <- op: return <-tab.addNodeHandled @@ -327,7 +327,7 @@ func (tab *Table) addFoundNode(n *node) bool { // repeatedly. // // The caller must not hold tab.mutex. -func (tab *Table) addInboundNode(n *node) bool { +func (tab *Table) addInboundNode(n *enode.Node) bool { op := addNodeOp{node: n, isInbound: true} select { case tab.addNodeCh <- op: @@ -337,7 +337,7 @@ func (tab *Table) addInboundNode(n *node) bool { } } -func (tab *Table) trackRequest(n *node, success bool, foundNodes []*node) { +func (tab *Table) trackRequest(n *enode.Node, success bool, foundNodes []*enode.Node) { op := trackRequestOp{n, foundNodes, success} select { case tab.trackRequestCh <- op: @@ -443,13 +443,14 @@ func (tab *Table) doRefresh(done chan struct{}) { } func (tab *Table) loadSeedNodes() { - seeds := wrapNodes(tab.db.QuerySeeds(seedCount, seedMaxAge)) + seeds := tab.db.QuerySeeds(seedCount, seedMaxAge) seeds = append(seeds, tab.nursery...) for i := range seeds { seed := seeds[i] if tab.log.Enabled(context.Background(), log.LevelTrace) { age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) - tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age) + addr, _ := seed.UDPEndpoint() + tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age) } tab.handleAddNode(addNodeOp{node: seed, isInbound: false}) } @@ -513,7 +514,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool { } b := tab.bucket(req.node.ID()) - n, _ := tab.bumpInBucket(b, req.node.Node, req.isInbound) + n, _ := tab.bumpInBucket(b, req.node, req.isInbound) if n != nil { // Already in bucket. return false @@ -529,15 +530,20 @@ func (tab *Table) handleAddNode(req addNodeOp) bool { } // Add to bucket. - b.entries = append(b.entries, req.node) - b.replacements = deleteNode(b.replacements, req.node) - tab.nodeAdded(b, req.node) + wn := &tableNode{Node: req.node} + if req.forceSetLive { + wn.livenessChecks = 1 + wn.isValidatedLive = true + } + b.entries = append(b.entries, wn) + b.replacements = deleteNode(b.replacements, wn.ID()) + tab.nodeAdded(b, wn) return true } // addReplacement adds n to the replacement cache of bucket b. -func (tab *Table) addReplacement(b *bucket, n *node) { - if contains(b.replacements, n.ID()) { +func (tab *Table) addReplacement(b *bucket, n *enode.Node) { + if containsID(b.replacements, n.ID()) { // TODO: update ENR return } @@ -545,15 +551,15 @@ func (tab *Table) addReplacement(b *bucket, n *node) { return } - n.addedToTable = time.Now() - var removed *node - b.replacements, removed = pushNode(b.replacements, n, maxReplacements) + wn := &tableNode{Node: n, addedToTable: time.Now()} + var removed *tableNode + b.replacements, removed = pushNode(b.replacements, wn, maxReplacements) if removed != nil { tab.removeIP(b, removed.IP()) } } -func (tab *Table) nodeAdded(b *bucket, n *node) { +func (tab *Table) nodeAdded(b *bucket, n *tableNode) { if n.addedToTable == (time.Time{}) { n.addedToTable = time.Now() } @@ -567,7 +573,7 @@ func (tab *Table) nodeAdded(b *bucket, n *node) { } } -func (tab *Table) nodeRemoved(b *bucket, n *node) { +func (tab *Table) nodeRemoved(b *bucket, n *tableNode) { tab.revalidation.nodeRemoved(n) if tab.nodeRemovedHook != nil { tab.nodeRemovedHook(b, n) @@ -579,8 +585,8 @@ func (tab *Table) nodeRemoved(b *bucket, n *node) { // deleteInBucket removes node n from the table. // If there are replacement nodes in the bucket, the node is replaced. -func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node { - index := slices.IndexFunc(b.entries, func(e *node) bool { return e.ID() == id }) +func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode { + index := slices.IndexFunc(b.entries, func(e *tableNode) bool { return e.ID() == id }) if index == -1 { // Entry has been removed already. return nil @@ -608,8 +614,8 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node { // bumpInBucket updates a node record if it exists in the bucket. // The second return value reports whether the node's endpoint (IP/port) was updated. -func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *node, endpointChanged bool) { - i := slices.IndexFunc(b.entries, func(elem *node) bool { +func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *tableNode, endpointChanged bool) { + i := slices.IndexFunc(b.entries, func(elem *tableNode) bool { return elem.ID() == newRecord.ID() }) if i == -1 { @@ -672,21 +678,12 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) { // Add found nodes. for _, n := range op.foundNodes { - tab.handleAddNode(addNodeOp{n, false}) + tab.handleAddNode(addNodeOp{n, false, false}) } } -func contains(ns []*node, id enode.ID) bool { - for _, n := range ns { - if n.ID() == id { - return true - } - } - return false -} - // pushNode adds n to the front of list, keeping at most max items. -func pushNode(list []*node, n *node, max int) ([]*node, *node) { +func pushNode(list []*tableNode, n *tableNode, max int) ([]*tableNode, *tableNode) { if len(list) < max { list = append(list, nil) } @@ -695,37 +692,3 @@ func pushNode(list []*node, n *node, max int) ([]*node, *node) { list[0] = n return list, removed } - -// deleteNode removes n from list. -func deleteNode(list []*node, n *node) []*node { - for i := range list { - if list[i].ID() == n.ID() { - return append(list[:i], list[i+1:]...) - } - } - return list -} - -// nodesByDistance is a list of nodes, ordered by distance to target. -type nodesByDistance struct { - entries []*node - target enode.ID -} - -// push adds the given node to the list, keeping the total size below maxElems. -func (h *nodesByDistance) push(n *node, maxElems int) { - ix := sort.Search(len(h.entries), func(i int) bool { - return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 - }) - - end := len(h.entries) - if len(h.entries) < maxElems { - h.entries = append(h.entries, n) - } - if ix < end { - // Slide existing entries down to make room. - // This will overwrite the entry we just appended. - copy(h.entries[ix+1:], h.entries[ix:]) - h.entries[ix] = n - } -} diff --git a/p2p/discover/table_reval.go b/p2p/discover/table_reval.go index 5d185aa8b4..f2ea8b34fa 100644 --- a/p2p/discover/table_reval.go +++ b/p2p/discover/table_reval.go @@ -39,7 +39,7 @@ type tableRevalidation struct { } type revalidationResponse struct { - n *node + n *tableNode newRecord *enode.Node didRespond bool } @@ -55,12 +55,12 @@ func (tr *tableRevalidation) init(cfg *Config) { } // nodeAdded is called when the table receives a new node. -func (tr *tableRevalidation) nodeAdded(tab *Table, n *node) { +func (tr *tableRevalidation) nodeAdded(tab *Table, n *tableNode) { tr.fast.push(n, tab.cfg.Clock.Now(), &tab.rand) } // nodeRemoved is called when a node was removed from the table. -func (tr *tableRevalidation) nodeRemoved(n *node) { +func (tr *tableRevalidation) nodeRemoved(n *tableNode) { if n.revalList == nil { panic(fmt.Errorf("removed node %v has nil revalList", n.ID())) } @@ -68,7 +68,7 @@ func (tr *tableRevalidation) nodeRemoved(n *node) { } // nodeEndpointChanged is called when a change in IP or port is detected. -func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *node) { +func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *tableNode) { n.isValidatedLive = false tr.moveToList(&tr.fast, n, tab.cfg.Clock.Now(), &tab.rand) } @@ -90,7 +90,7 @@ func (tr *tableRevalidation) run(tab *Table, now mclock.AbsTime) (nextTime mcloc } // startRequest spawns a revalidation request for node n. -func (tr *tableRevalidation) startRequest(tab *Table, n *node) { +func (tr *tableRevalidation) startRequest(tab *Table, n *tableNode) { if _, ok := tr.activeReq[n.ID()]; ok { panic(fmt.Errorf("duplicate startRequest (node %v)", n.ID())) } @@ -180,7 +180,7 @@ func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationRespons } // moveToList ensures n is in the 'dest' list. -func (tr *tableRevalidation) moveToList(dest *revalidationList, n *node, now mclock.AbsTime, rand randomSource) { +func (tr *tableRevalidation) moveToList(dest *revalidationList, n *tableNode, now mclock.AbsTime, rand randomSource) { if n.revalList == dest { return } @@ -192,14 +192,14 @@ func (tr *tableRevalidation) moveToList(dest *revalidationList, n *node, now mcl // revalidationList holds a list nodes and the next revalidation time. type revalidationList struct { - nodes []*node + nodes []*tableNode nextTime mclock.AbsTime interval time.Duration name string } // get returns a random node from the queue. Nodes in the 'exclude' map are not returned. -func (list *revalidationList) get(now mclock.AbsTime, rand randomSource, exclude map[enode.ID]struct{}) *node { +func (list *revalidationList) get(now mclock.AbsTime, rand randomSource, exclude map[enode.ID]struct{}) *tableNode { if now < list.nextTime || len(list.nodes) == 0 { return nil } @@ -217,7 +217,7 @@ func (list *revalidationList) schedule(now mclock.AbsTime, rand randomSource) { list.nextTime = now.Add(time.Duration(rand.Int63n(int64(list.interval)))) } -func (list *revalidationList) push(n *node, now mclock.AbsTime, rand randomSource) { +func (list *revalidationList) push(n *tableNode, now mclock.AbsTime, rand randomSource) { list.nodes = append(list.nodes, n) if list.nextTime == never { list.schedule(now, rand) @@ -225,7 +225,7 @@ func (list *revalidationList) push(n *node, now mclock.AbsTime, rand randomSourc n.revalList = list } -func (list *revalidationList) remove(n *node) { +func (list *revalidationList) remove(n *tableNode) { i := slices.Index(list.nodes, n) if i == -1 { panic(fmt.Errorf("node %v not found in list", n.ID())) @@ -238,7 +238,7 @@ func (list *revalidationList) remove(n *node) { } func (list *revalidationList) contains(id enode.ID) bool { - return slices.ContainsFunc(list.nodes, func(n *node) bool { + return slices.ContainsFunc(list.nodes, func(n *tableNode) bool { return n.ID() == id }) } diff --git a/p2p/discover/table_reval_test.go b/p2p/discover/table_reval_test.go index d168767e0d..3605443934 100644 --- a/p2p/discover/table_reval_test.go +++ b/p2p/discover/table_reval_test.go @@ -110,10 +110,10 @@ func TestRevalidation_endpointUpdate(t *testing.T) { } tr.handleResponse(tab, resp) - if !tr.fast.contains(node.ID()) { + if tr.fast.nodes[0].ID() != node.ID() { t.Fatal("node not contained in fast revalidation list") } - if node.isValidatedLive { + if tr.fast.nodes[0].isValidatedLive { t.Fatal("node is marked live after endpoint change") } } diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index b0be2a94c5..30e7d56f4a 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -22,6 +22,7 @@ import ( "math/rand" "net" "reflect" + "slices" "testing" "testing/quick" "time" @@ -64,7 +65,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding // Fill up the sender's bucket. replacementNodeKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8") - replacementNode := wrapNode(enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99)) + replacementNode := enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99) last := fillBucket(tab, replacementNode.ID()) tab.mutex.Lock() nodeEvents := newNodeEventRecorder(128) @@ -78,7 +79,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding transport.dead[replacementNode.ID()] = !newNodeIsResponding // Add replacement node to table. - tab.addFoundNode(replacementNode) + tab.addFoundNode(replacementNode, false) t.Log("last:", last.ID()) t.Log("replacement:", replacementNode.ID()) @@ -115,11 +116,11 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding if l := len(bucket.entries); l != wantSize { t.Errorf("wrong bucket size after revalidation: got %d, want %d", l, wantSize) } - if ok := contains(bucket.entries, last.ID()); ok != lastInBucketIsResponding { + if ok := containsID(bucket.entries, last.ID()); ok != lastInBucketIsResponding { t.Errorf("revalidated node found: %t, want: %t", ok, lastInBucketIsResponding) } wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding - if ok := contains(bucket.entries, replacementNode.ID()); ok != wantNewEntry { + if ok := containsID(bucket.entries, replacementNode.ID()); ok != wantNewEntry { t.Errorf("replacement node found: %t, want: %t", ok, wantNewEntry) } } @@ -153,7 +154,7 @@ func TestTable_IPLimit(t *testing.T) { for i := 0; i < tableIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) - tab.addFoundNode(n) + tab.addFoundNode(n, false) } if tab.len() > tableIPLimit { t.Errorf("too many nodes in table") @@ -171,7 +172,7 @@ func TestTable_BucketIPLimit(t *testing.T) { d := 3 for i := 0; i < bucketIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)}) - tab.addFoundNode(n) + tab.addFoundNode(n, false) } if tab.len() > bucketIPLimit { t.Errorf("too many nodes in table") @@ -232,7 +233,7 @@ func TestTable_findnodeByID(t *testing.T) { // check that the result nodes have minimum distance to target. for _, b := range tab.buckets { for _, n := range b.entries { - if contains(result, n.ID()) { + if containsID(result, n.ID()) { continue // don't run the check below for nodes in result } farthestResult := result[len(result)-1].ID() @@ -255,7 +256,7 @@ func TestTable_findnodeByID(t *testing.T) { type closeTest struct { Self enode.ID Target enode.ID - All []*node + All []*enode.Node N int } @@ -268,8 +269,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { r := new(enr.Record) r.Set(enr.IP(genIP(rand))) - n := wrapNode(enode.SignNull(r, id)) - n.livenessChecks = 1 + n := enode.SignNull(r, id) t.All = append(t.All, n) } return reflect.ValueOf(t) @@ -284,16 +284,16 @@ func TestTable_addInboundNode(t *testing.T) { // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addFoundNode(n1) - tab.addFoundNode(n2) - checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node}) + tab.addFoundNode(n1, false) + tab.addFoundNode(n2, false) + checkBucketContent(t, tab, []*enode.Node{n1, n2}) // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) n2v2 := enode.SignNull(newrec, n2.ID()) - tab.addInboundNode(wrapNode(n2v2)) - checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) + tab.addInboundNode(n2v2) + checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) // Try updating n2 without sequence number change. The update is accepted // because it's inbound. @@ -301,8 +301,8 @@ func TestTable_addInboundNode(t *testing.T) { newrec.Set(enr.IP{100, 100, 100, 100}) newrec.SetSeq(n2.Seq()) n2v3 := enode.SignNull(newrec, n2.ID()) - tab.addInboundNode(wrapNode(n2v3)) - checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v3}) + tab.addInboundNode(n2v3) + checkBucketContent(t, tab, []*enode.Node{n1, n2v3}) } func TestTable_addFoundNode(t *testing.T) { @@ -314,16 +314,16 @@ func TestTable_addFoundNode(t *testing.T) { // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addFoundNode(n1) - tab.addFoundNode(n2) - checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node}) + tab.addFoundNode(n1, false) + tab.addFoundNode(n2, false) + checkBucketContent(t, tab, []*enode.Node{n1, n2}) // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) n2v2 := enode.SignNull(newrec, n2.ID()) - tab.addFoundNode(wrapNode(n2v2)) - checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) + tab.addFoundNode(n2v2, false) + checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) // Try updating n2 without a sequence number change. // The update should not be accepted. @@ -331,8 +331,8 @@ func TestTable_addFoundNode(t *testing.T) { newrec.Set(enr.IP{100, 100, 100, 100}) newrec.SetSeq(n2.Seq()) n2v3 := enode.SignNull(newrec, n2.ID()) - tab.addFoundNode(wrapNode(n2v3)) - checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) + tab.addFoundNode(n2v3, false) + checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) } // This test checks that discv4 nodes can update their own endpoint via PING. @@ -345,13 +345,13 @@ func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) { // Add a v4 node. key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) - tab.addInboundNode(wrapNode(n1)) + tab.addInboundNode(n1) checkBucketContent(t, tab, []*enode.Node{n1}) // Add an updated version with changed IP. // The update will be accepted because it is inbound. n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) - tab.addInboundNode(wrapNode(n1v2)) + tab.addInboundNode(n1v2) checkBucketContent(t, tab, []*enode.Node{n1v2}) } @@ -366,13 +366,13 @@ func TestTable_addFoundNodeV4UpdateReject(t *testing.T) { // Add a v4 node. key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) - tab.addFoundNode(wrapNode(n1)) + tab.addFoundNode(n1, false) checkBucketContent(t, tab, []*enode.Node{n1}) // Add an updated version with changed IP. // The update won't be accepted because it isn't inbound. n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) - tab.addFoundNode(wrapNode(n1v2)) + tab.addFoundNode(n1v2, false) checkBucketContent(t, tab, []*enode.Node{n1}) } @@ -413,8 +413,8 @@ func TestTable_revalidateSyncRecord(t *testing.T) { var r enr.Record r.Set(enr.IP(net.IP{127, 0, 0, 1})) id := enode.ID{1} - n1 := wrapNode(enode.SignNull(&r, id)) - tab.addFoundNode(n1) + n1 := enode.SignNull(&r, id) + tab.addFoundNode(n1, false) // Update the node record. r.Set(enr.WithEntry("foo", "bar")) @@ -437,7 +437,7 @@ func TestNodesPush(t *testing.T) { n1 := nodeAtDistance(target, 255, intIP(1)) n2 := nodeAtDistance(target, 254, intIP(2)) n3 := nodeAtDistance(target, 253, intIP(3)) - perm := [][]*node{ + perm := [][]*enode.Node{ {n3, n2, n1}, {n3, n1, n2}, {n2, n3, n1}, @@ -452,7 +452,7 @@ func TestNodesPush(t *testing.T) { for _, n := range nodes { list.push(n, 3) } - if !slicesEqual(list.entries, perm[0], nodeIDEqual) { + if !slices.EqualFunc(list.entries, perm[0], nodeIDEqual) { t.Fatal("not equal") } } @@ -463,28 +463,16 @@ func TestNodesPush(t *testing.T) { for _, n := range nodes { list.push(n, 2) } - if !slicesEqual(list.entries, perm[0][:2], nodeIDEqual) { + if !slices.EqualFunc(list.entries, perm[0][:2], nodeIDEqual) { t.Fatal("not equal") } } } -func nodeIDEqual(n1, n2 *node) bool { +func nodeIDEqual[N nodeType](n1, n2 N) bool { return n1.ID() == n2.ID() } -func slicesEqual[T any](s1, s2 []T, check func(e1, e2 T) bool) bool { - if len(s1) != len(s2) { - return false - } - for i := range s1 { - if !check(s1[i], s2[i]) { - return false - } - } - return true -} - // gen wraps quick.Value so it's easier to use. // it generates a random value of the given value's type. func gen(typ interface{}, rand *rand.Rand) interface{} { diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 59045bf2a8..997ac37799 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -56,18 +56,18 @@ func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) { } // nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld. -func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node { +func nodeAtDistance(base enode.ID, ld int, ip net.IP) *enode.Node { var r enr.Record r.Set(enr.IP(ip)) r.Set(enr.UDP(30303)) - return wrapNode(enode.SignNull(&r, idAtDistance(base, ld))) + return enode.SignNull(&r, idAtDistance(base, ld)) } // nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld. func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node { results := make([]*enode.Node, n) for i := range results { - results[i] = unwrapNode(nodeAtDistance(base, ld, intIP(i))) + results[i] = nodeAtDistance(base, ld, intIP(i)) } return results } @@ -105,12 +105,12 @@ func intIP(i int) net.IP { } // fillBucket inserts nodes into the given bucket until it is full. -func fillBucket(tab *Table, id enode.ID) (last *node) { +func fillBucket(tab *Table, id enode.ID) (last *tableNode) { ld := enode.LogDist(tab.self().ID(), id) b := tab.bucket(id) for len(b.entries) < bucketSize { node := nodeAtDistance(tab.self().ID(), ld, intIP(ld)) - if !tab.addFoundNode(node) { + if !tab.addFoundNode(node, false) { panic("node not added") } } @@ -119,13 +119,9 @@ func fillBucket(tab *Table, id enode.ID) (last *node) { // fillTable adds nodes the table to the end of their corresponding bucket // if the bucket is not full. The caller must not hold tab.mutex. -func fillTable(tab *Table, nodes []*node, setLive bool) { +func fillTable(tab *Table, nodes []*enode.Node, setLive bool) { for _, n := range nodes { - if setLive { - n.livenessChecks = 1 - n.isValidatedLive = true - } - tab.addFoundNode(n) + tab.addFoundNode(n, setLive) } } @@ -219,7 +215,7 @@ func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) { return t.records[n.ID()], nil } -func hasDuplicates(slice []*node) bool { +func hasDuplicates(slice []*enode.Node) bool { seen := make(map[enode.ID]bool, len(slice)) for i, e := range slice { if e == nil { @@ -261,14 +257,14 @@ func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool { return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP()) } -func sortByID(nodes []*enode.Node) { - slices.SortFunc(nodes, func(a, b *enode.Node) int { +func sortByID[N nodeType](nodes []N) { + slices.SortFunc(nodes, func(a, b N) int { return bytes.Compare(a.ID().Bytes(), b.ID().Bytes()) }) } -func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { - return slices.IsSortedFunc(slice, func(a, b *node) int { +func sortedByDistanceTo(distbase enode.ID, slice []*enode.Node) bool { + return slices.IsSortedFunc(slice, func(a, b *enode.Node) int { return enode.DistCmp(distbase, a.ID(), b.ID()) }) } @@ -304,7 +300,7 @@ type nodeEventRecorder struct { } type recordedNodeEvent struct { - node *node + node *tableNode added bool } @@ -314,7 +310,7 @@ func newNodeEventRecorder(buffer int) *nodeEventRecorder { } } -func (set *nodeEventRecorder) nodeAdded(b *bucket, n *node) { +func (set *nodeEventRecorder) nodeAdded(b *bucket, n *tableNode) { select { case set.evc <- recordedNodeEvent{n, true}: default: @@ -322,7 +318,7 @@ func (set *nodeEventRecorder) nodeAdded(b *bucket, n *node) { } } -func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *node) { +func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *tableNode) { select { case set.evc <- recordedNodeEvent{n, false}: default: diff --git a/p2p/discover/v4_lookup_test.go b/p2p/discover/v4_lookup_test.go index 5682f262be..bc9475a8b3 100644 --- a/p2p/discover/v4_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -19,7 +19,7 @@ package discover import ( "crypto/ecdsa" "fmt" - "net" + "net/netip" "slices" "testing" @@ -40,7 +40,7 @@ func TestUDPv4_Lookup(t *testing.T) { } // Seed table with initial node. - fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}, true) + fillTable(test.table, []*enode.Node{lookupTestnet.node(256, 0)}, true) // Start the lookup. resultC := make(chan []*enode.Node, 1) @@ -70,9 +70,9 @@ func TestUDPv4_LookupIterator(t *testing.T) { defer test.close() // Seed table with initial nodes. - bootnodes := make([]*node, len(lookupTestnet.dists[256])) + bootnodes := make([]*enode.Node, len(lookupTestnet.dists[256])) for i := range lookupTestnet.dists[256] { - bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) + bootnodes[i] = lookupTestnet.node(256, i) } fillTable(test.table, bootnodes, true) go serveTestnet(test, lookupTestnet) @@ -105,9 +105,9 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) { defer test.close() // Seed table with initial nodes. - bootnodes := make([]*node, len(lookupTestnet.dists[256])) + bootnodes := make([]*enode.Node, len(lookupTestnet.dists[256])) for i := range lookupTestnet.dists[256] { - bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) + bootnodes[i] = lookupTestnet.node(256, i) } fillTable(test.table, bootnodes, true) go serveTestnet(test, lookupTestnet) @@ -136,7 +136,7 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) { func serveTestnet(test *udpTest, testnet *preminedTestnet) { for done := false; !done; { - done = test.waitPacketOut(func(p v4wire.Packet, to *net.UDPAddr, hash []byte) { + done = test.waitPacketOut(func(p v4wire.Packet, to netip.AddrPort, hash []byte) { n, key := testnet.nodeByAddr(to) switch p.(type) { case *v4wire.Ping: @@ -158,10 +158,10 @@ func checkLookupResults(t *testing.T, tn *preminedTestnet, results []*enode.Node for _, e := range results { t.Logf(" ld=%d, %x", enode.LogDist(tn.target.id(), e.ID()), e.ID().Bytes()) } - if hasDuplicates(wrapNodes(results)) { + if hasDuplicates(results) { t.Errorf("result set contains duplicate entries") } - if !sortedByDistanceTo(tn.target.id(), wrapNodes(results)) { + if !sortedByDistanceTo(tn.target.id(), results) { t.Errorf("result set not sorted by distance to target") } wantNodes := tn.closest(len(results)) @@ -264,9 +264,10 @@ func (tn *preminedTestnet) node(dist, index int) *enode.Node { return n } -func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.PrivateKey) { - dist := int(addr.IP[1])<<8 + int(addr.IP[2]) - index := int(addr.IP[3]) +func (tn *preminedTestnet) nodeByAddr(addr netip.AddrPort) (*enode.Node, *ecdsa.PrivateKey) { + ip := addr.Addr().As4() + dist := int(ip[1])<<8 + int(ip[2]) + index := int(ip[3]) key := tn.dists[dist][index] return tn.node(dist, index), key } @@ -274,7 +275,7 @@ func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.Pr func (tn *preminedTestnet) nodesAtDistance(dist int) []v4wire.Node { result := make([]v4wire.Node, len(tn.dists[dist])) for i := range result { - result[i] = nodeToRPC(wrapNode(tn.node(dist, i))) + result[i] = nodeToRPC(tn.node(dist, i)) } return result } diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index be6058ec50..3880ca34a7 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -26,6 +26,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "time" @@ -45,6 +46,7 @@ var ( errClockWarp = errors.New("reply deadline too far in the future") errClosed = errors.New("socket closed") errLowPort = errors.New("low port") + errNoUDPEndpoint = errors.New("node has no UDP endpoint") ) const ( @@ -93,7 +95,7 @@ type UDPv4 struct { type replyMatcher struct { // these fields must match in the reply. from enode.ID - ip net.IP + ip netip.Addr ptype byte // time when the request must complete @@ -119,7 +121,7 @@ type replyMatchFunc func(v4wire.Packet) (matched bool, requestDone bool) // reply is a reply packet from a certain node. type reply struct { from enode.ID - ip net.IP + ip netip.Addr data v4wire.Packet // loop indicates whether there was // a matching request by sending on this channel. @@ -201,9 +203,12 @@ func (t *UDPv4) Resolve(n *enode.Node) *enode.Node { } func (t *UDPv4) ourEndpoint() v4wire.Endpoint { - n := t.Self() - a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} - return v4wire.NewEndpoint(a, uint16(n.TCP())) + node := t.Self() + addr, ok := node.UDPEndpoint() + if !ok { + return v4wire.Endpoint{} + } + return v4wire.NewEndpoint(addr, uint16(node.TCP())) } // Ping sends a ping message to the given node. @@ -214,7 +219,11 @@ func (t *UDPv4) Ping(n *enode.Node) error { // ping sends a ping message to the given node and waits for a reply. func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { - rm := t.sendPing(n.ID(), &net.UDPAddr{IP: n.IP(), Port: n.UDP()}, nil) + addr, ok := n.UDPEndpoint() + if !ok { + return 0, errNoUDPEndpoint + } + rm := t.sendPing(n.ID(), addr, nil) if err = <-rm.errc; err == nil { seq = rm.reply.(*v4wire.Pong).ENRSeq } @@ -223,7 +232,7 @@ func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { // sendPing sends a ping message to the given node and invokes the callback // when the reply arrives. -func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *replyMatcher { +func (t *UDPv4) sendPing(toid enode.ID, toaddr netip.AddrPort, callback func()) *replyMatcher { req := t.makePing(toaddr) packet, hash, err := v4wire.Encode(t.priv, req) if err != nil { @@ -233,7 +242,7 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r } // Add a matcher for the reply to the pending reply queue. Pongs are matched if they // reference the ping we're about to send. - rm := t.pending(toid, toaddr.IP, v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) { + rm := t.pending(toid, toaddr.Addr(), v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) { matched = bytes.Equal(p.(*v4wire.Pong).ReplyTok, hash) if matched && callback != nil { callback() @@ -241,12 +250,13 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r return matched, matched }) // Send the packet. - t.localNode.UDPContact(toaddr) + toUDPAddr := &net.UDPAddr{IP: toaddr.Addr().AsSlice()} + t.localNode.UDPContact(toUDPAddr) t.write(toaddr, toid, req.Name(), packet) return rm } -func (t *UDPv4) makePing(toaddr *net.UDPAddr) *v4wire.Ping { +func (t *UDPv4) makePing(toaddr netip.AddrPort) *v4wire.Ping { return &v4wire.Ping{ Version: 4, From: t.ourEndpoint(), @@ -290,35 +300,39 @@ func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup { func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup { target := enode.ID(crypto.Keccak256Hash(targetKey[:])) ekey := v4wire.Pubkey(targetKey) - it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { - return t.findnode(n.ID(), n.addr(), ekey) + it := newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { + addr, ok := n.UDPEndpoint() + if !ok { + return nil, errNoUDPEndpoint + } + return t.findnode(n.ID(), addr, ekey) }) return it } // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. -func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubkey) ([]*node, error) { - t.ensureBond(toid, toaddr) +func (t *UDPv4) findnode(toid enode.ID, toAddrPort netip.AddrPort, target v4wire.Pubkey) ([]*enode.Node, error) { + t.ensureBond(toid, toAddrPort) // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is // active until enough nodes have been received. - nodes := make([]*node, 0, bucketSize) + nodes := make([]*enode.Node, 0, bucketSize) nreceived := 0 - rm := t.pending(toid, toaddr.IP, v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) { + rm := t.pending(toid, toAddrPort.Addr(), v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) { reply := r.(*v4wire.Neighbors) for _, rn := range reply.Nodes { nreceived++ - n, err := t.nodeFromRPC(toaddr, rn) + n, err := t.nodeFromRPC(toAddrPort, rn) if err != nil { - t.log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) + t.log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toAddrPort, "err", err) continue } nodes = append(nodes, n) } return true, nreceived >= bucketSize }) - t.send(toaddr, toid, &v4wire.Findnode{ + t.send(toAddrPort, toid, &v4wire.Findnode{ Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) @@ -336,7 +350,7 @@ func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubke // RequestENR sends ENRRequest to the given node and waits for a response. func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { - addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} + addr, _ := n.UDPEndpoint() t.ensureBond(n.ID(), addr) req := &v4wire.ENRRequest{ @@ -349,7 +363,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { // Add a matcher for the reply to the pending reply queue. Responses are matched if // they reference the request we're about to send. - rm := t.pending(n.ID(), addr.IP, v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) { + rm := t.pending(n.ID(), addr.Addr(), v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) { matched = bytes.Equal(r.(*v4wire.ENRResponse).ReplyTok, hash) return matched, matched }) @@ -369,7 +383,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { if respN.Seq() < n.Seq() { return n, nil // response record is older } - if err := netutil.CheckRelayIP(addr.IP, respN.IP()); err != nil { + if err := netutil.CheckRelayIP(addr.Addr().AsSlice(), respN.IP()); err != nil { return nil, fmt.Errorf("invalid IP in response record: %v", err) } return respN, nil @@ -381,7 +395,7 @@ func (t *UDPv4) TableBuckets() [][]BucketNode { // pending adds a reply matcher to the pending reply queue. // see the documentation of type replyMatcher for a detailed explanation. -func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) *replyMatcher { +func (t *UDPv4) pending(id enode.ID, ip netip.Addr, ptype byte, callback replyMatchFunc) *replyMatcher { ch := make(chan error, 1) p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch} select { @@ -395,7 +409,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF // handleReply dispatches a reply packet, invoking reply matchers. It returns // whether any matcher considered the packet acceptable. -func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req v4wire.Packet) bool { +func (t *UDPv4) handleReply(from enode.ID, fromIP netip.Addr, req v4wire.Packet) bool { matched := make(chan bool, 1) select { case t.gotreply <- reply{from, fromIP, req, matched}: @@ -461,7 +475,7 @@ func (t *UDPv4) loop() { var matched bool // whether any replyMatcher considered the reply acceptable. for el := plist.Front(); el != nil; el = el.Next() { p := el.Value.(*replyMatcher) - if p.from == r.from && p.ptype == r.data.Kind() && p.ip.Equal(r.ip) { + if p.from == r.from && p.ptype == r.data.Kind() && p.ip == r.ip { ok, requestDone := p.callback(r.data) matched = matched || ok p.reply = r.data @@ -500,7 +514,7 @@ func (t *UDPv4) loop() { } } -func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]byte, error) { +func (t *UDPv4) send(toaddr netip.AddrPort, toid enode.ID, req v4wire.Packet) ([]byte, error) { packet, hash, err := v4wire.Encode(t.priv, req) if err != nil { return hash, err @@ -508,8 +522,8 @@ func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]b return hash, t.write(toaddr, toid, req.Name(), packet) } -func (t *UDPv4) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error { - _, err := t.conn.WriteToUDP(packet, toaddr) +func (t *UDPv4) write(toaddr netip.AddrPort, toid enode.ID, what string, packet []byte) error { + _, err := t.conn.WriteToUDPAddrPort(packet, toaddr) t.log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err) return err } @@ -523,7 +537,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) { buf := make([]byte, maxPacketSize) for { - nbytes, from, err := t.conn.ReadFromUDP(buf) + nbytes, from, err := t.conn.ReadFromUDPAddrPort(buf) if netutil.IsTemporaryError(err) { // Ignore temporary read errors. t.log.Debug("Temporary UDP read error", "err", err) @@ -544,7 +558,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) { } } -func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error { +func (t *UDPv4) handlePacket(from netip.AddrPort, buf []byte) error { rawpacket, fromKey, hash, err := v4wire.Decode(buf) if err != nil { t.log.Debug("Bad discv4 packet", "addr", from, "err", err) @@ -563,15 +577,16 @@ func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error { } // checkBond checks if the given node has a recent enough endpoint proof. -func (t *UDPv4) checkBond(id enode.ID, ip net.IP) bool { - return time.Since(t.db.LastPongReceived(id, ip)) < bondExpiration +func (t *UDPv4) checkBond(id enode.ID, ip netip.AddrPort) bool { + return time.Since(t.db.LastPongReceived(id, ip.Addr().AsSlice())) < bondExpiration } // ensureBond solicits a ping from a node if we haven't seen a ping from it for a while. // This ensures there is a valid endpoint proof on the remote end. -func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) { - tooOld := time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration - if tooOld || t.db.FindFails(toid, toaddr.IP) > maxFindnodeFailures { +func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) { + ip := toaddr.Addr().AsSlice() + tooOld := time.Since(t.db.LastPingReceived(toid, ip)) > bondExpiration + if tooOld || t.db.FindFails(toid, ip) > maxFindnodeFailures { rm := t.sendPing(toid, toaddr, nil) <-rm.errc // Wait for them to ping back and process our pong. @@ -579,11 +594,11 @@ func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) { } } -func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error) { +func (t *UDPv4) nodeFromRPC(sender netip.AddrPort, rn v4wire.Node) (*enode.Node, error) { if rn.UDP <= 1024 { return nil, errLowPort } - if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { + if err := netutil.CheckRelayIP(sender.Addr().AsSlice(), rn.IP); err != nil { return nil, err } if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { @@ -593,12 +608,12 @@ func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error) if err != nil { return nil, err } - n := wrapNode(enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP))) + n := enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP)) err = n.ValidateComplete() return n, err } -func nodeToRPC(n *node) v4wire.Node { +func nodeToRPC(n *enode.Node) v4wire.Node { var key ecdsa.PublicKey var ekey v4wire.Pubkey if err := n.Load((*enode.Secp256k1)(&key)); err == nil { @@ -637,14 +652,14 @@ type packetHandlerV4 struct { senderKey *ecdsa.PublicKey // used for ping // preverify checks whether the packet is valid and should be handled at all. - preverify func(p *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error + preverify func(p *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error // handle handles the packet. - handle func(req *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) + handle func(req *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) } // PING/v4 -func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyPing(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.Ping) if v4wire.Expired(req.Expiration) { @@ -658,7 +673,7 @@ func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I return nil } -func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { +func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) { req := h.Packet.(*v4wire.Ping) // Reply. @@ -670,8 +685,9 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I }) // Ping back if our last pong on file is too far in the past. - n := wrapNode(enode.NewV4(h.senderKey, from.IP, int(req.From.TCP), from.Port)) - if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { + fromIP := from.Addr().AsSlice() + n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port())) + if time.Since(t.db.LastPongReceived(n.ID(), fromIP)) > bondExpiration { t.sendPing(fromID, from, func() { t.tab.addInboundNode(n) }) @@ -680,35 +696,40 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I } // Update node database and endpoint predictor. - t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now()) - t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) + t.db.UpdateLastPingReceived(n.ID(), fromIP, time.Now()) + fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())} + toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)} + t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) } // PONG/v4 -func (t *UDPv4) verifyPong(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyPong(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.Pong) if v4wire.Expired(req.Expiration) { return errExpired } - if !t.handleReply(fromID, from.IP, req) { + if !t.handleReply(fromID, from.Addr(), req) { return errUnsolicitedReply } - t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) - t.db.UpdateLastPongReceived(fromID, from.IP, time.Now()) + fromIP := from.Addr().AsSlice() + fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())} + toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)} + t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) + t.db.UpdateLastPongReceived(fromID, fromIP, time.Now()) return nil } // FINDNODE/v4 -func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.Findnode) if v4wire.Expired(req.Expiration) { return errExpired } - if !t.checkBond(fromID, from.IP) { + if !t.checkBond(fromID, from) { // No endpoint proof pong exists, we don't process the packet. This prevents an // attack vector where the discovery protocol could be used to amplify traffic in a // DDOS attack. A malicious actor would send a findnode request with the IP address @@ -720,7 +741,7 @@ func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno return nil } -func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { +func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) { req := h.Packet.(*v4wire.Findnode) // Determine closest nodes. @@ -732,7 +753,8 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} var sent bool for _, n := range closest { - if netutil.CheckRelayIP(from.IP, n.IP()) == nil { + fromIP := from.Addr().AsSlice() + if netutil.CheckRelayIP(fromIP, n.IP()) == nil { p.Nodes = append(p.Nodes, nodeToRPC(n)) } if len(p.Nodes) == v4wire.MaxNeighbors { @@ -748,13 +770,13 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno // NEIGHBORS/v4 -func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.Neighbors) if v4wire.Expired(req.Expiration) { return errExpired } - if !t.handleReply(fromID, from.IP, h.Packet) { + if !t.handleReply(fromID, from.Addr(), h.Packet) { return errUnsolicitedReply } return nil @@ -762,19 +784,19 @@ func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID en // ENRREQUEST/v4 -func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.ENRRequest) if v4wire.Expired(req.Expiration) { return errExpired } - if !t.checkBond(fromID, from.IP) { + if !t.checkBond(fromID, from) { return errUnknownNode } return nil } -func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { +func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) { t.send(from, fromID, &v4wire.ENRResponse{ ReplyTok: mac, Record: *t.localNode.Node().Record(), @@ -783,8 +805,8 @@ func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID e // ENRRESPONSE/v4 -func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { - if !t.handleReply(fromID, from.IP, h.Packet) { +func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { + if !t.handleReply(fromID, from.Addr(), h.Packet) { return errUnsolicitedReply } return nil diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go index 9c454d98e3..28a6fb8675 100644 --- a/p2p/discover/v4_udp_test.go +++ b/p2p/discover/v4_udp_test.go @@ -26,6 +26,7 @@ import ( "io" "math/rand" "net" + "net/netip" "reflect" "sync" "testing" @@ -55,7 +56,7 @@ type udpTest struct { udp *UDPv4 sent [][]byte localkey, remotekey *ecdsa.PrivateKey - remoteaddr *net.UDPAddr + remoteaddr netip.AddrPort } func newUDPTest(t *testing.T) *udpTest { @@ -64,7 +65,7 @@ func newUDPTest(t *testing.T) *udpTest { pipe: newpipe(), localkey: newkey(), remotekey: newkey(), - remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, + remoteaddr: netip.MustParseAddrPort("10.0.1.99:30303"), } test.db, _ = enode.OpenDB("") @@ -92,7 +93,7 @@ func (test *udpTest) packetIn(wantError error, data v4wire.Packet) { } // handles a packet as if it had been sent to the transport by the key/endpoint. -func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, data v4wire.Packet) { +func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr netip.AddrPort, data v4wire.Packet) { test.t.Helper() enc, _, err := v4wire.Encode(key, data) @@ -106,7 +107,7 @@ func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr * } // waits for a packet to be sent by the transport. -// validate should have type func(X, *net.UDPAddr, []byte), where X is a packet type. +// validate should have type func(X, netip.AddrPort, []byte), where X is a packet type. func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { test.t.Helper() @@ -128,7 +129,7 @@ func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) return false } - fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(hash)}) + fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(hash)}) return false } @@ -236,7 +237,7 @@ func TestUDPv4_findnodeTimeout(t *testing.T) { test := newUDPTest(t) defer test.close() - toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} + toaddr := netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 2222) toid := enode.ID{1, 2, 3, 4} target := v4wire.Pubkey{4, 5, 6, 7} result, err := test.udp.findnode(toid, toaddr, target) @@ -261,26 +262,25 @@ func TestUDPv4_findnode(t *testing.T) { for i := 0; i < numCandidates; i++ { key := newkey() ip := net.IP{10, 13, 0, byte(i)} - n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000)) + n := enode.NewV4(&key.PublicKey, ip, 0, 2000) // Ensure half of table content isn't verified live yet. if i > numCandidates/2 { - n.isValidatedLive = true live[n.ID()] = true } + test.table.addFoundNode(n, live[n.ID()]) nodes.push(n, numCandidates) } - fillTable(test.table, nodes.entries, false) // ensure there's a bond with the test node, // findnode won't be accepted otherwise. remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID() - test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now()) + test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr().AsSlice(), time.Now()) // check that closest neighbors are returned. expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true) test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp}) - waitNeighbors := func(want []*node) { - test.waitPacketOut(func(p *v4wire.Neighbors, to *net.UDPAddr, hash []byte) { + waitNeighbors := func(want []*enode.Node) { + test.waitPacketOut(func(p *v4wire.Neighbors, to netip.AddrPort, hash []byte) { if len(p.Nodes) != len(want) { t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), len(want)) return @@ -309,10 +309,10 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) { defer test.close() rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey) - test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now()) + test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr().AsSlice(), time.Now()) // queue a pending findnode request - resultc, errc := make(chan []*node, 1), make(chan error, 1) + resultc, errc := make(chan []*enode.Node, 1), make(chan error, 1) go func() { rid := encodePubkey(&test.remotekey.PublicKey).id() ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget) @@ -325,18 +325,18 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) { // wait for the findnode to be sent. // after it is sent, the transport is waiting for a reply - test.waitPacketOut(func(p *v4wire.Findnode, to *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Findnode, to netip.AddrPort, hash []byte) { if p.Target != testTarget { t.Errorf("wrong target: got %v, want %v", p.Target, testTarget) } }) // send the reply as two packets. - list := []*node{ - wrapNode(enode.MustParse("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304")), - wrapNode(enode.MustParse("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303")), - wrapNode(enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17")), - wrapNode(enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303")), + list := []*enode.Node{ + enode.MustParse("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304"), + enode.MustParse("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"), + enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17"), + enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"), } rpclist := make([]v4wire.Node, len(list)) for i := range list { @@ -368,8 +368,8 @@ func TestUDPv4_pingMatch(t *testing.T) { crand.Read(randToken) test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) - test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {}) - test.waitPacketOut(func(*v4wire.Ping, *net.UDPAddr, []byte) {}) + test.waitPacketOut(func(*v4wire.Pong, netip.AddrPort, []byte) {}) + test.waitPacketOut(func(*v4wire.Ping, netip.AddrPort, []byte) {}) test.packetIn(errUnsolicitedReply, &v4wire.Pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp}) } @@ -379,10 +379,10 @@ func TestUDPv4_pingMatchIP(t *testing.T) { defer test.close() test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) - test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {}) + test.waitPacketOut(func(*v4wire.Pong, netip.AddrPort, []byte) {}) - test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) { - wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000} + test.waitPacketOut(func(p *v4wire.Ping, to netip.AddrPort, hash []byte) { + wrongAddr := netip.MustParseAddrPort("33.44.1.2:30000") test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, &v4wire.Pong{ ReplyTok: hash, To: testLocalAnnounced, @@ -393,41 +393,36 @@ func TestUDPv4_pingMatchIP(t *testing.T) { func TestUDPv4_successfulPing(t *testing.T) { test := newUDPTest(t) - added := make(chan *node, 1) - test.table.nodeAddedHook = func(b *bucket, n *node) { added <- n } + added := make(chan *tableNode, 1) + test.table.nodeAddedHook = func(b *bucket, n *tableNode) { added <- n } defer test.close() // The remote side sends a ping packet to initiate the exchange. go test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) // The ping is replied to. - test.waitPacketOut(func(p *v4wire.Pong, to *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Pong, to netip.AddrPort, hash []byte) { pinghash := test.sent[0][:32] if !bytes.Equal(p.ReplyTok, pinghash) { t.Errorf("got pong.ReplyTok %x, want %x", p.ReplyTok, pinghash) } - wantTo := v4wire.Endpoint{ - // The mirrored UDP address is the UDP packet sender - IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port), - // The mirrored TCP port is the one from the ping packet - TCP: testRemote.TCP, - } + // The mirrored UDP address is the UDP packet sender. + // The mirrored TCP port is the one from the ping packet. + wantTo := v4wire.NewEndpoint(test.remoteaddr, testRemote.TCP) if !reflect.DeepEqual(p.To, wantTo) { t.Errorf("got pong.To %v, want %v", p.To, wantTo) } }) // Remote is unknown, the table pings back. - test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) { - if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) { + test.waitPacketOut(func(p *v4wire.Ping, to netip.AddrPort, hash []byte) { + wantFrom := test.udp.ourEndpoint() + wantFrom.IP = net.IP{} + if !reflect.DeepEqual(p.From, wantFrom) { t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint()) } - wantTo := v4wire.Endpoint{ - // The mirrored UDP address is the UDP packet sender. - IP: test.remoteaddr.IP, - UDP: uint16(test.remoteaddr.Port), - TCP: 0, - } + // The mirrored UDP address is the UDP packet sender. + wantTo := v4wire.NewEndpoint(test.remoteaddr, 0) if !reflect.DeepEqual(p.To, wantTo) { t.Errorf("got ping.To %v, want %v", p.To, wantTo) } @@ -442,11 +437,11 @@ func TestUDPv4_successfulPing(t *testing.T) { if n.ID() != rid { t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid) } - if !n.IP().Equal(test.remoteaddr.IP) { - t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.IP) + if !n.IP().Equal(test.remoteaddr.Addr().AsSlice()) { + t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.Addr()) } - if n.UDP() != test.remoteaddr.Port { - t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port) + if n.UDP() != int(test.remoteaddr.Port()) { + t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port()) } if n.TCP() != int(testRemote.TCP) { t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP(), testRemote.TCP) @@ -469,12 +464,12 @@ func TestUDPv4_EIP868(t *testing.T) { // Perform endpoint proof and check for sequence number in packet tail. test.packetIn(nil, &v4wire.Ping{Expiration: futureExp}) - test.waitPacketOut(func(p *v4wire.Pong, addr *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Pong, addr netip.AddrPort, hash []byte) { if p.ENRSeq != wantNode.Seq() { t.Errorf("wrong sequence number in pong: %d, want %d", p.ENRSeq, wantNode.Seq()) } }) - test.waitPacketOut(func(p *v4wire.Ping, addr *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Ping, addr netip.AddrPort, hash []byte) { if p.ENRSeq != wantNode.Seq() { t.Errorf("wrong sequence number in ping: %d, want %d", p.ENRSeq, wantNode.Seq()) } @@ -483,7 +478,7 @@ func TestUDPv4_EIP868(t *testing.T) { // Request should work now. test.packetIn(nil, &v4wire.ENRRequest{Expiration: futureExp}) - test.waitPacketOut(func(p *v4wire.ENRResponse, addr *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.ENRResponse, addr netip.AddrPort, hash []byte) { n, err := enode.New(enode.ValidSchemes, &p.Record) if err != nil { t.Fatalf("invalid record: %v", err) @@ -584,7 +579,7 @@ type dgramPipe struct { } type dgram struct { - to net.UDPAddr + to netip.AddrPort data []byte } @@ -597,8 +592,8 @@ func newpipe() *dgramPipe { } } -// WriteToUDP queues a datagram. -func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) { +// WriteToUDPAddrPort queues a datagram. +func (c *dgramPipe) WriteToUDPAddrPort(b []byte, to netip.AddrPort) (n int, err error) { msg := make([]byte, len(b)) copy(msg, b) c.mu.Lock() @@ -606,15 +601,15 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) { if c.closed { return 0, errors.New("closed") } - c.queue = append(c.queue, dgram{*to, b}) + c.queue = append(c.queue, dgram{to, b}) c.cond.Signal() return len(b), nil } -// ReadFromUDP just hangs until the pipe is closed. -func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { +// ReadFromUDPAddrPort just hangs until the pipe is closed. +func (c *dgramPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { <-c.closing - return 0, nil, io.EOF + return 0, netip.AddrPort{}, io.EOF } func (c *dgramPipe) Close() error { diff --git a/p2p/discover/v4wire/v4wire.go b/p2p/discover/v4wire/v4wire.go index 9c59359fb2..958cca324d 100644 --- a/p2p/discover/v4wire/v4wire.go +++ b/p2p/discover/v4wire/v4wire.go @@ -25,6 +25,7 @@ import ( "fmt" "math/big" "net" + "net/netip" "time" "github.com/ethereum/go-ethereum/common/math" @@ -150,14 +151,15 @@ type Endpoint struct { } // NewEndpoint creates an endpoint. -func NewEndpoint(addr *net.UDPAddr, tcpPort uint16) Endpoint { - ip := net.IP{} - if ip4 := addr.IP.To4(); ip4 != nil { - ip = ip4 - } else if ip6 := addr.IP.To16(); ip6 != nil { - ip = ip6 +func NewEndpoint(addr netip.AddrPort, tcpPort uint16) Endpoint { + var ip net.IP + if addr.Addr().Is4() || addr.Addr().Is4In6() { + ip4 := addr.Addr().As4() + ip = ip4[:] + } else { + ip = addr.Addr().AsSlice() } - return Endpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} + return Endpoint{IP: ip, UDP: addr.Port(), TCP: tcpPort} } type Packet interface { diff --git a/p2p/discover/v5_talk.go b/p2p/discover/v5_talk.go index c1f6787940..2246b47141 100644 --- a/p2p/discover/v5_talk.go +++ b/p2p/discover/v5_talk.go @@ -18,6 +18,7 @@ package discover import ( "net" + "net/netip" "sync" "time" @@ -70,7 +71,7 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) { } // handleRequest handles a talk request. -func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.TalkRequest) { +func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) { t.mutex.Lock() handler, ok := t.handlers[req.Protocol] t.mutex.Unlock() @@ -88,7 +89,8 @@ func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.T case <-t.slots: go func() { defer func() { t.slots <- struct{}{} }() - respMessage := handler(id, addr, req.Message) + udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())} + respMessage := handler(id, udpAddr, req.Message) resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage} t.transport.sendFromAnotherThread(id, addr, resp) }() diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 8cdc9dfbce..9ba54b3d40 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -25,6 +25,7 @@ import ( "fmt" "io" "net" + "net/netip" "slices" "sync" "time" @@ -101,14 +102,14 @@ type UDPv5 struct { type sendRequest struct { destID enode.ID - destAddr *net.UDPAddr + destAddr netip.AddrPort msg v5wire.Packet } // callV5 represents a remote procedure call against another node. type callV5 struct { id enode.ID - addr *net.UDPAddr + addr netip.AddrPort node *enode.Node // This is required to perform handshakes. packet v5wire.Packet @@ -233,7 +234,7 @@ func (t *UDPv5) AllNodes() []*enode.Node { for _, b := range &t.tab.buckets { for _, n := range b.entries { - nodes = append(nodes, unwrapNode(n)) + nodes = append(nodes, n.Node) } } return nodes @@ -266,7 +267,7 @@ func (t *UDPv5) TalkRequest(n *enode.Node, protocol string, request []byte) ([]b } // TalkRequestToID sends a talk request to a node and waits for a response. -func (t *UDPv5) TalkRequestToID(id enode.ID, addr *net.UDPAddr, protocol string, request []byte) ([]byte, error) { +func (t *UDPv5) TalkRequestToID(id enode.ID, addr netip.AddrPort, protocol string, request []byte) ([]byte, error) { req := &v5wire.TalkRequest{Protocol: protocol, Message: request} resp := t.callToID(id, addr, v5wire.TalkResponseMsg, req) defer t.callDone(resp) @@ -314,26 +315,26 @@ func (t *UDPv5) newRandomLookup(ctx context.Context) *lookup { } func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { - return newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { + return newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { return t.lookupWorker(n, target) }) } // lookupWorker performs FINDNODE calls against a single node during lookup. -func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) { +func (t *UDPv5) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { var ( dists = lookupDistances(target, destNode.ID()) nodes = nodesByDistance{target: target} err error ) var r []*enode.Node - r, err = t.findnode(unwrapNode(destNode), dists) + r, err = t.findnode(destNode, dists) if errors.Is(err, errClosed) { return nil, err } for _, n := range r { if n.ID() != t.Self().ID() { - nodes.push(wrapNode(n), findnodeResultLimit) + nodes.push(n, findnodeResultLimit) } } return nodes.entries, err @@ -427,7 +428,7 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s if err != nil { return nil, err } - if err := netutil.CheckRelayIP(c.addr.IP, node.IP()); err != nil { + if err := netutil.CheckRelayIP(c.addr.Addr().AsSlice(), node.IP()); err != nil { return nil, err } if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) { @@ -452,14 +453,14 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s // callToNode sends the given call and sets up a handler for response packets (of message // type responseType). Responses are dispatched to the call's response channel. func (t *UDPv5) callToNode(n *enode.Node, responseType byte, req v5wire.Packet) *callV5 { - addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} + addr, _ := n.UDPEndpoint() c := &callV5{id: n.ID(), addr: addr, node: n} t.initCall(c, responseType, req) return c } // callToID is like callToNode, but for cases where the node record is not available. -func (t *UDPv5) callToID(id enode.ID, addr *net.UDPAddr, responseType byte, req v5wire.Packet) *callV5 { +func (t *UDPv5) callToID(id enode.ID, addr netip.AddrPort, responseType byte, req v5wire.Packet) *callV5 { c := &callV5{id: id, addr: addr} t.initCall(c, responseType, req) return c @@ -619,12 +620,12 @@ func (t *UDPv5) sendCall(c *callV5) { // sendResponse sends a response packet to the given node. // This doesn't trigger a handshake even if no keys are available. -func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) error { +func (t *UDPv5) sendResponse(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet) error { _, err := t.send(toID, toAddr, packet, nil) return err } -func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) { +func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet) { select { case t.sendCh <- sendRequest{toID, toAddr, packet}: case <-t.closeCtx.Done(): @@ -632,7 +633,7 @@ func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet } // send sends a packet to the given node. -func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) { +func (t *UDPv5) send(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) { addr := toAddr.String() t.logcontext = append(t.logcontext[:0], "id", toID, "addr", addr) t.logcontext = packet.AppendLogInfo(t.logcontext) @@ -644,7 +645,7 @@ func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c return nonce, err } - _, err = t.conn.WriteToUDP(enc, toAddr) + _, err = t.conn.WriteToUDPAddrPort(enc, toAddr) t.log.Trace(">> "+packet.Name(), t.logcontext...) return nonce, err } @@ -655,7 +656,7 @@ func (t *UDPv5) readLoop() { buf := make([]byte, maxPacketSize) for range t.readNextCh { - nbytes, from, err := t.conn.ReadFromUDP(buf) + nbytes, from, err := t.conn.ReadFromUDPAddrPort(buf) if netutil.IsTemporaryError(err) { // Ignore temporary read errors. t.log.Debug("Temporary UDP read error", "err", err) @@ -672,7 +673,7 @@ func (t *UDPv5) readLoop() { } // dispatchReadPacket sends a packet into the dispatch loop. -func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool { +func (t *UDPv5) dispatchReadPacket(from netip.AddrPort, content []byte) bool { select { case t.packetInCh <- ReadPacket{content, from}: return true @@ -682,7 +683,7 @@ func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool { } // handlePacket decodes and processes an incoming packet from the network. -func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { +func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr netip.AddrPort) error { addr := fromAddr.String() fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr) if err != nil { @@ -699,7 +700,7 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { } if fromNode != nil { // Handshake succeeded, add to table. - t.tab.addInboundNode(wrapNode(fromNode)) + t.tab.addInboundNode(fromNode) } if packet.Kind() != v5wire.WhoareyouPacket { // WHOAREYOU logged separately to report errors. @@ -712,13 +713,13 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { } // handleCallResponse dispatches a response packet to the call waiting for it. -func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, p v5wire.Packet) bool { +func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr netip.AddrPort, p v5wire.Packet) bool { ac := t.activeCallByNode[fromID] if ac == nil || !bytes.Equal(p.RequestID(), ac.reqid) { t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.Name()), "id", fromID, "addr", fromAddr) return false } - if !fromAddr.IP.Equal(ac.addr.IP) || fromAddr.Port != ac.addr.Port { + if fromAddr != ac.addr { t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.Name()), "id", fromID, "addr", fromAddr) return false } @@ -743,7 +744,7 @@ func (t *UDPv5) getNode(id enode.ID) *enode.Node { } // handle processes incoming packets according to their message type. -func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) { +func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort) { switch p := p.(type) { case *v5wire.Unknown: t.handleUnknown(p, fromID, fromAddr) @@ -753,7 +754,9 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) t.handlePing(p, fromID, fromAddr) case *v5wire.Pong: if t.handleCallResponse(fromID, fromAddr, p) { - t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}) + fromUDPAddr := &net.UDPAddr{IP: fromAddr.Addr().AsSlice(), Port: int(fromAddr.Port())} + toUDPAddr := &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)} + t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) } case *v5wire.Findnode: t.handleFindnode(p, fromID, fromAddr) @@ -767,7 +770,7 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) } // handleUnknown initiates a handshake by responding with WHOAREYOU. -func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr *net.UDPAddr) { +func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr netip.AddrPort) { challenge := &v5wire.Whoareyou{Nonce: p.Nonce} crand.Read(challenge.IDNonce[:]) if n := t.getNode(fromID); n != nil { @@ -783,7 +786,7 @@ var ( ) // handleWhoareyou resends the active call as a handshake packet. -func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr *net.UDPAddr) { +func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr netip.AddrPort) { c, err := t.matchWithCall(fromID, p.Nonce) if err != nil { t.log.Debug("Invalid "+p.Name(), "addr", fromAddr, "err", err) @@ -817,32 +820,35 @@ func (t *UDPv5) matchWithCall(fromID enode.ID, nonce v5wire.Nonce) (*callV5, err } // handlePing sends a PONG response. -func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr *net.UDPAddr) { - remoteIP := fromAddr.IP - // Handle IPv4 mapped IPv6 addresses in the - // event the local node is binded to an - // ipv6 interface. - if remoteIP.To4() != nil { - remoteIP = remoteIP.To4() +func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr netip.AddrPort) { + var remoteIP net.IP + // Handle IPv4 mapped IPv6 addresses in the event the local node is binded + // to an ipv6 interface. + if fromAddr.Addr().Is4() || fromAddr.Addr().Is4In6() { + ip4 := fromAddr.Addr().As4() + remoteIP = ip4[:] + } else { + remoteIP = fromAddr.Addr().AsSlice() } t.sendResponse(fromID, fromAddr, &v5wire.Pong{ ReqID: p.ReqID, ToIP: remoteIP, - ToPort: uint16(fromAddr.Port), + ToPort: fromAddr.Port(), ENRSeq: t.localNode.Node().Seq(), }) } // handleFindnode returns nodes to the requester. -func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr *net.UDPAddr) { - nodes := t.collectTableNodes(fromAddr.IP, p.Distances, findnodeResultLimit) +func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr netip.AddrPort) { + nodes := t.collectTableNodes(fromAddr.Addr(), p.Distances, findnodeResultLimit) for _, resp := range packNodes(p.ReqID, nodes) { t.sendResponse(fromID, fromAddr, resp) } } // collectTableNodes creates a FINDNODE result set for the given distances. -func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node { +func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) []*enode.Node { + ripSlice := rip.AsSlice() var bn []*enode.Node var nodes []*enode.Node var processed = make(map[uint]struct{}) @@ -857,7 +863,7 @@ func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*en for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) { // Apply some pre-checks to avoid sending invalid nodes. // Note liveness is checked by appendLiveNodes. - if netutil.CheckRelayIP(rip, n.IP()) != nil { + if netutil.CheckRelayIP(ripSlice, n.IP()) != nil { continue } nodes = append(nodes, n) diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index 0015f7cc70..1f8e972200 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -23,6 +23,7 @@ import ( "fmt" "math/rand" "net" + "net/netip" "reflect" "slices" "testing" @@ -103,7 +104,7 @@ func TestUDPv5_pingHandling(t *testing.T) { defer test.close() test.packetIn(&v5wire.Ping{ReqID: []byte("foo")}) - test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Pong, addr netip.AddrPort, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, []byte("foo")) { t.Error("wrong request ID in response:", p.ReqID) } @@ -135,16 +136,16 @@ func TestUDPv5_unknownPacket(t *testing.T) { // Unknown packet from unknown node. test.packetIn(&v5wire.Unknown{Nonce: nonce}) - test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { check(p, 0) }) // Make node known. n := test.getNode(test.remotekey, test.remoteaddr).Node() - test.table.addFoundNode(wrapNode(n)) + test.table.addFoundNode(n, false) test.packetIn(&v5wire.Unknown{Nonce: nonce}) - test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { check(p, n.Seq()) }) } @@ -159,9 +160,9 @@ func TestUDPv5_findnodeHandling(t *testing.T) { nodes253 := nodesAtDistance(test.table.self().ID(), 253, 16) nodes249 := nodesAtDistance(test.table.self().ID(), 249, 4) nodes248 := nodesAtDistance(test.table.self().ID(), 248, 10) - fillTable(test.table, wrapNodes(nodes253), true) - fillTable(test.table, wrapNodes(nodes249), true) - fillTable(test.table, wrapNodes(nodes248), true) + fillTable(test.table, nodes253, true) + fillTable(test.table, nodes249, true) + fillTable(test.table, nodes248, true) // Requesting with distance zero should return the node's own record. test.packetIn(&v5wire.Findnode{ReqID: []byte{0}, Distances: []uint{0}}) @@ -199,7 +200,7 @@ func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes } for { - test.waitPacketOut(func(p *v5wire.Nodes, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Nodes, addr netip.AddrPort, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, wantReqID) { test.t.Fatalf("wrong request ID %v in response, want %v", p.ReqID, wantReqID) } @@ -238,7 +239,7 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {}) + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {}) if err := <-done; err != errTimeout { t.Fatalf("want errTimeout, got %q", err) } @@ -248,7 +249,7 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != nil { @@ -260,8 +261,8 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { - wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101} + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { + wrongAddr := netip.MustParseAddrPort("33.44.55.22:10101") test.packetInFrom(test.remotekey, wrongAddr, &v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != errTimeout { @@ -291,7 +292,7 @@ func TestUDPv5_findnodeCall(t *testing.T) { }() // Serve the responses: - test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Findnode, addr netip.AddrPort, _ v5wire.Nonce) { if !reflect.DeepEqual(p.Distances, distances) { t.Fatalf("wrong distances in request: %v", p.Distances) } @@ -337,15 +338,15 @@ func TestUDPv5_callResend(t *testing.T) { }() // Ping answered by WHOAREYOU. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) { test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) // Ping should be re-sent. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { test.packetIn(&v5wire.Pong{ReqID: p.ReqID}) }) // Answer the other ping. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { test.packetIn(&v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != nil { @@ -370,11 +371,11 @@ func TestUDPv5_multipleHandshakeRounds(t *testing.T) { }() // Ping answered by WHOAREYOU. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) { test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) // Ping answered by WHOAREYOU again. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) { test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) if err := <-done; err != errTimeout { @@ -401,7 +402,7 @@ func TestUDPv5_callTimeoutReset(t *testing.T) { }() // Serve two responses, slowly. - test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Findnode, addr netip.AddrPort, _ v5wire.Nonce) { time.Sleep(respTimeout - 50*time.Millisecond) test.packetIn(&v5wire.Nodes{ ReqID: p.ReqID, @@ -439,7 +440,7 @@ func TestUDPv5_talkHandling(t *testing.T) { Protocol: "test", Message: []byte("test request"), }) - test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.TalkResponse, addr netip.AddrPort, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, []byte("foo")) { t.Error("wrong request ID in response:", p.ReqID) } @@ -458,7 +459,7 @@ func TestUDPv5_talkHandling(t *testing.T) { Protocol: "wrong", Message: []byte("test request"), }) - test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.TalkResponse, addr netip.AddrPort, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, []byte("2")) { t.Error("wrong request ID in response:", p.ReqID) } @@ -485,7 +486,7 @@ func TestUDPv5_talkRequest(t *testing.T) { _, err := test.udp.TalkRequest(remote, "test", []byte("test request")) done <- err }() - test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {}) + test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) {}) if err := <-done; err != errTimeout { t.Fatalf("want errTimeout, got %q", err) } @@ -495,7 +496,7 @@ func TestUDPv5_talkRequest(t *testing.T) { _, err := test.udp.TalkRequest(remote, "test", []byte("test request")) done <- err }() - test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) { if p.Protocol != "test" { t.Errorf("wrong protocol ID in talk request: %q", p.Protocol) } @@ -516,7 +517,7 @@ func TestUDPv5_talkRequest(t *testing.T) { _, err := test.udp.TalkRequestToID(remote.ID(), test.remoteaddr, "test", []byte("test request 2")) done <- err }() - test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) { if p.Protocol != "test" { t.Errorf("wrong protocol ID in talk request: %q", p.Protocol) } @@ -583,13 +584,14 @@ func TestUDPv5_lookup(t *testing.T) { for d, nn := range lookupTestnet.dists { for i, key := range nn { n := lookupTestnet.node(d, i) - test.getNode(key, &net.UDPAddr{IP: n.IP(), Port: n.UDP()}) + addr, _ := n.UDPEndpoint() + test.getNode(key, addr) } } // Seed table with initial node. initialNode := lookupTestnet.node(256, 0) - fillTable(test.table, []*node{wrapNode(initialNode)}, true) + fillTable(test.table, []*enode.Node{initialNode}, true) // Start the lookup. resultC := make(chan []*enode.Node, 1) @@ -601,7 +603,7 @@ func TestUDPv5_lookup(t *testing.T) { // Answer lookup packets. asked := make(map[enode.ID]bool) for done := false; !done; { - done = test.waitPacketOut(func(p v5wire.Packet, to *net.UDPAddr, _ v5wire.Nonce) { + done = test.waitPacketOut(func(p v5wire.Packet, to netip.AddrPort, _ v5wire.Nonce) { recipient, key := lookupTestnet.nodeByAddr(to) switch p := p.(type) { case *v5wire.Ping: @@ -652,11 +654,8 @@ func TestUDPv5_PingWithIPV4MappedAddress(t *testing.T) { test := newUDPV5Test(t) defer test.close() - rawIP := net.IPv4(0xFF, 0x12, 0x33, 0xE5) - test.remoteaddr = &net.UDPAddr{ - IP: rawIP.To16(), - Port: 0, - } + rawIP := netip.AddrFrom4([4]byte{0xFF, 0x12, 0x33, 0xE5}) + test.remoteaddr = netip.AddrPortFrom(netip.AddrFrom16(rawIP.As16()), 0) remote := test.getNode(test.remotekey, test.remoteaddr).Node() done := make(chan struct{}, 1) @@ -665,14 +664,14 @@ func TestUDPv5_PingWithIPV4MappedAddress(t *testing.T) { test.udp.handlePing(&v5wire.Ping{ENRSeq: 1}, remote.ID(), test.remoteaddr) done <- struct{}{} }() - test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Pong, addr netip.AddrPort, _ v5wire.Nonce) { if len(p.ToIP) == net.IPv6len { t.Error("Received untruncated ip address") } if len(p.ToIP) != net.IPv4len { t.Errorf("Received ip address with incorrect length: %d", len(p.ToIP)) } - if !p.ToIP.Equal(rawIP) { + if !p.ToIP.Equal(rawIP.AsSlice()) { t.Errorf("Received incorrect ip address: wanted %s but received %s", rawIP.String(), p.ToIP.String()) } }) @@ -688,9 +687,9 @@ type udpV5Test struct { db *enode.DB udp *UDPv5 localkey, remotekey *ecdsa.PrivateKey - remoteaddr *net.UDPAddr + remoteaddr netip.AddrPort nodesByID map[enode.ID]*enode.LocalNode - nodesByIP map[string]*enode.LocalNode + nodesByIP map[netip.Addr]*enode.LocalNode } // testCodec is the packet encoding used by protocol tests. This codec does not perform encryption. @@ -750,9 +749,9 @@ func newUDPV5Test(t *testing.T) *udpV5Test { pipe: newpipe(), localkey: newkey(), remotekey: newkey(), - remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, + remoteaddr: netip.MustParseAddrPort("10.0.1.99:30303"), nodesByID: make(map[enode.ID]*enode.LocalNode), - nodesByIP: make(map[string]*enode.LocalNode), + nodesByIP: make(map[netip.Addr]*enode.LocalNode), } test.db, _ = enode.OpenDB("") ln := enode.NewLocalNode(test.db, test.localkey) @@ -777,8 +776,8 @@ func (test *udpV5Test) packetIn(packet v5wire.Packet) { test.packetInFrom(test.remotekey, test.remoteaddr, packet) } -// handles a packet as if it had been sent to the transport by the key/endpoint. -func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet v5wire.Packet) { +// packetInFrom handles a packet as if it had been sent to the transport by the key/endpoint. +func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr netip.AddrPort, packet v5wire.Packet) { test.t.Helper() ln := test.getNode(key, addr) @@ -793,22 +792,22 @@ func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, pa } // getNode ensures the test knows about a node at the given endpoint. -func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr *net.UDPAddr) *enode.LocalNode { +func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr netip.AddrPort) *enode.LocalNode { id := encodePubkey(&key.PublicKey).id() ln := test.nodesByID[id] if ln == nil { db, _ := enode.OpenDB("") ln = enode.NewLocalNode(db, key) - ln.SetStaticIP(addr.IP) - ln.Set(enr.UDP(addr.Port)) + ln.SetStaticIP(addr.Addr().AsSlice()) + ln.Set(enr.UDP(addr.Port())) test.nodesByID[id] = ln } - test.nodesByIP[string(addr.IP)] = ln + test.nodesByIP[addr.Addr()] = ln return ln } // waitPacketOut waits for the next output packet and handles it using the given 'validate' -// function. The function must be of type func (X, *net.UDPAddr, v5wire.Nonce) where X is +// function. The function must be of type func (X, netip.AddrPort, v5wire.Nonce) where X is // assignable to packetV5. func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { test.t.Helper() @@ -824,7 +823,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { test.t.Fatalf("timed out waiting for %v", exptype) return false } - ln := test.nodesByIP[string(dgram.to.IP)] + ln := test.nodesByIP[dgram.to.Addr()] if ln == nil { test.t.Fatalf("attempt to send to non-existing node %v", &dgram.to) return false @@ -839,7 +838,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) return false } - fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(frame.AuthTag)}) + fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(frame.AuthTag)}) return false } diff --git a/p2p/server.go b/p2p/server.go index a3c53b0781..13eebed3f4 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "net" + "net/netip" "slices" "sync" "sync/atomic" @@ -435,11 +436,11 @@ type sharedUDPConn struct { unhandled chan discover.ReadPacket } -// ReadFromUDP implements discover.UDPConn -func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { +// ReadFromUDPAddrPort implements discover.UDPConn +func (s *sharedUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { packet, ok := <-s.unhandled if !ok { - return 0, nil, errors.New("connection was closed") + return 0, netip.AddrPort{}, errors.New("connection was closed") } l := len(packet.Data) if l > len(b) {