Skip to content

Commit e6b9ecf

Browse files
fjlqu0b
authored and committed
p2p: fix race in dialScheduler (ethereum#29235)
Co-authored-by: Stefan <[email protected]>
1 parent bca9449 commit e6b9ecf

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

p2p/dial.go

Lines changed: 30 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ import (
2525
mrand "math/rand"
2626
"net"
2727
"sync"
28+
"sync/atomic"
2829
"time"
2930

3031
"github.com/ethereum/go-ethereum/common/mclock"
@@ -248,7 +249,7 @@ loop:
248249
}
249250

250251
case task := <-d.doneCh:
251-
id := task.dest.ID()
252+
id := task.dest().ID()
252253
delete(d.dialing, id)
253254
d.updateStaticPool(id)
254255
d.doneSinceLastLog++
@@ -410,7 +411,7 @@ func (d *dialScheduler) startStaticDials(n int) (started int) {
410411
// updateStaticPool attempts to move the given static dial back into staticPool.
411412
func (d *dialScheduler) updateStaticPool(id enode.ID) {
412413
task, ok := d.static[id]
413-
if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil {
414+
if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest()) == nil {
414415
d.addToStaticPool(task)
415416
}
416417
}
@@ -437,10 +438,11 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
437438

438439
// startDial runs the given dial task in a separate goroutine.
439440
func (d *dialScheduler) startDial(task *dialTask) {
440-
d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags)
441-
hkey := string(task.dest.ID().Bytes())
441+
node := task.dest()
442+
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags)
443+
hkey := string(node.ID().Bytes())
442444
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
443-
d.dialing[task.dest.ID()] = task
445+
d.dialing[node.ID()] = task
444446
go func() {
445447
task.run(d)
446448
d.doneCh <- task
@@ -451,39 +453,46 @@ func (d *dialScheduler) startDial(task *dialTask) {
451453
type dialTask struct {
452454
staticPoolIndex int
453455
flags connFlag
456+
454457
// These fields are private to the task and should not be
455458
// accessed by dialScheduler while the task is running.
456-
dest *enode.Node
459+
destPtr atomic.Pointer[enode.Node]
457460
lastResolved mclock.AbsTime
458461
resolveDelay time.Duration
459462
}
460463

461464
func newDialTask(dest *enode.Node, flags connFlag) *dialTask {
462-
return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1}
465+
t := &dialTask{flags: flags, staticPoolIndex: -1}
466+
t.destPtr.Store(dest)
467+
return t
463468
}
464469

465470
type dialError struct {
466471
error
467472
}
468473

474+
func (t *dialTask) dest() *enode.Node {
475+
return t.destPtr.Load()
476+
}
477+
469478
func (t *dialTask) run(d *dialScheduler) {
470479
if t.needResolve() && !t.resolve(d) {
471480
return
472481
}
473482

474-
err := t.dial(d, t.dest)
483+
err := t.dial(d, t.dest())
475484
if err != nil {
476485
// For static nodes, resolve one more time if dialing fails.
477486
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
478487
if t.resolve(d) {
479-
t.dial(d, t.dest)
488+
t.dial(d, t.dest())
480489
}
481490
}
482491
}
483492
}
484493

485494
func (t *dialTask) needResolve() bool {
486-
return t.flags&staticDialedConn != 0 && t.dest.IP() == nil
495+
return t.flags&staticDialedConn != 0 && t.dest().IP() == nil
487496
}
488497

489498
// resolve attempts to find the current endpoint for the destination
@@ -502,38 +511,41 @@ func (t *dialTask) resolve(d *dialScheduler) bool {
502511
if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay {
503512
return false
504513
}
505-
resolved := d.resolver.Resolve(t.dest)
514+
515+
node := t.dest()
516+
resolved := d.resolver.Resolve(node)
506517
t.lastResolved = d.clock.Now()
507518
if resolved == nil {
508519
t.resolveDelay *= 2
509520
if t.resolveDelay > maxResolveDelay {
510521
t.resolveDelay = maxResolveDelay
511522
}
512-
d.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay)
523+
d.log.Debug("Resolving node failed", "id", node.ID(), "newdelay", t.resolveDelay)
513524
return false
514525
}
515526
// The node was found.
516527
t.resolveDelay = initialResolveDelay
517-
t.dest = resolved
518-
d.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
528+
t.destPtr.Store(resolved)
529+
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()})
519530
return true
520531
}
521532

522533
// dial performs the actual connection attempt.
523534
func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
524535
dialMeter.Mark(1)
525-
fd, err := d.dialer.Dial(d.ctx, t.dest)
536+
fd, err := d.dialer.Dial(d.ctx, dest)
526537
if err != nil {
527-
d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanupDialErr(err))
538+
d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err))
528539
dialConnectionError.Mark(1)
529540
return &dialError{err}
530541
}
531542
return d.setupFunc(newMeteredConn(fd), t.flags, dest)
532543
}
533544

534545
func (t *dialTask) String() string {
535-
id := t.dest.ID()
536-
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
546+
node := t.dest()
547+
id := node.ID()
548+
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP())
537549
}
538550

539551
func cleanupDialErr(err error) error {

0 commit comments

Comments
 (0)