diff --git a/common/types.go b/common/types.go index b1666d7338..fec9861648 100644 --- a/common/types.go +++ b/common/types.go @@ -19,10 +19,12 @@ package common import ( "encoding/hex" "encoding/json" + "errors" "fmt" "math/big" "math/rand" "reflect" + "strings" ) const ( @@ -30,6 +32,8 @@ const ( AddressLength = 20 ) +var hashJsonLengthErr = errors.New("common: unmarshalJSON failed: hash must be exactly 32 bytes") + type ( Hash [HashLength]byte Address [AddressLength]byte @@ -58,6 +62,15 @@ func (h *Hash) UnmarshalJSON(input []byte) error { if length >= 2 && input[0] == '"' && input[length-1] == '"' { input = input[1 : length-1] } + // strip "0x" for length check + if len(input) > 1 && strings.ToLower(string(input[:2])) == "0x" { + input = input[2:] + } + + // validate the length of the input hash + if len(input) != HashLength*2 { + return hashJsonLengthErr + } h.SetBytes(FromHex(string(input))) return nil } diff --git a/common/types_test.go b/common/types_test.go index edf8d4d142..f2dfbf0c95 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -29,3 +29,25 @@ func TestBytesConversion(t *testing.T) { t.Errorf("expected %x got %x", exp, hash) } } + +func TestHashJsonValidation(t *testing.T) { + var h Hash + var tests = []struct { + Prefix string + Size int + Error error + }{ + {"", 2, hashJsonLengthErr}, + {"", 62, hashJsonLengthErr}, + {"", 66, hashJsonLengthErr}, + {"", 65, hashJsonLengthErr}, + {"0X", 64, nil}, + {"0x", 64, nil}, + {"0x", 62, hashJsonLengthErr}, + } + for i, test := range tests { + if err := h.UnmarshalJSON(append([]byte(test.Prefix), make([]byte, test.Size)...)); err != test.Error { + t.Error(i, "expected", test.Error, "got", err) + } + } +}