diff --git a/rlp/decode.go b/rlp/decode.go index 06786eae72..712d9fcf18 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -82,6 +82,20 @@ func (err decodeError) Error() string { return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) } +func wrapStreamError(err error, typ reflect.Type) error { + switch err { + case ErrExpectedList: + return decodeError{"expected input list", typ} + case ErrExpectedString: + return decodeError{"expected input string or byte", typ} + case errUintOverflow: + return decodeError{"input string too long", typ} + case errNotAtEOL: + return decodeError{"input list has too many elements", typ} + } + return err +} + var ( decoderInterface = reflect.TypeOf(new(Decoder)).Elem() bigInt = reflect.TypeOf(big.Int{}) @@ -118,10 +132,8 @@ func makeDecoder(typ reflect.Type) (dec decoder, err error) { func decodeUint(s *Stream, val reflect.Value) error { typ := val.Type() num, err := s.uint(typ.Bits()) - if err == errUintOverflow { - return decodeError{"input string too big", typ} - } else if err != nil { - return err + if err != nil { + return wrapStreamError(err, val.Type()) } val.SetUint(num) return nil @@ -130,7 +142,7 @@ func decodeUint(s *Stream, val reflect.Value) error { func decodeString(s *Stream, val reflect.Value) error { b, err := s.Bytes() if err != nil { - return err + return wrapStreamError(err, val.Type()) } val.SetString(string(b)) return nil @@ -143,7 +155,7 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error { func decodeBigInt(s *Stream, val reflect.Value) error { b, err := s.Bytes() if err != nil { - return err + return wrapStreamError(err, val.Type()) } i := val.Interface().(*big.Int) if i == nil { @@ -181,7 +193,7 @@ func makeListDecoder(typ reflect.Type) (decoder, error) { func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error { size, err := s.List() if err != nil { - return err + return wrapStreamError(err, val.Type()) } if size == 0 { val.Set(reflect.MakeSlice(val.Type(), 0, 0)) @@ -242,10 +254,7 @@ func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error { if i < vlen { zero(val, i) } - if err = s.ListEnd(); err == errNotAtEOL { - return decodeError{"input list has too many elements", val.Type()} - } - return err + return wrapStreamError(s.ListEnd(), val.Type()) } func decodeByteSlice(s *Stream, val reflect.Value) error { @@ -271,14 +280,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { switch kind { case Byte: if val.Len() == 0 { - return decodeError{"input string too big", val.Type()} + return decodeError{"input string too long", val.Type()} } bv, _ := s.Uint() val.Index(0).SetUint(bv) zero(val, 1) case String: if uint64(val.Len()) < size { - return decodeError{"input string too big", val.Type()} + return decodeError{"input string too long", val.Type()} } slice := val.Slice(0, int(size)).Interface().([]byte) if err := s.readFull(slice); err != nil { @@ -317,7 +326,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { } dec := func(s *Stream, val reflect.Value) (err error) { if _, err = s.List(); err != nil { - return err + return wrapStreamError(err, typ) } for _, f := range fields { err = f.info.decoder(s, val.Field(f.index)) @@ -328,10 +337,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return err } } - if err = s.ListEnd(); err == errNotAtEOL { - err = decodeError{"input list has too many elements", typ} - } - return err + return wrapStreamError(s.ListEnd(), typ) } return dec, nil } diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 4c030e24d2..7a1743937d 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -213,8 +213,8 @@ var decodeTests = []decodeTest{ {input: "820505", ptr: new(uint32), value: uint32(0x0505)}, {input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, {input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, - {input: "850505050505", ptr: new(uint32), error: "rlp: input string too big for uint32"}, - {input: "C0", ptr: new(uint32), error: ErrExpectedString.Error()}, + {input: "850505050505", ptr: new(uint32), error: "rlp: input string too long for uint32"}, + {input: "C0", ptr: new(uint32), error: "rlp: expected input string or byte for uint32"}, // slices {input: "C0", ptr: new([]uint), value: []uint{}}, @@ -231,7 +231,7 @@ var decodeTests = []decodeTest{ {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, {input: "C0", ptr: new([]byte), value: []byte{}}, {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, - {input: "C3820102", ptr: new([]byte), error: "rlp: input string too big for uint8"}, + {input: "C3820102", ptr: new([]byte), error: "rlp: input string too long for uint8"}, // byte arrays {input: "01", ptr: new([5]byte), value: [5]byte{1}}, @@ -239,8 +239,8 @@ var decodeTests = []decodeTest{ {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, {input: "C0", ptr: new([5]byte), value: [5]byte{}}, {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, - {input: "C3820102", ptr: new([5]byte), error: "rlp: input string too big for uint8"}, - {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too big for [5]uint8"}, + {input: "C3820102", ptr: new([5]byte), error: "rlp: input string too long for uint8"}, + {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too long for [5]uint8"}, {input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()}, // byte array reuse (should be zeroed) @@ -254,19 +254,19 @@ var decodeTests = []decodeTest{ // zero sized byte arrays {input: "80", ptr: new([0]byte), value: [0]byte{}}, {input: "C0", ptr: new([0]byte), value: [0]byte{}}, - {input: "01", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, - {input: "8101", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, + {input: "01", ptr: new([0]byte), error: "rlp: input string too long for [0]uint8"}, + {input: "8101", ptr: new([0]byte), error: "rlp: input string too long for [0]uint8"}, // strings {input: "00", ptr: new(string), value: "\000"}, {input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, - {input: "C0", ptr: new(string), error: ErrExpectedString.Error()}, + {input: "C0", ptr: new(string), error: "rlp: expected input string or byte for string"}, // big ints {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works - {input: "C0", ptr: new(*big.Int), error: ErrExpectedString.Error()}, + {input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"}, // structs {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},