p2p: fix race in dialScheduler (#29235)

Co-authored-by: Stefan <stefan@starflinger.eu>
pull/29078/head^2
Felix Lange 8 months ago committed by GitHub
parent 6c76b813df
commit 758fce71fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 48
      p2p/dial.go

@ -25,6 +25,7 @@ import (
mrand "math/rand" mrand "math/rand"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
@ -248,7 +249,7 @@ loop:
} }
case task := <-d.doneCh: case task := <-d.doneCh:
id := task.dest.ID() id := task.dest().ID()
delete(d.dialing, id) delete(d.dialing, id)
d.updateStaticPool(id) d.updateStaticPool(id)
d.doneSinceLastLog++ d.doneSinceLastLog++
@ -410,7 +411,7 @@ func (d *dialScheduler) startStaticDials(n int) (started int) {
// updateStaticPool attempts to move the given static dial back into staticPool. // updateStaticPool attempts to move the given static dial back into staticPool.
func (d *dialScheduler) updateStaticPool(id enode.ID) { func (d *dialScheduler) updateStaticPool(id enode.ID) {
task, ok := d.static[id] task, ok := d.static[id]
if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil { if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest()) == nil {
d.addToStaticPool(task) d.addToStaticPool(task)
} }
} }
@ -437,10 +438,11 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
// startDial runs the given dial task in a separate goroutine. // startDial runs the given dial task in a separate goroutine.
func (d *dialScheduler) startDial(task *dialTask) { func (d *dialScheduler) startDial(task *dialTask) {
d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags) node := task.dest()
hkey := string(task.dest.ID().Bytes()) d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags)
hkey := string(node.ID().Bytes())
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration)) d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
d.dialing[task.dest.ID()] = task d.dialing[node.ID()] = task
go func() { go func() {
task.run(d) task.run(d)
d.doneCh <- task d.doneCh <- task
@ -451,39 +453,46 @@ func (d *dialScheduler) startDial(task *dialTask) {
type dialTask struct { type dialTask struct {
staticPoolIndex int staticPoolIndex int
flags connFlag flags connFlag
// These fields are private to the task and should not be // These fields are private to the task and should not be
// accessed by dialScheduler while the task is running. // accessed by dialScheduler while the task is running.
dest *enode.Node destPtr atomic.Pointer[enode.Node]
lastResolved mclock.AbsTime lastResolved mclock.AbsTime
resolveDelay time.Duration resolveDelay time.Duration
} }
func newDialTask(dest *enode.Node, flags connFlag) *dialTask { func newDialTask(dest *enode.Node, flags connFlag) *dialTask {
return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1} t := &dialTask{flags: flags, staticPoolIndex: -1}
t.destPtr.Store(dest)
return t
} }
type dialError struct { type dialError struct {
error error
} }
func (t *dialTask) dest() *enode.Node {
return t.destPtr.Load()
}
func (t *dialTask) run(d *dialScheduler) { func (t *dialTask) run(d *dialScheduler) {
if t.needResolve() && !t.resolve(d) { if t.needResolve() && !t.resolve(d) {
return return
} }
err := t.dial(d, t.dest) err := t.dial(d, t.dest())
if err != nil { if err != nil {
// For static nodes, resolve one more time if dialing fails. // For static nodes, resolve one more time if dialing fails.
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
if t.resolve(d) { if t.resolve(d) {
t.dial(d, t.dest) t.dial(d, t.dest())
} }
} }
} }
} }
func (t *dialTask) needResolve() bool { func (t *dialTask) needResolve() bool {
return t.flags&staticDialedConn != 0 && t.dest.IP() == nil return t.flags&staticDialedConn != 0 && t.dest().IP() == nil
} }
// resolve attempts to find the current endpoint for the destination // resolve attempts to find the current endpoint for the destination
@ -502,29 +511,31 @@ func (t *dialTask) resolve(d *dialScheduler) bool {
if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay { if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay {
return false return false
} }
resolved := d.resolver.Resolve(t.dest)
node := t.dest()
resolved := d.resolver.Resolve(node)
t.lastResolved = d.clock.Now() t.lastResolved = d.clock.Now()
if resolved == nil { if resolved == nil {
t.resolveDelay *= 2 t.resolveDelay *= 2
if t.resolveDelay > maxResolveDelay { if t.resolveDelay > maxResolveDelay {
t.resolveDelay = maxResolveDelay t.resolveDelay = maxResolveDelay
} }
d.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) d.log.Debug("Resolving node failed", "id", node.ID(), "newdelay", t.resolveDelay)
return false return false
} }
// The node was found. // The node was found.
t.resolveDelay = initialResolveDelay t.resolveDelay = initialResolveDelay
t.dest = resolved t.destPtr.Store(resolved)
d.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()})
return true return true
} }
// dial performs the actual connection attempt. // dial performs the actual connection attempt.
func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
dialMeter.Mark(1) dialMeter.Mark(1)
fd, err := d.dialer.Dial(d.ctx, t.dest) fd, err := d.dialer.Dial(d.ctx, dest)
if err != nil { if err != nil {
d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanupDialErr(err)) d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err))
dialConnectionError.Mark(1) dialConnectionError.Mark(1)
return &dialError{err} return &dialError{err}
} }
@ -532,8 +543,9 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
} }
func (t *dialTask) String() string { func (t *dialTask) String() string {
id := t.dest.ID() node := t.dest()
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) id := node.ID()
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP())
} }
func cleanupDialErr(err error) error { func cleanupDialErr(err error) error {

Loading…
Cancel
Save