diff --git a/rlp/decode.go b/rlp/decode.go index 6952ecaea..0c660426f 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -820,6 +820,16 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) { func (s *Stream) readKind() (kind Kind, size uint64, err error) { b, err := s.readByte() if err != nil { + if len(s.stack) == 0 { + // At toplevel, Adjust the error to actual EOF. io.EOF is + // used by callers to determine when to stop decoding. + switch err { + case io.ErrUnexpectedEOF: + err = io.EOF + case ErrValueTooLarge: + err = io.EOF + } + } return 0, 0, err } s.byteval = 0 @@ -876,9 +886,6 @@ func (s *Stream) readUint(size byte) (uint64, error) { return 0, nil case 1: b, err := s.readByte() - if err == io.EOF { - err = io.ErrUnexpectedEOF - } return uint64(b), err default: start := int(8 - size) @@ -899,10 +906,9 @@ func (s *Stream) readUint(size byte) (uint64, error) { } func (s *Stream) readFull(buf []byte) (err error) { - if s.limited && s.remaining < uint64(len(buf)) { - return ErrValueTooLarge + if err := s.willRead(uint64(len(buf))); err != nil { + return err } - s.willRead(uint64(len(buf))) var nn, n int for n < len(buf) && err == nil { nn, err = s.r.Read(buf[n:]) @@ -915,23 +921,32 @@ func (s *Stream) readFull(buf []byte) (err error) { } func (s *Stream) readByte() (byte, error) { - if s.limited && s.remaining == 0 { - return 0, io.EOF + if err := s.willRead(1); err != nil { + return 0, err } - s.willRead(1) b, err := s.r.ReadByte() - if len(s.stack) > 0 && err == io.EOF { + if err == io.EOF { err = io.ErrUnexpectedEOF } return b, err } -func (s *Stream) willRead(n uint64) { +func (s *Stream) willRead(n uint64) error { s.kind = -1 // rearm Kind - if s.limited { - s.remaining -= n - } + if len(s.stack) > 0 { + // check list overflow + tos := s.stack[len(s.stack)-1] + if n > tos.size-tos.pos { + return ErrElemTooLarge + } s.stack[len(s.stack)-1].pos += n } + if s.limited { + if n > s.remaining { + return ErrValueTooLarge + } + s.remaining -= n + } + return nil } diff --git a/rlp/decode_test.go b/rlp/decode_test.go index d07520bd0..ae65346a9 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -119,6 +119,10 @@ func TestStreamErrors(t *testing.T) { {"8158", calls{"Uint", "Uint"}, nil, io.EOF}, {"C0", calls{"List", "ListEnd", "List"}, nil, io.EOF}, + {"", calls{"List"}, withoutInputLimit, io.EOF}, + {"8158", calls{"Uint", "Uint"}, withoutInputLimit, io.EOF}, + {"C0", calls{"List", "ListEnd", "List"}, withoutInputLimit, io.EOF}, + // Input limit errors. {"81", calls{"Bytes"}, nil, ErrValueTooLarge}, {"81", calls{"Uint"}, nil, ErrValueTooLarge}, @@ -426,6 +430,13 @@ var decodeTests = []decodeTest{ ptr: new([]io.Reader), error: "rlp: type io.Reader is not RLP-serializable", }, + + // fuzzer crashes + { + input: "c330f9c030f93030ce3030303030303030bd303030303030", + ptr: new(interface{}), + error: "rlp: element is larger than containing list", + }, } func uintp(i uint) *uint { return &i }