diff --git a/p2p/message.go b/p2p/message.go index 6521d09c24..dfc33f3497 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -174,10 +174,10 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) { // read magic and payload size start := make([]byte, 8) if _, err = io.ReadFull(rw.bufconn, start); err != nil { - return msg, newPeerError(errRead, "%v", err) + return msg, err } if !bytes.HasPrefix(start, magicToken) { - return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) + return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken) } size := binary.BigEndian.Uint32(start[4:]) diff --git a/p2p/peer.go b/p2p/peer.go index 1fa8264a35..b61cf96daf 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,6 +1,7 @@ package p2p import ( + "errors" "fmt" "io" "io/ioutil" @@ -71,7 +72,8 @@ type Peer struct { runlock sync.RWMutex // protects running running map[string]*proto - protocolHandshakeEnabled bool + // disables protocol handshake, for testing + noHandshake bool protoWG sync.WaitGroup protoErr chan error @@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) { // String implements fmt.Stringer. func (p *Peer) String() string { - return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr()) + return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr()) } func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { - logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr()) + logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr()) return &Peer{ Logger: logger.NewLogger(logtag), rw: newFrameRW(conn, msgWriteTimeout), @@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason { var readErr = make(chan error, 1) defer p.closeProtocols() defer close(p.closed) - defer p.rw.Close() - // start the read loop go func() { readErr <- p.readLoop() }() - if p.protocolHandshakeEnabled { + if !p.noHandshake { if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { p.DebugDetailf("Protocol handshake error: %v\n", err) + p.rw.Close() return DiscProtocolError } } - // wait for an error or disconnect + // Wait for an error or disconnect. var reason DiscReason select { case err := <-readErr: // We rely on protocols to abort if there is a write error. It // might be more robust to handle them here as well. p.DebugDetailf("Read error: %v\n", err) - reason = DiscNetworkError + p.rw.Close() + return DiscNetworkError + case err := <-p.protoErr: reason = discReasonForError(err) case reason = <-p.disc: } - if reason != DiscNetworkError { - p.politeDisconnect(reason) - } + p.politeDisconnect(reason) + + // Wait for readLoop. It will end because conn is now closed. + <-readErr p.Debugf("Disconnected: %v\n", reason) return reason } @@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason { func (p *Peer) politeDisconnect(reason DiscReason) { done := make(chan struct{}) go func() { - // send reason EncodeMsg(p.rw, discMsg, uint(reason)) - // discard any data that might arrive + // Wait for the other side to close the connection. + // Discard any data that they send until then. io.Copy(ioutil.Discard, p.rw) close(done) }() @@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) { case <-done: case <-time.After(disconnectGracePeriod): } + p.rw.Close() } func (p *Peer) readLoop() error { - if p.protocolHandshakeEnabled { + if !p.noHandshake { if err := readProtocolHandshake(p, p.rw); err != nil { return err } @@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error { return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) } if msg.Size > baseProtocolMaxMsgSize { - return newPeerError(errMisc, "message too big") + return newPeerError(errInvalidMsg, "message too big") } var hs handshake if err := msg.Decode(&hs); err != nil { @@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto { err := impl.Run(p, rw) if err == nil { p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) - err = newPeerError(errMisc, "protocol returned") + err = errors.New("protocol returned") } else { p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) } diff --git a/p2p/peer_error.go b/p2p/peer_error.go index 9133768f90..0ff4f4b43e 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -123,7 +123,7 @@ func discReasonForError(err error) DiscReason { return DiscProtocolError case errPingTimeout: return DiscReadTimeout - case errRead, errWrite, errMisc: + case errRead, errWrite: return DiscNetworkError default: return DiscSubprotocolError diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 76d856d3e0..68c9910a26 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -30,10 +30,10 @@ var discard = Protocol{ }, } -func testPeer(handshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { +func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { conn1, conn2 := net.Pipe() peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{}) - peer.protocolHandshakeEnabled = handshake + peer.noHandshake = noHandshake errc := make(chan DiscReason, 1) go func() { errc <- peer.run() }() return newFrameRW(conn2, msgWriteTimeout), peer, errc @@ -61,7 +61,7 @@ func TestPeerProtoReadMsg(t *testing.T) { }, } - rw, peer, errc := testPeer(false, []Protocol{proto}) + rw, peer, errc := testPeer(true, []Protocol{proto}) defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) @@ -100,7 +100,7 @@ func TestPeerProtoReadLargeMsg(t *testing.T) { }, } - rw, peer, errc := testPeer(false, []Protocol{proto}) + rw, peer, errc := testPeer(true, []Protocol{proto}) defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) @@ -130,7 +130,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) { return nil }, } - rw, peer, _ := testPeer(false, []Protocol{proto}) + rw, peer, _ := testPeer(true, []Protocol{proto}) defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) @@ -142,7 +142,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) { func TestPeerWriteForBroadcast(t *testing.T) { defer testlog(t).detach() - rw, peer, peerErr := testPeer(false, []Protocol{discard}) + rw, peer, peerErr := testPeer(true, []Protocol{discard}) defer rw.Close() peer.startSubprotocols([]Cap{discard.cap()}) @@ -179,7 +179,7 @@ func TestPeerWriteForBroadcast(t *testing.T) { func TestPeerPing(t *testing.T) { defer testlog(t).detach() - rw, _, _ := testPeer(false, nil) + rw, _, _ := testPeer(true, nil) defer rw.Close() if err := EncodeMsg(rw, pingMsg); err != nil { t.Fatal(err) @@ -192,7 +192,7 @@ func TestPeerPing(t *testing.T) { func TestPeerDisconnect(t *testing.T) { defer testlog(t).detach() - rw, _, disc := testPeer(false, nil) + rw, _, disc := testPeer(true, nil) defer rw.Close() if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { t.Fatal(err) @@ -233,7 +233,7 @@ func TestPeerHandshake(t *testing.T) { {Name: "c", Version: 3, Length: 1, Run: run}, {Name: "d", Version: 4, Length: 1, Run: run}, } - rw, p, disc := testPeer(true, protocols) + rw, p, disc := testPeer(false, protocols) p.remoteID = remote.ourID defer rw.Close() @@ -269,6 +269,7 @@ func TestPeerHandshake(t *testing.T) { } close(stop) + expectMsg(rw, discMsg, nil) t.Logf("disc reason: %v", <-disc) } diff --git a/p2p/server.go b/p2p/server.go index 87be97a2f3..c6d7fc2e82 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -408,7 +408,9 @@ func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) { return } - srv.newPeerHook(p) + if srv.newPeerHook != nil { + srv.newPeerHook(p) + } p.run() srv.removePeer(p) } diff --git a/p2p/server_test.go b/p2p/server_test.go index 89300cf1cc..d1e1640fb1 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -118,6 +118,7 @@ func TestServerBroadcast(t *testing.T) { srv := startTestServer(t, func(p *Peer) { p.protocols = []Protocol{discard} p.startSubprotocols([]Cap{discard.cap()}) + p.noHandshake = true connected.Done() }) defer srv.Stop()