p2p/discover: refactor node and endpoint representation (#29844)

Here we clean up internal uses of type discover.node, converting most code to use
enode.Node instead. The discover.node type used to be the canonical representation of
network hosts before ENR was introduced. Most code worked with *node to avoid conversions
when interacting with Table methods. Since *node also contains internal state of Table and
is a mutable type, using *node outside of Table code is prone to data races. It's also
cleaner not having to wrap/unwrap *enode.Node all the time.

discover.node has been renamed to tableNode to clarify its purpose.

While here, we also change most uses of net.UDPAddr into netip.AddrPort. While this is
technically a separate refactoring from the *node -> *enode.Node change, it is more
convenient because *enode.Node handles IP addresses as netip.Addr. The switch to package
netip in discovery would've happened very soon anyway.

The change to netip.AddrPort stops at certain interface points. For example, since package
p2p/netutil has not been converted to use netip.Addr yet, we still have to convert to
net.IP/net.UDPAddr in a few places.
pull/29879/head
Felix Lange 4 months ago committed by GitHub
parent e26fa9e40e
commit 94a8b296e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      cmd/devp2p/internal/v4test/framework.go
  2. 7
      p2p/discover/common.go
  3. 18
      p2p/discover/lookup.go
  4. 14
      p2p/discover/metrics.go
  5. 69
      p2p/discover/node.go
  6. 125
      p2p/discover/table.go
  7. 22
      p2p/discover/table_reval.go
  8. 4
      p2p/discover/table_reval_test.go
  9. 80
      p2p/discover/table_test.go
  10. 34
      p2p/discover/table_util_test.go
  11. 27
      p2p/discover/v4_lookup_test.go
  12. 146
      p2p/discover/v4_udp.go
  13. 107
      p2p/discover/v4_udp_test.go
  14. 16
      p2p/discover/v4wire/v4wire.go
  15. 6
      p2p/discover/v5_talk.go
  16. 80
      p2p/discover/v5_udp.go
  17. 93
      p2p/discover/v5_udp_test.go
  18. 7
      p2p/server.go

@ -110,7 +110,7 @@ func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint {
}
func (te *testenv) remoteEndpoint() v4wire.Endpoint {
return v4wire.NewEndpoint(te.remoteAddr, 0)
return v4wire.NewEndpoint(te.remoteAddr.AddrPort(), 0)
}
func contains(ns []v4wire.Node, key v4wire.Pubkey) bool {

@ -22,6 +22,7 @@ import (
"encoding/binary"
"math/rand"
"net"
"net/netip"
"sync"
"time"
@ -34,8 +35,8 @@ import (
// UDPConn is a network connection on which discovery can operate.
type UDPConn interface {
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error)
WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (n int, err error)
Close() error
LocalAddr() net.Addr
}
@ -94,7 +95,7 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
// channel if configured.
type ReadPacket struct {
Data []byte
Addr *net.UDPAddr
Addr netip.AddrPort
}
type randomSource interface {

@ -29,16 +29,16 @@ import (
// not need to be an actual node identifier.
type lookup struct {
tab *Table
queryfunc func(*node) ([]*node, error)
replyCh chan []*node
queryfunc queryFunc
replyCh chan []*enode.Node
cancelCh <-chan struct{}
asked, seen map[enode.ID]bool
result nodesByDistance
replyBuffer []*node
replyBuffer []*enode.Node
queries int
}
type queryFunc func(*node) ([]*node, error)
type queryFunc func(*enode.Node) ([]*enode.Node, error)
func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup {
it := &lookup{
@ -47,7 +47,7 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l
asked: make(map[enode.ID]bool),
seen: make(map[enode.ID]bool),
result: nodesByDistance{target: target},
replyCh: make(chan []*node, alpha),
replyCh: make(chan []*enode.Node, alpha),
cancelCh: ctx.Done(),
queries: -1,
}
@ -61,7 +61,7 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l
func (it *lookup) run() []*enode.Node {
for it.advance() {
}
return unwrapNodes(it.result.entries)
return it.result.entries
}
// advance advances the lookup until any new nodes have been found.
@ -139,7 +139,7 @@ func (it *lookup) slowdown() {
}
}
func (it *lookup) query(n *node, reply chan<- []*node) {
func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) {
r, err := it.queryfunc(n)
if !errors.Is(err, errClosed) { // avoid recording failures on shutdown.
success := len(r) > 0
@ -154,7 +154,7 @@ func (it *lookup) query(n *node, reply chan<- []*node) {
// lookupIterator performs lookup operations and iterates over all seen nodes.
// When a lookup finishes, a new one is created through nextLookup.
type lookupIterator struct {
buffer []*node
buffer []*enode.Node
nextLookup lookupFunc
ctx context.Context
cancel func()
@ -173,7 +173,7 @@ func (it *lookupIterator) Node() *enode.Node {
if len(it.buffer) == 0 {
return nil
}
return unwrapNode(it.buffer[0])
return it.buffer[0]
}
// Next moves to the next node.

@ -18,7 +18,7 @@ package discover
import (
"fmt"
"net"
"net/netip"
"github.com/ethereum/go-ethereum/metrics"
)
@ -58,16 +58,16 @@ func newMeteredConn(conn UDPConn) UDPConn {
return &meteredUdpConn{UDPConn: conn}
}
// ReadFromUDP delegates a network read to the underlying connection, bumping the udp ingress traffic meter along the way.
func (c *meteredUdpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
n, addr, err = c.UDPConn.ReadFromUDP(b)
// ReadFromUDPAddrPort delegates a network read to the underlying connection, bumping the udp ingress traffic meter along the way.
func (c *meteredUdpConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
n, addr, err = c.UDPConn.ReadFromUDPAddrPort(b)
ingressTrafficMeter.Mark(int64(n))
return n, addr, err
}
// Write delegates a network write to the underlying connection, bumping the udp egress traffic meter along the way.
func (c *meteredUdpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) {
n, err = c.UDPConn.WriteToUDP(b, addr)
// WriteToUDP delegates a network write to the underlying connection, bumping the udp egress traffic meter along the way.
func (c *meteredUdpConn) WriteToUDP(b []byte, addr netip.AddrPort) (n int, err error) {
n, err = c.UDPConn.WriteToUDPAddrPort(b, addr)
egressTrafficMeter.Mark(int64(n))
return n, err
}

@ -21,7 +21,8 @@ import (
"crypto/elliptic"
"errors"
"math/big"
"net"
"slices"
"sort"
"time"
"github.com/ethereum/go-ethereum/common/math"
@ -37,9 +38,8 @@ type BucketNode struct {
Live bool `json:"live"`
}
// node represents a host on the network.
// The fields of Node may not be modified.
type node struct {
// tableNode is an entry in Table.
type tableNode struct {
*enode.Node
revalList *revalidationList
addedToTable time.Time // first time node was added to bucket or replacement list
@ -75,34 +75,59 @@ func (e encPubkey) id() enode.ID {
return enode.ID(crypto.Keccak256Hash(e[:]))
}
func wrapNode(n *enode.Node) *node {
return &node{Node: n}
}
func wrapNodes(ns []*enode.Node) []*node {
result := make([]*node, len(ns))
func unwrapNodes(ns []*tableNode) []*enode.Node {
result := make([]*enode.Node, len(ns))
for i, n := range ns {
result[i] = wrapNode(n)
result[i] = n.Node
}
return result
}
func unwrapNode(n *node) *enode.Node {
return n.Node
func (n *tableNode) String() string {
return n.Node.String()
}
// nodesByDistance is a list of nodes, ordered by distance to target.
type nodesByDistance struct {
entries []*enode.Node
target enode.ID
}
func unwrapNodes(ns []*node) []*enode.Node {
result := make([]*enode.Node, len(ns))
for i, n := range ns {
result[i] = unwrapNode(n)
// push adds the given node to the list, keeping the total size below maxElems.
func (h *nodesByDistance) push(n *enode.Node, maxElems int) {
ix := sort.Search(len(h.entries), func(i int) bool {
return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0
})
end := len(h.entries)
if len(h.entries) < maxElems {
h.entries = append(h.entries, n)
}
if ix < end {
// Slide existing entries down to make room.
// This will overwrite the entry we just appended.
copy(h.entries[ix+1:], h.entries[ix:])
h.entries[ix] = n
}
return result
}
func (n *node) addr() *net.UDPAddr {
return &net.UDPAddr{IP: n.IP(), Port: n.UDP()}
type nodeType interface {
ID() enode.ID
}
func (n *node) String() string {
return n.Node.String()
// containsID reports whether ns contains a node with the given ID.
func containsID[N nodeType](ns []N, id enode.ID) bool {
for _, n := range ns {
if n.ID() == id {
return true
}
}
return false
}
// deleteNode removes a node from the list.
func deleteNode[N nodeType](list []N, id enode.ID) []N {
return slices.DeleteFunc(list, func(n N) bool {
return n.ID() == id
})
}

@ -27,7 +27,6 @@ import (
"fmt"
"net"
"slices"
"sort"
"sync"
"time"
@ -65,7 +64,7 @@ const (
type Table struct {
mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*node // bootstrap nodes
nursery []*enode.Node // bootstrap nodes
rand reseedingRandom // source of randomness, periodically reseeded
ips netutil.DistinctNetSet
revalidation tableRevalidation
@ -85,8 +84,8 @@ type Table struct {
closeReq chan struct{}
closed chan struct{}
nodeAddedHook func(*bucket, *node)
nodeRemovedHook func(*bucket, *node)
nodeAddedHook func(*bucket, *tableNode)
nodeRemovedHook func(*bucket, *tableNode)
}
// transport is implemented by the UDP transports.
@ -101,20 +100,21 @@ type transport interface {
// bucket contains nodes, ordered by their last activity. the entry
// that was most recently active is the first element in entries.
type bucket struct {
entries []*node // live entries, sorted by time of last contact
replacements []*node // recently seen nodes to be used if revalidation fails
entries []*tableNode // live entries, sorted by time of last contact
replacements []*tableNode // recently seen nodes to be used if revalidation fails
ips netutil.DistinctNetSet
index int
}
type addNodeOp struct {
node *node
isInbound bool
node *enode.Node
isInbound bool
forceSetLive bool // for tests
}
type trackRequestOp struct {
node *node
foundNodes []*node
node *enode.Node
foundNodes []*enode.Node
success bool
}
@ -186,7 +186,7 @@ func (tab *Table) getNode(id enode.ID) *enode.Node {
b := tab.bucket(id)
for _, e := range b.entries {
if e.ID() == id {
return unwrapNode(e)
return e.Node
}
}
return nil
@ -202,7 +202,7 @@ func (tab *Table) close() {
// are used to connect to the network if the table is empty and there
// are no known nodes in the database.
func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
nursery := make([]*node, 0, len(nodes))
nursery := make([]*enode.Node, 0, len(nodes))
for _, n := range nodes {
if err := n.ValidateComplete(); err != nil {
return fmt.Errorf("bad bootstrap node %q: %v", n, err)
@ -211,7 +211,7 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP())
continue
}
nursery = append(nursery, wrapNode(n))
nursery = append(nursery, n)
}
tab.nursery = nursery
return nil
@ -255,9 +255,9 @@ func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *
liveNodes := &nodesByDistance{target: target}
for _, b := range &tab.buckets {
for _, n := range b.entries {
nodes.push(n, nresults)
nodes.push(n.Node, nresults)
if preferLive && n.isValidatedLive {
liveNodes.push(n, nresults)
liveNodes.push(n.Node, nresults)
}
}
}
@ -309,8 +309,8 @@ func (tab *Table) len() (n int) {
// list.
//
// The caller must not hold tab.mutex.
func (tab *Table) addFoundNode(n *node) bool {
op := addNodeOp{node: n, isInbound: false}
func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool {
op := addNodeOp{node: n, isInbound: false, forceSetLive: forceSetLive}
select {
case tab.addNodeCh <- op:
return <-tab.addNodeHandled
@ -327,7 +327,7 @@ func (tab *Table) addFoundNode(n *node) bool {
// repeatedly.
//
// The caller must not hold tab.mutex.
func (tab *Table) addInboundNode(n *node) bool {
func (tab *Table) addInboundNode(n *enode.Node) bool {
op := addNodeOp{node: n, isInbound: true}
select {
case tab.addNodeCh <- op:
@ -337,7 +337,7 @@ func (tab *Table) addInboundNode(n *node) bool {
}
}
func (tab *Table) trackRequest(n *node, success bool, foundNodes []*node) {
func (tab *Table) trackRequest(n *enode.Node, success bool, foundNodes []*enode.Node) {
op := trackRequestOp{n, foundNodes, success}
select {
case tab.trackRequestCh <- op:
@ -443,13 +443,14 @@ func (tab *Table) doRefresh(done chan struct{}) {
}
func (tab *Table) loadSeedNodes() {
seeds := wrapNodes(tab.db.QuerySeeds(seedCount, seedMaxAge))
seeds := tab.db.QuerySeeds(seedCount, seedMaxAge)
seeds = append(seeds, tab.nursery...)
for i := range seeds {
seed := seeds[i]
if tab.log.Enabled(context.Background(), log.LevelTrace) {
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP()))
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age)
addr, _ := seed.UDPEndpoint()
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age)
}
tab.handleAddNode(addNodeOp{node: seed, isInbound: false})
}
@ -513,7 +514,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
}
b := tab.bucket(req.node.ID())
n, _ := tab.bumpInBucket(b, req.node.Node, req.isInbound)
n, _ := tab.bumpInBucket(b, req.node, req.isInbound)
if n != nil {
// Already in bucket.
return false
@ -529,15 +530,20 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
}
// Add to bucket.
b.entries = append(b.entries, req.node)
b.replacements = deleteNode(b.replacements, req.node)
tab.nodeAdded(b, req.node)
wn := &tableNode{Node: req.node}
if req.forceSetLive {
wn.livenessChecks = 1
wn.isValidatedLive = true
}
b.entries = append(b.entries, wn)
b.replacements = deleteNode(b.replacements, wn.ID())
tab.nodeAdded(b, wn)
return true
}
// addReplacement adds n to the replacement cache of bucket b.
func (tab *Table) addReplacement(b *bucket, n *node) {
if contains(b.replacements, n.ID()) {
func (tab *Table) addReplacement(b *bucket, n *enode.Node) {
if containsID(b.replacements, n.ID()) {
// TODO: update ENR
return
}
@ -545,15 +551,15 @@ func (tab *Table) addReplacement(b *bucket, n *node) {
return
}
n.addedToTable = time.Now()
var removed *node
b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
wn := &tableNode{Node: n, addedToTable: time.Now()}
var removed *tableNode
b.replacements, removed = pushNode(b.replacements, wn, maxReplacements)
if removed != nil {
tab.removeIP(b, removed.IP())
}
}
func (tab *Table) nodeAdded(b *bucket, n *node) {
func (tab *Table) nodeAdded(b *bucket, n *tableNode) {
if n.addedToTable == (time.Time{}) {
n.addedToTable = time.Now()
}
@ -567,7 +573,7 @@ func (tab *Table) nodeAdded(b *bucket, n *node) {
}
}
func (tab *Table) nodeRemoved(b *bucket, n *node) {
func (tab *Table) nodeRemoved(b *bucket, n *tableNode) {
tab.revalidation.nodeRemoved(n)
if tab.nodeRemovedHook != nil {
tab.nodeRemovedHook(b, n)
@ -579,8 +585,8 @@ func (tab *Table) nodeRemoved(b *bucket, n *node) {
// deleteInBucket removes node n from the table.
// If there are replacement nodes in the bucket, the node is replaced.
func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node {
index := slices.IndexFunc(b.entries, func(e *node) bool { return e.ID() == id })
func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode {
index := slices.IndexFunc(b.entries, func(e *tableNode) bool { return e.ID() == id })
if index == -1 {
// Entry has been removed already.
return nil
@ -608,8 +614,8 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node {
// bumpInBucket updates a node record if it exists in the bucket.
// The second return value reports whether the node's endpoint (IP/port) was updated.
func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *node, endpointChanged bool) {
i := slices.IndexFunc(b.entries, func(elem *node) bool {
func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *tableNode, endpointChanged bool) {
i := slices.IndexFunc(b.entries, func(elem *tableNode) bool {
return elem.ID() == newRecord.ID()
})
if i == -1 {
@ -672,21 +678,12 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) {
// Add found nodes.
for _, n := range op.foundNodes {
tab.handleAddNode(addNodeOp{n, false})
tab.handleAddNode(addNodeOp{n, false, false})
}
}
func contains(ns []*node, id enode.ID) bool {
for _, n := range ns {
if n.ID() == id {
return true
}
}
return false
}
// pushNode adds n to the front of list, keeping at most max items.
func pushNode(list []*node, n *node, max int) ([]*node, *node) {
func pushNode(list []*tableNode, n *tableNode, max int) ([]*tableNode, *tableNode) {
if len(list) < max {
list = append(list, nil)
}
@ -695,37 +692,3 @@ func pushNode(list []*node, n *node, max int) ([]*node, *node) {
list[0] = n
return list, removed
}
// deleteNode removes n from list.
func deleteNode(list []*node, n *node) []*node {
for i := range list {
if list[i].ID() == n.ID() {
return append(list[:i], list[i+1:]...)
}
}
return list
}
// nodesByDistance is a list of nodes, ordered by distance to target.
type nodesByDistance struct {
entries []*node
target enode.ID
}
// push adds the given node to the list, keeping the total size below maxElems.
func (h *nodesByDistance) push(n *node, maxElems int) {
ix := sort.Search(len(h.entries), func(i int) bool {
return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0
})
end := len(h.entries)
if len(h.entries) < maxElems {
h.entries = append(h.entries, n)
}
if ix < end {
// Slide existing entries down to make room.
// This will overwrite the entry we just appended.
copy(h.entries[ix+1:], h.entries[ix:])
h.entries[ix] = n
}
}

@ -39,7 +39,7 @@ type tableRevalidation struct {
}
type revalidationResponse struct {
n *node
n *tableNode
newRecord *enode.Node
didRespond bool
}
@ -55,12 +55,12 @@ func (tr *tableRevalidation) init(cfg *Config) {
}
// nodeAdded is called when the table receives a new node.
func (tr *tableRevalidation) nodeAdded(tab *Table, n *node) {
func (tr *tableRevalidation) nodeAdded(tab *Table, n *tableNode) {
tr.fast.push(n, tab.cfg.Clock.Now(), &tab.rand)
}
// nodeRemoved is called when a node was removed from the table.
func (tr *tableRevalidation) nodeRemoved(n *node) {
func (tr *tableRevalidation) nodeRemoved(n *tableNode) {
if n.revalList == nil {
panic(fmt.Errorf("removed node %v has nil revalList", n.ID()))
}
@ -68,7 +68,7 @@ func (tr *tableRevalidation) nodeRemoved(n *node) {
}
// nodeEndpointChanged is called when a change in IP or port is detected.
func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *node) {
func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *tableNode) {
n.isValidatedLive = false
tr.moveToList(&tr.fast, n, tab.cfg.Clock.Now(), &tab.rand)
}
@ -90,7 +90,7 @@ func (tr *tableRevalidation) run(tab *Table, now mclock.AbsTime) (nextTime mcloc
}
// startRequest spawns a revalidation request for node n.
func (tr *tableRevalidation) startRequest(tab *Table, n *node) {
func (tr *tableRevalidation) startRequest(tab *Table, n *tableNode) {
if _, ok := tr.activeReq[n.ID()]; ok {
panic(fmt.Errorf("duplicate startRequest (node %v)", n.ID()))
}
@ -180,7 +180,7 @@ func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationRespons
}
// moveToList ensures n is in the 'dest' list.
func (tr *tableRevalidation) moveToList(dest *revalidationList, n *node, now mclock.AbsTime, rand randomSource) {
func (tr *tableRevalidation) moveToList(dest *revalidationList, n *tableNode, now mclock.AbsTime, rand randomSource) {
if n.revalList == dest {
return
}
@ -192,14 +192,14 @@ func (tr *tableRevalidation) moveToList(dest *revalidationList, n *node, now mcl
// revalidationList holds a list nodes and the next revalidation time.
type revalidationList struct {
nodes []*node
nodes []*tableNode
nextTime mclock.AbsTime
interval time.Duration
name string
}
// get returns a random node from the queue. Nodes in the 'exclude' map are not returned.
func (list *revalidationList) get(now mclock.AbsTime, rand randomSource, exclude map[enode.ID]struct{}) *node {
func (list *revalidationList) get(now mclock.AbsTime, rand randomSource, exclude map[enode.ID]struct{}) *tableNode {
if now < list.nextTime || len(list.nodes) == 0 {
return nil
}
@ -217,7 +217,7 @@ func (list *revalidationList) schedule(now mclock.AbsTime, rand randomSource) {
list.nextTime = now.Add(time.Duration(rand.Int63n(int64(list.interval))))
}
func (list *revalidationList) push(n *node, now mclock.AbsTime, rand randomSource) {
func (list *revalidationList) push(n *tableNode, now mclock.AbsTime, rand randomSource) {
list.nodes = append(list.nodes, n)
if list.nextTime == never {
list.schedule(now, rand)
@ -225,7 +225,7 @@ func (list *revalidationList) push(n *node, now mclock.AbsTime, rand randomSourc
n.revalList = list
}
func (list *revalidationList) remove(n *node) {
func (list *revalidationList) remove(n *tableNode) {
i := slices.Index(list.nodes, n)
if i == -1 {
panic(fmt.Errorf("node %v not found in list", n.ID()))
@ -238,7 +238,7 @@ func (list *revalidationList) remove(n *node) {
}
func (list *revalidationList) contains(id enode.ID) bool {
return slices.ContainsFunc(list.nodes, func(n *node) bool {
return slices.ContainsFunc(list.nodes, func(n *tableNode) bool {
return n.ID() == id
})
}

@ -110,10 +110,10 @@ func TestRevalidation_endpointUpdate(t *testing.T) {
}
tr.handleResponse(tab, resp)
if !tr.fast.contains(node.ID()) {
if tr.fast.nodes[0].ID() != node.ID() {
t.Fatal("node not contained in fast revalidation list")
}
if node.isValidatedLive {
if tr.fast.nodes[0].isValidatedLive {
t.Fatal("node is marked live after endpoint change")
}
}

@ -22,6 +22,7 @@ import (
"math/rand"
"net"
"reflect"
"slices"
"testing"
"testing/quick"
"time"
@ -64,7 +65,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
// Fill up the sender's bucket.
replacementNodeKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8")
replacementNode := wrapNode(enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99))
replacementNode := enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99)
last := fillBucket(tab, replacementNode.ID())
tab.mutex.Lock()
nodeEvents := newNodeEventRecorder(128)
@ -78,7 +79,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
transport.dead[replacementNode.ID()] = !newNodeIsResponding
// Add replacement node to table.
tab.addFoundNode(replacementNode)
tab.addFoundNode(replacementNode, false)
t.Log("last:", last.ID())
t.Log("replacement:", replacementNode.ID())
@ -115,11 +116,11 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
if l := len(bucket.entries); l != wantSize {
t.Errorf("wrong bucket size after revalidation: got %d, want %d", l, wantSize)
}
if ok := contains(bucket.entries, last.ID()); ok != lastInBucketIsResponding {
if ok := containsID(bucket.entries, last.ID()); ok != lastInBucketIsResponding {
t.Errorf("revalidated node found: %t, want: %t", ok, lastInBucketIsResponding)
}
wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
if ok := contains(bucket.entries, replacementNode.ID()); ok != wantNewEntry {
if ok := containsID(bucket.entries, replacementNode.ID()); ok != wantNewEntry {
t.Errorf("replacement node found: %t, want: %t", ok, wantNewEntry)
}
}
@ -153,7 +154,7 @@ func TestTable_IPLimit(t *testing.T) {
for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
tab.addFoundNode(n)
tab.addFoundNode(n, false)
}
if tab.len() > tableIPLimit {
t.Errorf("too many nodes in table")
@ -171,7 +172,7 @@ func TestTable_BucketIPLimit(t *testing.T) {
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)})
tab.addFoundNode(n)
tab.addFoundNode(n, false)
}
if tab.len() > bucketIPLimit {
t.Errorf("too many nodes in table")
@ -232,7 +233,7 @@ func TestTable_findnodeByID(t *testing.T) {
// check that the result nodes have minimum distance to target.
for _, b := range tab.buckets {
for _, n := range b.entries {
if contains(result, n.ID()) {
if containsID(result, n.ID()) {
continue // don't run the check below for nodes in result
}
farthestResult := result[len(result)-1].ID()
@ -255,7 +256,7 @@ func TestTable_findnodeByID(t *testing.T) {
type closeTest struct {
Self enode.ID
Target enode.ID
All []*node
All []*enode.Node
N int
}
@ -268,8 +269,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
r := new(enr.Record)
r.Set(enr.IP(genIP(rand)))
n := wrapNode(enode.SignNull(r, id))
n.livenessChecks = 1
n := enode.SignNull(r, id)
t.All = append(t.All, n)
}
return reflect.ValueOf(t)
@ -284,16 +284,16 @@ func TestTable_addInboundNode(t *testing.T) {
// Insert two nodes.
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addFoundNode(n1)
tab.addFoundNode(n2)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node})
tab.addFoundNode(n1, false)
tab.addFoundNode(n2, false)
checkBucketContent(t, tab, []*enode.Node{n1, n2})
// Add a changed version of n2. The bucket should be updated.
newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99})
n2v2 := enode.SignNull(newrec, n2.ID())
tab.addInboundNode(wrapNode(n2v2))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})
tab.addInboundNode(n2v2)
checkBucketContent(t, tab, []*enode.Node{n1, n2v2})
// Try updating n2 without sequence number change. The update is accepted
// because it's inbound.
@ -301,8 +301,8 @@ func TestTable_addInboundNode(t *testing.T) {
newrec.Set(enr.IP{100, 100, 100, 100})
newrec.SetSeq(n2.Seq())
n2v3 := enode.SignNull(newrec, n2.ID())
tab.addInboundNode(wrapNode(n2v3))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v3})
tab.addInboundNode(n2v3)
checkBucketContent(t, tab, []*enode.Node{n1, n2v3})
}
func TestTable_addFoundNode(t *testing.T) {
@ -314,16 +314,16 @@ func TestTable_addFoundNode(t *testing.T) {
// Insert two nodes.
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addFoundNode(n1)
tab.addFoundNode(n2)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node})
tab.addFoundNode(n1, false)
tab.addFoundNode(n2, false)
checkBucketContent(t, tab, []*enode.Node{n1, n2})
// Add a changed version of n2. The bucket should be updated.
newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99})
n2v2 := enode.SignNull(newrec, n2.ID())
tab.addFoundNode(wrapNode(n2v2))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})
tab.addFoundNode(n2v2, false)
checkBucketContent(t, tab, []*enode.Node{n1, n2v2})
// Try updating n2 without a sequence number change.
// The update should not be accepted.
@ -331,8 +331,8 @@ func TestTable_addFoundNode(t *testing.T) {
newrec.Set(enr.IP{100, 100, 100, 100})
newrec.SetSeq(n2.Seq())
n2v3 := enode.SignNull(newrec, n2.ID())
tab.addFoundNode(wrapNode(n2v3))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})
tab.addFoundNode(n2v3, false)
checkBucketContent(t, tab, []*enode.Node{n1, n2v2})
}
// This test checks that discv4 nodes can update their own endpoint via PING.
@ -345,13 +345,13 @@ func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) {
// Add a v4 node.
key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3")
n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
tab.addInboundNode(wrapNode(n1))
tab.addInboundNode(n1)
checkBucketContent(t, tab, []*enode.Node{n1})
// Add an updated version with changed IP.
// The update will be accepted because it is inbound.
n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)
tab.addInboundNode(wrapNode(n1v2))
tab.addInboundNode(n1v2)
checkBucketContent(t, tab, []*enode.Node{n1v2})
}
@ -366,13 +366,13 @@ func TestTable_addFoundNodeV4UpdateReject(t *testing.T) {
// Add a v4 node.
key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3")
n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
tab.addFoundNode(wrapNode(n1))
tab.addFoundNode(n1, false)
checkBucketContent(t, tab, []*enode.Node{n1})
// Add an updated version with changed IP.
// The update won't be accepted because it isn't inbound.
n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)
tab.addFoundNode(wrapNode(n1v2))
tab.addFoundNode(n1v2, false)
checkBucketContent(t, tab, []*enode.Node{n1})
}
@ -413,8 +413,8 @@ func TestTable_revalidateSyncRecord(t *testing.T) {
var r enr.Record
r.Set(enr.IP(net.IP{127, 0, 0, 1}))
id := enode.ID{1}
n1 := wrapNode(enode.SignNull(&r, id))
tab.addFoundNode(n1)
n1 := enode.SignNull(&r, id)
tab.addFoundNode(n1, false)
// Update the node record.
r.Set(enr.WithEntry("foo", "bar"))
@ -437,7 +437,7 @@ func TestNodesPush(t *testing.T) {
n1 := nodeAtDistance(target, 255, intIP(1))
n2 := nodeAtDistance(target, 254, intIP(2))
n3 := nodeAtDistance(target, 253, intIP(3))
perm := [][]*node{
perm := [][]*enode.Node{
{n3, n2, n1},
{n3, n1, n2},
{n2, n3, n1},
@ -452,7 +452,7 @@ func TestNodesPush(t *testing.T) {
for _, n := range nodes {
list.push(n, 3)
}
if !slicesEqual(list.entries, perm[0], nodeIDEqual) {
if !slices.EqualFunc(list.entries, perm[0], nodeIDEqual) {
t.Fatal("not equal")
}
}
@ -463,28 +463,16 @@ func TestNodesPush(t *testing.T) {
for _, n := range nodes {
list.push(n, 2)
}
if !slicesEqual(list.entries, perm[0][:2], nodeIDEqual) {
if !slices.EqualFunc(list.entries, perm[0][:2], nodeIDEqual) {
t.Fatal("not equal")
}
}
}
func nodeIDEqual(n1, n2 *node) bool {
func nodeIDEqual[N nodeType](n1, n2 N) bool {
return n1.ID() == n2.ID()
}
func slicesEqual[T any](s1, s2 []T, check func(e1, e2 T) bool) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if !check(s1[i], s2[i]) {
return false
}
}
return true
}
// gen wraps quick.Value so it's easier to use.
// it generates a random value of the given value's type.
func gen(typ interface{}, rand *rand.Rand) interface{} {

@ -56,18 +56,18 @@ func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) {
}
// nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld.
func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node {
func nodeAtDistance(base enode.ID, ld int, ip net.IP) *enode.Node {
var r enr.Record
r.Set(enr.IP(ip))
r.Set(enr.UDP(30303))
return wrapNode(enode.SignNull(&r, idAtDistance(base, ld)))
return enode.SignNull(&r, idAtDistance(base, ld))
}
// nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld.
func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node {
results := make([]*enode.Node, n)
for i := range results {
results[i] = unwrapNode(nodeAtDistance(base, ld, intIP(i)))
results[i] = nodeAtDistance(base, ld, intIP(i))
}
return results
}
@ -105,12 +105,12 @@ func intIP(i int) net.IP {
}
// fillBucket inserts nodes into the given bucket until it is full.
func fillBucket(tab *Table, id enode.ID) (last *node) {
func fillBucket(tab *Table, id enode.ID) (last *tableNode) {
ld := enode.LogDist(tab.self().ID(), id)
b := tab.bucket(id)
for len(b.entries) < bucketSize {
node := nodeAtDistance(tab.self().ID(), ld, intIP(ld))
if !tab.addFoundNode(node) {
if !tab.addFoundNode(node, false) {
panic("node not added")
}
}
@ -119,13 +119,9 @@ func fillBucket(tab *Table, id enode.ID) (last *node) {
// fillTable adds nodes the table to the end of their corresponding bucket
// if the bucket is not full. The caller must not hold tab.mutex.
func fillTable(tab *Table, nodes []*node, setLive bool) {
func fillTable(tab *Table, nodes []*enode.Node, setLive bool) {
for _, n := range nodes {
if setLive {
n.livenessChecks = 1
n.isValidatedLive = true
}
tab.addFoundNode(n)
tab.addFoundNode(n, setLive)
}
}
@ -219,7 +215,7 @@ func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) {
return t.records[n.ID()], nil
}
func hasDuplicates(slice []*node) bool {
func hasDuplicates(slice []*enode.Node) bool {
seen := make(map[enode.ID]bool, len(slice))
for i, e := range slice {
if e == nil {
@ -261,14 +257,14 @@ func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP())
}
func sortByID(nodes []*enode.Node) {
slices.SortFunc(nodes, func(a, b *enode.Node) int {
func sortByID[N nodeType](nodes []N) {
slices.SortFunc(nodes, func(a, b N) int {
return bytes.Compare(a.ID().Bytes(), b.ID().Bytes())
})
}
func sortedByDistanceTo(distbase enode.ID, slice []*node) bool {
return slices.IsSortedFunc(slice, func(a, b *node) int {
func sortedByDistanceTo(distbase enode.ID, slice []*enode.Node) bool {
return slices.IsSortedFunc(slice, func(a, b *enode.Node) int {
return enode.DistCmp(distbase, a.ID(), b.ID())
})
}
@ -304,7 +300,7 @@ type nodeEventRecorder struct {
}
type recordedNodeEvent struct {
node *node
node *tableNode
added bool
}
@ -314,7 +310,7 @@ func newNodeEventRecorder(buffer int) *nodeEventRecorder {
}
}
func (set *nodeEventRecorder) nodeAdded(b *bucket, n *node) {
func (set *nodeEventRecorder) nodeAdded(b *bucket, n *tableNode) {
select {
case set.evc <- recordedNodeEvent{n, true}:
default:
@ -322,7 +318,7 @@ func (set *nodeEventRecorder) nodeAdded(b *bucket, n *node) {
}
}
func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *node) {
func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *tableNode) {
select {
case set.evc <- recordedNodeEvent{n, false}:
default:

@ -19,7 +19,7 @@ package discover
import (
"crypto/ecdsa"
"fmt"
"net"
"net/netip"
"slices"
"testing"
@ -40,7 +40,7 @@ func TestUDPv4_Lookup(t *testing.T) {
}
// Seed table with initial node.
fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}, true)
fillTable(test.table, []*enode.Node{lookupTestnet.node(256, 0)}, true)
// Start the lookup.
resultC := make(chan []*enode.Node, 1)
@ -70,9 +70,9 @@ func TestUDPv4_LookupIterator(t *testing.T) {
defer test.close()
// Seed table with initial nodes.
bootnodes := make([]*node, len(lookupTestnet.dists[256]))
bootnodes := make([]*enode.Node, len(lookupTestnet.dists[256]))
for i := range lookupTestnet.dists[256] {
bootnodes[i] = wrapNode(lookupTestnet.node(256, i))
bootnodes[i] = lookupTestnet.node(256, i)
}
fillTable(test.table, bootnodes, true)
go serveTestnet(test, lookupTestnet)
@ -105,9 +105,9 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) {
defer test.close()
// Seed table with initial nodes.
bootnodes := make([]*node, len(lookupTestnet.dists[256]))
bootnodes := make([]*enode.Node, len(lookupTestnet.dists[256]))
for i := range lookupTestnet.dists[256] {
bootnodes[i] = wrapNode(lookupTestnet.node(256, i))
bootnodes[i] = lookupTestnet.node(256, i)
}
fillTable(test.table, bootnodes, true)
go serveTestnet(test, lookupTestnet)
@ -136,7 +136,7 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) {
func serveTestnet(test *udpTest, testnet *preminedTestnet) {
for done := false; !done; {
done = test.waitPacketOut(func(p v4wire.Packet, to *net.UDPAddr, hash []byte) {
done = test.waitPacketOut(func(p v4wire.Packet, to netip.AddrPort, hash []byte) {
n, key := testnet.nodeByAddr(to)
switch p.(type) {
case *v4wire.Ping:
@ -158,10 +158,10 @@ func checkLookupResults(t *testing.T, tn *preminedTestnet, results []*enode.Node
for _, e := range results {
t.Logf(" ld=%d, %x", enode.LogDist(tn.target.id(), e.ID()), e.ID().Bytes())
}
if hasDuplicates(wrapNodes(results)) {
if hasDuplicates(results) {
t.Errorf("result set contains duplicate entries")
}
if !sortedByDistanceTo(tn.target.id(), wrapNodes(results)) {
if !sortedByDistanceTo(tn.target.id(), results) {
t.Errorf("result set not sorted by distance to target")
}
wantNodes := tn.closest(len(results))
@ -264,9 +264,10 @@ func (tn *preminedTestnet) node(dist, index int) *enode.Node {
return n
}
func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.PrivateKey) {
dist := int(addr.IP[1])<<8 + int(addr.IP[2])
index := int(addr.IP[3])
func (tn *preminedTestnet) nodeByAddr(addr netip.AddrPort) (*enode.Node, *ecdsa.PrivateKey) {
ip := addr.Addr().As4()
dist := int(ip[1])<<8 + int(ip[2])
index := int(ip[3])
key := tn.dists[dist][index]
return tn.node(dist, index), key
}
@ -274,7 +275,7 @@ func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.Pr
func (tn *preminedTestnet) nodesAtDistance(dist int) []v4wire.Node {
result := make([]v4wire.Node, len(tn.dists[dist]))
for i := range result {
result[i] = nodeToRPC(wrapNode(tn.node(dist, i)))
result[i] = nodeToRPC(tn.node(dist, i))
}
return result
}

@ -26,6 +26,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"sync"
"time"
@ -45,6 +46,7 @@ var (
errClockWarp = errors.New("reply deadline too far in the future")
errClosed = errors.New("socket closed")
errLowPort = errors.New("low port")
errNoUDPEndpoint = errors.New("node has no UDP endpoint")
)
const (
@ -93,7 +95,7 @@ type UDPv4 struct {
type replyMatcher struct {
// these fields must match in the reply.
from enode.ID
ip net.IP
ip netip.Addr
ptype byte
// time when the request must complete
@ -119,7 +121,7 @@ type replyMatchFunc func(v4wire.Packet) (matched bool, requestDone bool)
// reply is a reply packet from a certain node.
type reply struct {
from enode.ID
ip net.IP
ip netip.Addr
data v4wire.Packet
// loop indicates whether there was
// a matching request by sending on this channel.
@ -201,9 +203,12 @@ func (t *UDPv4) Resolve(n *enode.Node) *enode.Node {
}
func (t *UDPv4) ourEndpoint() v4wire.Endpoint {
n := t.Self()
a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()}
return v4wire.NewEndpoint(a, uint16(n.TCP()))
node := t.Self()
addr, ok := node.UDPEndpoint()
if !ok {
return v4wire.Endpoint{}
}
return v4wire.NewEndpoint(addr, uint16(node.TCP()))
}
// Ping sends a ping message to the given node.
@ -214,7 +219,11 @@ func (t *UDPv4) Ping(n *enode.Node) error {
// ping sends a ping message to the given node and waits for a reply.
func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) {
rm := t.sendPing(n.ID(), &net.UDPAddr{IP: n.IP(), Port: n.UDP()}, nil)
addr, ok := n.UDPEndpoint()
if !ok {
return 0, errNoUDPEndpoint
}
rm := t.sendPing(n.ID(), addr, nil)
if err = <-rm.errc; err == nil {
seq = rm.reply.(*v4wire.Pong).ENRSeq
}
@ -223,7 +232,7 @@ func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) {
// sendPing sends a ping message to the given node and invokes the callback
// when the reply arrives.
func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *replyMatcher {
func (t *UDPv4) sendPing(toid enode.ID, toaddr netip.AddrPort, callback func()) *replyMatcher {
req := t.makePing(toaddr)
packet, hash, err := v4wire.Encode(t.priv, req)
if err != nil {
@ -233,7 +242,7 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r
}
// Add a matcher for the reply to the pending reply queue. Pongs are matched if they
// reference the ping we're about to send.
rm := t.pending(toid, toaddr.IP, v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) {
rm := t.pending(toid, toaddr.Addr(), v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) {
matched = bytes.Equal(p.(*v4wire.Pong).ReplyTok, hash)
if matched && callback != nil {
callback()
@ -241,12 +250,13 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r
return matched, matched
})
// Send the packet.
t.localNode.UDPContact(toaddr)
toUDPAddr := &net.UDPAddr{IP: toaddr.Addr().AsSlice()}
t.localNode.UDPContact(toUDPAddr)
t.write(toaddr, toid, req.Name(), packet)
return rm
}
func (t *UDPv4) makePing(toaddr *net.UDPAddr) *v4wire.Ping {
func (t *UDPv4) makePing(toaddr netip.AddrPort) *v4wire.Ping {
return &v4wire.Ping{
Version: 4,
From: t.ourEndpoint(),
@ -290,35 +300,39 @@ func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup {
func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup {
target := enode.ID(crypto.Keccak256Hash(targetKey[:]))
ekey := v4wire.Pubkey(targetKey)
it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) {
return t.findnode(n.ID(), n.addr(), ekey)
it := newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) {
addr, ok := n.UDPEndpoint()
if !ok {
return nil, errNoUDPEndpoint
}
return t.findnode(n.ID(), addr, ekey)
})
return it
}
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubkey) ([]*node, error) {
t.ensureBond(toid, toaddr)
func (t *UDPv4) findnode(toid enode.ID, toAddrPort netip.AddrPort, target v4wire.Pubkey) ([]*enode.Node, error) {
t.ensureBond(toid, toAddrPort)
// Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is
// active until enough nodes have been received.
nodes := make([]*node, 0, bucketSize)
nodes := make([]*enode.Node, 0, bucketSize)
nreceived := 0
rm := t.pending(toid, toaddr.IP, v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) {
rm := t.pending(toid, toAddrPort.Addr(), v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) {
reply := r.(*v4wire.Neighbors)
for _, rn := range reply.Nodes {
nreceived++
n, err := t.nodeFromRPC(toaddr, rn)
n, err := t.nodeFromRPC(toAddrPort, rn)
if err != nil {
t.log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err)
t.log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toAddrPort, "err", err)
continue
}
nodes = append(nodes, n)
}
return true, nreceived >= bucketSize
})
t.send(toaddr, toid, &v4wire.Findnode{
t.send(toAddrPort, toid, &v4wire.Findnode{
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
@ -336,7 +350,7 @@ func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubke
// RequestENR sends ENRRequest to the given node and waits for a response.
func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()}
addr, _ := n.UDPEndpoint()
t.ensureBond(n.ID(), addr)
req := &v4wire.ENRRequest{
@ -349,7 +363,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
// Add a matcher for the reply to the pending reply queue. Responses are matched if
// they reference the request we're about to send.
rm := t.pending(n.ID(), addr.IP, v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) {
rm := t.pending(n.ID(), addr.Addr(), v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) {
matched = bytes.Equal(r.(*v4wire.ENRResponse).ReplyTok, hash)
return matched, matched
})
@ -369,7 +383,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
if respN.Seq() < n.Seq() {
return n, nil // response record is older
}
if err := netutil.CheckRelayIP(addr.IP, respN.IP()); err != nil {
if err := netutil.CheckRelayIP(addr.Addr().AsSlice(), respN.IP()); err != nil {
return nil, fmt.Errorf("invalid IP in response record: %v", err)
}
return respN, nil
@ -381,7 +395,7 @@ func (t *UDPv4) TableBuckets() [][]BucketNode {
// pending adds a reply matcher to the pending reply queue.
// see the documentation of type replyMatcher for a detailed explanation.
func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) *replyMatcher {
func (t *UDPv4) pending(id enode.ID, ip netip.Addr, ptype byte, callback replyMatchFunc) *replyMatcher {
ch := make(chan error, 1)
p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch}
select {
@ -395,7 +409,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF
// handleReply dispatches a reply packet, invoking reply matchers. It returns
// whether any matcher considered the packet acceptable.
func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req v4wire.Packet) bool {
func (t *UDPv4) handleReply(from enode.ID, fromIP netip.Addr, req v4wire.Packet) bool {
matched := make(chan bool, 1)
select {
case t.gotreply <- reply{from, fromIP, req, matched}:
@ -461,7 +475,7 @@ func (t *UDPv4) loop() {
var matched bool // whether any replyMatcher considered the reply acceptable.
for el := plist.Front(); el != nil; el = el.Next() {
p := el.Value.(*replyMatcher)
if p.from == r.from && p.ptype == r.data.Kind() && p.ip.Equal(r.ip) {
if p.from == r.from && p.ptype == r.data.Kind() && p.ip == r.ip {
ok, requestDone := p.callback(r.data)
matched = matched || ok
p.reply = r.data
@ -500,7 +514,7 @@ func (t *UDPv4) loop() {
}
}
func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]byte, error) {
func (t *UDPv4) send(toaddr netip.AddrPort, toid enode.ID, req v4wire.Packet) ([]byte, error) {
packet, hash, err := v4wire.Encode(t.priv, req)
if err != nil {
return hash, err
@ -508,8 +522,8 @@ func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]b
return hash, t.write(toaddr, toid, req.Name(), packet)
}
func (t *UDPv4) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr)
func (t *UDPv4) write(toaddr netip.AddrPort, toid enode.ID, what string, packet []byte) error {
_, err := t.conn.WriteToUDPAddrPort(packet, toaddr)
t.log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err)
return err
}
@ -523,7 +537,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) {
buf := make([]byte, maxPacketSize)
for {
nbytes, from, err := t.conn.ReadFromUDP(buf)
nbytes, from, err := t.conn.ReadFromUDPAddrPort(buf)
if netutil.IsTemporaryError(err) {
// Ignore temporary read errors.
t.log.Debug("Temporary UDP read error", "err", err)
@ -544,7 +558,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) {
}
}
func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error {
func (t *UDPv4) handlePacket(from netip.AddrPort, buf []byte) error {
rawpacket, fromKey, hash, err := v4wire.Decode(buf)
if err != nil {
t.log.Debug("Bad discv4 packet", "addr", from, "err", err)
@ -563,15 +577,16 @@ func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error {
}
// checkBond checks if the given node has a recent enough endpoint proof.
func (t *UDPv4) checkBond(id enode.ID, ip net.IP) bool {
return time.Since(t.db.LastPongReceived(id, ip)) < bondExpiration
func (t *UDPv4) checkBond(id enode.ID, ip netip.AddrPort) bool {
return time.Since(t.db.LastPongReceived(id, ip.Addr().AsSlice())) < bondExpiration
}
// ensureBond solicits a ping from a node if we haven't seen a ping from it for a while.
// This ensures there is a valid endpoint proof on the remote end.
func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) {
tooOld := time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration
if tooOld || t.db.FindFails(toid, toaddr.IP) > maxFindnodeFailures {
func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) {
ip := toaddr.Addr().AsSlice()
tooOld := time.Since(t.db.LastPingReceived(toid, ip)) > bondExpiration
if tooOld || t.db.FindFails(toid, ip) > maxFindnodeFailures {
rm := t.sendPing(toid, toaddr, nil)
<-rm.errc
// Wait for them to ping back and process our pong.
@ -579,11 +594,11 @@ func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) {
}
}
func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error) {
func (t *UDPv4) nodeFromRPC(sender netip.AddrPort, rn v4wire.Node) (*enode.Node, error) {
if rn.UDP <= 1024 {
return nil, errLowPort
}
if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
if err := netutil.CheckRelayIP(sender.Addr().AsSlice(), rn.IP); err != nil {
return nil, err
}
if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
@ -593,12 +608,12 @@ func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error)
if err != nil {
return nil, err
}
n := wrapNode(enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP)))
n := enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP))
err = n.ValidateComplete()
return n, err
}
func nodeToRPC(n *node) v4wire.Node {
func nodeToRPC(n *enode.Node) v4wire.Node {
var key ecdsa.PublicKey
var ekey v4wire.Pubkey
if err := n.Load((*enode.Secp256k1)(&key)); err == nil {
@ -637,14 +652,14 @@ type packetHandlerV4 struct {
senderKey *ecdsa.PublicKey // used for ping
// preverify checks whether the packet is valid and should be handled at all.
preverify func(p *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error
preverify func(p *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error
// handle handles the packet.
handle func(req *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte)
handle func(req *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte)
}
// PING/v4
func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error {
func (t *UDPv4) verifyPing(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.Ping)
if v4wire.Expired(req.Expiration) {
@ -658,7 +673,7 @@ func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I
return nil
}
func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) {
func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) {
req := h.Packet.(*v4wire.Ping)
// Reply.
@ -670,8 +685,9 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I
})
// Ping back if our last pong on file is too far in the past.
n := wrapNode(enode.NewV4(h.senderKey, from.IP, int(req.From.TCP), from.Port))
if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration {
fromIP := from.Addr().AsSlice()
n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port()))
if time.Since(t.db.LastPongReceived(n.ID(), fromIP)) > bondExpiration {
t.sendPing(fromID, from, func() {
t.tab.addInboundNode(n)
})
@ -680,35 +696,40 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I
}
// Update node database and endpoint predictor.
t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now())
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
t.db.UpdateLastPingReceived(n.ID(), fromIP, time.Now())
fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())}
toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
}
// PONG/v4
func (t *UDPv4) verifyPong(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error {
func (t *UDPv4) verifyPong(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.Pong)
if v4wire.Expired(req.Expiration) {
return errExpired
}
if !t.handleReply(fromID, from.IP, req) {
if !t.handleReply(fromID, from.Addr(), req) {
return errUnsolicitedReply
}
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
t.db.UpdateLastPongReceived(fromID, from.IP, time.Now())
fromIP := from.Addr().AsSlice()
fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())}
toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
t.db.UpdateLastPongReceived(fromID, fromIP, time.Now())
return nil
}
// FINDNODE/v4
func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error {
func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.Findnode)
if v4wire.Expired(req.Expiration) {
return errExpired
}
if !t.checkBond(fromID, from.IP) {
if !t.checkBond(fromID, from) {
// No endpoint proof pong exists, we don't process the packet. This prevents an
// attack vector where the discovery protocol could be used to amplify traffic in a
// DDOS attack. A malicious actor would send a findnode request with the IP address
@ -720,7 +741,7 @@ func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno
return nil
}
func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) {
func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) {
req := h.Packet.(*v4wire.Findnode)
// Determine closest nodes.
@ -732,7 +753,8 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno
p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool
for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP()) == nil {
fromIP := from.Addr().AsSlice()
if netutil.CheckRelayIP(fromIP, n.IP()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n))
}
if len(p.Nodes) == v4wire.MaxNeighbors {
@ -748,13 +770,13 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno
// NEIGHBORS/v4
func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error {
func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.Neighbors)
if v4wire.Expired(req.Expiration) {
return errExpired
}
if !t.handleReply(fromID, from.IP, h.Packet) {
if !t.handleReply(fromID, from.Addr(), h.Packet) {
return errUnsolicitedReply
}
return nil
@ -762,19 +784,19 @@ func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID en
// ENRREQUEST/v4
func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error {
func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
req := h.Packet.(*v4wire.ENRRequest)
if v4wire.Expired(req.Expiration) {
return errExpired
}
if !t.checkBond(fromID, from.IP) {
if !t.checkBond(fromID, from) {
return errUnknownNode
}
return nil
}
func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) {
func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) {
t.send(from, fromID, &v4wire.ENRResponse{
ReplyTok: mac,
Record: *t.localNode.Node().Record(),
@ -783,8 +805,8 @@ func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID e
// ENRRESPONSE/v4
func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error {
if !t.handleReply(fromID, from.IP, h.Packet) {
func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error {
if !t.handleReply(fromID, from.Addr(), h.Packet) {
return errUnsolicitedReply
}
return nil

@ -26,6 +26,7 @@ import (
"io"
"math/rand"
"net"
"net/netip"
"reflect"
"sync"
"testing"
@ -55,7 +56,7 @@ type udpTest struct {
udp *UDPv4
sent [][]byte
localkey, remotekey *ecdsa.PrivateKey
remoteaddr *net.UDPAddr
remoteaddr netip.AddrPort
}
func newUDPTest(t *testing.T) *udpTest {
@ -64,7 +65,7 @@ func newUDPTest(t *testing.T) *udpTest {
pipe: newpipe(),
localkey: newkey(),
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
remoteaddr: netip.MustParseAddrPort("10.0.1.99:30303"),
}
test.db, _ = enode.OpenDB("")
@ -92,7 +93,7 @@ func (test *udpTest) packetIn(wantError error, data v4wire.Packet) {
}
// handles a packet as if it had been sent to the transport by the key/endpoint.
func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, data v4wire.Packet) {
func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr netip.AddrPort, data v4wire.Packet) {
test.t.Helper()
enc, _, err := v4wire.Encode(key, data)
@ -106,7 +107,7 @@ func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *
}
// waits for a packet to be sent by the transport.
// validate should have type func(X, *net.UDPAddr, []byte), where X is a packet type.
// validate should have type func(X, netip.AddrPort, []byte), where X is a packet type.
func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) {
test.t.Helper()
@ -128,7 +129,7 @@ func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) {
test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
return false
}
fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(hash)})
fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(hash)})
return false
}
@ -236,7 +237,7 @@ func TestUDPv4_findnodeTimeout(t *testing.T) {
test := newUDPTest(t)
defer test.close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
toaddr := netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 2222)
toid := enode.ID{1, 2, 3, 4}
target := v4wire.Pubkey{4, 5, 6, 7}
result, err := test.udp.findnode(toid, toaddr, target)
@ -261,26 +262,25 @@ func TestUDPv4_findnode(t *testing.T) {
for i := 0; i < numCandidates; i++ {
key := newkey()
ip := net.IP{10, 13, 0, byte(i)}
n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000))
n := enode.NewV4(&key.PublicKey, ip, 0, 2000)
// Ensure half of table content isn't verified live yet.
if i > numCandidates/2 {
n.isValidatedLive = true
live[n.ID()] = true
}
test.table.addFoundNode(n, live[n.ID()])
nodes.push(n, numCandidates)
}
fillTable(test.table, nodes.entries, false)
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID()
test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now())
test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr().AsSlice(), time.Now())
// check that closest neighbors are returned.
expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true)
test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp})
waitNeighbors := func(want []*node) {
test.waitPacketOut(func(p *v4wire.Neighbors, to *net.UDPAddr, hash []byte) {
waitNeighbors := func(want []*enode.Node) {
test.waitPacketOut(func(p *v4wire.Neighbors, to netip.AddrPort, hash []byte) {
if len(p.Nodes) != len(want) {
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), len(want))
return
@ -309,10 +309,10 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) {
defer test.close()
rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now())
test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr().AsSlice(), time.Now())
// queue a pending findnode request
resultc, errc := make(chan []*node, 1), make(chan error, 1)
resultc, errc := make(chan []*enode.Node, 1), make(chan error, 1)
go func() {
rid := encodePubkey(&test.remotekey.PublicKey).id()
ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
@ -325,18 +325,18 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) {
// wait for the findnode to be sent.
// after it is sent, the transport is waiting for a reply
test.waitPacketOut(func(p *v4wire.Findnode, to *net.UDPAddr, hash []byte) {
test.waitPacketOut(func(p *v4wire.Findnode, to netip.AddrPort, hash []byte) {
if p.Target != testTarget {
t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
}
})
// send the reply as two packets.
list := []*node{
wrapNode(enode.MustParse("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304")),
wrapNode(enode.MustParse("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303")),
wrapNode(enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17")),
wrapNode(enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303")),
list := []*enode.Node{
enode.MustParse("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304"),
enode.MustParse("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17"),
enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
}
rpclist := make([]v4wire.Node, len(list))
for i := range list {
@ -368,8 +368,8 @@ func TestUDPv4_pingMatch(t *testing.T) {
crand.Read(randToken)
test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {})
test.waitPacketOut(func(*v4wire.Ping, *net.UDPAddr, []byte) {})
test.waitPacketOut(func(*v4wire.Pong, netip.AddrPort, []byte) {})
test.waitPacketOut(func(*v4wire.Ping, netip.AddrPort, []byte) {})
test.packetIn(errUnsolicitedReply, &v4wire.Pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp})
}
@ -379,10 +379,10 @@ func TestUDPv4_pingMatchIP(t *testing.T) {
defer test.close()
test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {})
test.waitPacketOut(func(*v4wire.Pong, netip.AddrPort, []byte) {})
test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) {
wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000}
test.waitPacketOut(func(p *v4wire.Ping, to netip.AddrPort, hash []byte) {
wrongAddr := netip.MustParseAddrPort("33.44.1.2:30000")
test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, &v4wire.Pong{
ReplyTok: hash,
To: testLocalAnnounced,
@ -393,41 +393,36 @@ func TestUDPv4_pingMatchIP(t *testing.T) {
func TestUDPv4_successfulPing(t *testing.T) {
test := newUDPTest(t)
added := make(chan *node, 1)
test.table.nodeAddedHook = func(b *bucket, n *node) { added <- n }
added := make(chan *tableNode, 1)
test.table.nodeAddedHook = func(b *bucket, n *tableNode) { added <- n }
defer test.close()
// The remote side sends a ping packet to initiate the exchange.
go test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
// The ping is replied to.
test.waitPacketOut(func(p *v4wire.Pong, to *net.UDPAddr, hash []byte) {
test.waitPacketOut(func(p *v4wire.Pong, to netip.AddrPort, hash []byte) {
pinghash := test.sent[0][:32]
if !bytes.Equal(p.ReplyTok, pinghash) {
t.Errorf("got pong.ReplyTok %x, want %x", p.ReplyTok, pinghash)
}
wantTo := v4wire.Endpoint{
// The mirrored UDP address is the UDP packet sender
IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port),
// The mirrored TCP port is the one from the ping packet
TCP: testRemote.TCP,
}
// The mirrored UDP address is the UDP packet sender.
// The mirrored TCP port is the one from the ping packet.
wantTo := v4wire.NewEndpoint(test.remoteaddr, testRemote.TCP)
if !reflect.DeepEqual(p.To, wantTo) {
t.Errorf("got pong.To %v, want %v", p.To, wantTo)
}
})
// Remote is unknown, the table pings back.
test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) {
test.waitPacketOut(func(p *v4wire.Ping, to netip.AddrPort, hash []byte) {
wantFrom := test.udp.ourEndpoint()
wantFrom.IP = net.IP{}
if !reflect.DeepEqual(p.From, wantFrom) {
t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint())
}
wantTo := v4wire.Endpoint{
// The mirrored UDP address is the UDP packet sender.
IP: test.remoteaddr.IP,
UDP: uint16(test.remoteaddr.Port),
TCP: 0,
}
// The mirrored UDP address is the UDP packet sender.
wantTo := v4wire.NewEndpoint(test.remoteaddr, 0)
if !reflect.DeepEqual(p.To, wantTo) {
t.Errorf("got ping.To %v, want %v", p.To, wantTo)
}
@ -442,11 +437,11 @@ func TestUDPv4_successfulPing(t *testing.T) {
if n.ID() != rid {
t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid)
}
if !n.IP().Equal(test.remoteaddr.IP) {
t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.IP)
if !n.IP().Equal(test.remoteaddr.Addr().AsSlice()) {
t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.Addr())
}
if n.UDP() != test.remoteaddr.Port {
t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port)
if n.UDP() != int(test.remoteaddr.Port()) {
t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port())
}
if n.TCP() != int(testRemote.TCP) {
t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP(), testRemote.TCP)
@ -469,12 +464,12 @@ func TestUDPv4_EIP868(t *testing.T) {
// Perform endpoint proof and check for sequence number in packet tail.
test.packetIn(nil, &v4wire.Ping{Expiration: futureExp})
test.waitPacketOut(func(p *v4wire.Pong, addr *net.UDPAddr, hash []byte) {
test.waitPacketOut(func(p *v4wire.Pong, addr netip.AddrPort, hash []byte) {
if p.ENRSeq != wantNode.Seq() {
t.Errorf("wrong sequence number in pong: %d, want %d", p.ENRSeq, wantNode.Seq())
}
})
test.waitPacketOut(func(p *v4wire.Ping, addr *net.UDPAddr, hash []byte) {
test.waitPacketOut(func(p *v4wire.Ping, addr netip.AddrPort, hash []byte) {
if p.ENRSeq != wantNode.Seq() {
t.Errorf("wrong sequence number in ping: %d, want %d", p.ENRSeq, wantNode.Seq())
}
@ -483,7 +478,7 @@ func TestUDPv4_EIP868(t *testing.T) {
// Request should work now.
test.packetIn(nil, &v4wire.ENRRequest{Expiration: futureExp})
test.waitPacketOut(func(p *v4wire.ENRResponse, addr *net.UDPAddr, hash []byte) {
test.waitPacketOut(func(p *v4wire.ENRResponse, addr netip.AddrPort, hash []byte) {
n, err := enode.New(enode.ValidSchemes, &p.Record)
if err != nil {
t.Fatalf("invalid record: %v", err)
@ -584,7 +579,7 @@ type dgramPipe struct {
}
type dgram struct {
to net.UDPAddr
to netip.AddrPort
data []byte
}
@ -597,8 +592,8 @@ func newpipe() *dgramPipe {
}
}
// WriteToUDP queues a datagram.
func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
// WriteToUDPAddrPort queues a datagram.
func (c *dgramPipe) WriteToUDPAddrPort(b []byte, to netip.AddrPort) (n int, err error) {
msg := make([]byte, len(b))
copy(msg, b)
c.mu.Lock()
@ -606,15 +601,15 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
if c.closed {
return 0, errors.New("closed")
}
c.queue = append(c.queue, dgram{*to, b})
c.queue = append(c.queue, dgram{to, b})
c.cond.Signal()
return len(b), nil
}
// ReadFromUDP just hangs until the pipe is closed.
func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
// ReadFromUDPAddrPort just hangs until the pipe is closed.
func (c *dgramPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
<-c.closing
return 0, nil, io.EOF
return 0, netip.AddrPort{}, io.EOF
}
func (c *dgramPipe) Close() error {

@ -25,6 +25,7 @@ import (
"fmt"
"math/big"
"net"
"net/netip"
"time"
"github.com/ethereum/go-ethereum/common/math"
@ -150,14 +151,15 @@ type Endpoint struct {
}
// NewEndpoint creates an endpoint.
func NewEndpoint(addr *net.UDPAddr, tcpPort uint16) Endpoint {
ip := net.IP{}
if ip4 := addr.IP.To4(); ip4 != nil {
ip = ip4
} else if ip6 := addr.IP.To16(); ip6 != nil {
ip = ip6
func NewEndpoint(addr netip.AddrPort, tcpPort uint16) Endpoint {
var ip net.IP
if addr.Addr().Is4() || addr.Addr().Is4In6() {
ip4 := addr.Addr().As4()
ip = ip4[:]
} else {
ip = addr.Addr().AsSlice()
}
return Endpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
return Endpoint{IP: ip, UDP: addr.Port(), TCP: tcpPort}
}
type Packet interface {

@ -18,6 +18,7 @@ package discover
import (
"net"
"net/netip"
"sync"
"time"
@ -70,7 +71,7 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) {
}
// handleRequest handles a talk request.
func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.TalkRequest) {
func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) {
t.mutex.Lock()
handler, ok := t.handlers[req.Protocol]
t.mutex.Unlock()
@ -88,7 +89,8 @@ func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.T
case <-t.slots:
go func() {
defer func() { t.slots <- struct{}{} }()
respMessage := handler(id, addr, req.Message)
udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())}
respMessage := handler(id, udpAddr, req.Message)
resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage}
t.transport.sendFromAnotherThread(id, addr, resp)
}()

@ -25,6 +25,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"slices"
"sync"
"time"
@ -101,14 +102,14 @@ type UDPv5 struct {
type sendRequest struct {
destID enode.ID
destAddr *net.UDPAddr
destAddr netip.AddrPort
msg v5wire.Packet
}
// callV5 represents a remote procedure call against another node.
type callV5 struct {
id enode.ID
addr *net.UDPAddr
addr netip.AddrPort
node *enode.Node // This is required to perform handshakes.
packet v5wire.Packet
@ -233,7 +234,7 @@ func (t *UDPv5) AllNodes() []*enode.Node {
for _, b := range &t.tab.buckets {
for _, n := range b.entries {
nodes = append(nodes, unwrapNode(n))
nodes = append(nodes, n.Node)
}
}
return nodes
@ -266,7 +267,7 @@ func (t *UDPv5) TalkRequest(n *enode.Node, protocol string, request []byte) ([]b
}
// TalkRequestToID sends a talk request to a node and waits for a response.
func (t *UDPv5) TalkRequestToID(id enode.ID, addr *net.UDPAddr, protocol string, request []byte) ([]byte, error) {
func (t *UDPv5) TalkRequestToID(id enode.ID, addr netip.AddrPort, protocol string, request []byte) ([]byte, error) {
req := &v5wire.TalkRequest{Protocol: protocol, Message: request}
resp := t.callToID(id, addr, v5wire.TalkResponseMsg, req)
defer t.callDone(resp)
@ -314,26 +315,26 @@ func (t *UDPv5) newRandomLookup(ctx context.Context) *lookup {
}
func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup {
return newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) {
return newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) {
return t.lookupWorker(n, target)
})
}
// lookupWorker performs FINDNODE calls against a single node during lookup.
func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) {
func (t *UDPv5) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) {
var (
dists = lookupDistances(target, destNode.ID())
nodes = nodesByDistance{target: target}
err error
)
var r []*enode.Node
r, err = t.findnode(unwrapNode(destNode), dists)
r, err = t.findnode(destNode, dists)
if errors.Is(err, errClosed) {
return nil, err
}
for _, n := range r {
if n.ID() != t.Self().ID() {
nodes.push(wrapNode(n), findnodeResultLimit)
nodes.push(n, findnodeResultLimit)
}
}
return nodes.entries, err
@ -427,7 +428,7 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s
if err != nil {
return nil, err
}
if err := netutil.CheckRelayIP(c.addr.IP, node.IP()); err != nil {
if err := netutil.CheckRelayIP(c.addr.Addr().AsSlice(), node.IP()); err != nil {
return nil, err
}
if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) {
@ -452,14 +453,14 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s
// callToNode sends the given call and sets up a handler for response packets (of message
// type responseType). Responses are dispatched to the call's response channel.
func (t *UDPv5) callToNode(n *enode.Node, responseType byte, req v5wire.Packet) *callV5 {
addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()}
addr, _ := n.UDPEndpoint()
c := &callV5{id: n.ID(), addr: addr, node: n}
t.initCall(c, responseType, req)
return c
}
// callToID is like callToNode, but for cases where the node record is not available.
func (t *UDPv5) callToID(id enode.ID, addr *net.UDPAddr, responseType byte, req v5wire.Packet) *callV5 {
func (t *UDPv5) callToID(id enode.ID, addr netip.AddrPort, responseType byte, req v5wire.Packet) *callV5 {
c := &callV5{id: id, addr: addr}
t.initCall(c, responseType, req)
return c
@ -619,12 +620,12 @@ func (t *UDPv5) sendCall(c *callV5) {
// sendResponse sends a response packet to the given node.
// This doesn't trigger a handshake even if no keys are available.
func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) error {
func (t *UDPv5) sendResponse(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet) error {
_, err := t.send(toID, toAddr, packet, nil)
return err
}
func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) {
func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet) {
select {
case t.sendCh <- sendRequest{toID, toAddr, packet}:
case <-t.closeCtx.Done():
@ -632,7 +633,7 @@ func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet
}
// send sends a packet to the given node.
func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) {
func (t *UDPv5) send(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) {
addr := toAddr.String()
t.logcontext = append(t.logcontext[:0], "id", toID, "addr", addr)
t.logcontext = packet.AppendLogInfo(t.logcontext)
@ -644,7 +645,7 @@ func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c
return nonce, err
}
_, err = t.conn.WriteToUDP(enc, toAddr)
_, err = t.conn.WriteToUDPAddrPort(enc, toAddr)
t.log.Trace(">> "+packet.Name(), t.logcontext...)
return nonce, err
}
@ -655,7 +656,7 @@ func (t *UDPv5) readLoop() {
buf := make([]byte, maxPacketSize)
for range t.readNextCh {
nbytes, from, err := t.conn.ReadFromUDP(buf)
nbytes, from, err := t.conn.ReadFromUDPAddrPort(buf)
if netutil.IsTemporaryError(err) {
// Ignore temporary read errors.
t.log.Debug("Temporary UDP read error", "err", err)
@ -672,7 +673,7 @@ func (t *UDPv5) readLoop() {
}
// dispatchReadPacket sends a packet into the dispatch loop.
func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool {
func (t *UDPv5) dispatchReadPacket(from netip.AddrPort, content []byte) bool {
select {
case t.packetInCh <- ReadPacket{content, from}:
return true
@ -682,7 +683,7 @@ func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool {
}
// handlePacket decodes and processes an incoming packet from the network.
func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr netip.AddrPort) error {
addr := fromAddr.String()
fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr)
if err != nil {
@ -699,7 +700,7 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
}
if fromNode != nil {
// Handshake succeeded, add to table.
t.tab.addInboundNode(wrapNode(fromNode))
t.tab.addInboundNode(fromNode)
}
if packet.Kind() != v5wire.WhoareyouPacket {
// WHOAREYOU logged separately to report errors.
@ -712,13 +713,13 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
}
// handleCallResponse dispatches a response packet to the call waiting for it.
func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, p v5wire.Packet) bool {
func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr netip.AddrPort, p v5wire.Packet) bool {
ac := t.activeCallByNode[fromID]
if ac == nil || !bytes.Equal(p.RequestID(), ac.reqid) {
t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.Name()), "id", fromID, "addr", fromAddr)
return false
}
if !fromAddr.IP.Equal(ac.addr.IP) || fromAddr.Port != ac.addr.Port {
if fromAddr != ac.addr {
t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.Name()), "id", fromID, "addr", fromAddr)
return false
}
@ -743,7 +744,7 @@ func (t *UDPv5) getNode(id enode.ID) *enode.Node {
}
// handle processes incoming packets according to their message type.
func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) {
func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort) {
switch p := p.(type) {
case *v5wire.Unknown:
t.handleUnknown(p, fromID, fromAddr)
@ -753,7 +754,9 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr)
t.handlePing(p, fromID, fromAddr)
case *v5wire.Pong:
if t.handleCallResponse(fromID, fromAddr, p) {
t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)})
fromUDPAddr := &net.UDPAddr{IP: fromAddr.Addr().AsSlice(), Port: int(fromAddr.Port())}
toUDPAddr := &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
}
case *v5wire.Findnode:
t.handleFindnode(p, fromID, fromAddr)
@ -767,7 +770,7 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr)
}
// handleUnknown initiates a handshake by responding with WHOAREYOU.
func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr *net.UDPAddr) {
func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr netip.AddrPort) {
challenge := &v5wire.Whoareyou{Nonce: p.Nonce}
crand.Read(challenge.IDNonce[:])
if n := t.getNode(fromID); n != nil {
@ -783,7 +786,7 @@ var (
)
// handleWhoareyou resends the active call as a handshake packet.
func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr *net.UDPAddr) {
func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr netip.AddrPort) {
c, err := t.matchWithCall(fromID, p.Nonce)
if err != nil {
t.log.Debug("Invalid "+p.Name(), "addr", fromAddr, "err", err)
@ -817,32 +820,35 @@ func (t *UDPv5) matchWithCall(fromID enode.ID, nonce v5wire.Nonce) (*callV5, err
}
// handlePing sends a PONG response.
func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr *net.UDPAddr) {
remoteIP := fromAddr.IP
// Handle IPv4 mapped IPv6 addresses in the
// event the local node is binded to an
// ipv6 interface.
if remoteIP.To4() != nil {
remoteIP = remoteIP.To4()
func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr netip.AddrPort) {
var remoteIP net.IP
// Handle IPv4 mapped IPv6 addresses in the event the local node is binded
// to an ipv6 interface.
if fromAddr.Addr().Is4() || fromAddr.Addr().Is4In6() {
ip4 := fromAddr.Addr().As4()
remoteIP = ip4[:]
} else {
remoteIP = fromAddr.Addr().AsSlice()
}
t.sendResponse(fromID, fromAddr, &v5wire.Pong{
ReqID: p.ReqID,
ToIP: remoteIP,
ToPort: uint16(fromAddr.Port),
ToPort: fromAddr.Port(),
ENRSeq: t.localNode.Node().Seq(),
})
}
// handleFindnode returns nodes to the requester.
func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr *net.UDPAddr) {
nodes := t.collectTableNodes(fromAddr.IP, p.Distances, findnodeResultLimit)
func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr netip.AddrPort) {
nodes := t.collectTableNodes(fromAddr.Addr(), p.Distances, findnodeResultLimit)
for _, resp := range packNodes(p.ReqID, nodes) {
t.sendResponse(fromID, fromAddr, resp)
}
}
// collectTableNodes creates a FINDNODE result set for the given distances.
func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node {
func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) []*enode.Node {
ripSlice := rip.AsSlice()
var bn []*enode.Node
var nodes []*enode.Node
var processed = make(map[uint]struct{})
@ -857,7 +863,7 @@ func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*en
for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) {
// Apply some pre-checks to avoid sending invalid nodes.
// Note liveness is checked by appendLiveNodes.
if netutil.CheckRelayIP(rip, n.IP()) != nil {
if netutil.CheckRelayIP(ripSlice, n.IP()) != nil {
continue
}
nodes = append(nodes, n)

@ -23,6 +23,7 @@ import (
"fmt"
"math/rand"
"net"
"net/netip"
"reflect"
"slices"
"testing"
@ -103,7 +104,7 @@ func TestUDPv5_pingHandling(t *testing.T) {
defer test.close()
test.packetIn(&v5wire.Ping{ReqID: []byte("foo")})
test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Pong, addr netip.AddrPort, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("foo")) {
t.Error("wrong request ID in response:", p.ReqID)
}
@ -135,16 +136,16 @@ func TestUDPv5_unknownPacket(t *testing.T) {
// Unknown packet from unknown node.
test.packetIn(&v5wire.Unknown{Nonce: nonce})
test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) {
check(p, 0)
})
// Make node known.
n := test.getNode(test.remotekey, test.remoteaddr).Node()
test.table.addFoundNode(wrapNode(n))
test.table.addFoundNode(n, false)
test.packetIn(&v5wire.Unknown{Nonce: nonce})
test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) {
check(p, n.Seq())
})
}
@ -159,9 +160,9 @@ func TestUDPv5_findnodeHandling(t *testing.T) {
nodes253 := nodesAtDistance(test.table.self().ID(), 253, 16)
nodes249 := nodesAtDistance(test.table.self().ID(), 249, 4)
nodes248 := nodesAtDistance(test.table.self().ID(), 248, 10)
fillTable(test.table, wrapNodes(nodes253), true)
fillTable(test.table, wrapNodes(nodes249), true)
fillTable(test.table, wrapNodes(nodes248), true)
fillTable(test.table, nodes253, true)
fillTable(test.table, nodes249, true)
fillTable(test.table, nodes248, true)
// Requesting with distance zero should return the node's own record.
test.packetIn(&v5wire.Findnode{ReqID: []byte{0}, Distances: []uint{0}})
@ -199,7 +200,7 @@ func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes
}
for {
test.waitPacketOut(func(p *v5wire.Nodes, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Nodes, addr netip.AddrPort, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, wantReqID) {
test.t.Fatalf("wrong request ID %v in response, want %v", p.ReqID, wantReqID)
}
@ -238,7 +239,7 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote)
done <- err
}()
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {})
test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {})
if err := <-done; err != errTimeout {
t.Fatalf("want errTimeout, got %q", err)
}
@ -248,7 +249,7 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote)
done <- err
}()
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.Pong{ReqID: p.ReqID})
})
if err := <-done; err != nil {
@ -260,8 +261,8 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote)
done <- err
}()
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {
wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101}
test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
wrongAddr := netip.MustParseAddrPort("33.44.55.22:10101")
test.packetInFrom(test.remotekey, wrongAddr, &v5wire.Pong{ReqID: p.ReqID})
})
if err := <-done; err != errTimeout {
@ -291,7 +292,7 @@ func TestUDPv5_findnodeCall(t *testing.T) {
}()
// Serve the responses:
test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Findnode, addr netip.AddrPort, _ v5wire.Nonce) {
if !reflect.DeepEqual(p.Distances, distances) {
t.Fatalf("wrong distances in request: %v", p.Distances)
}
@ -337,15 +338,15 @@ func TestUDPv5_callResend(t *testing.T) {
}()
// Ping answered by WHOAREYOU.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) {
test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
})
// Ping should be re-sent.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
test.packetIn(&v5wire.Pong{ReqID: p.ReqID})
})
// Answer the other ping.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
test.packetIn(&v5wire.Pong{ReqID: p.ReqID})
})
if err := <-done; err != nil {
@ -370,11 +371,11 @@ func TestUDPv5_multipleHandshakeRounds(t *testing.T) {
}()
// Ping answered by WHOAREYOU.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) {
test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
})
// Ping answered by WHOAREYOU again.
test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) {
test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
})
if err := <-done; err != errTimeout {
@ -401,7 +402,7 @@ func TestUDPv5_callTimeoutReset(t *testing.T) {
}()
// Serve two responses, slowly.
test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Findnode, addr netip.AddrPort, _ v5wire.Nonce) {
time.Sleep(respTimeout - 50*time.Millisecond)
test.packetIn(&v5wire.Nodes{
ReqID: p.ReqID,
@ -439,7 +440,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
Protocol: "test",
Message: []byte("test request"),
})
test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.TalkResponse, addr netip.AddrPort, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("foo")) {
t.Error("wrong request ID in response:", p.ReqID)
}
@ -458,7 +459,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
Protocol: "wrong",
Message: []byte("test request"),
})
test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.TalkResponse, addr netip.AddrPort, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("2")) {
t.Error("wrong request ID in response:", p.ReqID)
}
@ -485,7 +486,7 @@ func TestUDPv5_talkRequest(t *testing.T) {
_, err := test.udp.TalkRequest(remote, "test", []byte("test request"))
done <- err
}()
test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {})
test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) {})
if err := <-done; err != errTimeout {
t.Fatalf("want errTimeout, got %q", err)
}
@ -495,7 +496,7 @@ func TestUDPv5_talkRequest(t *testing.T) {
_, err := test.udp.TalkRequest(remote, "test", []byte("test request"))
done <- err
}()
test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) {
if p.Protocol != "test" {
t.Errorf("wrong protocol ID in talk request: %q", p.Protocol)
}
@ -516,7 +517,7 @@ func TestUDPv5_talkRequest(t *testing.T) {
_, err := test.udp.TalkRequestToID(remote.ID(), test.remoteaddr, "test", []byte("test request 2"))
done <- err
}()
test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) {
if p.Protocol != "test" {
t.Errorf("wrong protocol ID in talk request: %q", p.Protocol)
}
@ -583,13 +584,14 @@ func TestUDPv5_lookup(t *testing.T) {
for d, nn := range lookupTestnet.dists {
for i, key := range nn {
n := lookupTestnet.node(d, i)
test.getNode(key, &net.UDPAddr{IP: n.IP(), Port: n.UDP()})
addr, _ := n.UDPEndpoint()
test.getNode(key, addr)
}
}
// Seed table with initial node.
initialNode := lookupTestnet.node(256, 0)
fillTable(test.table, []*node{wrapNode(initialNode)}, true)
fillTable(test.table, []*enode.Node{initialNode}, true)
// Start the lookup.
resultC := make(chan []*enode.Node, 1)
@ -601,7 +603,7 @@ func TestUDPv5_lookup(t *testing.T) {
// Answer lookup packets.
asked := make(map[enode.ID]bool)
for done := false; !done; {
done = test.waitPacketOut(func(p v5wire.Packet, to *net.UDPAddr, _ v5wire.Nonce) {
done = test.waitPacketOut(func(p v5wire.Packet, to netip.AddrPort, _ v5wire.Nonce) {
recipient, key := lookupTestnet.nodeByAddr(to)
switch p := p.(type) {
case *v5wire.Ping:
@ -652,11 +654,8 @@ func TestUDPv5_PingWithIPV4MappedAddress(t *testing.T) {
test := newUDPV5Test(t)
defer test.close()
rawIP := net.IPv4(0xFF, 0x12, 0x33, 0xE5)
test.remoteaddr = &net.UDPAddr{
IP: rawIP.To16(),
Port: 0,
}
rawIP := netip.AddrFrom4([4]byte{0xFF, 0x12, 0x33, 0xE5})
test.remoteaddr = netip.AddrPortFrom(netip.AddrFrom16(rawIP.As16()), 0)
remote := test.getNode(test.remotekey, test.remoteaddr).Node()
done := make(chan struct{}, 1)
@ -665,14 +664,14 @@ func TestUDPv5_PingWithIPV4MappedAddress(t *testing.T) {
test.udp.handlePing(&v5wire.Ping{ENRSeq: 1}, remote.ID(), test.remoteaddr)
done <- struct{}{}
}()
test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) {
test.waitPacketOut(func(p *v5wire.Pong, addr netip.AddrPort, _ v5wire.Nonce) {
if len(p.ToIP) == net.IPv6len {
t.Error("Received untruncated ip address")
}
if len(p.ToIP) != net.IPv4len {
t.Errorf("Received ip address with incorrect length: %d", len(p.ToIP))
}
if !p.ToIP.Equal(rawIP) {
if !p.ToIP.Equal(rawIP.AsSlice()) {
t.Errorf("Received incorrect ip address: wanted %s but received %s", rawIP.String(), p.ToIP.String())
}
})
@ -688,9 +687,9 @@ type udpV5Test struct {
db *enode.DB
udp *UDPv5
localkey, remotekey *ecdsa.PrivateKey
remoteaddr *net.UDPAddr
remoteaddr netip.AddrPort
nodesByID map[enode.ID]*enode.LocalNode
nodesByIP map[string]*enode.LocalNode
nodesByIP map[netip.Addr]*enode.LocalNode
}
// testCodec is the packet encoding used by protocol tests. This codec does not perform encryption.
@ -750,9 +749,9 @@ func newUDPV5Test(t *testing.T) *udpV5Test {
pipe: newpipe(),
localkey: newkey(),
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
remoteaddr: netip.MustParseAddrPort("10.0.1.99:30303"),
nodesByID: make(map[enode.ID]*enode.LocalNode),
nodesByIP: make(map[string]*enode.LocalNode),
nodesByIP: make(map[netip.Addr]*enode.LocalNode),
}
test.db, _ = enode.OpenDB("")
ln := enode.NewLocalNode(test.db, test.localkey)
@ -777,8 +776,8 @@ func (test *udpV5Test) packetIn(packet v5wire.Packet) {
test.packetInFrom(test.remotekey, test.remoteaddr, packet)
}
// handles a packet as if it had been sent to the transport by the key/endpoint.
func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet v5wire.Packet) {
// packetInFrom handles a packet as if it had been sent to the transport by the key/endpoint.
func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr netip.AddrPort, packet v5wire.Packet) {
test.t.Helper()
ln := test.getNode(key, addr)
@ -793,22 +792,22 @@ func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, pa
}
// getNode ensures the test knows about a node at the given endpoint.
func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr *net.UDPAddr) *enode.LocalNode {
func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr netip.AddrPort) *enode.LocalNode {
id := encodePubkey(&key.PublicKey).id()
ln := test.nodesByID[id]
if ln == nil {
db, _ := enode.OpenDB("")
ln = enode.NewLocalNode(db, key)
ln.SetStaticIP(addr.IP)
ln.Set(enr.UDP(addr.Port))
ln.SetStaticIP(addr.Addr().AsSlice())
ln.Set(enr.UDP(addr.Port()))
test.nodesByID[id] = ln
}
test.nodesByIP[string(addr.IP)] = ln
test.nodesByIP[addr.Addr()] = ln
return ln
}
// waitPacketOut waits for the next output packet and handles it using the given 'validate'
// function. The function must be of type func (X, *net.UDPAddr, v5wire.Nonce) where X is
// function. The function must be of type func (X, netip.AddrPort, v5wire.Nonce) where X is
// assignable to packetV5.
func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
test.t.Helper()
@ -824,7 +823,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
test.t.Fatalf("timed out waiting for %v", exptype)
return false
}
ln := test.nodesByIP[string(dgram.to.IP)]
ln := test.nodesByIP[dgram.to.Addr()]
if ln == nil {
test.t.Fatalf("attempt to send to non-existing node %v", &dgram.to)
return false
@ -839,7 +838,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
return false
}
fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(frame.AuthTag)})
fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(frame.AuthTag)})
return false
}

@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"slices"
"sync"
"sync/atomic"
@ -435,11 +436,11 @@ type sharedUDPConn struct {
unhandled chan discover.ReadPacket
}
// ReadFromUDP implements discover.UDPConn
func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
// ReadFromUDPAddrPort implements discover.UDPConn
func (s *sharedUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
packet, ok := <-s.unhandled
if !ok {
return 0, nil, errors.New("connection was closed")
return 0, netip.AddrPort{}, errors.New("connection was closed")
}
l := len(packet.Data)
if l > len(b) {

Loading…
Cancel
Save