diff --git a/common/math/big.go b/common/math/big.go index 704ca40a93..0b67a1b503 100644 --- a/common/math/big.go +++ b/common/math/big.go @@ -18,6 +18,7 @@ package math import ( + "fmt" "math/big" ) @@ -35,6 +36,24 @@ const ( wordBytes = wordBits / 8 ) +// HexOrDecimal256 marshals big.Int as hex or decimal. +type HexOrDecimal256 big.Int + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *HexOrDecimal256) UnmarshalText(input []byte) error { + bigint, ok := ParseBig256(string(input)) + if !ok { + return fmt.Errorf("invalid hex or decimal integer %q", input) + } + *i = HexOrDecimal256(*bigint) + return nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i *HexOrDecimal256) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("%#x", (*big.Int)(i))), nil +} + // ParseBig256 parses s as a 256 bit integer in decimal or hexadecimal syntax. // Leading zeros are accepted. The empty string parses as zero. func ParseBig256(s string) (*big.Int, bool) { diff --git a/common/math/big_test.go b/common/math/big_test.go index 6eb13f4f12..deff254658 100644 --- a/common/math/big_test.go +++ b/common/math/big_test.go @@ -23,7 +23,7 @@ import ( "testing" ) -func TestParseBig256(t *testing.T) { +func TestHexOrDecimal256(t *testing.T) { tests := []struct { input string num *big.Int @@ -47,13 +47,14 @@ func TestParseBig256(t *testing.T) { {"115792089237316195423570985008687907853269984665640564039457584007913129639936", nil, false}, } for _, test := range tests { - num, ok := ParseBig256(test.input) - if ok != test.ok { - t.Errorf("ParseBig(%q) -> ok = %t, want %t", test.input, ok, test.ok) + var num HexOrDecimal256 + err := num.UnmarshalText([]byte(test.input)) + if (err == nil) != test.ok { + t.Errorf("ParseBig(%q) -> (err == nil) == %t, want %t", test.input, err == nil, test.ok) continue } - if num != nil && test.num != nil && num.Cmp(test.num) != 0 { - t.Errorf("ParseBig(%q) -> %d, want %d", test.input, num, test.num) + if test.num != nil && (*big.Int)(&num).Cmp(test.num) != 0 { + t.Errorf("ParseBig(%q) -> %d, want %d", test.input, (*big.Int)(&num), test.num) } } } diff --git a/common/math/integer.go b/common/math/integer.go index a3eeee27e7..7eff4d3b00 100644 --- a/common/math/integer.go +++ b/common/math/integer.go @@ -16,7 +16,10 @@ package math -import "strconv" +import ( + "fmt" + "strconv" +) const ( // Integer limit values. @@ -34,6 +37,24 @@ const ( MaxUint64 = 1<<64 - 1 ) +// HexOrDecimal64 marshals uint64 as hex or decimal. +type HexOrDecimal64 uint64 + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *HexOrDecimal64) UnmarshalText(input []byte) error { + int, ok := ParseUint64(string(input)) + if !ok { + return fmt.Errorf("invalid hex or decimal integer %q", input) + } + *i = HexOrDecimal64(int) + return nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i HexOrDecimal64) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("%#x", uint64(i))), nil +} + // ParseUint64 parses s as an integer in decimal or hexadecimal syntax. // Leading zeros are accepted. The empty string parses as zero. func ParseUint64(s string) (uint64, bool) { diff --git a/common/math/integer_test.go b/common/math/integer_test.go index 05bba221f9..b31c7c26c2 100644 --- a/common/math/integer_test.go +++ b/common/math/integer_test.go @@ -65,7 +65,7 @@ func TestOverflow(t *testing.T) { } } -func TestParseUint64(t *testing.T) { +func TestHexOrDecimal64(t *testing.T) { tests := []struct { input string num uint64 @@ -88,12 +88,13 @@ func TestParseUint64(t *testing.T) { {"18446744073709551617", 0, false}, } for _, test := range tests { - num, ok := ParseUint64(test.input) - if ok != test.ok { - t.Errorf("ParseUint64(%q) -> ok = %t, want %t", test.input, ok, test.ok) + var num HexOrDecimal64 + err := num.UnmarshalText([]byte(test.input)) + if (err == nil) != test.ok { + t.Errorf("ParseUint64(%q) -> (err == nil) = %t, want %t", test.input, err == nil, test.ok) continue } - if ok && num != test.num { + if err == nil && uint64(num) != test.num { t.Errorf("ParseUint64(%q) -> %d, want %d", test.input, num, test.num) } }