From c1c250714842768687e8a4abe0dd9f57ca9dcc53 Mon Sep 17 00:00:00 2001 From: lorenzo <31852651+lorenzo-dev1@users.noreply.github.com> Date: Thu, 12 Dec 2024 12:33:42 +0100 Subject: [PATCH] p2p: fix DiscReason encoding/decoding (#30855) This fixes an issue where the disconnect message was not wrapped in a list. The specification requires it to be a list like any other message. In order to remain compatible with legacy geth versions, we now accept both encodings when parsing a disconnect message. --------- Co-authored-by: Felix Lange --- p2p/peer.go | 25 ++++++++++++++++++++++--- p2p/peer_error.go | 5 ++++- p2p/transport.go | 24 ++++++++++-------------- p2p/transport_test.go | 10 ++++++++-- 4 files changed, 44 insertions(+), 20 deletions(-) diff --git a/p2p/peer.go b/p2p/peer.go index 30be151a2f..a01df63d0c 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -345,9 +345,7 @@ func (p *Peer) handle(msg Msg) error { case msg.Code == discMsg: // This is the last message. We don't need to discard or // check errors because, the connection will be closed after it. - var m struct{ R DiscReason } - rlp.Decode(msg.Payload, &m) - return m.R + return decodeDisconnectMessage(msg.Payload) case msg.Code < baseProtocolLength: // ignore other base protocol messages return msg.Discard() @@ -372,6 +370,27 @@ func (p *Peer) handle(msg Msg) error { return nil } +// decodeDisconnectMessage decodes the payload of discMsg. +func decodeDisconnectMessage(r io.Reader) (reason DiscReason) { + s := rlp.NewStream(r, 100) + k, _, err := s.Kind() + if err != nil { + return DiscInvalid + } + if k == rlp.List { + s.List() + err = s.Decode(&reason) + } else { + // Legacy path: some implementations, including geth, used to send the disconnect + // reason as a byte array by accident. + err = s.Decode(&reason) + } + if err != nil { + reason = DiscInvalid + } + return reason +} + func countMatchingProtocols(protocols []Protocol, caps []Cap) int { n := 0 for _, cap := range caps { diff --git a/p2p/peer_error.go b/p2p/peer_error.go index ebc59de251..dcdadf7fe3 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -70,6 +70,8 @@ const ( DiscSelf DiscReadTimeout DiscSubprotocolError = DiscReason(0x10) + + DiscInvalid = 0xff ) var discReasonToString = [...]string{ @@ -86,10 +88,11 @@ var discReasonToString = [...]string{ DiscSelf: "connected to self", DiscReadTimeout: "read timeout", DiscSubprotocolError: "subprotocol error", + DiscInvalid: "invalid disconnect reason", } func (d DiscReason) String() string { - if len(discReasonToString) <= int(d) { + if len(discReasonToString) <= int(d) || discReasonToString[d] == "" { return fmt.Sprintf("unknown disconnect reason %d", d) } return discReasonToString[d] diff --git a/p2p/transport.go b/p2p/transport.go index 360e73a0de..87d3013f11 100644 --- a/p2p/transport.go +++ b/p2p/transport.go @@ -113,15 +113,14 @@ func (t *rlpxTransport) close(err error) { // Tell the remote end why we're disconnecting if possible. // We only bother doing this if the underlying connection supports // setting a timeout tough. - if t.conn != nil { - if r, ok := err.(DiscReason); ok && r != DiscNetworkError { - deadline := time.Now().Add(discWriteTimeout) - if err := t.conn.SetWriteDeadline(deadline); err == nil { - // Connection supports write deadline. - t.wbuf.Reset() - rlp.Encode(&t.wbuf, []DiscReason{r}) - t.conn.Write(discMsg, t.wbuf.Bytes()) - } + if reason, ok := err.(DiscReason); ok && reason != DiscNetworkError { + // We do not use the WriteMsg func since we want a custom deadline + deadline := time.Now().Add(discWriteTimeout) + if err := t.conn.SetWriteDeadline(deadline); err == nil { + // Connection supports write deadline. + t.wbuf.Reset() + rlp.Encode(&t.wbuf, []any{reason}) + t.conn.Write(discMsg, t.wbuf.Bytes()) } } t.conn.Close() @@ -163,11 +162,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) { if msg.Code == discMsg { // Disconnect before protocol handshake is valid according to the // spec and we send it ourself if the post-handshake checks fail. - // We can't return the reason directly, though, because it is echoed - // back otherwise. Wrap it in a string instead. - var reason [1]DiscReason - rlp.Decode(msg.Payload, &reason) - return nil, reason[0] + r := decodeDisconnectMessage(msg.Payload) + return nil, r } if msg.Code != handshakeMsg { return nil, fmt.Errorf("expected handshake, got %x", msg.Code) diff --git a/p2p/transport_test.go b/p2p/transport_test.go index 01695cd3af..777be1bd0d 100644 --- a/p2p/transport_test.go +++ b/p2p/transport_test.go @@ -97,7 +97,7 @@ func TestProtocolHandshake(t *testing.T) { return } - if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil { + if err := ExpectMsg(rlpx, discMsg, []any{DiscQuitting}); err != nil { t.Errorf("error receiving disconnect: %v", err) } }() @@ -112,7 +112,13 @@ func TestProtocolHandshakeErrors(t *testing.T) { }{ { code: discMsg, - msg: []DiscReason{DiscQuitting}, + msg: []any{DiscQuitting}, + err: DiscQuitting, + }, + { + // legacy disconnect encoding as byte array + code: discMsg, + msg: []byte{byte(DiscQuitting)}, err: DiscQuitting, }, {