p2p/discover: refactor node and endpoint representation (#29844)

Here we clean up internal uses of type discover.node, converting most code to use
enode.Node instead. The discover.node type used to be the canonical representation of
network hosts before ENR was introduced. Most code worked with *node to avoid conversions
when interacting with Table methods. Since *node also contains internal state of Table and
is a mutable type, using *node outside of Table code is prone to data races. It's also
cleaner not having to wrap/unwrap *enode.Node all the time.

discover.node has been renamed to tableNode to clarify its purpose.

While here, we also change most uses of net.UDPAddr into netip.AddrPort. While this is
technically a separate refactoring from the *node -> *enode.Node change, it is more
convenient because *enode.Node handles IP addresses as netip.Addr. The switch to package
netip in discovery would've happened very soon anyway.

The change to netip.AddrPort stops at certain interface points. For example, since package
p2p/netutil has not been converted to use netip.Addr yet, we still have to convert to
net.IP/net.UDPAddr in a few places.
pull/29879/head
Felix Lange 4 months ago committed by GitHub
parent e26fa9e40e
commit 94a8b296e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      cmd/devp2p/internal/v4test/framework.go
  2. 7
      p2p/discover/common.go
  3. 18
      p2p/discover/lookup.go
  4. 14
      p2p/discover/metrics.go
  5. 69
      p2p/discover/node.go
  6. 125
      p2p/discover/table.go
  7. 22
      p2p/discover/table_reval.go
  8. 4
      p2p/discover/table_reval_test.go
  9. 80
      p2p/discover/table_test.go
  10. 34
      p2p/discover/table_util_test.go
  11. 27
      p2p/discover/v4_lookup_test.go
  12. 146
      p2p/discover/v4_udp.go
  13. 107
      p2p/discover/v4_udp_test.go
  14. 16
      p2p/discover/v4wire/v4wire.go
  15. 6
      p2p/discover/v5_talk.go
  16. 80
      p2p/discover/v5_udp.go
  17. 93
      p2p/discover/v5_udp_test.go
  18. 7
      p2p/server.go

@ -110,7 +110,7 @@ func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint {
} }
func (te *testenv) remoteEndpoint() v4wire.Endpoint { func (te *testenv) remoteEndpoint() v4wire.Endpoint {
return v4wire.NewEndpoint(te.remoteAddr, 0) return v4wire.NewEndpoint(te.remoteAddr.AddrPort(), 0)
} }
func contains(ns []v4wire.Node, key v4wire.Pubkey) bool { func contains(ns []v4wire.Node, key v4wire.Pubkey) bool {

@ -22,6 +22,7 @@ import (
"encoding/binary" "encoding/binary"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"sync" "sync"
"time" "time"
@ -34,8 +35,8 @@ import (
// UDPConn is a network connection on which discovery can operate. // UDPConn is a network connection on which discovery can operate.
type UDPConn interface { type UDPConn interface {
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error)
WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (n int, err error)
Close() error Close() error
LocalAddr() net.Addr LocalAddr() net.Addr
} }
@ -94,7 +95,7 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
// channel if configured. // channel if configured.
type ReadPacket struct { type ReadPacket struct {
Data []byte Data []byte
Addr *net.UDPAddr Addr netip.AddrPort
} }
type randomSource interface { type randomSource interface {

@ -29,16 +29,16 @@ import (
// not need to be an actual node identifier. // not need to be an actual node identifier.
type lookup struct { type lookup struct {
tab *Table tab *Table
queryfunc func(*node) ([]*node, error) queryfunc queryFunc
replyCh chan []*node replyCh chan []*enode.Node
cancelCh <-chan struct{} cancelCh <-chan struct{}
asked, seen map[enode.ID]bool asked, seen map[enode.ID]bool
result nodesByDistance result nodesByDistance
replyBuffer []*node replyBuffer []*enode.Node
queries int queries int
} }
type queryFunc func(*node) ([]*node, error) type queryFunc func(*enode.Node) ([]*enode.Node, error)
func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup { func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup {
it := &lookup{ it := &lookup{
@ -47,7 +47,7 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l
asked: make(map[enode.ID]bool), asked: make(map[enode.ID]bool),
seen: make(map[enode.ID]bool), seen: make(map[enode.ID]bool),
result: nodesByDistance{target: target}, result: nodesByDistance{target: target},
replyCh: make(chan []*node, alpha), replyCh: make(chan []*enode.Node, alpha),
cancelCh: ctx.Done(), cancelCh: ctx.Done(),
queries: -1, queries: -1,
} }
@ -61,7 +61,7 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l
func (it *lookup) run() []*enode.Node { func (it *lookup) run() []*enode.Node {
for it.advance() { for it.advance() {
} }
return unwrapNodes(it.result.entries) return it.result.entries
} }
// advance advances the lookup until any new nodes have been found. // advance advances the lookup until any new nodes have been found.
@ -139,7 +139,7 @@ func (it *lookup) slowdown() {
} }
} }
func (it *lookup) query(n *node, reply chan<- []*node) { func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) {
r, err := it.queryfunc(n) r, err := it.queryfunc(n)
if !errors.Is(err, errClosed) { // avoid recording failures on shutdown. if !errors.Is(err, errClosed) { // avoid recording failures on shutdown.
success := len(r) > 0 success := len(r) > 0
@ -154,7 +154,7 @@ func (it *lookup) query(n *node, reply chan<- []*node) {
// lookupIterator performs lookup operations and iterates over all seen nodes. // lookupIterator performs lookup operations and iterates over all seen nodes.
// When a lookup finishes, a new one is created through nextLookup. // When a lookup finishes, a new one is created through nextLookup.
type lookupIterator struct { type lookupIterator struct {
buffer []*node buffer []*enode.Node
nextLookup lookupFunc nextLookup lookupFunc
ctx context.Context ctx context.Context
cancel func() cancel func()
@ -173,7 +173,7 @@ func (it *lookupIterator) Node() *enode.Node {
if len(it.buffer) == 0 { if len(it.buffer) == 0 {
return nil return nil
} }
return unwrapNode(it.buffer[0]) return it.buffer[0]
} }
// Next moves to the next node. // Next moves to the next node.

@ -18,7 +18,7 @@ package discover
import ( import (
"fmt" "fmt"
"net" "net/netip"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
) )
@ -58,16 +58,16 @@ func newMeteredConn(conn UDPConn) UDPConn {
return &meteredUdpConn{UDPConn: conn} return &meteredUdpConn{UDPConn: conn}
} }
// ReadFromUDP delegates a network read to the underlying connection, bumping the udp ingress traffic meter along the way. // ReadFromUDPAddrPort delegates a network read to the underlying connection, bumping the udp ingress traffic meter along the way.
func (c *meteredUdpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { func (c *meteredUdpConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
n, addr, err = c.UDPConn.ReadFromUDP(b) n, addr, err = c.UDPConn.ReadFromUDPAddrPort(b)
ingressTrafficMeter.Mark(int64(n)) ingressTrafficMeter.Mark(int64(n))
return n, addr, err return n, addr, err
} }
// Write delegates a network write to the underlying connection, bumping the udp egress traffic meter along the way. // WriteToUDP delegates a network write to the underlying connection, bumping the udp egress traffic meter along the way.
func (c *meteredUdpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { func (c *meteredUdpConn) WriteToUDP(b []byte, addr netip.AddrPort) (n int, err error) {
n, err = c.UDPConn.WriteToUDP(b, addr) n, err = c.UDPConn.WriteToUDPAddrPort(b, addr)
egressTrafficMeter.Mark(int64(n)) egressTrafficMeter.Mark(int64(n))
return n, err return n, err
} }

@ -21,7 +21,8 @@ import (
"crypto/elliptic" "crypto/elliptic"
"errors" "errors"
"math/big" "math/big"
"net" "slices"
"sort"
"time" "time"
"github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/common/math"
@ -37,9 +38,8 @@ type BucketNode struct {
Live bool `json:"live"` Live bool `json:"live"`
} }
// node represents a host on the network. // tableNode is an entry in Table.
// The fields of Node may not be modified. type tableNode struct {
type node struct {
*enode.Node *enode.Node
revalList *revalidationList revalList *revalidationList
addedToTable time.Time // first time node was added to bucket or replacement list addedToTable time.Time // first time node was added to bucket or replacement list
@ -75,34 +75,59 @@ func (e encPubkey) id() enode.ID {
return enode.ID(crypto.Keccak256Hash(e[:])) return enode.ID(crypto.Keccak256Hash(e[:]))
} }
func wrapNode(n *enode.Node) *node { func unwrapNodes(ns []*tableNode) []*enode.Node {
return &node{Node: n} result := make([]*enode.Node, len(ns))
}
func wrapNodes(ns []*enode.Node) []*node {
result := make([]*node, len(ns))
for i, n := range ns { for i, n := range ns {
result[i] = wrapNode(n) result[i] = n.Node
} }
return result return result
} }
func unwrapNode(n *node) *enode.Node { func (n *tableNode) String() string {
return n.Node return n.Node.String()
}
// nodesByDistance is a list of nodes, ordered by distance to target.
type nodesByDistance struct {
entries []*enode.Node
target enode.ID
} }
func unwrapNodes(ns []*node) []*enode.Node { // push adds the given node to the list, keeping the total size below maxElems.
result := make([]*enode.Node, len(ns)) func (h *nodesByDistance) push(n *enode.Node, maxElems int) {
for i, n := range ns { ix := sort.Search(len(h.entries), func(i int) bool {
result[i] = unwrapNode(n) return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0
})
end := len(h.entries)
if len(h.entries) < maxElems {
h.entries = append(h.entries, n)
}
if ix < end {
// Slide existing entries down to make room.
// This will overwrite the entry we just appended.
copy(h.entries[ix+1:], h.entries[ix:])
h.entries[ix] = n
} }
return result
} }
func (n *node) addr() *net.UDPAddr { type nodeType interface {
return &net.UDPAddr{IP: n.IP(), Port: n.UDP()} ID() enode.ID
} }
func (n *node) String() string { // containsID reports whether ns contains a node with the given ID.
return n.Node.String() func containsID[N nodeType](ns []N, id enode.ID) bool {
for _, n := range ns {
if n.ID() == id {
return true
}
}
return false
}
// deleteNode removes a node from the list.
func deleteNode[N nodeType](list []N, id enode.ID) []N {
return slices.DeleteFunc(list, func(n N) bool {
return n.ID() == id
})
} }

@ -27,7 +27,6 @@ import (
"fmt" "fmt"
"net" "net"
"slices" "slices"
"sort"
"sync" "sync"
"time" "time"
@ -65,7 +64,7 @@ const (
type Table struct { type Table struct {
mutex sync.Mutex // protects buckets, bucket content, nursery, rand mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*node // bootstrap nodes nursery []*enode.Node // bootstrap nodes
rand reseedingRandom // source of randomness, periodically reseeded rand reseedingRandom // source of randomness, periodically reseeded
ips netutil.DistinctNetSet ips netutil.DistinctNetSet
revalidation tableRevalidation revalidation tableRevalidation
@ -85,8 +84,8 @@ type Table struct {
closeReq chan struct{} closeReq chan struct{}
closed chan struct{} closed chan struct{}
nodeAddedHook func(*bucket, *node) nodeAddedHook func(*bucket, *tableNode)
nodeRemovedHook func(*bucket, *node) nodeRemovedHook func(*bucket, *tableNode)
} }
// transport is implemented by the UDP transports. // transport is implemented by the UDP transports.
@ -101,20 +100,21 @@ type transport interface {
// bucket contains nodes, ordered by their last activity. the entry // bucket contains nodes, ordered by their last activity. the entry
// that was most recently active is the first element in entries. // that was most recently active is the first element in entries.
type bucket struct { type bucket struct {
entries []*node // live entries, sorted by time of last contact entries []*tableNode // live entries, sorted by time of last contact
replacements []*node // recently seen nodes to be used if revalidation fails replacements []*tableNode // recently seen nodes to be used if revalidation fails
ips netutil.DistinctNetSet ips netutil.DistinctNetSet
index int index int
} }
type addNodeOp struct { type addNodeOp struct {
node *node node *enode.Node
isInbound bool isInbound bool
forceSetLive bool // for tests
} }
type trackRequestOp struct { type trackRequestOp struct {
node *node node *enode.Node
foundNodes []*node foundNodes []*enode.Node
success bool success bool
} }
@ -186,7 +186,7 @@ func (tab *Table) getNode(id enode.ID) *enode.Node {
b := tab.bucket(id) b := tab.bucket(id)
for _, e := range b.entries { for _, e := range b.entries {
if e.ID() == id { if e.ID() == id {
return unwrapNode(e) return e.Node
} }
} }
return nil return nil
@ -202,7 +202,7 @@ func (tab *Table) close() {
// are used to connect to the network if the table is empty and there // are used to connect to the network if the table is empty and there
// are no known nodes in the database. // are no known nodes in the database.
func (tab *Table) setFallbackNodes(nodes []*enode.Node) error { func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
nursery := make([]*node, 0, len(nodes)) nursery := make([]*enode.Node, 0, len(nodes))
for _, n := range nodes { for _, n := range nodes {
if err := n.ValidateComplete(); err != nil { if err := n.ValidateComplete(); err != nil {
return fmt.Errorf("bad bootstrap node %q: %v", n, err) return fmt.Errorf("bad bootstrap node %q: %v", n, err)
@ -211,7 +211,7 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP()) tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP())
continue continue
} }
nursery = append(nursery, wrapNode(n)) nursery = append(nursery, n)
} }
tab.nursery = nursery tab.nursery = nursery
return nil return nil
@ -255,9 +255,9 @@ func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *
liveNodes := &nodesByDistance{target: target} liveNodes := &nodesByDistance{target: target}
for _, b := range &tab.buckets { for _, b := range &tab.buckets {
for _, n := range b.entries { for _, n := range b.entries {
nodes.push(n, nresults) nodes.push(n.Node, nresults)
if preferLive && n.isValidatedLive { if preferLive && n.isValidatedLive {
liveNodes.push(n, nresults) liveNodes.push(n.Node, nresults)
} }
} }
} }
@ -309,8 +309,8 @@ func (tab *Table) len() (n int) {
// list. // list.
// //
// The caller must not hold tab.mutex. // The caller must not hold tab.mutex.
func (tab *Table) addFoundNode(n *node) bool { func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool {
op := addNodeOp{node: n, isInbound: false} op := addNodeOp{node: n, isInbound: false, forceSetLive: forceSetLive}
select { select {
case tab.addNodeCh <- op: case tab.addNodeCh <- op:
return <-tab.addNodeHandled return <-tab.addNodeHandled
@ -327,7 +327,7 @@ func (tab *Table) addFoundNode(n *node) bool {
// repeatedly. // repeatedly.
// //
// The caller must not hold tab.mutex. // The caller must not hold tab.mutex.
func (tab *Table) addInboundNode(n *node) bool { func (tab *Table) addInboundNode(n *enode.Node) bool {
op := addNodeOp{node: n, isInbound: true} op := addNodeOp{node: n, isInbound: true}
select { select {
case tab.addNodeCh <- op: case tab.addNodeCh <- op:
@ -337,7 +337,7 @@ func (tab *Table) addInboundNode(n *node) bool {
} }
} }
func (tab *Table) trackRequest(n *node, success bool, foundNodes []*node) { func (tab *Table) trackRequest(n *enode.Node, success bool, foundNodes []*enode.Node) {
op := trackRequestOp{n, foundNodes, success} op := trackRequestOp{n, foundNodes, success}
select { select {
case tab.trackRequestCh <- op: case tab.trackRequestCh <- op:
@ -443,13 +443,14 @@ func (tab *Table) doRefresh(done chan struct{}) {
} }
func (tab *Table) loadSeedNodes() { func (tab *Table) loadSeedNodes() {
seeds := wrapNodes(tab.db.QuerySeeds(seedCount, seedMaxAge)) seeds := tab.db.QuerySeeds(seedCount, seedMaxAge)
seeds = append(seeds, tab.nursery...) seeds = append(seeds, tab.nursery...)
for i := range seeds { for i := range seeds {
seed := seeds[i] seed := seeds[i]
if tab.log.Enabled(context.Background(), log.LevelTrace) { if tab.log.Enabled(context.Background(), log.LevelTrace) {
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP()))
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age) addr, _ := seed.UDPEndpoint()
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age)
} }
tab.handleAddNode(addNodeOp{node: seed, isInbound: false}) tab.handleAddNode(addNodeOp{node: seed, isInbound: false})
} }
@ -513,7 +514,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
} }
b := tab.bucket(req.node.ID()) b := tab.bucket(req.node.ID())
n, _ := tab.bumpInBucket(b, req.node.Node, req.isInbound) n, _ := tab.bumpInBucket(b, req.node, req.isInbound)
if n != nil { if n != nil {
// Already in bucket. // Already in bucket.
return false return false
@ -529,15 +530,20 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
} }
// Add to bucket. // Add to bucket.
b.entries = append(b.entries, req.node) wn := &tableNode{Node: req.node}
b.replacements = deleteNode(b.replacements, req.node) if req.forceSetLive {
tab.nodeAdded(b, req.node) wn.livenessChecks = 1
wn.isValidatedLive = true
}
b.entries = append(b.entries, wn)
b.replacements = deleteNode(b.replacements, wn.ID())
tab.nodeAdded(b, wn)
return true return true
} }
// addReplacement adds n to the replacement cache of bucket b. // addReplacement adds n to the replacement cache of bucket b.
func (tab *Table) addReplacement(b *bucket, n *node) { func (tab *Table) addReplacement(b *bucket, n *enode.Node) {
if contains(b.replacements, n.ID()) { if containsID(b.replacements, n.ID()) {
// TODO: update ENR // TODO: update ENR
return return
} }
@ -545,15 +551,15 @@ func (tab *Table) addReplacement(b *bucket, n *node) {
return return
} }
n.addedToTable = time.Now() wn := &tableNode{Node: n, addedToTable: time.Now()}
var removed *node var removed *tableNode
b.replacements, removed = pushNode(b.replacements, n, maxReplacements) b.replacements, removed = pushNode(b.replacements, wn, maxReplacements)
if removed != nil { if removed != nil {
tab.removeIP(b, removed.IP()) tab.removeIP(b, removed.IP())
} }
} }
func (tab *Table) nodeAdded(b *bucket, n *node) { func (tab *Table) nodeAdded(b *bucket, n *tableNode) {
if n.addedToTable == (time.Time{}) { if n.addedToTable == (time.Time{}) {
n.addedToTable = time.Now() n.addedToTable = time.Now()
} }
@ -567,7 +573,7 @@ func (tab *Table) nodeAdded(b *bucket, n *node) {
} }
} }
func (tab *Table) nodeRemoved(b *bucket, n *node) { func (tab *Table) nodeRemoved(b *bucket, n *tableNode) {
tab.revalidation.nodeRemoved(n) tab.revalidation.nodeRemoved(n)
if tab.nodeRemovedHook != nil { if tab.nodeRemovedHook != nil {
tab.nodeRemovedHook(b, n) tab.nodeRemovedHook(b, n)
@ -579,8 +585,8 @@ func (tab *Table) nodeRemoved(b *bucket, n *node) {
// deleteInBucket removes node n from the table. // deleteInBucket removes node n from the table.
// If there are replacement nodes in the bucket, the node is replaced. // If there are replacement nodes in the bucket, the node is replaced.
func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node { func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode {
index := slices.IndexFunc(b.entries, func(e *node) bool { return e.ID() == id }) index := slices.IndexFunc(b.entries, func(e *tableNode) bool { return e.ID() == id })
if index == -1 { if index == -1 {
// Entry has been removed already. // Entry has been removed already.
return nil return nil
@ -608,8 +614,8 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node {
// bumpInBucket updates a node record if it exists in the bucket. // bumpInBucket updates a node record if it exists in the bucket.
// The second return value reports whether the node's endpoint (IP/port) was updated. // The second return value reports whether the node's endpoint (IP/port) was updated.
func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *node, endpointChanged bool) { func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *tableNode, endpointChanged bool) {
i := slices.IndexFunc(b.entries, func(elem *node) bool { i := slices.IndexFunc(b.entries, func(elem *tableNode) bool {
return elem.ID() == newRecord.ID() return elem.ID() == newRecord.ID()
}) })
if i == -1 { if i == -1 {
@ -672,21 +678,12 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) {
// Add found nodes. // Add found nodes.
for _, n := range op.foundNodes { for _, n := range op.foundNodes {
tab.handleAddNode(addNodeOp{n, false}) tab.handleAddNode(addNodeOp{n, false, false})
} }
} }
func contains(ns []*node, id enode.ID) bool {
for _, n := range ns {
if n.ID() == id {
return true
}
}
return false
}
// pushNode adds n to the front of list, keeping at most max items. // pushNode adds n to the front of list, keeping at most max items.
func pushNode(list []*node, n *node, max int) ([]*node, *node) { func pushNode(list []*tableNode, n *tableNode, max int) ([]*tableNode, *tableNode) {
if len(list) < max { if len(list) < max {
list = append(list, nil) list = append(list, nil)
} }
@ -695,37 +692,3 @@ func pushNode(list []*node, n *node, max int) ([]*node, *node) {
list[0] = n list[0] = n
return list, removed return list, removed
} }
// deleteNode removes n from list.
func deleteNode(list []*node, n *node) []*node {
for i := range list {
if list[i].ID() == n.ID() {
return append(list[:i], list[i+1:]...)
}
}
return list
}
// nodesByDistance is a list of nodes, ordered by distance to target.
type nodesByDistance struct {
entries []*node
target enode.ID
}
// push adds the given node to the list, keeping the total size below maxElems.
func (h *nodesByDistance) push(n *node, maxElems int) {
ix := sort.Search(len(h.entries), func(i int) bool {
return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0
})
end := len(h.entries)
if len(h.entries) < maxElems {
h.entries = append(h.entries, n)
}
if ix < end {
// Slide existing entries down to make room.
// This will overwrite the entry we just appended.
copy(h.entries[ix+1:], h.entries[ix:])
h.entries[ix] = n
}
}

@ -39,7 +39,7 @@ type tableRevalidation struct {
} }
type revalidationResponse struct { type revalidationResponse struct {
n *node n *tableNode
newRecord *enode.Node newRecord *enode.Node
didRespond bool didRespond bool
} }
@ -55,12 +55,12 @@ func (tr *tableRevalidation) init(cfg *Config) {
} }
// nodeAdded is called when the table receives a new node. // nodeAdded is called when the table receives a new node.
func (tr *tableRevalidation) nodeAdded(tab *Table, n *node) { func (tr *tableRevalidation) nodeAdded(tab *Table, n *tableNode) {
tr.fast.push(n, tab.cfg.Clock.Now(), &tab.rand) tr.fast.push(n, tab.cfg.Clock.Now(), &tab.rand)
} }
// nodeRemoved is called when a node was removed from the table. // nodeRemoved is called when a node was removed from the table.
func (tr *tableRevalidation) nodeRemoved(n *node) { func (tr *tableRevalidation) nodeRemoved(n *tableNode) {
if n.revalList == nil { if n.revalList == nil {
panic(fmt.Errorf("removed node %v has nil revalList", n.ID())) panic(fmt.Errorf("removed node %v has nil revalList", n.ID()))
} }
@ -68,7 +68,7 @@ func (tr *tableRevalidation) nodeRemoved(n *node) {
} }
// nodeEndpointChanged is called when a change in IP or port is detected. // nodeEndpointChanged is called when a change in IP or port is detected.
func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *node) { func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *tableNode) {
n.isValidatedLive = false n.isValidatedLive = false
tr.moveToList(&tr.fast, n, tab.cfg.Clock.Now(), &tab.rand) tr.moveToList(&tr.fast, n, tab.cfg.Clock.Now(), &tab.rand)
} }
@ -90,7 +90,7 @@ func (tr *tableRevalidation) run(tab *Table, now mclock.AbsTime) (nextTime mcloc
} }
// startRequest spawns a revalidation request for node n. // startRequest spawns a revalidation request for node n.
func (tr *tableRevalidation) startRequest(tab *Table, n *node) { func (tr *tableRevalidation) startRequest(tab *Table, n *tableNode) {
if _, ok := tr.activeReq[n.ID()]; ok { if _, ok := tr.activeReq[n.ID()]; ok {
panic(fmt.Errorf("duplicate startRequest (node %v)", n.ID())) panic(fmt.Errorf("duplicate startRequest (node %v)", n.ID()))
} }
@ -180,7 +180,7 @@ func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationRespons
} }
// moveToList ensures n is in the 'dest' list. // moveToList ensures n is in the 'dest' list.
func (tr *tableRevalidation) moveToList(dest *revalidationList, n *node, now mclock.AbsTime, rand randomSource) { func (tr *tableRevalidation) moveToList(dest *revalidationList, n *tableNode, now mclock.AbsTime, rand randomSource) {
if n.revalList == dest { if n.revalList == dest {
return return
} }
@ -192,14 +192,14 @@ func (tr *tableRevalidation) moveToList(dest *revalidationList, n *node, now mcl
// revalidationList holds a list nodes and the next revalidation time. // revalidationList holds a list nodes and the next revalidation time.
type revalidationList struct { type revalidationList struct {
nodes []*node nodes []*tableNode
nextTime mclock.AbsTime nextTime mclock.AbsTime
interval time.Duration interval time.Duration
name string name string
} }
// get returns a random node from the queue. Nodes in the 'exclude' map are not returned. // get returns a random node from the queue. Nodes in the 'exclude' map are not returned.
func (list *revalidationList) get(now mclock.AbsTime, rand randomSource, exclude map[enode.ID]struct{}) *node { func (list *revalidationList) get(now mclock.AbsTime, rand randomSource, exclude map[enode.ID]struct{}) *tableNode {
if now < list.nextTime || len(list.nodes) == 0 { if now < list.nextTime || len(list.nodes) == 0 {
return nil return nil
} }
@ -217,7 +217,7 @@ func (list *revalidationList) schedule(now mclock.AbsTime, rand randomSource) {
list.nextTime = now.Add(time.Duration(rand.Int63n(int64(list.interval)))) list.nextTime = now.Add(time.Duration(rand.Int63n(int64(list.interval))))
} }
func (list *revalidationList) push(n *node, now mclock.AbsTime, rand randomSource) { func (list *revalidationList) push(n *tableNode, now mclock.AbsTime, rand randomSource) {
list.nodes = append(list.nodes, n) list.nodes = append(list.nodes, n)
if list.nextTime == never { if list.nextTime == never {
list.schedule(now, rand) list.schedule(now, rand)
@ -225,7 +225,7 @@ func (list *revalidationList) push(n *node, now mclock.AbsTime, rand randomSourc
n.revalList = list n.revalList = list
} }
func (list *revalidationList) remove(n *node) { func (list *revalidationList) remove(n *tableNode) {
i := slices.Index(list.nodes, n) i := slices.Index(list.nodes, n)
if i == -1 { if i == -1 {
panic(fmt.Errorf("node %v not found in list", n.ID())) panic(fmt.Errorf("node %v not found in list", n.ID()))
@ -238,7 +238,7 @@ func (list *revalidationList) remove(n *node) {
} }
func (list *revalidationList) contains(id enode.ID) bool { func (list *revalidationList) contains(id enode.ID) bool {
return slices.ContainsFunc(list.nodes, func(n *node) bool { return slices.ContainsFunc(list.nodes, func(n *tableNode) bool {
return n.ID() == id return n.ID() == id
}) })
} }

@ -110,10 +110,10 @@ func TestRevalidation_endpointUpdate(t *testing.T) {
} }
tr.handleResponse(tab, resp) tr.handleResponse(tab, resp)
if !tr.fast.contains(node.ID()) { if tr.fast.nodes[0].ID() != node.ID() {
t.Fatal("node not contained in fast revalidation list") t.Fatal("node not contained in fast revalidation list")
} }
if node.isValidatedLive { if tr.fast.nodes[0].isValidatedLive {
t.Fatal("node is marked live after endpoint change") t.Fatal("node is marked live after endpoint change")
} }
} }

@ -22,6 +22,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"reflect" "reflect"
"slices"
"testing" "testing"
"testing/quick" "testing/quick"
"time" "time"
@ -64,7 +65,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
// Fill up the sender's bucket. // Fill up the sender's bucket.
replacementNodeKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8") replacementNodeKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8")
replacementNode := wrapNode(enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99)) replacementNode := enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99)
last := fillBucket(tab, replacementNode.ID()) last := fillBucket(tab, replacementNode.ID())
tab.mutex.Lock() tab.mutex.Lock()
nodeEvents := newNodeEventRecorder(128) nodeEvents := newNodeEventRecorder(128)
@ -78,7 +79,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
transport.dead[replacementNode.ID()] = !newNodeIsResponding transport.dead[replacementNode.ID()] = !newNodeIsResponding
// Add replacement node to table. // Add replacement node to table.
tab.addFoundNode(replacementNode) tab.addFoundNode(replacementNode, false)
t.Log("last:", last.ID()) t.Log("last:", last.ID())
t.Log("replacement:", replacementNode.ID()) t.Log("replacement:", replacementNode.ID())
@ -115,11 +116,11 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
if l := len(bucket.entries); l != wantSize { if l := len(bucket.entries); l != wantSize {
t.Errorf("wrong bucket size after revalidation: got %d, want %d", l, wantSize) t.Errorf("wrong bucket size after revalidation: got %d, want %d", l, wantSize)
} }
if ok := contains(bucket.entries, last.ID()); ok != lastInBucketIsResponding { if ok := containsID(bucket.entries, last.ID()); ok != lastInBucketIsResponding {
t.Errorf("revalidated node found: %t, want: %t", ok, lastInBucketIsResponding) t.Errorf("revalidated node found: %t, want: %t", ok, lastInBucketIsResponding)
} }
wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
if ok := contains(bucket.entries, replacementNode.ID()); ok != wantNewEntry { if ok := containsID(bucket.entries, replacementNode.ID()); ok != wantNewEntry {
t.Errorf("replacement node found: %t, want: %t", ok, wantNewEntry) t.Errorf("replacement node found: %t, want: %t", ok, wantNewEntry)
} }
} }
@ -153,7 +154,7 @@ func TestTable_IPLimit(t *testing.T) {
for i := 0; i < tableIPLimit+1; i++ { for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
tab.addFoundNode(n) tab.addFoundNode(n, false)
} }
if tab.len() > tableIPLimit { if tab.len() > tableIPLimit {
t.Errorf("too many nodes in table") t.Errorf("too many nodes in table")
@ -171,7 +172,7 @@ func TestTable_BucketIPLimit(t *testing.T) {
d := 3 d := 3
for i := 0; i < bucketIPLimit+1; i++ { for i := 0; i < bucketIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)}) n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)})
tab.addFoundNode(n) tab.addFoundNode(n, false)
} }
if tab.len() > bucketIPLimit { if tab.len() > bucketIPLimit {
t.Errorf("too many nodes in table") t.Errorf("too many nodes in table")
@ -232,7 +233,7 @@ func TestTable_findnodeByID(t *testing.T) {
// check that the result nodes have minimum distance to target. // check that the result nodes have minimum distance to target.
for _, b := range tab.buckets { for _, b := range tab.buckets {
for _, n := range b.entries { for _, n := range b.entries {
if contains(result, n.ID()) { if containsID(result, n.ID()) {
continue // don't run the check below for nodes in result continue // don't run the check below for nodes in result
} }
farthestResult := result[len(result)-1].ID() farthestResult := result[len(result)-1].ID()
@ -255,7 +256,7 @@ func TestTable_findnodeByID(t *testing.T) {
type closeTest struct { type closeTest struct {
Self enode.ID Self enode.ID
Target enode.ID Target enode.ID
All []*node All []*enode.Node
N int N int
} }
@ -268,8 +269,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
r := new(enr.Record) r := new(enr.Record)
r.Set(enr.IP(genIP(rand))) r.Set(enr.IP(genIP(rand)))
n := wrapNode(enode.SignNull(r, id)) n := enode.SignNull(r, id)
n.livenessChecks = 1
t.All = append(t.All, n) t.All = append(t.All, n)
} }
return reflect.ValueOf(t) return reflect.ValueOf(t)
@ -284,16 +284,16 @@ func TestTable_addInboundNode(t *testing.T) {
// Insert two nodes. // Insert two nodes.
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addFoundNode(n1) tab.addFoundNode(n1, false)
tab.addFoundNode(n2) tab.addFoundNode(n2, false)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node}) checkBucketContent(t, tab, []*enode.Node{n1, n2})
// Add a changed version of n2. The bucket should be updated. // Add a changed version of n2. The bucket should be updated.
newrec := n2.Record() newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99}) newrec.Set(enr.IP{99, 99, 99, 99})
n2v2 := enode.SignNull(newrec, n2.ID()) n2v2 := enode.SignNull(newrec, n2.ID())
tab.addInboundNode(wrapNode(n2v2)) tab.addInboundNode(n2v2)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) checkBucketContent(t, tab, []*enode.Node{n1, n2v2})
// Try updating n2 without sequence number change. The update is accepted // Try updating n2 without sequence number change. The update is accepted
// because it's inbound. // because it's inbound.
@ -301,8 +301,8 @@ func TestTable_addInboundNode(t *testing.T) {
newrec.Set(enr.IP{100, 100, 100, 100}) newrec.Set(enr.IP{100, 100, 100, 100})
newrec.SetSeq(n2.Seq()) newrec.SetSeq(n2.Seq())
n2v3 := enode.SignNull(newrec, n2.ID()) n2v3 := enode.SignNull(newrec, n2.ID())
tab.addInboundNode(wrapNode(n2v3)) tab.addInboundNode(n2v3)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v3}) checkBucketContent(t, tab, []*enode.Node{n1, n2v3})
} }
func TestTable_addFoundNode(t *testing.T) { func TestTable_addFoundNode(t *testing.T) {
@ -314,16 +314,16 @@ func TestTable_addFoundNode(t *testing.T) {
// Insert two nodes. // Insert two nodes.
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addFoundNode(n1) tab.addFoundNode(n1, false)
tab.addFoundNode(n2) tab.addFoundNode(n2, false)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node}) checkBucketContent(t, tab, []*enode.Node{n1, n2})
// Add a changed version of n2. The bucket should be updated. // Add a changed version of n2. The bucket should be updated.
newrec := n2.Record() newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99}) newrec.Set(enr.IP{99, 99, 99, 99})
n2v2 := enode.SignNull(newrec, n2.ID()) n2v2 := enode.SignNull(newrec, n2.ID())
tab.addFoundNode(wrapNode(n2v2)) tab.addFoundNode(n2v2, false)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) checkBucketContent(t, tab, []*enode.Node{n1, n2v2})
// Try updating n2 without a sequence number change. // Try updating n2 without a sequence number change.
// The update should not be accepted. // The update should not be accepted.
@ -331,8 +331,8 @@ func TestTable_addFoundNode(t *testing.T) {
newrec.Set(enr.IP{100, 100, 100, 100}) newrec.Set(enr.IP{100, 100, 100, 100})
newrec.SetSeq(n2.Seq()) newrec.SetSeq(n2.Seq())
n2v3 := enode.SignNull(newrec, n2.ID()) n2v3 := enode.SignNull(newrec, n2.ID())
tab.addFoundNode(wrapNode(n2v3)) tab.addFoundNode(n2v3, false)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) checkBucketContent(t, tab, []*enode.Node{n1, n2v2})
} }
// This test checks that discv4 nodes can update their own endpoint via PING. // This test checks that discv4 nodes can update their own endpoint via PING.
@ -345,13 +345,13 @@ func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) {
// Add a v4 node. // Add a v4 node.
key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3")
n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
tab.addInboundNode(wrapNode(n1)) tab.addInboundNode(n1)
checkBucketContent(t, tab, []*enode.Node{n1}) checkBucketContent(t, tab, []*enode.Node{n1})
// Add an updated version with changed IP. // Add an updated version with changed IP.
// The update will be accepted because it is inbound. // The update will be accepted because it is inbound.
n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)
tab.addInboundNode(wrapNode(n1v2)) tab.addInboundNode(n1v2)
checkBucketContent(t, tab, []*enode.Node{n1v2}) checkBucketContent(t, tab, []*enode.Node{n1v2})
} }
@ -366,13 +366,13 @@ func TestTable_addFoundNodeV4UpdateReject(t *testing.T) {
// Add a v4 node. // Add a v4 node.
key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3")
n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
tab.addFoundNode(wrapNode(n1)) tab.addFoundNode(n1, false)
checkBucketContent(t, tab, []*enode.Node{n1}) checkBucketContent(t, tab, []*enode.Node{n1})
// Add an updated version with changed IP. // Add an updated version with changed IP.
// The update won't be accepted because it isn't inbound. // The update won't be accepted because it isn't inbound.
n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)
tab.addFoundNode(wrapNode(n1v2)) tab.addFoundNode(n1v2, false)
checkBucketContent(t, tab, []*enode.Node{n1}) checkBucketContent(t, tab, []*enode.Node{n1})
} }
@ -413,8 +413,8 @@ func TestTable_revalidateSyncRecord(t *testing.T) {
var r enr.Record var r enr.Record
r.Set(enr.IP(net.IP{127, 0, 0, 1})) r.Set(enr.IP(net.IP{127, 0, 0, 1}))
id := enode.ID{1} id := enode.ID{1}
n1 := wrapNode(enode.SignNull(&r, id)) n1 := enode.SignNull(&r, id)
tab.addFoundNode(n1) tab.addFoundNode(n1, false)
// Update the node record. // Update the node record.
r.Set(enr.WithEntry("foo", "bar")) r.Set(enr.WithEntry("foo", "bar"))
@ -437,7 +437,7 @@ func TestNodesPush(t *testing.T) {
n1 := nodeAtDistance(target, 255, intIP(1)) n1 := nodeAtDistance(target, 255, intIP(1))
n2 := nodeAtDistance(target, 254, intIP(2)) n2 := nodeAtDistance(target, 254, intIP(2))
n3 := nodeAtDistance(target, 253, intIP(3)) n3 := nodeAtDistance(target, 253, intIP(3))
perm := [][]*node{ perm := [][]*enode.Node{
{n3, n2, n1}, {n3, n2, n1},
{n3, n1, n2}, {n3, n1, n2},
{n2, n3, n1}, {n2, n3, n1},
@ -452,7 +452,7 @@ func TestNodesPush(t *testing.T) {
for _, n := range nodes { for _, n := range nodes {
list.push(n, 3) list.push(n, 3)
} }
if !slicesEqual(list.entries, perm[0], nodeIDEqual) { if !slices.EqualFunc(list.entries, perm[0], nodeIDEqual) {
t.Fatal("not equal") t.Fatal("not equal")
} }
} }
@ -463,28 +463,16 @@ func TestNodesPush(t *testing.T) {
for _, n := range nodes { for _, n := range nodes {
list.push(n, 2) list.push(n, 2)
} }
if !slicesEqual(list.entries, perm[0][:2], nodeIDEqual) { if !slices.EqualFunc(list.entries, perm[0][:2], nodeIDEqual) {
t.Fatal("not equal") t.Fatal("not equal")
} }
} }
} }
func nodeIDEqual(n1, n2 *node) bool { func nodeIDEqual[N nodeType](n1, n2 N) bool {
return n1.ID() == n2.ID() return n1.ID() == n2.ID()
} }
func slicesEqual[T any](s1, s2 []T, check func(e1, e2 T) bool) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if !check(s1[i], s2[i]) {
return false
}
}
return true
}
// gen wraps quick.Value so it's easier to use. // gen wraps quick.Value so it's easier to use.
// it generates a random value of the given value's type. // it generates a random value of the given value's type.
func gen(typ interface{}, rand *rand.Rand) interface{} { func gen(typ interface{}, rand *rand.Rand) interface{} {

@ -56,18 +56,18 @@ func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) {
} }
// nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld. // nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld.
func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node { func nodeAtDistance(base enode.ID, ld int, ip net.IP) *enode.Node {
var r enr.Record var r enr.Record
r.Set(enr.IP(ip)) r.Set(enr.IP(ip))
r.Set(enr.UDP(30303)) r.Set(enr.UDP(30303))
return wrapNode(enode.SignNull(&r, idAtDistance(base, ld))) return enode.SignNull(&r, idAtDistance(base, ld))
} }
// nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld. // nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld.
func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node { func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node {
results := make([]*enode.Node, n) results := make([]*enode.Node, n)
for i := range results { for i := range results {
results[i] = unwrapNode(nodeAtDistance(base, ld, intIP(i))) results[i] = nodeAtDistance(base, ld, intIP(i))
} }
return results return results
} }
@ -105,12 +105,12 @@ func intIP(i int) net.IP {
} }
// fillBucket inserts nodes into the given bucket until it is full. // fillBucket inserts nodes into the given bucket until it is full.
func fillBucket(tab *Table, id enode.ID) (last *node) { func fillBucket(tab *Table, id enode.ID) (last *tableNode) {
ld := enode.LogDist(tab.self().ID(), id) ld := enode.LogDist(tab.self().ID(), id)
b := tab.bucket(id) b := tab.bucket(id)
for len(b.entries) < bucketSize { for len(b.entries) < bucketSize {
node := nodeAtDistance(tab.self().ID(), ld, intIP(ld)) node := nodeAtDistance(tab.self().ID(), ld, intIP(ld))
if !tab.addFoundNode(node) { if !tab.addFoundNode(node, false) {
panic("node not added") panic("node not added")
} }
} }
@ -119,13 +119,9 @@ func fillBucket(tab *Table, id enode.ID) (last *node) {
// fillTable adds nodes the table to the end of their corresponding bucket // fillTable adds nodes the table to the end of their corresponding bucket
// if the bucket is not full. The caller must not hold tab.mutex. // if the bucket is not full. The caller must not hold tab.mutex.
func fillTable(tab *Table, nodes []*node, setLive bool) { func fillTable(tab *Table, nodes []*enode.Node, setLive bool) {
for _, n := range nodes { for _, n := range nodes {
if setLive { tab.addFoundNode(n, setLive)
n.livenessChecks = 1
n.isValidatedLive = true
}
tab.addFoundNode(n)
} }
} }
@ -219,7 +215,7 @@ func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) {
return t.records[n.ID()], nil return t.records[n.ID()], nil
} }
func hasDuplicates(slice []*node) bool { func hasDuplicates(slice []*enode.Node) bool {
seen := make(map[enode.ID]bool, len(slice)) seen := make(map[enode.ID]bool, len(slice))
for i, e := range slice { for i, e := range slice {
if e == nil { if e == nil {
@ -261,14 +257,14 @@ func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP()) return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP())
} }
func sortByID(nodes []*enode.Node) { func sortByID[N nodeType](nodes []N) {
slices.SortFunc(nodes, func(a, b *enode.Node) int { slices.SortFunc(nodes, func(a, b N) int {
return bytes.Compare(a.ID().Bytes(), b.ID().Bytes()) return bytes.Compare(a.ID().Bytes(), b.ID().Bytes())
}) })
} }
func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { func sortedByDistanceTo(distbase enode.ID, slice []*enode.Node) bool {
return slices.IsSortedFunc(slice, func(a, b *node) int { return slices.IsSortedFunc(slice, func(a, b *enode.Node) int {
return enode.DistCmp(distbase, a.ID(), b.ID()) return enode.DistCmp(distbase, a.ID(), b.ID())
}) })
} }
@ -304,7 +300,7 @@ type nodeEventRecorder struct {
} }
type recordedNodeEvent struct { type recordedNodeEvent struct {
node *node node *tableNode
added bool added bool
} }
@ -314,7 +310,7 @@ func newNodeEventRecorder(buffer int) *nodeEventRecorder {
} }
} }
func (set *nodeEventRecorder) nodeAdded(b *bucket, n *node) { func (set *nodeEventRecorder) nodeAdded(b *bucket, n *tableNode) {
select { select {
case set.evc <- recordedNodeEvent{n, true}: case set.evc <- recordedNodeEvent{n, true}:
default: default:
@ -322,7 +318,7 @@ func (set *nodeEventRecorder) nodeAdded(b *bucket, n *node) {
} }
} }
func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *node) { func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *tableNode) {
select { select {
case set.evc <- recordedNodeEvent{n, false}: case set.evc <- recordedNodeEvent{n, false}:
default: default:

@ -19,7 +19,7 @@ package discover
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "fmt"
"net" "net/netip"
"slices" "slices"
"testing" "testing"
@ -40,7 +40,7 @@ func TestUDPv4_Lookup(t *testing.T) {
} }
// Seed table with initial node. // Seed table with initial node.
fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}, true) fillTable(test.table, []*enode.Node{lookupTestnet.node(256, 0)}, true)
// Start the lookup. // Start the lookup.
resultC := make(chan []*enode.Node, 1) resultC := make(chan []*enode.Node, 1)
@ -70,9 +70,9 @@ func TestUDPv4_LookupIterator(t *testing.T) {
defer test.close() defer test.close()
// Seed table with initial nodes. // Seed table with initial nodes.
bootnodes := make([]*node, len(lookupTestnet.dists[256])) bootnodes := make([]*enode.Node, len(lookupTestnet.dists[256]))
for i := range lookupTestnet.dists[256] { for i := range lookupTestnet.dists[256] {
bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) bootnodes[i] = lookupTestnet.node(256, i)
} }
fillTable(test.table, bootnodes, true) fillTable(test.table, bootnodes, true)
go serveTestnet(test, lookupTestnet) go serveTestnet(test, lookupTestnet)
@ -105,9 +105,9 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) {
defer test.close() defer test.close()
// Seed table with initial nodes. // Seed table with initial nodes.
bootnodes := make([]*node, len(lookupTestnet.dists[256])) bootnodes := make([]*enode.Node, len(lookupTestnet.dists[256]))
for i := range lookupTestnet.dists[256] { for i := range lookupTestnet.dists[256] {
bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) bootnodes[i] = lookupTestnet.node(256, i)
} }
fillTable(test.table, bootnodes, true) fillTable(test.table, bootnodes, true)
go serveTestnet(test, lookupTestnet) go serveTestnet(test, lookupTestnet)
@ -136,7 +136,7 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) {
func serveTestnet(test *udpTest, testnet *preminedTestnet) { func serveTestnet(test *udpTest, testnet *preminedTestnet) {
for done := false; !done; { for done := false; !done; {
done = test.waitPacketOut(func(p v4wire.Packet, to *net.UDPAddr, hash []byte) { done = test.waitPacketOut(func(p v4wire.Packet, to netip.AddrPort, hash []byte) {
n, key := testnet.nodeByAddr(to) n, key := testnet.nodeByAddr(to)
switch p.(type) { switch p.(type) {
case *v4wire.Ping: case *v4wire.Ping:
@ -158,10 +158,10 @@ func checkLookupResults(t *testing.T, tn *preminedTestnet, results []*enode.Node
for _, e := range results { for _, e := range results {
t.Logf(" ld=%d, %x", enode.LogDist(tn.target.id(), e.ID()), e.ID().Bytes()) t.Logf(" ld=%d, %x", enode.LogDist(tn.target.id(), e.ID()), e.ID().Bytes())
} }
if hasDuplicates(wrapNodes(results)) { if hasDuplicates(results) {
t.Errorf("result set contains duplicate entries") t.Errorf("result set contains duplicate entries")
} }
if !sortedByDistanceTo(tn.target.id(), wrapNodes(results)) { if !sortedByDistanceTo(tn.target.id(), results) {
t.Errorf("result set not sorted by distance to target") t.Errorf("result set not sorted by distance to target")
} }
wantNodes := tn.closest(len(results)) wantNodes := tn.closest(len(results))
@ -264,9 +264,10 @@ func (tn *preminedTestnet) node(dist, index int) *enode.Node {
return n return n
} }
func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.PrivateKey) { func (tn *preminedTestnet) nodeByAddr(addr netip.AddrPort) (*enode.Node, *ecdsa.PrivateKey) {
dist := int(addr.IP[1])<<8 + int(addr.IP[2]) ip := addr.Addr().As4()
index := int(addr.IP[3]) dist := int(ip[1])<<8 + int(ip[2])
index := int(ip[3])
key := tn.dists[dist][index] key := tn.dists[dist][index]
return tn.node(dist, index), key return tn.node(dist, index), key
} }
@ -274,7 +275,7 @@ func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.Pr
func (tn *preminedTestnet) nodesAtDistance(dist int) []v4wire.Node { func (tn *preminedTestnet) nodesAtDistance(dist int) []v4wire.Node {
result := make([]v4wire.Node, len(tn.dists[dist])) result := make([]v4wire.Node, len(tn.dists[dist]))
for i := range result { for i := range result {
result[i] = nodeToRPC(wrapNode(tn.node(dist, i))) result[i] = nodeToRPC(tn.node(dist, i))
} }
return result return result
} }

@ -26,6 +26,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"sync" "sync"
"time" "time"
@ -45,6 +46,7 @@ var (
errClockWarp = errors.New("reply deadline too far in the future") errClockWarp = errors.New("reply deadline too far in the future")
errClosed = errors.New("socket closed") errClosed = errors.New("socket closed")
errLowPort = errors.New("low port") errLowPort = errors.New("low port")
errNoUDPEndpoint = errors.New("node has no UDP endpoint")
) )
const ( const (
@ -93,7 +95,7 @@ type UDPv4 struct {
type replyMatcher struct { type replyMatcher struct {
// these fields must match in the reply. // these fields must match in the reply.
from enode.ID from enode.ID
ip net.IP ip netip.Addr
ptype byte ptype byte
// time when the request must complete // time when the request must complete
@ -119,7 +121,7 @@ type replyMatchFunc func(v4wire.Packet) (matched bool, requestDone bool)
// reply is a reply packet from a certain node. // reply is a reply packet from a certain node.
type reply struct { type reply struct {
from enode.ID from enode.ID
ip net.IP ip netip.Addr
data v4wire.Packet data v4wire.Packet
// loop indicates whether there was // loop indicates whether there was
// a matching request by sending on this channel. // a matching request by sending on this channel.
@ -201,9 +203,12 @@ func (t *UDPv4) Resolve(n *enode.Node) *enode.Node {
} }
func (t *UDPv4) ourEndpoint() v4wire.Endpoint { func (t *UDPv4) ourEndpoint() v4wire.Endpoint {
n := t.Self() node := t.Self()
a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} addr, ok := node.UDPEndpoint()
return v4wire.NewEndpoint(a, uint16(n.TCP())) if !ok {
return v4wire.Endpoint{}
}
return v4wire.NewEndpoint(addr, uint16(node.TCP()))
} }
// Ping sends a ping message to the given node. // Ping sends a ping message to the given node.
@ -214,7 +219,11 @@ func (t *UDPv4) Ping(n *enode.Node) error {
// ping sends a ping message to the given node and waits for a reply. // ping sends a ping message to the given node and waits for a reply.
func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) {
rm := t.sendPing(n.ID(), &net.UDPAddr{IP: n.IP(), Port: n.UDP()}, nil) addr, ok := n.UDPEndpoint()
if !ok {
return 0, errNoUDPEndpoint
}
rm := t.sendPing(n.ID(), addr, nil)
if err = <-rm.errc; err == nil { if err = <-rm.errc; err == nil {
seq = rm.reply.(*v4wire.Pong).ENRSeq seq = rm.reply.(*v4wire.Pong).ENRSeq
} }
@ -223,7 +232,7 @@ func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) {
// sendPing sends a ping message to the given node and invokes the callback // sendPing sends a ping message to the given node and invokes the callback
// when the reply arrives. // when the reply arrives.
func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *replyMatcher { func (t *UDPv4) sendPing(toid enode.ID, toaddr netip.AddrPort, callback func()) *replyMatcher {
req := t.makePing(toaddr) req := t.makePing(toaddr)
packet, hash, err := v4wire.Encode(t.priv, req) packet, hash, err := v4wire.Encode(t.priv, req)
if err != nil { if err != nil {
@ -233,7 +242,7 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r
} }
// Add a matcher for the reply to the pending reply queue. Pongs are matched if they // Add a matcher for the reply to the pending reply queue. Pongs are matched if they
// reference the ping we're about to send. // reference the ping we're about to send.
rm := t.pending(toid, toaddr.IP, v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) { rm := t.pending(toid, toaddr.Addr(), v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) {
matched = bytes.Equal(p.(*v4wire.Pong).ReplyTok, hash) matched = bytes.Equal(p.(*v4wire.Pong).ReplyTok, hash)
if matched && callback != nil { if matched && callback != nil {
callback() callback()
@ -241,12 +250,13 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r
return matched, matched return matched, matched
}) })
// Send the packet. // Send the packet.
t.localNode.UDPContact(toaddr) toUDPAddr := &net.UDPAddr{IP: toaddr.Addr().AsSlice()}
t.localNode.UDPContact(toUDPAddr)
t.write(toaddr, toid, req.Name(), packet) t.write(toaddr, toid, req.Name(), packet)
return rm return rm
} }
func (t *UDPv4) makePing(toaddr *net.UDPAddr) *v4wire.Ping { func (t *UDPv4) makePing(toaddr netip.AddrPort) *v4wire.Ping {
return &v4wire.Ping{ return &v4wire.Ping{
Version: 4, Version: 4,
From: t.ourEndpoint(), From: t.ourEndpoint(),
@ -290,35 +300,39 @@ func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup {
func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup { func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup {
target := enode.ID(crypto.Keccak256Hash(targetKey[:])) target := enode.ID(crypto.Keccak256Hash(targetKey[:]))
ekey := v4wire.Pubkey(targetKey) ekey := v4wire.Pubkey(targetKey)
it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { it := newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) {
return t.findnode(n.ID(), n.addr(), ekey) addr, ok := n.UDPEndpoint()
if !ok {
return nil, errNoUDPEndpoint
}
return t.findnode(n.ID(), addr, ekey)
}) })
return it return it
} }
// findnode sends a findnode request to the given node and waits until // findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors. // the node has sent up to k neighbors.
func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubkey) ([]*node, error) { func (t *UDPv4) findnode(toid enode.ID, toAddrPort netip.AddrPort, target v4wire.Pubkey) ([]*enode.Node, error) {
t.ensureBond(toid, toaddr) t.ensureBond(toid, toAddrPort)
// Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is
// active until enough nodes have been received. // active until enough nodes have been received.
nodes := make([]*node, 0, bucketSize) nodes := make([]*enode.Node, 0, bucketSize)
nreceived := 0 nreceived := 0
rm := t.pending(toid, toaddr.IP, v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) { rm := t.pending(toid, toAddrPort.Addr(), v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) {
reply := r.(*v4wire.Neighbors) reply := r.(*v4wire.Neighbors)
for _, rn := range reply.Nodes { for _, rn := range reply.Nodes {
nreceived++ nreceived++
n, err := t.nodeFromRPC(toaddr, rn) n, err := t.nodeFromRPC(toAddrPort, rn)
if err != nil { if err != nil {
t.log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) t.log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toAddrPort, "err", err)
continue continue
} }
nodes = append(nodes, n) nodes = append(nodes, n)
} }
return true, nreceived >= bucketSize return true, nreceived >= bucketSize
}) })
t.send(toaddr, toid, &v4wire.Findnode{ t.send(toAddrPort, toid, &v4wire.Findnode{
Target: target, Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()), Expiration: uint64(time.Now().Add(expiration).Unix()),
}) })
@ -336,7 +350,7 @@ func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubke
// RequestENR sends ENRRequest to the given node and waits for a response. // RequestENR sends ENRRequest to the given node and waits for a response.
func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} addr, _ := n.UDPEndpoint()
t.ensureBond(n.ID(), addr) t.ensureBond(n.ID(), addr)
req := &v4wire.ENRRequest{ req := &v4wire.ENRRequest{
@ -349,7 +363,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
// Add a matcher for the reply to the pending reply queue. Responses are matched if // Add a matcher for the reply to the pending reply queue. Responses are matched if
// they reference the request we're about to send. // they reference the request we're about to send.
rm := t.pending(n.ID(), addr.IP, v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) { rm := t.pending(n.ID(), addr.Addr(), v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) {
matched = bytes.Equal(r.(*v4wire.ENRResponse).ReplyTok, hash) matched = bytes.Equal(r.(*v4wire.ENRResponse).ReplyTok, hash)
return matched, matched return matched, matched
}) })
@ -369,7 +383,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
if respN.Seq() < n.Seq() { if respN.Seq() < n.Seq() {
return n, nil // response record is older return n, nil // response record is older
} }
if err := netutil.CheckRelayIP(addr.IP, respN.IP()); err != nil { if err := netutil.CheckRelayIP(addr.Addr().AsSlice(), respN.IP()); err != nil {
return nil, fmt.Errorf("invalid IP in response record: %v", err) return nil, fmt.Errorf("invalid IP in response record: %v", err)
} }
return respN, nil return respN, nil
@ -381,7 +395,7 @@ func (t *UDPv4) TableBuckets() [][]BucketNode {
// pending adds a reply matcher to the pending reply queue. // pending adds a reply matcher to the pending reply queue.
// see the documentation of type replyMatcher for a detailed explanation. // see the documentation of type replyMatcher for a detailed explanation.
func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) *replyMatcher { func (t *UDPv4) pending(id enode.ID, ip netip.Addr, ptype byte, callback replyMatchFunc) *replyMatcher {
ch := make(chan error, 1) ch := make(chan error, 1)
p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch} p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch}
select { select {
@ -395,7 +409,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF
// handleReply dispatches a reply packet, invoking reply matchers. It returns // handleReply dispatches a reply packet, invoking reply matchers. It returns
// whether any matcher considered the packet acceptable. // whether any matcher considered the packet acceptable.
func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req v4wire.Packet) bool { func (t *UDPv4) handleReply(from enode.ID, fromIP netip.Addr, req v4wire.Packet) bool {
matched := make(chan bool, 1) matched := make(chan bool, 1)
select { select {
case t.gotreply <- reply{from, fromIP, req, matched}: case t.gotreply <- reply{from, fromIP, req, matched}:
@ -461,7 +475,7 @@ func (t *UDPv4) loop() {
var matched bool // whether any replyMatcher considered the reply acceptable. var matched bool // whether any replyMatcher considered the reply acceptable.
for el := plist.Front(); el != nil; el = el.Next() { for el := plist.Front(); el != nil; el = el.Next() {
p := el.Value.(*replyMatcher) p := el.Value.(*replyMatcher)
if p.from == r.from && p.ptype == r.data.Kind() && p.ip.Equal(r.ip) { if p.from == r.from && p.ptype == r.data.Kind() && p.ip == r.ip {
ok, requestDone := p.callback(r.data) ok, requestDone := p.callback(r.data)
matched = matched || ok matched = matched || ok
p.reply = r.data p.reply = r.data
@ -500,7 +514,7 @@ func (t *UDPv4) loop() {
} }
} }
func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]byte, error) { func (t *UDPv4) send(toaddr netip.AddrPort, toid enode.ID, req v4wire.Packet) ([]byte, error) {
packet, hash, err := v4wire.Encode(t.priv, req) packet, hash, err := v4wire.Encode(t.priv, req)
if err != nil { if err != nil {
return hash, err return hash, err
@ -508,8 +522,8 @@ func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]b
return hash, t.write(toaddr, toid, req.Name(), packet) return hash, t.write(toaddr, toid, req.Name(), packet)
} }
func (t *UDPv4) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error { func (t *UDPv4) write(toaddr netip.AddrPort, toid enode.ID, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr) _, err := t.conn.WriteToUDPAddrPort(packet, toaddr)
t.log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err) t.log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err)
return err return err
} }
@ -523,7 +537,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) {
buf := make([]byte, maxPacketSize) buf := make([]byte, maxPacketSize)
for { for {
nbytes, from, err := t.conn.ReadFromUDP(buf) nbytes, from, err := t.conn.ReadFromUDPAddrPort(buf)
if netutil.IsTemporaryError(err) { if netutil.IsTemporaryError(err) {
// Ignore temporary read errors. // Ignore temporary read errors.
t.log.Debug("Temporary UDP read error", "err", err) t.log.Debug("Temporary UDP read error", "err", err)
@ -544,7 +558,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) {
} }
} }
func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error { func (t *UDPv4) handlePacket(from netip.AddrPort, buf []byte) error {
rawpacket, fromKey, hash, err := v4wire.Decode(buf) rawpacket, fromKey, hash, err := v4wire.Decode(buf)
if err != nil { if err != nil {
t.log.Debug("Bad discv4 packet", "addr", from, "err", err) t.log.Debug("Bad discv4 packet", "addr", from, "err", err)
@ -563,15 +577,16 @@ func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error {
} }
// checkBond checks if the given node has a recent enough endpoint proof. // checkBond checks if the given node has a recent enough endpoint proof.
func (t *UDPv4) checkBond(id enode.ID, ip net.IP) bool { func (t *UDPv4) checkBond(id enode.ID, ip netip.AddrPort) bool {
return time.Since(t.db.LastPongReceived(id, ip)) < bondExpiration return time.Since(t.db.LastPongReceived(id, ip.Addr().AsSlice())) < bondExpiration
} }
// ensureBond solicits a ping from a node if we haven't seen a ping from it for a while. // ensureBond solicits a ping from a node if we haven't seen a ping from it for a while.
// This ensures there is a valid endpoint proof on the remote end. // This ensures there is a valid endpoint proof on the remote end.
func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) { func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) {
tooOld := time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration ip := toaddr.Addr().AsSlice()
if tooOld || t.db.FindFails(toid, toaddr.IP) > maxFindnodeFailures { tooOld := time.Since(t.db.LastPingReceived(toid, ip)) > bondExpiration
if tooOld || t.db.FindFails(toid, ip) > maxFindnodeFailures {
rm := t.sendPing(toid, toaddr, nil) rm := t.sendPing(toid, toaddr, nil)
<-rm.errc <-rm.errc
// Wait for them to ping back and process our pong. // Wait for them to ping back and process our pong.
@ -579,11 +594,11 @@ func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) {
} }
} }
func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error) { func (t *UDPv4) nodeFromRPC(sender netip.AddrPort, rn v4wire.Node) (*enode.Node, error) {
if rn.UDP <= 1024 { if rn.UDP <= 1024 {
return nil, errLowPort return nil, errLowPort
} }
if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { if err := netutil.CheckRelayIP(sender.Addr().AsSlice(), rn.IP); err != nil {
return nil, err return nil, err
} }
if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
@ -593,12 +608,12 @@ func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
n := wrapNode(enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP))) n := enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP))
err = n.ValidateComplete() err = n.ValidateComplete()
return n, err return n, err
} }
func nodeToRPC(n *node) v4wire.Node { func nodeToRPC(n *enode.Node) v4wire.Node {
var key ecdsa.PublicKey var key ecdsa.PublicKey
var ekey v4wire.Pubkey var ekey v4wire.Pubkey
if err := n.Load((*enode.Secp256k1)(&key)); err == nil { if err := n.Load((*enode.Secp256k1)(&key)); err == nil {
@ -637,14 +652,14 @@ type packetHandlerV4 struct {
senderKey *ecdsa.PublicKey // used for ping senderKey *ecdsa.PublicKey // used for ping
// preverify checks whether the packet is valid and should be handled at all. // preverify checks whether the packet is valid and should be handled at all.
preverify func(p *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error preverify func(p *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error
// handle handles the packet. // handle handles the packet.
handle func(req *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) handle func(req *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte)
} }
// PING/v4 // PING/v4
func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { func (t *UDPv4) verifyPing(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.Ping) req := h.Packet.(*v4wire.Ping)
if v4wire.Expired(req.Expiration) { if v4wire.Expired(req.Expiration) {
@ -658,7 +673,7 @@ func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I
return nil return nil
} }
func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) {
req := h.Packet.(*v4wire.Ping) req := h.Packet.(*v4wire.Ping)
// Reply. // Reply.
@ -670,8 +685,9 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I
}) })
// Ping back if our last pong on file is too far in the past. // Ping back if our last pong on file is too far in the past.
n := wrapNode(enode.NewV4(h.senderKey, from.IP, int(req.From.TCP), from.Port)) fromIP := from.Addr().AsSlice()
if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port()))
if time.Since(t.db.LastPongReceived(n.ID(), fromIP)) > bondExpiration {
t.sendPing(fromID, from, func() { t.sendPing(fromID, from, func() {
t.tab.addInboundNode(n) t.tab.addInboundNode(n)
}) })
@ -680,35 +696,40 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I
} }
// Update node database and endpoint predictor. // Update node database and endpoint predictor.
t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now()) t.db.UpdateLastPingReceived(n.ID(), fromIP, time.Now())
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())}
toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
} }
// PONG/v4 // PONG/v4
func (t *UDPv4) verifyPong(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { func (t *UDPv4) verifyPong(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.Pong) req := h.Packet.(*v4wire.Pong)
if v4wire.Expired(req.Expiration) { if v4wire.Expired(req.Expiration) {
return errExpired return errExpired
} }
if !t.handleReply(fromID, from.IP, req) { if !t.handleReply(fromID, from.Addr(), req) {
return errUnsolicitedReply return errUnsolicitedReply
} }
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) fromIP := from.Addr().AsSlice()
t.db.UpdateLastPongReceived(fromID, from.IP, time.Now()) fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())}
toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
t.db.UpdateLastPongReceived(fromID, fromIP, time.Now())
return nil return nil
} }
// FINDNODE/v4 // FINDNODE/v4
func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.Findnode) req := h.Packet.(*v4wire.Findnode)
if v4wire.Expired(req.Expiration) { if v4wire.Expired(req.Expiration) {
return errExpired return errExpired
} }
if !t.checkBond(fromID, from.IP) { if !t.checkBond(fromID, from) {
// No endpoint proof pong exists, we don't process the packet. This prevents an // No endpoint proof pong exists, we don't process the packet. This prevents an
// attack vector where the discovery protocol could be used to amplify traffic in a // attack vector where the discovery protocol could be used to amplify traffic in a
// DDOS attack. A malicious actor would send a findnode request with the IP address // DDOS attack. A malicious actor would send a findnode request with the IP address
@ -720,7 +741,7 @@ func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno
return nil return nil
} }
func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) {
req := h.Packet.(*v4wire.Findnode) req := h.Packet.(*v4wire.Findnode)
// Determine closest nodes. // Determine closest nodes.
@ -732,7 +753,8 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno
p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool var sent bool
for _, n := range closest { for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP()) == nil { fromIP := from.Addr().AsSlice()
if netutil.CheckRelayIP(fromIP, n.IP()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n)) p.Nodes = append(p.Nodes, nodeToRPC(n))
} }
if len(p.Nodes) == v4wire.MaxNeighbors { if len(p.Nodes) == v4wire.MaxNeighbors {
@ -748,13 +770,13 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno
// NEIGHBORS/v4 // NEIGHBORS/v4
func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.Neighbors) req := h.Packet.(*v4wire.Neighbors)
if v4wire.Expired(req.Expiration) { if v4wire.Expired(req.Expiration) {
return errExpired return errExpired
} }
if !t.handleReply(fromID, from.IP, h.Packet) { if !t.handleReply(fromID, from.Addr(), h.Packet) {
return errUnsolicitedReply return errUnsolicitedReply
} }
return nil return nil
@ -762,19 +784,19 @@ func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID en
// ENRREQUEST/v4 // ENRREQUEST/v4
func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.ENRRequest) req := h.Packet.(*v4wire.ENRRequest)
if v4wire.Expired(req.Expiration) { if v4wire.Expired(req.Expiration) {
return errExpired return errExpired
} }
if !t.checkBond(fromID, from.IP) { if !t.checkBond(fromID, from) {
return errUnknownNode return errUnknownNode
} }
return nil return nil
} }
func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) {
t.send(from, fromID, &v4wire.ENRResponse{ t.send(from, fromID, &v4wire.ENRResponse{
ReplyTok: mac, ReplyTok: mac,
Record: *t.localNode.Node().Record(), Record: *t.localNode.Node().Record(),
@ -783,8 +805,8 @@ func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID e
// ENRRESPONSE/v4 // ENRRESPONSE/v4
func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
if !t.handleReply(fromID, from.IP, h.Packet) { if !t.handleReply(fromID, from.Addr(), h.Packet) {
return errUnsolicitedReply return errUnsolicitedReply
} }
return nil return nil

@ -26,6 +26,7 @@ import (
"io" "io"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
@ -55,7 +56,7 @@ type udpTest struct {
udp *UDPv4 udp *UDPv4
sent [][]byte sent [][]byte
localkey, remotekey *ecdsa.PrivateKey localkey, remotekey *ecdsa.PrivateKey
remoteaddr *net.UDPAddr remoteaddr netip.AddrPort
} }
func newUDPTest(t *testing.T) *udpTest { func newUDPTest(t *testing.T) *udpTest {
@ -64,7 +65,7 @@ func newUDPTest(t *testing.T) *udpTest {
pipe: newpipe(), pipe: newpipe(),
localkey: newkey(), localkey: newkey(),
remotekey: newkey(), remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, remoteaddr: netip.MustParseAddrPort("10.0.1.99:30303"),
} }
test.db, _ = enode.OpenDB("") test.db, _ = enode.OpenDB("")
@ -92,7 +93,7 @@ func (test *udpTest) packetIn(wantError error, data v4wire.Packet) {
} }
// handles a packet as if it had been sent to the transport by the key/endpoint. // handles a packet as if it had been sent to the transport by the key/endpoint.
func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, data v4wire.Packet) { func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr netip.AddrPort, data v4wire.Packet) {
test.t.Helper() test.t.Helper()
enc, _, err := v4wire.Encode(key, data) enc, _, err := v4wire.Encode(key, data)
@ -106,7 +107,7 @@ func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *
} }
// waits for a packet to be sent by the transport. // waits for a packet to be sent by the transport.
// validate should have type func(X, *net.UDPAddr, []byte), where X is a packet type. // validate should have type func(X, netip.AddrPort, []byte), where X is a packet type.
func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) {
test.t.Helper() test.t.Helper()
@ -128,7 +129,7 @@ func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) {
test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
return false return false
} }
fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(hash)}) fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(hash)})
return false return false
} }
@ -236,7 +237,7 @@ func TestUDPv4_findnodeTimeout(t *testing.T) {
test := newUDPTest(t) test := newUDPTest(t)
defer test.close() defer test.close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} toaddr := netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 2222)
toid := enode.ID{1, 2, 3, 4} toid := enode.ID{1, 2, 3, 4}
target := v4wire.Pubkey{4, 5, 6, 7} target := v4wire.Pubkey{4, 5, 6, 7}
result, err := test.udp.findnode(toid, toaddr, target) result, err := test.udp.findnode(toid, toaddr, target)
@ -261,26 +262,25 @@ func TestUDPv4_findnode(t *testing.T) {
for i := 0; i < numCandidates; i++ { for i := 0; i < numCandidates; i++ {
key := newkey() key := newkey()
ip := net.IP{10, 13, 0, byte(i)} ip := net.IP{10, 13, 0, byte(i)}
n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000)) n := enode.NewV4(&key.PublicKey, ip, 0, 2000)
// Ensure half of table content isn't verified live yet. // Ensure half of table content isn't verified live yet.
if i > numCandidates/2 { if i > numCandidates/2 {
n.isValidatedLive = true
live[n.ID()] = true live[n.ID()] = true
} }
test.table.addFoundNode(n, live[n.ID()])
nodes.push(n, numCandidates) nodes.push(n, numCandidates)
} }
fillTable(test.table, nodes.entries, false)
// ensure there's a bond with the test node, // ensure there's a bond with the test node,
// findnode won't be accepted otherwise. // findnode won't be accepted otherwise.
remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID() remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID()
test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now()) test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr().AsSlice(), time.Now())
// check that closest neighbors are returned. // check that closest neighbors are returned.
expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true) expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true)
test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp}) test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp})
waitNeighbors := func(want []*node) { waitNeighbors := func(want []*enode.Node) {
test.waitPacketOut(func(p *v4wire.Neighbors, to *net.UDPAddr, hash []byte) { test.waitPacketOut(func(p *v4wire.Neighbors, to netip.AddrPort, hash []byte) {
if len(p.Nodes) != len(want) { if len(p.Nodes) != len(want) {
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), len(want)) t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), len(want))
return return
@ -309,10 +309,10 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) {
defer test.close() defer test.close()
rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey) rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now()) test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr().AsSlice(), time.Now())
// queue a pending findnode request // queue a pending findnode request
resultc, errc := make(chan []*node, 1), make(chan error, 1) resultc, errc := make(chan []*enode.Node, 1), make(chan error, 1)
go func() { go func() {
rid := encodePubkey(&test.remotekey.PublicKey).id() rid := encodePubkey(&test.remotekey.PublicKey).id()
ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget) ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
@ -325,18 +325,18 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) {
// wait for the findnode to be sent. // wait for the findnode to be sent.
// after it is sent, the transport is waiting for a reply // after it is sent, the transport is waiting for a reply
test.waitPacketOut(func(p *v4wire.Findnode, to *net.UDPAddr, hash []byte) { test.waitPacketOut(func(p *v4wire.Findnode, to netip.AddrPort, hash []byte) {
if p.Target != testTarget { if p.Target != testTarget {
t.Errorf("wrong target: got %v, want %v", p.Target, testTarget) t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
} }
}) })
// send the reply as two packets. // send the reply as two packets.
list := []*node{ list := []*enode.Node{
wrapNode(enode.MustParse("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304")), enode.MustParse("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304"),
wrapNode(enode.MustParse("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303")), enode.MustParse("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
wrapNode(enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17")), enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17"),
wrapNode(enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303")), enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
} }
rpclist := make([]v4wire.Node, len(list)) rpclist := make([]v4wire.Node, len(list))
for i := range list { for i := range list {
@ -368,8 +368,8 @@ func TestUDPv4_pingMatch(t *testing.T) {
crand.Read(randToken) crand.Read(randToken)
test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {}) test.waitPacketOut(func(*v4wire.Pong, netip.AddrPort, []byte) {})
test.waitPacketOut(func(*v4wire.Ping, *net.UDPAddr, []byte) {}) test.waitPacketOut(func(*v4wire.Ping, netip.AddrPort, []byte) {})
test.packetIn(errUnsolicitedReply, &v4wire.Pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp}) test.packetIn(errUnsolicitedReply, &v4wire.Pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp})
} }
@ -379,10 +379,10 @@ func TestUDPv4_pingMatchIP(t *testing.T) {
defer test.close() defer test.close()
test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {}) test.waitPacketOut(func(*v4wire.Pong, netip.AddrPort, []byte) {})
test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) { test.waitPacketOut(func(p *v4wire.Ping, to netip.AddrPort, hash []byte) {
wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000} wrongAddr := netip.MustParseAddrPort("33.44.1.2:30000")
test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, &v4wire.Pong{ test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, &v4wire.Pong{
ReplyTok: hash, ReplyTok: hash,
To: testLocalAnnounced, To: testLocalAnnounced,
@ -393,41 +393,36 @@ func TestUDPv4_pingMatchIP(t *testing.T) {
func TestUDPv4_successfulPing(t *testing.T) { func TestUDPv4_successfulPing(t *testing.T) {
test := newUDPTest(t) test := newUDPTest(t)
added := make(chan *node, 1) added := make(chan *tableNode, 1)
test.table.nodeAddedHook = func(b *bucket, n *node) { added <- n } test.table.nodeAddedHook = func(b *bucket, n *tableNode) { added <- n }
defer test.close() defer test.close()
// The remote side sends a ping packet to initiate the exchange. // The remote side sends a ping packet to initiate the exchange.
go test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) go test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
// The ping is replied to. // The ping is replied to.
test.waitPacketOut(func(p *v4wire.Pong, to *net.UDPAddr, hash []byte) { test.waitPacketOut(func(p *v4wire.Pong, to netip.AddrPort, hash []byte) {
pinghash := test.sent[0][:32] pinghash := test.sent[0][:32]
if !bytes.Equal(p.ReplyTok, pinghash) { if !bytes.Equal(p.ReplyTok, pinghash) {
t.Errorf("got pong.ReplyTok %x, want %x", p.ReplyTok, pinghash) t.Errorf("got pong.ReplyTok %x, want %x", p.ReplyTok, pinghash)
} }
wantTo := v4wire.Endpoint{ // The mirrored UDP address is the UDP packet sender.
// The mirrored UDP address is the UDP packet sender // The mirrored TCP port is the one from the ping packet.
IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port), wantTo := v4wire.NewEndpoint(test.remoteaddr, testRemote.TCP)
// The mirrored TCP port is the one from the ping packet
TCP: testRemote.TCP,
}
if !reflect.DeepEqual(p.To, wantTo) { if !reflect.DeepEqual(p.To, wantTo) {
t.Errorf("got pong.To %v, want %v", p.To, wantTo) t.Errorf("got pong.To %v, want %v", p.To, wantTo)
} }
}) })
// Remote is unknown, the table pings back. // Remote is unknown, the table pings back.
test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) { test.waitPacketOut(func(p *v4wire.Ping, to netip.AddrPort, hash []byte) {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) { wantFrom := test.udp.ourEndpoint()
wantFrom.IP = net.IP{}
if !reflect.DeepEqual(p.From, wantFrom) {
t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint()) t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint())
} }
wantTo := v4wire.Endpoint{ // The mirrored UDP address is the UDP packet sender.
// The mirrored UDP address is the UDP packet sender. wantTo := v4wire.NewEndpoint(test.remoteaddr, 0)
IP: test.remoteaddr.IP,
UDP: uint16(test.remoteaddr.Port),
TCP: 0,
}
if !reflect.DeepEqual(p.To, wantTo) { if !reflect.DeepEqual(p.To, wantTo) {
t.Errorf("got ping.To %v, want %v", p.To, wantTo) t.Errorf("got ping.To %v, want %v", p.To, wantTo)
} }
@ -442,11 +437,11 @@ func TestUDPv4_successfulPing(t *testing.T) {
if n.ID() != rid { if n.ID() != rid {
t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid) t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid)
} }
if !n.IP().Equal(test.remoteaddr.IP) { if !n.IP().Equal(test.remoteaddr.Addr().AsSlice()) {
t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.IP) t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.Addr())
} }
if n.UDP() != test.remoteaddr.Port { if n.UDP() != int(test.remoteaddr.Port()) {
t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port) t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port())
} }
if n.TCP() != int(testRemote.TCP) { if n.TCP() != int(testRemote.TCP) {
t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP(), testRemote.TCP) t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP(), testRemote.TCP)
@ -469,12 +464,12 @@ func TestUDPv4_EIP868(t *testing.T) {
// Perform endpoint proof and check for sequence number in packet tail. // Perform endpoint proof and check for sequence number in packet tail.
test.packetIn(nil, &v4wire.Ping{Expiration: futureExp}) test.packetIn(nil, &v4wire.Ping{Expiration: futureExp})
test.waitPacketOut(func(p *v4wire.Pong, addr *net.UDPAddr, hash []byte) { test.waitPacketOut(func(p *v4wire.Pong, addr netip.AddrPort, hash []byte) {
if p.ENRSeq != wantNode.Seq() { if p.ENRSeq != wantNode.Seq() {
t.Errorf("wrong sequence number in pong: %d, want %d", p.ENRSeq, wantNode.Seq()) t.Errorf("wrong sequence number in pong: %d, want %d", p.ENRSeq, wantNode.Seq())
} }
}) })
test.waitPacketOut(func(p *v4wire.Ping, addr *net.UDPAddr, hash []byte) { test.waitPacketOut(func(p *v4wire.Ping, addr netip.AddrPort, hash []byte) {
if p.ENRSeq != wantNode.Seq() { if p.ENRSeq != wantNode.Seq() {
t.Errorf("wrong sequence number in ping: %d, want %d", p.ENRSeq, wantNode.Seq()) t.Errorf("wrong sequence number in ping: %d, want %d", p.ENRSeq, wantNode.Seq())
} }
@ -483,7 +478,7 @@ func TestUDPv4_EIP868(t *testing.T) {
// Request should work now. // Request should work now.
test.packetIn(nil, &v4wire.ENRRequest{Expiration: futureExp}) test.packetIn(nil, &v4wire.ENRRequest{Expiration: futureExp})
test.waitPacketOut(func(p *v4wire.ENRResponse, addr *net.UDPAddr, hash []byte) { test.waitPacketOut(func(p *v4wire.ENRResponse, addr netip.AddrPort, hash []byte) {
n, err := enode.New(enode.ValidSchemes, &p.Record) n, err := enode.New(enode.ValidSchemes, &p.Record)
if err != nil { if err != nil {
t.Fatalf("invalid record: %v", err) t.Fatalf("invalid record: %v", err)
@ -584,7 +579,7 @@ type dgramPipe struct {
} }
type dgram struct { type dgram struct {
to net.UDPAddr to netip.AddrPort
data []byte data []byte
} }
@ -597,8 +592,8 @@ func newpipe() *dgramPipe {
} }
} }
// WriteToUDP queues a datagram. // WriteToUDPAddrPort queues a datagram.
func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) { func (c *dgramPipe) WriteToUDPAddrPort(b []byte, to netip.AddrPort) (n int, err error) {
msg := make([]byte, len(b)) msg := make([]byte, len(b))
copy(msg, b) copy(msg, b)
c.mu.Lock() c.mu.Lock()
@ -606,15 +601,15 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
if c.closed { if c.closed {
return 0, errors.New("closed") return 0, errors.New("closed")
} }
c.queue = append(c.queue, dgram{*to, b}) c.queue = append(c.queue, dgram{to, b})
c.cond.Signal() c.cond.Signal()
return len(b), nil return len(b), nil
} }
// ReadFromUDP just hangs until the pipe is closed. // ReadFromUDPAddrPort just hangs until the pipe is closed.
func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { func (c *dgramPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
<-c.closing <-c.closing
return 0, nil, io.EOF return 0, netip.AddrPort{}, io.EOF
} }
func (c *dgramPipe) Close() error { func (c *dgramPipe) Close() error {

@ -25,6 +25,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"net" "net"
"net/netip"
"time" "time"
"github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/common/math"
@ -150,14 +151,15 @@ type Endpoint struct {
} }
// NewEndpoint creates an endpoint. // NewEndpoint creates an endpoint.
func NewEndpoint(addr *net.UDPAddr, tcpPort uint16) Endpoint { func NewEndpoint(addr netip.AddrPort, tcpPort uint16) Endpoint {
ip := net.IP{} var ip net.IP
if ip4 := addr.IP.To4(); ip4 != nil { if addr.Addr().Is4() || addr.Addr().Is4In6() {
ip = ip4 ip4 := addr.Addr().As4()
} else if ip6 := addr.IP.To16(); ip6 != nil { ip = ip4[:]
ip = ip6 } else {
ip = addr.Addr().AsSlice()
} }
return Endpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} return Endpoint{IP: ip, UDP: addr.Port(), TCP: tcpPort}
} }
type Packet interface { type Packet interface {

@ -18,6 +18,7 @@ package discover
import ( import (
"net" "net"
"net/netip"
"sync" "sync"
"time" "time"
@ -70,7 +71,7 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) {
} }
// handleRequest handles a talk request. // handleRequest handles a talk request.
func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.TalkRequest) { func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) {
t.mutex.Lock() t.mutex.Lock()
handler, ok := t.handlers[req.Protocol] handler, ok := t.handlers[req.Protocol]
t.mutex.Unlock() t.mutex.Unlock()
@ -88,7 +89,8 @@ func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.T
case <-t.slots: case <-t.slots:
go func() { go func() {
defer func() { t.slots <- struct{}{} }() defer func() { t.slots <- struct{}{} }()
respMessage := handler(id, addr, req.Message) udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())}
respMessage := handler(id, udpAddr, req.Message)
resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage} resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage}
t.transport.sendFromAnotherThread(id, addr, resp) t.transport.sendFromAnotherThread(id, addr, resp)
}() }()

@ -25,6 +25,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"slices" "slices"
"sync" "sync"
"time" "time"
@ -101,14 +102,14 @@ type UDPv5 struct {
type sendRequest struct { type sendRequest struct {
destID enode.ID destID enode.ID
destAddr *net.UDPAddr destAddr netip.AddrPort
msg v5wire.Packet msg v5wire.Packet
} }
// callV5 represents a remote procedure call against another node. // callV5 represents a remote procedure call against another node.
type callV5 struct { type callV5 struct {
id enode.ID id enode.ID
addr *net.UDPAddr addr netip.AddrPort
node *enode.Node // This is required to perform handshakes. node *enode.Node // This is required to perform handshakes.
packet v5wire.Packet packet v5wire.Packet
@ -233,7 +234,7 @@ func (t *UDPv5) AllNodes() []*enode.Node {
for _, b := range &t.tab.buckets { for _, b := range &t.tab.buckets {
for _, n := range b.entries { for _, n := range b.entries {
nodes = append(nodes, unwrapNode(n)) nodes = append(nodes, n.Node)
} }
} }
return nodes return nodes
@ -266,7 +267,7 @@ func (t *UDPv5) TalkRequest(n *enode.Node, protocol string, request []byte) ([]b
} }
// TalkRequestToID sends a talk request to a node and waits for a response. // TalkRequestToID sends a talk request to a node and waits for a response.
func (t *UDPv5) TalkRequestToID(id enode.ID, addr *net.UDPAddr, protocol string, request []byte) ([]byte, error) { func (t *UDPv5) TalkRequestToID(id enode.ID, addr netip.AddrPort, protocol string, request []byte) ([]byte, error) {
req := &v5wire.TalkRequest{Protocol: protocol, Message: request} req := &v5wire.TalkRequest{Protocol: protocol, Message: request}
resp := t.callToID(id, addr, v5wire.TalkResponseMsg, req) resp := t.callToID(id, addr, v5wire.TalkResponseMsg, req)
defer t.callDone(resp) defer t.callDone(resp)
@ -314,26 +315,26 @@ func (t *UDPv5) newRandomLookup(ctx context.Context) *lookup {
} }
func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup {
return newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { return newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) {
return t.lookupWorker(n, target) return t.lookupWorker(n, target)
}) })
} }
// lookupWorker performs FINDNODE calls against a single node during lookup. // lookupWorker performs FINDNODE calls against a single node during lookup.
func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) { func (t *UDPv5) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) {
var ( var (
dists = lookupDistances(target, destNode.ID()) dists = lookupDistances(target, destNode.ID())
nodes = nodesByDistance{target: target} nodes = nodesByDistance{target: target}
err error err error
) )
var r []*enode.Node var r []*enode.Node
r, err = t.findnode(unwrapNode(destNode), dists) r, err = t.findnode(destNode, dists)
if errors.Is(err, errClosed) { if errors.Is(err, errClosed) {
return nil, err return nil, err
} }
for _, n := range r { for _, n := range r {
if n.ID() != t.Self().ID() { if n.ID() != t.Self().ID() {
nodes.push(wrapNode(n), findnodeResultLimit) nodes.push(n, findnodeResultLimit)
} }
} }
return nodes.entries, err return nodes.entries, err
@ -427,7 +428,7 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := netutil.CheckRelayIP(c.addr.IP, node.IP()); err != nil { if err := netutil.CheckRelayIP(c.addr.Addr().AsSlice(), node.IP()); err != nil {
return nil, err return nil, err
} }
if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) { if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) {
@ -452,14 +453,14 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s
// callToNode sends the given call and sets up a handler for response packets (of message // callToNode sends the given call and sets up a handler for response packets (of message
// type responseType). Responses are dispatched to the call's response channel. // type responseType). Responses are dispatched to the call's response channel.
func (t *UDPv5) callToNode(n *enode.Node, responseType byte, req v5wire.Packet) *callV5 { func (t *UDPv5) callToNode(n *enode.Node, responseType byte, req v5wire.Packet) *callV5 {
addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} addr, _ := n.UDPEndpoint()
c := &callV5{id: n.ID(), addr: addr, node: n} c := &callV5{id: n.ID(), addr: addr, node: n}
t.initCall(c, responseType, req) t.initCall(c, responseType, req)
return c return c
} }
// callToID is like callToNode, but for cases where the node record is not available. // callToID is like callToNode, but for cases where the node record is not available.
func (t *UDPv5) callToID(id enode.ID, addr *net.UDPAddr, responseType byte, req v5wire.Packet) *callV5 { func (t *UDPv5) callToID(id enode.ID, addr netip.AddrPort, responseType byte, req v5wire.Packet) *callV5 {
c := &callV5{id: id, addr: addr} c := &callV5{id: id, addr: addr}
t.initCall(c, responseType, req) t.initCall(c, responseType, req)
return c return c
@ -619,12 +620,12 @@ func (t *UDPv5) sendCall(c *callV5) {
// sendResponse sends a response packet to the given node. // sendResponse sends a response packet to the given node.
// This doesn't trigger a handshake even if no keys are available. // This doesn't trigger a handshake even if no keys are available.
func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) error { func (t *UDPv5) sendResponse(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet) error {
_, err := t.send(toID, toAddr, packet, nil) _, err := t.send(toID, toAddr, packet, nil)
return err return err
} }
func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) { func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet) {
select { select {
case t.sendCh <- sendRequest{toID, toAddr, packet}: case t.sendCh <- sendRequest{toID, toAddr, packet}:
case <-t.closeCtx.Done(): case <-t.closeCtx.Done():
@ -632,7 +633,7 @@ func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet
} }
// send sends a packet to the given node. // send sends a packet to the given node.
func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) { func (t *UDPv5) send(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) {
addr := toAddr.String() addr := toAddr.String()
t.logcontext = append(t.logcontext[:0], "id", toID, "addr", addr) t.logcontext = append(t.logcontext[:0], "id", toID, "addr", addr)
t.logcontext = packet.AppendLogInfo(t.logcontext) t.logcontext = packet.AppendLogInfo(t.logcontext)
@ -644,7 +645,7 @@ func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c
return nonce, err return nonce, err
} }
_, err = t.conn.WriteToUDP(enc, toAddr) _, err = t.conn.WriteToUDPAddrPort(enc, toAddr)
t.log.Trace(">> "+packet.Name(), t.logcontext...) t.log.Trace(">> "+packet.Name(), t.logcontext...)
return nonce, err return nonce, err
} }
@ -655,7 +656,7 @@ func (t *UDPv5) readLoop() {
buf := make([]byte, maxPacketSize) buf := make([]byte, maxPacketSize)
for range t.readNextCh { for range t.readNextCh {
nbytes, from, err := t.conn.ReadFromUDP(buf) nbytes, from, err := t.conn.ReadFromUDPAddrPort(buf)
if netutil.IsTemporaryError(err) { if netutil.IsTemporaryError(err) {
// Ignore temporary read errors. // Ignore temporary read errors.
t.log.Debug("Temporary UDP read error", "err", err) t.log.Debug("Temporary UDP read error", "err", err)
@ -672,7 +673,7 @@ func (t *UDPv5) readLoop() {
} }
// dispatchReadPacket sends a packet into the dispatch loop. // dispatchReadPacket sends a packet into the dispatch loop.
func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool { func (t *UDPv5) dispatchReadPacket(from netip.AddrPort, content []byte) bool {
select { select {
case t.packetInCh <- ReadPacket{content, from}: case t.packetInCh <- ReadPacket{content, from}:
return true return true
@ -682,7 +683,7 @@ func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool {
} }
// handlePacket decodes and processes an incoming packet from the network. // handlePacket decodes and processes an incoming packet from the network.
func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr netip.AddrPort) error {
addr := fromAddr.String() addr := fromAddr.String()
fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr) fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr)
if err != nil { if err != nil {
@ -699,7 +700,7 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
} }
if fromNode != nil { if fromNode != nil {
// Handshake succeeded, add to table. // Handshake succeeded, add to table.
t.tab.addInboundNode(wrapNode(fromNode)) t.tab.addInboundNode(fromNode)
} }
if packet.Kind() != v5wire.WhoareyouPacket { if packet.Kind() != v5wire.WhoareyouPacket {
// WHOAREYOU logged separately to report errors. // WHOAREYOU logged separately to report errors.
@ -712,13 +713,13 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
} }
// handleCallResponse dispatches a response packet to the call waiting for it. // handleCallResponse dispatches a response packet to the call waiting for it.
func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, p v5wire.Packet) bool { func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr netip.AddrPort, p v5wire.Packet) bool {
ac := t.activeCallByNode[fromID] ac := t.activeCallByNode[fromID]
if ac == nil || !bytes.Equal(p.RequestID(), ac.reqid) { if ac == nil || !bytes.Equal(p.RequestID(), ac.reqid) {
t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.Name()), "id", fromID, "addr", fromAddr) t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.Name()), "id", fromID, "addr", fromAddr)
return false return false
} }
if !fromAddr.IP.Equal(ac.addr.IP) || fromAddr.Port != ac.addr.Port { if fromAddr != ac.addr {
t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.Name()), "id", fromID, "addr", fromAddr) t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.Name()), "id", fromID, "addr", fromAddr)
return false return false
} }
@ -743,7 +744,7 @@ func (t *UDPv5) getNode(id enode.ID) *enode.Node {
} }
// handle processes incoming packets according to their message type. // handle processes incoming packets according to their message type.
func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) { func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort) {
switch p := p.(type) { switch p := p.(type) {
case *v5wire.Unknown: case *v5wire.Unknown:
t.handleUnknown(p, fromID, fromAddr) t.handleUnknown(p, fromID, fromAddr)
@ -753,7 +754,9 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr)
t.handlePing(p, fromID, fromAddr) t.handlePing(p, fromID, fromAddr)
case *v5wire.Pong: case *v5wire.Pong:
if t.handleCallResponse(fromID, fromAddr, p) { if t.handleCallResponse(fromID, fromAddr, p) {
t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}) fromUDPAddr := &net.UDPAddr{IP: fromAddr.Addr().AsSlice(), Port: int(fromAddr.Port())}
toUDPAddr := &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
} }
case *v5wire.Findnode: case *v5wire.Findnode:
t.handleFindnode(p, fromID, fromAddr) t.handleFindnode(p, fromID, fromAddr)
@ -767,7 +770,7 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr)
} }
// handleUnknown initiates a handshake by responding with WHOAREYOU. // handleUnknown initiates a handshake by responding with WHOAREYOU.
func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr *net.UDPAddr) { func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr netip.AddrPort) {
challenge := &v5wire.Whoareyou{Nonce: p.Nonce} challenge := &v5wire.Whoareyou{Nonce: p.Nonce}
crand.Read(challenge.IDNonce[:]) crand.Read(challenge.IDNonce[:])
if n := t.getNode(fromID); n != nil { if n := t.getNode(fromID); n != nil {
@ -783,7 +786,7 @@ var (
) )
// handleWhoareyou resends the active call as a handshake packet. // handleWhoareyou resends the active call as a handshake packet.
func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr *net.UDPAddr) { func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr netip.AddrPort) {
c, err := t.matchWithCall(fromID, p.Nonce) c, err := t.matchWithCall(fromID, p.Nonce)
if err != nil { if err != nil {
t.log.Debug("Invalid "+p.Name(), "addr", fromAddr, "err", err) t.log.Debug("Invalid "+p.Name(), "addr", fromAddr, "err", err)
@ -817,32 +820,35 @@ func (t *UDPv5) matchWithCall(fromID enode.ID, nonce v5wire.Nonce) (*callV5, err
} }
// handlePing sends a PONG response. // handlePing sends a PONG response.
func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr *net.UDPAddr) { func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr netip.AddrPort) {
remoteIP := fromAddr.IP var remoteIP net.IP
// Handle IPv4 mapped IPv6 addresses in the // Handle IPv4 mapped IPv6 addresses in the event the local node is binded
// event the local node is binded to an // to an ipv6 interface.
// ipv6 interface. if fromAddr.Addr().Is4() || fromAddr.Addr().Is4In6() {
if remoteIP.To4() != nil { ip4 := fromAddr.Addr().As4()
remoteIP = remoteIP.To4() remoteIP = ip4[:]
} else {
remoteIP = fromAddr.Addr().AsSlice()
} }
t.sendResponse(fromID, fromAddr, &v5wire.Pong{ t.sendResponse(fromID, fromAddr, &v5wire.Pong{
ReqID: p.ReqID, ReqID: p.ReqID,
ToIP: remoteIP, ToIP: remoteIP,
ToPort: uint16(fromAddr.Port), ToPort: fromAddr.Port(),
ENRSeq: t.localNode.Node().Seq(), ENRSeq: t.localNode.Node().Seq(),
}) })
} }
// handleFindnode returns nodes to the requester. // handleFindnode returns nodes to the requester.
func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr *net.UDPAddr) { func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr netip.AddrPort) {
nodes := t.collectTableNodes(fromAddr.IP, p.Distances, findnodeResultLimit) nodes := t.collectTableNodes(fromAddr.Addr(), p.Distances, findnodeResultLimit)
for _, resp := range packNodes(p.ReqID, nodes) { for _, resp := range packNodes(p.ReqID, nodes) {
t.sendResponse(fromID, fromAddr, resp) t.sendResponse(fromID, fromAddr, resp)
} }
} }
// collectTableNodes creates a FINDNODE result set for the given distances. // collectTableNodes creates a FINDNODE result set for the given distances.
func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node { func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) []*enode.Node {
ripSlice := rip.AsSlice()
var bn []*enode.Node var bn []*enode.Node
var nodes []*enode.Node var nodes []*enode.Node
var processed = make(map[uint]struct{}) var processed = make(map[uint]struct{})
@ -857,7 +863,7 @@ func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*en
for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) { for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) {
// Apply some pre-checks to avoid sending invalid nodes. // Apply some pre-checks to avoid sending invalid nodes.
// Note liveness is checked by appendLiveNodes. // Note liveness is checked by appendLiveNodes.
if netutil.CheckRelayIP(rip, n.IP()) != nil { if netutil.CheckRelayIP(ripSlice, n.IP()) != nil {
continue continue
} }
nodes = append(nodes, n) nodes = append(nodes, n)

@ -23,6 +23,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"reflect" "reflect"
"slices" "slices"
"testing" "testing"
@ -103,7 +104,7 @@ func TestUDPv5_pingHandling(t *testing.T) {
defer test.close() defer test.close()
test.packetIn(&v5wire.Ping{ReqID: []byte("foo")}) test.packetIn(&v5wire.Ping{ReqID: []byte("foo")})
test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Pong, addr netip.AddrPort, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("foo")) { if !bytes.Equal(p.ReqID, []byte("foo")) {
t.Error("wrong request ID in response:", p.ReqID) t.Error("wrong request ID in response:", p.ReqID)
} }
@ -135,16 +136,16 @@ func TestUDPv5_unknownPacket(t *testing.T) {
// Unknown packet from unknown node. // Unknown packet from unknown node.
test.packetIn(&v5wire.Unknown{Nonce: nonce}) test.packetIn(&v5wire.Unknown{Nonce: nonce})
test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) {
check(p, 0) check(p, 0)
}) })
// Make node known. // Make node known.
n := test.getNode(test.remotekey, test.remoteaddr).Node() n := test.getNode(test.remotekey, test.remoteaddr).Node()
test.table.addFoundNode(wrapNode(n)) test.table.addFoundNode(n, false)
test.packetIn(&v5wire.Unknown{Nonce: nonce}) test.packetIn(&v5wire.Unknown{Nonce: nonce})
test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) {
check(p, n.Seq()) check(p, n.Seq())
}) })
} }
@ -159,9 +160,9 @@ func TestUDPv5_findnodeHandling(t *testing.T) {
nodes253 := nodesAtDistance(test.table.self().ID(), 253, 16) nodes253 := nodesAtDistance(test.table.self().ID(), 253, 16)
nodes249 := nodesAtDistance(test.table.self().ID(), 249, 4) nodes249 := nodesAtDistance(test.table.self().ID(), 249, 4)
nodes248 := nodesAtDistance(test.table.self().ID(), 248, 10) nodes248 := nodesAtDistance(test.table.self().ID(), 248, 10)
fillTable(test.table, wrapNodes(nodes253), true) fillTable(test.table, nodes253, true)
fillTable(test.table, wrapNodes(nodes249), true) fillTable(test.table, nodes249, true)
fillTable(test.table, wrapNodes(nodes248), true) fillTable(test.table, nodes248, true)
// Requesting with distance zero should return the node's own record. // Requesting with distance zero should return the node's own record.
test.packetIn(&v5wire.Findnode{ReqID: []byte{0}, Distances: []uint{0}}) test.packetIn(&v5wire.Findnode{ReqID: []byte{0}, Distances: []uint{0}})
@ -199,7 +200,7 @@ func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes
} }
for { for {
test.waitPacketOut(func(p *v5wire.Nodes, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Nodes, addr netip.AddrPort, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, wantReqID) { if !bytes.Equal(p.ReqID, wantReqID) {
test.t.Fatalf("wrong request ID %v in response, want %v", p.ReqID, wantReqID) test.t.Fatalf("wrong request ID %v in response, want %v", p.ReqID, wantReqID)
} }
@ -238,7 +239,7 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote) _, err := test.udp.ping(remote)
done <- err done <- err
}() }()
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {}) test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {})
if err := <-done; err != errTimeout { if err := <-done; err != errTimeout {
t.Fatalf("want errTimeout, got %q", err) t.Fatalf("want errTimeout, got %q", err)
} }
@ -248,7 +249,7 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote) _, err := test.udp.ping(remote)
done <- err done <- err
}() }()
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.Pong{ReqID: p.ReqID}) test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.Pong{ReqID: p.ReqID})
}) })
if err := <-done; err != nil { if err := <-done; err != nil {
@ -260,8 +261,8 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote) _, err := test.udp.ping(remote)
done <- err done <- err
}() }()
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101} wrongAddr := netip.MustParseAddrPort("33.44.55.22:10101")
test.packetInFrom(test.remotekey, wrongAddr, &v5wire.Pong{ReqID: p.ReqID}) test.packetInFrom(test.remotekey, wrongAddr, &v5wire.Pong{ReqID: p.ReqID})
}) })
if err := <-done; err != errTimeout { if err := <-done; err != errTimeout {
@ -291,7 +292,7 @@ func TestUDPv5_findnodeCall(t *testing.T) {
}() }()
// Serve the responses: // Serve the responses:
test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Findnode, addr netip.AddrPort, _ v5wire.Nonce) {
if !reflect.DeepEqual(p.Distances, distances) { if !reflect.DeepEqual(p.Distances, distances) {
t.Fatalf("wrong distances in request: %v", p.Distances) t.Fatalf("wrong distances in request: %v", p.Distances)
} }
@ -337,15 +338,15 @@ func TestUDPv5_callResend(t *testing.T) {
}() }()
// Ping answered by WHOAREYOU. // Ping answered by WHOAREYOU.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) {
test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
}) })
// Ping should be re-sent. // Ping should be re-sent.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
test.packetIn(&v5wire.Pong{ReqID: p.ReqID}) test.packetIn(&v5wire.Pong{ReqID: p.ReqID})
}) })
// Answer the other ping. // Answer the other ping.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
test.packetIn(&v5wire.Pong{ReqID: p.ReqID}) test.packetIn(&v5wire.Pong{ReqID: p.ReqID})
}) })
if err := <-done; err != nil { if err := <-done; err != nil {
@ -370,11 +371,11 @@ func TestUDPv5_multipleHandshakeRounds(t *testing.T) {
}() }()
// Ping answered by WHOAREYOU. // Ping answered by WHOAREYOU.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) {
test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
}) })
// Ping answered by WHOAREYOU again. // Ping answered by WHOAREYOU again.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) {
test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
}) })
if err := <-done; err != errTimeout { if err := <-done; err != errTimeout {
@ -401,7 +402,7 @@ func TestUDPv5_callTimeoutReset(t *testing.T) {
}() }()
// Serve two responses, slowly. // Serve two responses, slowly.
test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Findnode, addr netip.AddrPort, _ v5wire.Nonce) {
time.Sleep(respTimeout - 50*time.Millisecond) time.Sleep(respTimeout - 50*time.Millisecond)
test.packetIn(&v5wire.Nodes{ test.packetIn(&v5wire.Nodes{
ReqID: p.ReqID, ReqID: p.ReqID,
@ -439,7 +440,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
Protocol: "test", Protocol: "test",
Message: []byte("test request"), Message: []byte("test request"),
}) })
test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.TalkResponse, addr netip.AddrPort, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("foo")) { if !bytes.Equal(p.ReqID, []byte("foo")) {
t.Error("wrong request ID in response:", p.ReqID) t.Error("wrong request ID in response:", p.ReqID)
} }
@ -458,7 +459,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
Protocol: "wrong", Protocol: "wrong",
Message: []byte("test request"), Message: []byte("test request"),
}) })
test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.TalkResponse, addr netip.AddrPort, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("2")) { if !bytes.Equal(p.ReqID, []byte("2")) {
t.Error("wrong request ID in response:", p.ReqID) t.Error("wrong request ID in response:", p.ReqID)
} }
@ -485,7 +486,7 @@ func TestUDPv5_talkRequest(t *testing.T) {
_, err := test.udp.TalkRequest(remote, "test", []byte("test request")) _, err := test.udp.TalkRequest(remote, "test", []byte("test request"))
done <- err done <- err
}() }()
test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {}) test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) {})
if err := <-done; err != errTimeout { if err := <-done; err != errTimeout {
t.Fatalf("want errTimeout, got %q", err) t.Fatalf("want errTimeout, got %q", err)
} }
@ -495,7 +496,7 @@ func TestUDPv5_talkRequest(t *testing.T) {
_, err := test.udp.TalkRequest(remote, "test", []byte("test request")) _, err := test.udp.TalkRequest(remote, "test", []byte("test request"))
done <- err done <- err
}() }()
test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) {
if p.Protocol != "test" { if p.Protocol != "test" {
t.Errorf("wrong protocol ID in talk request: %q", p.Protocol) t.Errorf("wrong protocol ID in talk request: %q", p.Protocol)
} }
@ -516,7 +517,7 @@ func TestUDPv5_talkRequest(t *testing.T) {
_, err := test.udp.TalkRequestToID(remote.ID(), test.remoteaddr, "test", []byte("test request 2")) _, err := test.udp.TalkRequestToID(remote.ID(), test.remoteaddr, "test", []byte("test request 2"))
done <- err done <- err
}() }()
test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) {
if p.Protocol != "test" { if p.Protocol != "test" {
t.Errorf("wrong protocol ID in talk request: %q", p.Protocol) t.Errorf("wrong protocol ID in talk request: %q", p.Protocol)
} }
@ -583,13 +584,14 @@ func TestUDPv5_lookup(t *testing.T) {
for d, nn := range lookupTestnet.dists { for d, nn := range lookupTestnet.dists {
for i, key := range nn { for i, key := range nn {
n := lookupTestnet.node(d, i) n := lookupTestnet.node(d, i)
test.getNode(key, &net.UDPAddr{IP: n.IP(), Port: n.UDP()}) addr, _ := n.UDPEndpoint()
test.getNode(key, addr)
} }
} }
// Seed table with initial node. // Seed table with initial node.
initialNode := lookupTestnet.node(256, 0) initialNode := lookupTestnet.node(256, 0)
fillTable(test.table, []*node{wrapNode(initialNode)}, true) fillTable(test.table, []*enode.Node{initialNode}, true)
// Start the lookup. // Start the lookup.
resultC := make(chan []*enode.Node, 1) resultC := make(chan []*enode.Node, 1)
@ -601,7 +603,7 @@ func TestUDPv5_lookup(t *testing.T) {
// Answer lookup packets. // Answer lookup packets.
asked := make(map[enode.ID]bool) asked := make(map[enode.ID]bool)
for done := false; !done; { for done := false; !done; {
done = test.waitPacketOut(func(p v5wire.Packet, to *net.UDPAddr, _ v5wire.Nonce) { done = test.waitPacketOut(func(p v5wire.Packet, to netip.AddrPort, _ v5wire.Nonce) {
recipient, key := lookupTestnet.nodeByAddr(to) recipient, key := lookupTestnet.nodeByAddr(to)
switch p := p.(type) { switch p := p.(type) {
case *v5wire.Ping: case *v5wire.Ping:
@ -652,11 +654,8 @@ func TestUDPv5_PingWithIPV4MappedAddress(t *testing.T) {
test := newUDPV5Test(t) test := newUDPV5Test(t)
defer test.close() defer test.close()
rawIP := net.IPv4(0xFF, 0x12, 0x33, 0xE5) rawIP := netip.AddrFrom4([4]byte{0xFF, 0x12, 0x33, 0xE5})
test.remoteaddr = &net.UDPAddr{ test.remoteaddr = netip.AddrPortFrom(netip.AddrFrom16(rawIP.As16()), 0)
IP: rawIP.To16(),
Port: 0,
}
remote := test.getNode(test.remotekey, test.remoteaddr).Node() remote := test.getNode(test.remotekey, test.remoteaddr).Node()
done := make(chan struct{}, 1) done := make(chan struct{}, 1)
@ -665,14 +664,14 @@ func TestUDPv5_PingWithIPV4MappedAddress(t *testing.T) {
test.udp.handlePing(&v5wire.Ping{ENRSeq: 1}, remote.ID(), test.remoteaddr) test.udp.handlePing(&v5wire.Ping{ENRSeq: 1}, remote.ID(), test.remoteaddr)
done <- struct{}{} done <- struct{}{}
}() }()
test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Pong, addr netip.AddrPort, _ v5wire.Nonce) {
if len(p.ToIP) == net.IPv6len { if len(p.ToIP) == net.IPv6len {
t.Error("Received untruncated ip address") t.Error("Received untruncated ip address")
} }
if len(p.ToIP) != net.IPv4len { if len(p.ToIP) != net.IPv4len {
t.Errorf("Received ip address with incorrect length: %d", len(p.ToIP)) t.Errorf("Received ip address with incorrect length: %d", len(p.ToIP))
} }
if !p.ToIP.Equal(rawIP) { if !p.ToIP.Equal(rawIP.AsSlice()) {
t.Errorf("Received incorrect ip address: wanted %s but received %s", rawIP.String(), p.ToIP.String()) t.Errorf("Received incorrect ip address: wanted %s but received %s", rawIP.String(), p.ToIP.String())
} }
}) })
@ -688,9 +687,9 @@ type udpV5Test struct {
db *enode.DB db *enode.DB
udp *UDPv5 udp *UDPv5
localkey, remotekey *ecdsa.PrivateKey localkey, remotekey *ecdsa.PrivateKey
remoteaddr *net.UDPAddr remoteaddr netip.AddrPort
nodesByID map[enode.ID]*enode.LocalNode nodesByID map[enode.ID]*enode.LocalNode
nodesByIP map[string]*enode.LocalNode nodesByIP map[netip.Addr]*enode.LocalNode
} }
// testCodec is the packet encoding used by protocol tests. This codec does not perform encryption. // testCodec is the packet encoding used by protocol tests. This codec does not perform encryption.
@ -750,9 +749,9 @@ func newUDPV5Test(t *testing.T) *udpV5Test {
pipe: newpipe(), pipe: newpipe(),
localkey: newkey(), localkey: newkey(),
remotekey: newkey(), remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, remoteaddr: netip.MustParseAddrPort("10.0.1.99:30303"),
nodesByID: make(map[enode.ID]*enode.LocalNode), nodesByID: make(map[enode.ID]*enode.LocalNode),
nodesByIP: make(map[string]*enode.LocalNode), nodesByIP: make(map[netip.Addr]*enode.LocalNode),
} }
test.db, _ = enode.OpenDB("") test.db, _ = enode.OpenDB("")
ln := enode.NewLocalNode(test.db, test.localkey) ln := enode.NewLocalNode(test.db, test.localkey)
@ -777,8 +776,8 @@ func (test *udpV5Test) packetIn(packet v5wire.Packet) {
test.packetInFrom(test.remotekey, test.remoteaddr, packet) test.packetInFrom(test.remotekey, test.remoteaddr, packet)
} }
// handles a packet as if it had been sent to the transport by the key/endpoint. // packetInFrom handles a packet as if it had been sent to the transport by the key/endpoint.
func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet v5wire.Packet) { func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr netip.AddrPort, packet v5wire.Packet) {
test.t.Helper() test.t.Helper()
ln := test.getNode(key, addr) ln := test.getNode(key, addr)
@ -793,22 +792,22 @@ func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, pa
} }
// getNode ensures the test knows about a node at the given endpoint. // getNode ensures the test knows about a node at the given endpoint.
func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr *net.UDPAddr) *enode.LocalNode { func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr netip.AddrPort) *enode.LocalNode {
id := encodePubkey(&key.PublicKey).id() id := encodePubkey(&key.PublicKey).id()
ln := test.nodesByID[id] ln := test.nodesByID[id]
if ln == nil { if ln == nil {
db, _ := enode.OpenDB("") db, _ := enode.OpenDB("")
ln = enode.NewLocalNode(db, key) ln = enode.NewLocalNode(db, key)
ln.SetStaticIP(addr.IP) ln.SetStaticIP(addr.Addr().AsSlice())
ln.Set(enr.UDP(addr.Port)) ln.Set(enr.UDP(addr.Port()))
test.nodesByID[id] = ln test.nodesByID[id] = ln
} }
test.nodesByIP[string(addr.IP)] = ln test.nodesByIP[addr.Addr()] = ln
return ln return ln
} }
// waitPacketOut waits for the next output packet and handles it using the given 'validate' // waitPacketOut waits for the next output packet and handles it using the given 'validate'
// function. The function must be of type func (X, *net.UDPAddr, v5wire.Nonce) where X is // function. The function must be of type func (X, netip.AddrPort, v5wire.Nonce) where X is
// assignable to packetV5. // assignable to packetV5.
func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
test.t.Helper() test.t.Helper()
@ -824,7 +823,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
test.t.Fatalf("timed out waiting for %v", exptype) test.t.Fatalf("timed out waiting for %v", exptype)
return false return false
} }
ln := test.nodesByIP[string(dgram.to.IP)] ln := test.nodesByIP[dgram.to.Addr()]
if ln == nil { if ln == nil {
test.t.Fatalf("attempt to send to non-existing node %v", &dgram.to) test.t.Fatalf("attempt to send to non-existing node %v", &dgram.to)
return false return false
@ -839,7 +838,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
return false return false
} }
fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(frame.AuthTag)}) fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(frame.AuthTag)})
return false return false
} }

@ -24,6 +24,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"slices" "slices"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -435,11 +436,11 @@ type sharedUDPConn struct {
unhandled chan discover.ReadPacket unhandled chan discover.ReadPacket
} }
// ReadFromUDP implements discover.UDPConn // ReadFromUDPAddrPort implements discover.UDPConn
func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { func (s *sharedUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
packet, ok := <-s.unhandled packet, ok := <-s.unhandled
if !ok { if !ok {
return 0, nil, errors.New("connection was closed") return 0, netip.AddrPort{}, errors.New("connection was closed")
} }
l := len(packet.Data) l := len(packet.Data)
if l > len(b) { if l > len(b) {

Loading…
Cancel
Save