diff --git a/signer/core/signed_data.go b/signer/core/signed_data.go index fec464417a..7fc66b4b74 100644 --- a/signer/core/signed_data.go +++ b/signer/core/signed_data.go @@ -481,6 +481,24 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter return buffer.Bytes(), nil } +// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes. +func parseBytes(encType interface{}) ([]byte, bool) { + switch v := encType.(type) { + case []byte: + return v, true + case hexutil.Bytes: + return []byte(v), true + case string: + bytes, err := hexutil.Decode(v) + if err != nil { + return nil, false + } + return bytes, true + default: + return nil, false + } +} + func parseInteger(encType string, encValue interface{}) (*big.Int, error) { var ( length int @@ -560,7 +578,7 @@ func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interf } return crypto.Keccak256([]byte(strVal)), nil case "bytes": - bytesValue, ok := encValue.([]byte) + bytesValue, ok := parseBytes(encValue) if !ok { return nil, dataMismatchError(encType, encValue) } @@ -575,10 +593,13 @@ func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interf if length < 0 || length > 32 { return nil, fmt.Errorf("invalid size on bytes: %d", length) } - if byteValue, ok := encValue.(hexutil.Bytes); !ok { + if byteValue, ok := parseBytes(encValue); !ok || len(byteValue) != length { return nil, dataMismatchError(encType, encValue) } else { - return math.PaddedBigBytes(new(big.Int).SetBytes(byteValue), 32), nil + // Right-pad the bits + dst := make([]byte, 32) + copy(dst, byteValue) + return dst, nil } } if strings.HasPrefix(encType, "int") || strings.HasPrefix(encType, "uint") { diff --git a/signer/core/signed_data_internal_test.go b/signer/core/signed_data_internal_test.go index 0d59fcfca8..9768ee0b3f 100644 --- a/signer/core/signed_data_internal_test.go +++ b/signer/core/signed_data_internal_test.go @@ -17,10 +17,104 @@ package core import ( + "bytes" "math/big" "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" ) +func TestBytesPadding(t *testing.T) { + tests := []struct { + Type string + Input []byte + Output []byte // nil => error + }{ + { + // Fail on wrong length + Type: "bytes20", + Input: []byte{}, + Output: nil, + }, + { + Type: "bytes1", + Input: []byte{1}, + Output: []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + Type: "bytes1", + Input: []byte{1, 2}, + Output: nil, + }, + { + Type: "bytes7", + Input: []byte{1, 2, 3, 4, 5, 6, 7}, + Output: []byte{1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + Type: "bytes32", + Input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}, + Output: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}, + }, + { + Type: "bytes32", + Input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33}, + Output: nil, + }, + } + + d := TypedData{} + for i, test := range tests { + val, err := d.EncodePrimitiveValue(test.Type, test.Input, 1) + if test.Output == nil { + if err == nil { + t.Errorf("test %d: expected error, got no error (result %x)", i, val) + } + } else { + if err != nil { + t.Errorf("test %d: expected no error, got %v", i, err) + } + if len(val) != 32 { + t.Errorf("test %d: expected len 32, got %d", i, len(val)) + } + if !bytes.Equal(val, test.Output) { + t.Errorf("test %d: expected %x, got %x", i, test.Output, val) + } + } + } +} + +func TestParseBytes(t *testing.T) { + for i, tt := range []struct { + v interface{} + exp []byte + }{ + {"0x", []byte{}}, + {"0x1234", []byte{0x12, 0x34}}, + {[]byte{12, 34}, []byte{12, 34}}, + {hexutil.Bytes([]byte{12, 34}), []byte{12, 34}}, + {"1234", nil}, // not a proper hex-string + {"0x01233", nil}, // nibbles should be rejected + {"not a hex string", nil}, + {15, nil}, + {nil, nil}, + } { + out, ok := parseBytes(tt.v) + if tt.exp == nil { + if ok || out != nil { + t.Errorf("test %d: expected !ok, got ok = %v with out = %x", i, ok, out) + } + continue + } + if !ok { + t.Errorf("test %d: expected ok got !ok", i) + } + if !bytes.Equal(out, tt.exp) { + t.Errorf("test %d: expected %x got %x", i, tt.exp, out) + } + } +} + func TestParseInteger(t *testing.T) { for i, tt := range []struct { t string