New args types with stricter checking

pull/619/head
Taylor Gerring 10 years ago
parent 14c14fd61f
commit e402e1dc2e
  1. 12
      rpc/api.go
  2. 88
      rpc/args.go
  3. 12
      rpc/args_test.go

@ -108,15 +108,15 @@ func (api *EthereumApi) GetRequestReply(req *RpcRequest, reply *interface{}) err
count := api.xethAtStateNum(args.BlockNumber).TxCountAt(args.Address) count := api.xethAtStateNum(args.BlockNumber).TxCountAt(args.Address)
*reply = common.ToHex(big.NewInt(int64(count)).Bytes()) *reply = common.ToHex(big.NewInt(int64(count)).Bytes())
case "eth_getBlockTransactionCountByHash": case "eth_getBlockTransactionCountByHash":
args := new(GetBlockByHashArgs) args := new(HashArgs)
if err := json.Unmarshal(req.Params, &args); err != nil { if err := json.Unmarshal(req.Params, &args); err != nil {
return err return err
} }
block := NewBlockRes(api.xeth().EthBlockByHash(args.BlockHash), false) block := NewBlockRes(api.xeth().EthBlockByHash(args.Hash), false)
*reply = common.ToHex(big.NewInt(int64(len(block.Transactions))).Bytes()) *reply = common.ToHex(big.NewInt(int64(len(block.Transactions))).Bytes())
case "eth_getBlockTransactionCountByNumber": case "eth_getBlockTransactionCountByNumber":
args := new(GetBlockByNumberArgs) args := new(BlockNumArg)
if err := json.Unmarshal(req.Params, &args); err != nil { if err := json.Unmarshal(req.Params, &args); err != nil {
return err return err
} }
@ -124,16 +124,16 @@ func (api *EthereumApi) GetRequestReply(req *RpcRequest, reply *interface{}) err
block := NewBlockRes(api.xeth().EthBlockByNumber(args.BlockNumber), false) block := NewBlockRes(api.xeth().EthBlockByNumber(args.BlockNumber), false)
*reply = common.ToHex(big.NewInt(int64(len(block.Transactions))).Bytes()) *reply = common.ToHex(big.NewInt(int64(len(block.Transactions))).Bytes())
case "eth_getUncleCountByBlockHash": case "eth_getUncleCountByBlockHash":
args := new(GetBlockByHashArgs) args := new(HashArgs)
if err := json.Unmarshal(req.Params, &args); err != nil { if err := json.Unmarshal(req.Params, &args); err != nil {
return err return err
} }
block := api.xeth().EthBlockByHash(args.BlockHash) block := api.xeth().EthBlockByHash(args.Hash)
br := NewBlockRes(block, false) br := NewBlockRes(block, false)
*reply = common.ToHex(big.NewInt(int64(len(br.Uncles))).Bytes()) *reply = common.ToHex(big.NewInt(int64(len(br.Uncles))).Bytes())
case "eth_getUncleCountByBlockNumber": case "eth_getUncleCountByBlockNumber":
args := new(GetBlockByNumberArgs) args := new(BlockNumArg)
if err := json.Unmarshal(req.Params, &args); err != nil { if err := json.Unmarshal(req.Params, &args); err != nil {
return err return err
} }

