Skip to content

Commit 642f552

Browse files
fjlqu0b
authored and committed
p2p: fix race in dialScheduler (ethereum#29235)
Co-authored-by: Stefan <[email protected]>
1 parent 0e093b2 commit 642f552

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

p2p/dial.go

Lines changed: 29 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ import (
2525
mrand "math/rand"
2626
"net"
2727
"sync"
28+
"sync/atomic"
2829
"time"
2930

3031
"github.com/ethereum/go-ethereum/common/mclock"
@@ -254,7 +255,7 @@ loop:
254255
}
255256

256257
case task := <-d.doneCh:
257-
id := task.dest.ID()
258+
id := task.dest().ID()
258259
delete(d.dialing, id)
259260
d.updateStaticPool(id)
260261
d.doneSinceLastLog++
@@ -431,7 +432,7 @@ func (d *dialScheduler) startStaticDials(n int) (started int) {
431432
// updateStaticPool attempts to move the given static dial back into staticPool.
432433
func (d *dialScheduler) updateStaticPool(id enode.ID) {
433434
task, ok := d.static[id]
434-
if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil {
435+
if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest()) == nil {
435436
d.addToStaticPool(task)
436437
}
437438
}
@@ -459,11 +460,11 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
459460

460461
// startDial runs the given dial task in a separate goroutine.
461462
func (d *dialScheduler) startDial(task *dialTask) {
462-
d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags)
463-
hkey := string(task.dest.ID().Bytes())
463+
node := task.dest()
464+
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags)
465+
hkey := string(node.ID().Bytes())
464466
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
465-
d.dialing[task.dest.ID()] = task
466-
467+
d.dialing[node.ID()] = task
467468
go func() {
468469
task.run(d)
469470
d.doneCh <- task
@@ -474,39 +475,46 @@ func (d *dialScheduler) startDial(task *dialTask) {
474475
type dialTask struct {
475476
staticPoolIndex int
476477
flags connFlag
478+
477479
// These fields are private to the task and should not be
478480
// accessed by dialScheduler while the task is running.
479-
dest *enode.Node
481+
destPtr atomic.Pointer[enode.Node]
480482
lastResolved mclock.AbsTime
481483
resolveDelay time.Duration
482484
}
483485

484486
func newDialTask(dest *enode.Node, flags connFlag) *dialTask {
485-
return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1}
487+
t := &dialTask{flags: flags, staticPoolIndex: -1}
488+
t.destPtr.Store(dest)
489+
return t
486490
}
487491

488492
type dialError struct {
489493
error
490494
}
491495

496+
func (t *dialTask) dest() *enode.Node {
497+
return t.destPtr.Load()
498+
}
499+
492500
func (t *dialTask) run(d *dialScheduler) {
493501
if t.needResolve() && !t.resolve(d) {
494502
return
495503
}
496504

497-
err := t.dial(d, t.dest)
505+
err := t.dial(d, t.dest())
498506
if err != nil {
499507
// For static nodes, resolve one more time if dialing fails.
500508
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
501509
if t.resolve(d) {
502-
t.dial(d, t.dest)
510+
t.dial(d, t.dest())
503511
}
504512
}
505513
}
506514
}
507515

508516
func (t *dialTask) needResolve() bool {
509-
return t.flags&staticDialedConn != 0 && t.dest.IP() == nil
517+
return t.flags&staticDialedConn != 0 && t.dest().IP() == nil
510518
}
511519

512520
// resolve attempts to find the current endpoint for the destination
@@ -528,42 +536,41 @@ func (t *dialTask) resolve(d *dialScheduler) bool {
528536
return false
529537
}
530538

531-
resolved := d.resolver.Resolve(t.dest)
539+
node := t.dest()
540+
resolved := d.resolver.Resolve(node)
532541
t.lastResolved = d.clock.Now()
533542

534543
if resolved == nil {
535544
t.resolveDelay *= 2
536545
if t.resolveDelay > maxResolveDelay {
537546
t.resolveDelay = maxResolveDelay
538547
}
539-
540-
d.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay)
541-
548+
d.log.Debug("Resolving node failed", "id", node.ID(), "newdelay", t.resolveDelay)
542549
return false
543550
}
544551
// The node was found.
545552
t.resolveDelay = initialResolveDelay
546-
t.dest = resolved
547-
d.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
548-
553+
t.destPtr.Store(resolved)
554+
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()})
549555
return true
550556
}
551557

552558
// dial performs the actual connection attempt.
553559
func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
554560
dialMeter.Mark(1)
555-
fd, err := d.dialer.Dial(d.ctx, t.dest)
561+
fd, err := d.dialer.Dial(d.ctx, dest)
556562
if err != nil {
557-
d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanupDialErr(err))
563+
d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err))
558564
dialConnectionError.Mark(1)
559565
return &dialError{err}
560566
}
561567
return d.setupFunc(newMeteredConn(fd), t.flags, dest)
562568
}
563569

564570
func (t *dialTask) String() string {
565-
id := t.dest.ID()
566-
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
571+
node := t.dest()
572+
id := node.ID()
573+
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP())
567574
}
568575

569576
func cleanupDialErr(err error) error {

0 commit comments

Comments
 (0)