From 58d0f6440ba82a0cbf327ca77b7cd795c2d687c6 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Fri, 17 Mar 2023 06:51:55 -0400 Subject: [PATCH] rlp: support for uint256 (#26898) This adds built-in support in package rlp for encoding, decoding and generating code dealing with uint256.Int. --------- Co-authored-by: Felix Lange --- rlp/decode.go | 64 +++++++++++++++++++++++++++++ rlp/decode_test.go | 42 +++++++++++++++++++ rlp/encbuffer.go | 25 +++++++++++ rlp/encode.go | 21 ++++++++++ rlp/encode_test.go | 49 ++++++++++++++++++++++ rlp/rlpgen/gen.go | 51 ++++++++++++++++++++++- rlp/rlpgen/gen_test.go | 2 +- rlp/rlpgen/testdata/uint256.in.txt | 10 +++++ rlp/rlpgen/testdata/uint256.out.txt | 44 ++++++++++++++++++++ rlp/rlpgen/types.go | 10 +++++ 10 files changed, 316 insertions(+), 2 deletions(-) create mode 100644 rlp/rlpgen/testdata/uint256.in.txt create mode 100644 rlp/rlpgen/testdata/uint256.out.txt diff --git a/rlp/decode.go b/rlp/decode.go index c9b2652414..c9b50e8c18 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -29,6 +29,7 @@ import ( "sync" "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct" + "github.com/holiman/uint256" ) //lint:ignore ST1012 EOL is not an error. @@ -52,6 +53,7 @@ var ( errUintOverflow = errors.New("rlp: uint overflow") errNoPointer = errors.New("rlp: interface given to Decode must be a pointer") errDecodeIntoNil = errors.New("rlp: pointer given to Decode must not be nil") + errUint256Large = errors.New("rlp: value too large for uint256") streamPool = sync.Pool{ New: func() interface{} { return new(Stream) }, @@ -148,6 +150,7 @@ func addErrorContext(err error, ctx string) error { var ( decoderInterface = reflect.TypeOf(new(Decoder)).Elem() bigInt = reflect.TypeOf(big.Int{}) + u256Int = reflect.TypeOf(uint256.Int{}) ) func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) { @@ -159,6 +162,10 @@ func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) return decodeBigInt, nil case typ.AssignableTo(bigInt): return decodeBigIntNoPtr, nil + case typ == reflect.PtrTo(u256Int): + return decodeU256, nil + case typ == u256Int: + return decodeU256NoPtr, nil case kind == reflect.Ptr: return makePtrDecoder(typ, tags) case reflect.PtrTo(typ).Implements(decoderInterface): @@ -235,6 +242,24 @@ func decodeBigInt(s *Stream, val reflect.Value) error { return nil } +func decodeU256NoPtr(s *Stream, val reflect.Value) error { + return decodeU256(s, val.Addr()) +} + +func decodeU256(s *Stream, val reflect.Value) error { + i := val.Interface().(*uint256.Int) + if i == nil { + i = new(uint256.Int) + val.Set(reflect.ValueOf(i)) + } + + err := s.ReadUint256(i) + if err != nil { + return wrapStreamError(err, val.Type()) + } + return nil +} + func makeListDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { etype := typ.Elem() if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) { @@ -863,6 +888,45 @@ func (s *Stream) decodeBigInt(dst *big.Int) error { return nil } +// ReadUint256 decodes the next value as a uint256. +func (s *Stream) ReadUint256(dst *uint256.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. + s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // All possible uint256 values fit into s.uintbuf. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + return errUint256Large + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + // Decode decodes a value and stores the result in the value pointed // to by val. Please see the documentation for the Decode function // to learn about the decoding rules. diff --git a/rlp/decode_test.go b/rlp/decode_test.go index dbcfcffed1..07d9c579a6 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -28,6 +28,7 @@ import ( "testing" "github.com/ethereum/go-ethereum/common/math" + "github.com/holiman/uint256" ) func TestStreamKind(t *testing.T) { @@ -468,6 +469,10 @@ var ( veryVeryBigInt = new(big.Int).Exp(veryBigInt, big.NewInt(8), nil) ) +var ( + veryBigInt256, _ = uint256.FromBig(veryBigInt) +) + var decodeTests = []decodeTest{ // booleans {input: "01", ptr: new(bool), value: true}, @@ -541,11 +546,27 @@ var decodeTests = []decodeTest{ {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, {input: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", ptr: new(*big.Int), value: veryVeryBigInt}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works + + // big int errors {input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"}, {input: "00", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, {input: "820001", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, {input: "8105", ptr: new(*big.Int), error: "rlp: non-canonical size information for *big.Int"}, + // uint256 + {input: "80", ptr: new(*uint256.Int), value: uint256.NewInt(0)}, + {input: "01", ptr: new(*uint256.Int), value: uint256.NewInt(1)}, + {input: "88FFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: uint256.NewInt(math.MaxUint64)}, + {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: veryBigInt256}, + {input: "10", ptr: new(uint256.Int), value: *uint256.NewInt(16)}, // non-pointer also works + + // uint256 errors + {input: "C0", ptr: new(*uint256.Int), error: "rlp: expected input string or byte for *uint256.Int"}, + {input: "00", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "820001", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "8105", ptr: new(*uint256.Int), error: "rlp: non-canonical size information for *uint256.Int"}, + {input: "A1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00", ptr: new(*uint256.Int), error: "rlp: value too large for uint256"}, + // structs { input: "C50583343434", @@ -1223,6 +1244,27 @@ func BenchmarkDecodeBigInts(b *testing.B) { } } +func BenchmarkDecodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + enc, err := EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*uint256.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + func encodeTestSlice(n uint) []byte { s := make([]uint, n) for i := uint(0); i < n; i++ { diff --git a/rlp/encbuffer.go b/rlp/encbuffer.go index d2c6d93bca..9ac4fcb2ca 100644 --- a/rlp/encbuffer.go +++ b/rlp/encbuffer.go @@ -17,10 +17,13 @@ package rlp import ( + "encoding/binary" "io" "math/big" "reflect" "sync" + + "github.com/holiman/uint256" ) type encBuffer struct { @@ -169,6 +172,23 @@ func (w *encBuffer) writeBigInt(i *big.Int) { } } +// writeUint256 writes z as an integer. +func (w *encBuffer) writeUint256(z *uint256.Int) { + bitlen := z.BitLen() + if bitlen <= 64 { + w.writeUint64(z.Uint64()) + return + } + nBytes := byte((bitlen + 7) / 8) + var b [33]byte + binary.BigEndian.PutUint64(b[1:9], z[3]) + binary.BigEndian.PutUint64(b[9:17], z[2]) + binary.BigEndian.PutUint64(b[17:25], z[1]) + binary.BigEndian.PutUint64(b[25:33], z[0]) + b[32-nBytes] = 0x80 + nBytes + w.str = append(w.str, b[32-nBytes:]...) +} + // list adds a new list header to the header stack. It returns the index of the header. // Call listEnd with this index after encoding the content of the list. func (buf *encBuffer) list() int { @@ -376,6 +396,11 @@ func (w EncoderBuffer) WriteBigInt(i *big.Int) { w.buf.writeBigInt(i) } +// WriteUint256 encodes uint256.Int as an RLP string. +func (w EncoderBuffer) WriteUint256(i *uint256.Int) { + w.buf.writeUint256(i) +} + // WriteBytes encodes b as an RLP string. func (w EncoderBuffer) WriteBytes(b []byte) { w.buf.writeBytes(b) diff --git a/rlp/encode.go b/rlp/encode.go index a377a1ef4c..3fac0bd2d3 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -24,6 +24,7 @@ import ( "reflect" "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct" + "github.com/holiman/uint256" ) var ( @@ -144,6 +145,10 @@ func makeWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { return writeBigIntPtr, nil case typ.AssignableTo(bigInt): return writeBigIntNoPtr, nil + case typ == reflect.PtrTo(u256Int): + return writeU256IntPtr, nil + case typ == u256Int: + return writeU256IntNoPtr, nil case kind == reflect.Ptr: return makePtrWriter(typ, ts) case reflect.PtrTo(typ).Implements(encoderInterface): @@ -206,6 +211,22 @@ func writeBigIntNoPtr(val reflect.Value, w *encBuffer) error { return nil } +func writeU256IntPtr(val reflect.Value, w *encBuffer) error { + ptr := val.Interface().(*uint256.Int) + if ptr == nil { + w.str = append(w.str, 0x80) + return nil + } + w.writeUint256(ptr) + return nil +} + +func writeU256IntNoPtr(val reflect.Value, w *encBuffer) error { + i := val.Interface().(uint256.Int) + w.writeUint256(&i) + return nil +} + func writeBytes(val reflect.Value, w *encBuffer) error { w.writeBytes(val.Bytes()) return nil diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 82c490a802..02be47d0ef 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/ethereum/go-ethereum/common/math" + "github.com/holiman/uint256" ) type testEncoder struct { @@ -147,6 +148,30 @@ var encTests = []encTest{ {val: big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, {val: *big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, + // uint256 + {val: uint256.NewInt(0), output: "80"}, + {val: uint256.NewInt(1), output: "01"}, + {val: uint256.NewInt(127), output: "7F"}, + {val: uint256.NewInt(128), output: "8180"}, + {val: uint256.NewInt(256), output: "820100"}, + {val: uint256.NewInt(1024), output: "820400"}, + {val: uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFF), output: "84FFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFF), output: "85FFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, + { + val: new(uint256.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + output: "8F102030405060708090A0B0C0D0E0F2", + }, + { + val: new(uint256.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + output: "9C0100020003000400050006000700080009000A000B000C000D000E01", + }, + // non-pointer uint256.Int + {val: *uint256.NewInt(0), output: "80"}, + {val: *uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + // byte arrays {val: [0]byte{}, output: "80"}, {val: [1]byte{0}, output: "00"}, @@ -256,6 +281,12 @@ var encTests = []encTest{ output: "F90200CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376", }, + // Non-byte arrays are encoded as lists. + // Note that it is important to test [4]uint64 specifically, + // because that's the underlying type of uint256.Int. + {val: [4]uint32{1, 2, 3, 4}, output: "C401020304"}, + {val: [4]uint64{1, 2, 3, 4}, output: "C401020304"}, + // RawValue {val: RawValue(unhex("01")), output: "01"}, {val: RawValue(unhex("82FFFF")), output: "82FFFF"}, @@ -301,6 +332,7 @@ var encTests = []encTest{ {val: (*[]byte)(nil), output: "80"}, {val: (*[10]byte)(nil), output: "80"}, {val: (*big.Int)(nil), output: "80"}, + {val: (*uint256.Int)(nil), output: "80"}, {val: (*[]string)(nil), output: "C0"}, {val: (*[10]string)(nil), output: "C0"}, {val: (*[]interface{})(nil), output: "C0"}, @@ -509,6 +541,23 @@ func BenchmarkEncodeBigInts(b *testing.B) { } } +func BenchmarkEncodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkEncodeConcurrentInterface(b *testing.B) { type struct1 struct { A string diff --git a/rlp/rlpgen/gen.go b/rlp/rlpgen/gen.go index 1deb5a93c2..0c65864826 100644 --- a/rlp/rlpgen/gen.go +++ b/rlp/rlpgen/gen.go @@ -283,7 +283,7 @@ func (op byteArrayOp) genDecode(ctx *genContext) (string, string) { return resultV, b.String() } -// bigIntNoPtrOp handles non-pointer big.Int. +// bigIntOp handles big.Int. // This exists because big.Int has it's own decoder operation on rlp.Stream, // but the decode method returns *big.Int, so it needs to be dereferenced. type bigIntOp struct { @@ -330,6 +330,49 @@ func (op bigIntOp) genDecode(ctx *genContext) (string, string) { return result, b.String() } +// uint256Op handles "github.com/holiman/uint256".Int +type uint256Op struct { + pointer bool +} + +func (op uint256Op) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteUint256(%s)\n", dst) + + // Wrap with nil check. + if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op uint256Op) genDecode(ctx *genContext) (string, string) { + ctx.addImport("github.com/holiman/uint256") + + var b bytes.Buffer + resultV := ctx.temp() + fmt.Fprintf(&b, "var %s uint256.Int\n", resultV) + fmt.Fprintf(&b, "if err := dec.ReadUint256(&%s); err != nil { return err }\n", resultV) + + result := resultV + if op.pointer { + result = "&" + resultV + } + return result, b.String() +} + // encoderDecoderOp handles rlp.Encoder and rlp.Decoder. // In order to be used with this, the type must implement both interfaces. // This restriction may be lifted in the future by creating separate ops for @@ -635,6 +678,9 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru if isBigInt(typ) { return bigIntOp{}, nil } + if isUint256(typ) { + return uint256Op{}, nil + } if typ == bctx.rawValueType { return bctx.makeRawValueOp(), nil } @@ -647,6 +693,9 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru if isBigInt(typ.Elem()) { return bigIntOp{pointer: true}, nil } + if isUint256(typ.Elem()) { + return uint256Op{pointer: true}, nil + } // Encoder/Decoder interfaces. if bctx.isEncoder(typ) { if bctx.isDecoder(typ) { diff --git a/rlp/rlpgen/gen_test.go b/rlp/rlpgen/gen_test.go index 241c34b6df..ff3b8385ae 100644 --- a/rlp/rlpgen/gen_test.go +++ b/rlp/rlpgen/gen_test.go @@ -47,7 +47,7 @@ func init() { } } -var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint"} +var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"} func TestOutput(t *testing.T) { for _, test := range tests { diff --git a/rlp/rlpgen/testdata/uint256.in.txt b/rlp/rlpgen/testdata/uint256.in.txt new file mode 100644 index 0000000000..ed16e0a788 --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "github.com/holiman/uint256" + +type Test struct { + Int *uint256.Int + IntNoPtr uint256.Int +} diff --git a/rlp/rlpgen/testdata/uint256.out.txt b/rlp/rlpgen/testdata/uint256.out.txt new file mode 100644 index 0000000000..5e6d3ed992 --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.out.txt @@ -0,0 +1,44 @@ +package test + +import "github.com/ethereum/go-ethereum/rlp" +import "github.com/holiman/uint256" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + w.WriteUint256(obj.Int) + } + w.WriteUint256(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + var _tmp1 uint256.Int + if err := dec.ReadUint256(&_tmp1); err != nil { + return err + } + _tmp0.Int = &_tmp1 + // IntNoPtr: + var _tmp2 uint256.Int + if err := dec.ReadUint256(&_tmp2); err != nil { + return err + } + _tmp0.IntNoPtr = _tmp2 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/types.go b/rlp/rlpgen/types.go index 19694262e5..ea7dc96d88 100644 --- a/rlp/rlpgen/types.go +++ b/rlp/rlpgen/types.go @@ -97,6 +97,16 @@ func isBigInt(typ types.Type) bool { return name.Pkg().Path() == "math/big" && name.Name() == "Int" } +// isUint256 checks whether 'typ' is "github.com/holiman/uint256".Int. +func isUint256(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "github.com/holiman/uint256" && name.Name() == "Int" +} + // isByte checks whether the underlying type of 'typ' is uint8. func isByte(typ types.Type) bool { basic, ok := resolveUnderlying(typ).(*types.Basic)