p2p: fix DiscReason encoding/decoding (#30855)

This fixes an issue where the disconnect message was not wrapped in a list.
The specification requires it to be a list like any other message.

In order to remain compatible with legacy geth versions, we now accept both
encodings when parsing a disconnect message.

---------

Co-authored-by: Felix Lange <fjl@twurst.com>
pull/30981/head
lorenzo 2 months ago committed by GitHub
parent c7e740f40c
commit c1c2507148
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 25
      p2p/peer.go
  2. 5
      p2p/peer_error.go
  3. 24
      p2p/transport.go
  4. 10
      p2p/transport_test.go

@ -345,9 +345,7 @@ func (p *Peer) handle(msg Msg) error {
case msg.Code == discMsg: case msg.Code == discMsg:
// This is the last message. We don't need to discard or // This is the last message. We don't need to discard or
// check errors because, the connection will be closed after it. // check errors because, the connection will be closed after it.
var m struct{ R DiscReason } return decodeDisconnectMessage(msg.Payload)
rlp.Decode(msg.Payload, &m)
return m.R
case msg.Code < baseProtocolLength: case msg.Code < baseProtocolLength:
// ignore other base protocol messages // ignore other base protocol messages
return msg.Discard() return msg.Discard()
@ -372,6 +370,27 @@ func (p *Peer) handle(msg Msg) error {
return nil return nil
} }
// decodeDisconnectMessage decodes the payload of discMsg.
func decodeDisconnectMessage(r io.Reader) (reason DiscReason) {
s := rlp.NewStream(r, 100)
k, _, err := s.Kind()
if err != nil {
return DiscInvalid
}
if k == rlp.List {
s.List()
err = s.Decode(&reason)
} else {
// Legacy path: some implementations, including geth, used to send the disconnect
// reason as a byte array by accident.
err = s.Decode(&reason)
}
if err != nil {
reason = DiscInvalid
}
return reason
}
func countMatchingProtocols(protocols []Protocol, caps []Cap) int { func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
n := 0 n := 0
for _, cap := range caps { for _, cap := range caps {

@ -70,6 +70,8 @@ const (
DiscSelf DiscSelf
DiscReadTimeout DiscReadTimeout
DiscSubprotocolError = DiscReason(0x10) DiscSubprotocolError = DiscReason(0x10)
DiscInvalid = 0xff
) )
var discReasonToString = [...]string{ var discReasonToString = [...]string{
@ -86,10 +88,11 @@ var discReasonToString = [...]string{
DiscSelf: "connected to self", DiscSelf: "connected to self",
DiscReadTimeout: "read timeout", DiscReadTimeout: "read timeout",
DiscSubprotocolError: "subprotocol error", DiscSubprotocolError: "subprotocol error",
DiscInvalid: "invalid disconnect reason",
} }
func (d DiscReason) String() string { func (d DiscReason) String() string {
if len(discReasonToString) <= int(d) { if len(discReasonToString) <= int(d) || discReasonToString[d] == "" {
return fmt.Sprintf("unknown disconnect reason %d", d) return fmt.Sprintf("unknown disconnect reason %d", d)
} }
return discReasonToString[d] return discReasonToString[d]

@ -113,15 +113,14 @@ func (t *rlpxTransport) close(err error) {
// Tell the remote end why we're disconnecting if possible. // Tell the remote end why we're disconnecting if possible.
// We only bother doing this if the underlying connection supports // We only bother doing this if the underlying connection supports
// setting a timeout tough. // setting a timeout tough.
if t.conn != nil { if reason, ok := err.(DiscReason); ok && reason != DiscNetworkError {
if r, ok := err.(DiscReason); ok && r != DiscNetworkError { // We do not use the WriteMsg func since we want a custom deadline
deadline := time.Now().Add(discWriteTimeout) deadline := time.Now().Add(discWriteTimeout)
if err := t.conn.SetWriteDeadline(deadline); err == nil { if err := t.conn.SetWriteDeadline(deadline); err == nil {
// Connection supports write deadline. // Connection supports write deadline.
t.wbuf.Reset() t.wbuf.Reset()
rlp.Encode(&t.wbuf, []DiscReason{r}) rlp.Encode(&t.wbuf, []any{reason})
t.conn.Write(discMsg, t.wbuf.Bytes()) t.conn.Write(discMsg, t.wbuf.Bytes())
}
} }
} }
t.conn.Close() t.conn.Close()
@ -163,11 +162,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
if msg.Code == discMsg { if msg.Code == discMsg {
// Disconnect before protocol handshake is valid according to the // Disconnect before protocol handshake is valid according to the
// spec and we send it ourself if the post-handshake checks fail. // spec and we send it ourself if the post-handshake checks fail.
// We can't return the reason directly, though, because it is echoed r := decodeDisconnectMessage(msg.Payload)
// back otherwise. Wrap it in a string instead. return nil, r
var reason [1]DiscReason
rlp.Decode(msg.Payload, &reason)
return nil, reason[0]
} }
if msg.Code != handshakeMsg { if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code) return nil, fmt.Errorf("expected handshake, got %x", msg.Code)

@ -97,7 +97,7 @@ func TestProtocolHandshake(t *testing.T) {
return return
} }
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil { if err := ExpectMsg(rlpx, discMsg, []any{DiscQuitting}); err != nil {
t.Errorf("error receiving disconnect: %v", err) t.Errorf("error receiving disconnect: %v", err)
} }
}() }()
@ -112,7 +112,13 @@ func TestProtocolHandshakeErrors(t *testing.T) {
}{ }{
{ {
code: discMsg, code: discMsg,
msg: []DiscReason{DiscQuitting}, msg: []any{DiscQuitting},
err: DiscQuitting,
},
{
// legacy disconnect encoding as byte array
code: discMsg,
msg: []byte{byte(DiscQuitting)},
err: DiscQuitting, err: DiscQuitting,
}, },
{ {

Loading…
Cancel
Save