diff --git a/p2p/peer.go b/p2p/peer.go index ff86026028..c4c1fcd7c7 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -95,11 +95,10 @@ type PeerEvent struct { // Peer represents a connected remote node. type Peer struct { - rw *conn - isInbound bool // Cached from rw.flags to avoid a race condition - running map[string]*protoRW - log log.Logger - created mclock.AbsTime + rw *conn + running map[string]*protoRW + log log.Logger + created mclock.AbsTime wg sync.WaitGroup protoErr chan error @@ -161,20 +160,19 @@ func (p *Peer) String() string { // Inbound returns true if the peer is an inbound connection func (p *Peer) Inbound() bool { - return p.isInbound + return p.rw.is(inboundConn) } func newPeer(conn *conn, protocols []Protocol) *Peer { protomap := matchProtocols(protocols, conn.caps, conn) p := &Peer{ - rw: conn, - isInbound: conn.is(inboundConn), - running: protomap, - created: mclock.Now(), - disc: make(chan DiscReason), - protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop - closed: make(chan struct{}), - log: log.New("id", conn.id, "conn", conn.flags), + rw: conn, + running: protomap, + created: mclock.Now(), + disc: make(chan DiscReason), + protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop + closed: make(chan struct{}), + log: log.New("id", conn.id, "conn", conn.flags), } return p } diff --git a/p2p/server.go b/p2p/server.go index 39ff2f51e8..d2cb949255 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -23,6 +23,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" "github.com/ethereum/go-ethereum/common" @@ -187,7 +188,7 @@ type peerDrop struct { requested bool // true if signaled by the peer } -type connFlag int +type connFlag int32 const ( dynDialedConn connFlag = 1 << iota @@ -252,7 +253,18 @@ func (f connFlag) String() string { } func (c *conn) is(f connFlag) bool { - return c.flags&f != 0 + flags := connFlag(atomic.LoadInt32((*int32)(&c.flags))) + return flags&f != 0 +} + +func (c *conn) set(f connFlag, val bool) { + flags := connFlag(atomic.LoadInt32((*int32)(&c.flags))) + if val { + flags |= f + } else { + flags &= ^f + } + atomic.StoreInt32((*int32)(&c.flags), int32(flags)) } // Peers returns all connected peers. @@ -632,7 +644,7 @@ running: trusted[n.ID] = true // Mark any already-connected peer as trusted if p, ok := peers[n.ID]; ok { - p.rw.flags |= trustedConn + p.rw.set(trustedConn, true) } case n := <-srv.removetrusted: // This channel is used by RemoveTrustedPeer to remove an enode @@ -643,7 +655,7 @@ running: } // Unmark any already-connected peer as trusted if p, ok := peers[n.ID]; ok { - p.rw.flags &= ^trustedConn + p.rw.set(trustedConn, false) } case op := <-srv.peerOp: // This channel is used by Peers and PeerCount. diff --git a/p2p/server_test.go b/p2p/server_test.go index 65897e0185..3f24a79bae 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -189,12 +189,10 @@ func TestServerDial(t *testing.T) { } done <- true }() - // Trigger potential race conditions peer = srv.Peers()[0] _ = peer.Inbound() _ = peer.Info() - <-done case <-time.After(1 * time.Second): t.Error("server did not launch peer within one second")