@ -108,8 +108,8 @@ func (args *GetBlockByHashArgs) UnmarshalJSON(b []byte) (err error) {
return NewDecodeParamError(err.Error()) return NewDecodeParamError(err.Error())
} }
if len(obj) < 1 { if len(obj) < 2 {
return NewInsufficientParamsError(len(obj), 1) return NewInsufficientParamsError(len(obj), 2)
} }
argstr, ok := obj[0].(string) argstr, ok := obj[0].(string)
@ -118,9 +118,7 @@ func (args *GetBlockByHashArgs) UnmarshalJSON(b []byte) (err error) {
} }
args.BlockHash = argstr args.BlockHash = argstr
if len(obj) > 1 { args.IncludeTxs = obj[1].(bool)
args.IncludeTxs = obj[1].(bool)
}
return nil return nil
} }
@ -136,8 +134,8 @@ func (args *GetBlockByNumberArgs) UnmarshalJSON(b []byte) (err error) {
return NewDecodeParamError(err.Error()) return NewDecodeParamError(err.Error())
} }
if len(obj) < 1 { if len(obj) < 2 {
return NewInsufficientParamsError(len(obj), 1) return NewInsufficientParamsError(len(obj), 2)
} }
if v, ok := obj[0].(float64); ok { if v, ok := obj[0].(float64); ok {
@ -148,9 +146,7 @@ func (args *GetBlockByNumberArgs) UnmarshalJSON(b []byte) (err error) {
return NewInvalidTypeError("blockNumber", "not a number or string") return NewInvalidTypeError("blockNumber", "not a number or string")
} }
if len(obj) > 1 { args.IncludeTxs = obj[1].(bool)
args.IncludeTxs = obj[1].(bool)
}
return nil return nil
} }
@ -496,6 +492,27 @@ func (args *GetDataArgs) UnmarshalJSON(b []byte) (err error) {
return nil return nil
} }
type BlockNumArg struct {
BlockNumber int64
}
func (args *BlockNumArg) UnmarshalJSON(b []byte) (err error) {
var obj []interface{}
if err := json.Unmarshal(b, &obj); err != nil {
return NewDecodeParamError(err.Error())
}
if len(obj) < 1 {
return NewInsufficientParamsError(len(obj), 1)
}
if err := blockHeight(obj[0], &args.BlockNumber); err != nil {
return err
}
return nil
}
type BlockNumIndexArgs struct { type BlockNumIndexArgs struct {
BlockNumber int64 BlockNumber int64
Index int64 Index int64
@ -507,21 +524,42 @@ func (args *BlockNumIndexArgs) UnmarshalJSON(b []byte) (err error) {
return NewDecodeParamError(err.Error()) return NewDecodeParamError(err.Error())
} }
if len(obj) < 1 { if len(obj) < 2 {
return NewInsufficientParamsError(len(obj), 1) return NewInsufficientParamsError(len(obj), 2)
} }
if err := blockHeight(obj[0], &args.BlockNumber); err != nil { if err := blockHeight(obj[0], &args.BlockNumber); err != nil {
return err return err
} }
if len(obj) > 1 { arg1, ok := obj[1].(string)
arg1, ok := obj[1].(string) if !ok {
if !ok { return NewInvalidTypeError("index", "not a string")
return NewInvalidTypeError("index", "not a string")
}
args.Index = common.Big(arg1).Int64()
} }
args.Index = common.Big(arg1).Int64()
return nil
}
type HashArgs struct {
Hash string
}
func (args *HashArgs) UnmarshalJSON(b []byte) (err error) {
var obj []interface{}
if err := json.Unmarshal(b, &obj); err != nil {
return NewDecodeParamError(err.Error())
}
if len(obj) < 1 {
return NewInsufficientParamsError(len(obj), 1)
}
arg0, ok := obj[0].(string)
if !ok {
return NewInvalidTypeError("hash", "not a string")
}
args.Hash = arg0
return nil return nil
} }
@ -537,8 +575,8 @@ func (args *HashIndexArgs) UnmarshalJSON(b []byte) (err error) {
return NewDecodeParamError(err.Error()) return NewDecodeParamError(err.Error())
} }
if len(obj) < 1 { if len(obj) < 2 {
return NewInsufficientParamsError(len(obj), 1) return NewInsufficientParamsError(len(obj), 2)
} }
arg0, ok := obj[0].(string) arg0, ok := obj[0].(string)
@ -547,13 +585,11 @@ func (args *HashIndexArgs) UnmarshalJSON(b []byte) (err error) {
} }
args.Hash = arg0 args.Hash = arg0
if len(obj) > 1 { arg1, ok := obj[1].(string)
arg1, ok := obj[1].(string) if !ok {
if !ok { return NewInvalidTypeError("index", "not a string")
return NewInvalidTypeError("index", "not a string")
}
args.Index = common.Big(arg1).Int64()
} }
args.Index = common.Big(arg1).Int64()
return nil return nil
} }

@ -225,7 +225,7 @@ func TestGetBlockByHashArgsHashInt(t *testing.T) {
input := `[8]` input := `[8]`
args := new(GetBlockByHashArgs) args := new(GetBlockByHashArgs)
str := ExpectInvalidTypeError(json.Unmarshal([]byte(input), &args)) str := ExpectInsufficientParamsError(json.Unmarshal([]byte(input), &args))
if len(str) > 0 { if len(str) > 0 {
t.Error(str) t.Error(str)
} }
@ -281,6 +281,16 @@ func TestGetBlockByNumberEmpty(t *testing.T) {
} }
} }
func TestGetBlockByNumberShort(t *testing.T) {
input := `["0xbbb"]`
args := new(GetBlockByNumberArgs)
str := ExpectInsufficientParamsError(json.Unmarshal([]byte(input), &args))
if len(str) > 0 {
t.Error(str)
}
}
func TestGetBlockByNumberBool(t *testing.T) { func TestGetBlockByNumberBool(t *testing.T) {
input := `[true, true]` input := `[true, true]`

Loading…
Cancel
Save