diff --git a/p2p/peer.go b/p2p/peer.go index ac691f2ce8..c7ec08887a 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -131,10 +131,11 @@ func (p *Peer) run() DiscReason { case err := <-p.protoErr: reason = discReasonForError(err) case reason = <-p.disc: + p.politeDisconnect(reason) + reason = DiscRequested } close(p.closed) - p.politeDisconnect(reason) p.wg.Wait() glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason) return reason @@ -191,7 +192,7 @@ func (p *Peer) handle(msg Msg) error { // check errors because, the connection will be closed after it. rlp.Decode(msg.Payload, &reason) glog.V(logger.Debug).Infof("%v: Disconnect Requested: %v\n", p, reason[0]) - return DiscRequested + return reason[0] case msg.Code < baseProtocolLength: // ignore other base protocol messages return msg.Discard() diff --git a/p2p/peer_test.go b/p2p/peer_test.go index fb76818a06..7d17d447cd 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -172,12 +172,13 @@ func TestPeerDisconnect(t *testing.T) { if err := SendItems(rw, discMsg, DiscQuitting); err != nil { t.Fatal(err) } - if err := ExpectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { - t.Error(err) - } - closer() - if reason := <-disc; reason != DiscRequested { - t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) + select { + case reason := <-disc: + if reason != DiscQuitting { + t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) + } + case <-time.After(500 * time.Millisecond): + t.Error("peer did not return") } }