diff --git a/rlp/decode.go b/rlp/decode.go index 0e99d9caa0..0fde0a9473 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -540,6 +540,31 @@ func (s *Stream) Bytes() ([]byte, error) { } } +// Raw reads a raw encoded value including RLP type information. +func (s *Stream) Raw() ([]byte, error) { + kind, size, err := s.Kind() + if err != nil { + return nil, err + } + if kind == Byte { + s.kind = -1 // rearm Kind + return []byte{s.byteval}, nil + } + // the original header has already been read and is no longer + // available. read content and put a new header in front of it. + start := headsize(size) + buf := make([]byte, uint64(start)+size) + if err := s.readFull(buf[start:]); err != nil { + return nil, err + } + if kind == String { + puthead(buf, 0x80, 0xB8, size) + } else { + puthead(buf, 0xC0, 0xF7, size) + } + return buf, nil +} + var errUintOverflow = errors.New("rlp: uint overflow") // Uint reads an RLP string of up to 8 bytes and returns its contents diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 0f034d5d80..a18ff1d080 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -165,6 +165,20 @@ func TestStreamList(t *testing.T) { } } +func TestStreamRaw(t *testing.T) { + s := NewStream(bytes.NewReader(unhex("C58401010101"))) + s.List() + + want := unhex("8401010101") + raw, err := s.Raw() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(want, raw) { + t.Errorf("raw mismatch: got %x, want %x", raw, want) + } +} + func TestDecodeErrors(t *testing.T) { r := bytes.NewReader(nil) @@ -331,7 +345,7 @@ var decodeTests = []decodeTest{ {input: "C109", ptr: new(*[]uint), value: &[]uint{9}}, {input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}}, - // check that input position is advanced also empty values. + // check that input position is advanced also for empty values. {input: "C3808005", ptr: new([]*uint), value: []*uint{nil, nil, uintp(5)}}, // pointer should be reset to nil diff --git a/rlp/encode.go b/rlp/encode.go index 7ac74d8fb9..289bc4eaaa 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -70,7 +70,7 @@ func (e flatenc) EncodeRLP(out io.Writer) error { newhead := eb.lheads[prevnheads] copy(eb.lheads[prevnheads:], eb.lheads[prevnheads+1:]) eb.lheads = eb.lheads[:len(eb.lheads)-1] - eb.lhsize -= newhead.tagsize() + eb.lhsize -= headsize(uint64(newhead.size)) return nil } @@ -155,21 +155,29 @@ type listhead struct { // encode writes head to the given buffer, which must be at least // 9 bytes long. It returns the encoded bytes. func (head *listhead) encode(buf []byte) []byte { - if head.size < 56 { - buf[0] = 0xC0 + byte(head.size) - return buf[:1] - } else { - sizesize := putint(buf[1:], uint64(head.size)) - buf[0] = 0xF7 + byte(sizesize) - return buf[:sizesize+1] + return buf[:puthead(buf, 0xC0, 0xF7, uint64(head.size))] +} + +// headsize returns the size of a list or string header +// for a value of the given size. +func headsize(size uint64) int { + if size < 56 { + return 1 } + return 1 + intsize(size) } -func (head *listhead) tagsize() int { - if head.size < 56 { +// puthead writes a list or string header to buf. +// buf must be at least 9 bytes long. +func puthead(buf []byte, smalltag, largetag byte, size uint64) int { + if size < 56 { + buf[0] = smalltag + byte(size) return 1 + } else { + sizesize := putint(buf[1:], size) + buf[0] = largetag + byte(sizesize) + return sizesize + 1 } - return 1 + intsize(uint64(head.size)) } func newencbuf() *encbuf {