rpc: add limit for batch request items and response size (#26681)

This PR adds server-side limits for JSON-RPC batch requests. Before this change, batches
were limited only by processing time. The server would pick calls from the batch and
answer them until the response timeout occurred, then stop processing the remaining batch
items.

Here, we are adding two additional limits which can be configured:

- the 'item limit': batches can have at most N items
- the 'response size limit': batches can contain at most X response bytes

These limits are optional in package rpc. In Geth, we set a default limit of 1000 items
and 25MB response size.

When a batch goes over the limit, an error response is returned to the client. However,
doing this correctly isn't always possible. In JSON-RPC, only method calls with a valid
`id` can be responded to. Since batches may also contain non-call messages or
notifications, the best effort thing we can do to report an error with the batch itself is
reporting the limit violation as an error for the first method call in the batch. If a batch is
too large, but contains only notifications and responses, the error will be reported with
a null `id`.

The RPC client was also changed so it can deal with errors resulting from too large
batches. An older client connected to the server code in this PR could get stuck
until the request timeout occurred when the batch is too large. **Upgrading to a version
of the RPC client containing this change is strongly recommended to avoid timeout issues.**

For some weird reason, when writing the original client implementation, @fjl worked off of
the assumption that responses could be distributed across batches arbitrarily. So for a
batch request containing requests `[A B C]`, the server could respond with `[A B C]` but
also with `[A B] [C]` or even `[A] [B] [C]` and it wouldn't make a difference to the
client.

So in the implementation of BatchCallContext, the client waited for all requests in the
batch individually. If the server didn't respond to some of the requests in the batch, the
client would eventually just time out (if a context was used).

With the addition of batch limits into the server, we anticipate that people will hit this
kind of error way more often. To handle this properly, the client now waits for a single
response batch and expects it to contain all responses to the requests.

---------

Co-authored-by: Felix Lange <fjl@twurst.com>
Co-authored-by: Martin Holst Swende <martin@swende.se>
pull/27462/head
mmsqe 1 year ago committed by GitHub
parent 5ac4da3653
commit f3314bb6df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      cmd/clef/main.go
  2. 2
      cmd/geth/main.go
  3. 20
      cmd/utils/flags.go
  4. 8
      node/api.go
  5. 6
      node/config.go
  6. 24
      node/defaults.go
  7. 31
      node/node.go
  8. 18
      node/rpcstack.go
  9. 6
      node/rpcstack_test.go
  10. 131
      rpc/client.go
  11. 29
      rpc/client_opt.go
  12. 104
      rpc/client_test.go
  13. 5
      rpc/errors.go
  14. 283
      rpc/handler.go
  15. 19
      rpc/http.go
  16. 3
      rpc/inproc.go
  17. 3
      rpc/ipc.go
  18. 28
      rpc/server.go
  19. 39
      rpc/server_test.go
  20. 3
      rpc/stdio.go
  21. 13
      rpc/testdata/invalid-batch-toolarge.js
  22. 4
      rpc/websocket.go

@ -732,6 +732,7 @@ func signer(c *cli.Context) error {
cors := utils.SplitAndTrim(c.String(utils.HTTPCORSDomainFlag.Name)) cors := utils.SplitAndTrim(c.String(utils.HTTPCORSDomainFlag.Name))
srv := rpc.NewServer() srv := rpc.NewServer()
srv.SetBatchLimits(node.DefaultConfig.BatchRequestLimit, node.DefaultConfig.BatchResponseMaxSize)
err := node.RegisterApis(rpcAPI, []string{"account"}, srv) err := node.RegisterApis(rpcAPI, []string{"account"}, srv)
if err != nil { if err != nil {
utils.Fatalf("Could not register API: %w", err) utils.Fatalf("Could not register API: %w", err)

@ -168,6 +168,8 @@ var (
utils.RPCGlobalEVMTimeoutFlag, utils.RPCGlobalEVMTimeoutFlag,
utils.RPCGlobalTxFeeCapFlag, utils.RPCGlobalTxFeeCapFlag,
utils.AllowUnprotectedTxs, utils.AllowUnprotectedTxs,
utils.BatchRequestLimit,
utils.BatchResponseMaxSize,
} }
metricsFlags = []cli.Flag{ metricsFlags = []cli.Flag{

@ -713,6 +713,18 @@ var (
Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC", Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC",
Category: flags.APICategory, Category: flags.APICategory,
} }
BatchRequestLimit = &cli.IntFlag{
Name: "rpc.batch-request-limit",
Usage: "Maximum number of requests in a batch",
Value: node.DefaultConfig.BatchRequestLimit,
Category: flags.APICategory,
}
BatchResponseMaxSize = &cli.IntFlag{
Name: "rpc.batch-response-max-size",
Usage: "Maximum number of bytes returned from a batched call",
Value: node.DefaultConfig.BatchResponseMaxSize,
Category: flags.APICategory,
}
EnablePersonal = &cli.BoolFlag{ EnablePersonal = &cli.BoolFlag{
Name: "rpc.enabledeprecatedpersonal", Name: "rpc.enabledeprecatedpersonal",
Usage: "Enables the (deprecated) personal namespace", Usage: "Enables the (deprecated) personal namespace",
@ -1130,6 +1142,14 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) {
if ctx.IsSet(AllowUnprotectedTxs.Name) { if ctx.IsSet(AllowUnprotectedTxs.Name) {
cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name) cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name)
} }
if ctx.IsSet(BatchRequestLimit.Name) {
cfg.BatchRequestLimit = ctx.Int(BatchRequestLimit.Name)
}
if ctx.IsSet(BatchResponseMaxSize.Name) {
cfg.BatchResponseMaxSize = ctx.Int(BatchResponseMaxSize.Name)
}
} }
// setGraphQL creates the GraphQL listener interface string from the set // setGraphQL creates the GraphQL listener interface string from the set

@ -176,6 +176,10 @@ func (api *adminAPI) StartHTTP(host *string, port *int, cors *string, apis *stri
CorsAllowedOrigins: api.node.config.HTTPCors, CorsAllowedOrigins: api.node.config.HTTPCors,
Vhosts: api.node.config.HTTPVirtualHosts, Vhosts: api.node.config.HTTPVirtualHosts,
Modules: api.node.config.HTTPModules, Modules: api.node.config.HTTPModules,
rpcEndpointConfig: rpcEndpointConfig{
batchItemLimit: api.node.config.BatchRequestLimit,
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
},
} }
if cors != nil { if cors != nil {
config.CorsAllowedOrigins = nil config.CorsAllowedOrigins = nil
@ -250,6 +254,10 @@ func (api *adminAPI) StartWS(host *string, port *int, allowedOrigins *string, ap
Modules: api.node.config.WSModules, Modules: api.node.config.WSModules,
Origins: api.node.config.WSOrigins, Origins: api.node.config.WSOrigins,
// ExposeAll: api.node.config.WSExposeAll, // ExposeAll: api.node.config.WSExposeAll,
rpcEndpointConfig: rpcEndpointConfig{
batchItemLimit: api.node.config.BatchRequestLimit,
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
},
} }
if apis != nil { if apis != nil {
config.Modules = nil config.Modules = nil

@ -197,6 +197,12 @@ type Config struct {
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC. // AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
AllowUnprotectedTxs bool `toml:",omitempty"` AllowUnprotectedTxs bool `toml:",omitempty"`
// BatchRequestLimit is the maximum number of requests in a batch.
BatchRequestLimit int `toml:",omitempty"`
// BatchResponseMaxSize is the maximum number of bytes returned from a batched rpc call.
BatchResponseMaxSize int `toml:",omitempty"`
// JWTSecret is the path to the hex-encoded jwt secret. // JWTSecret is the path to the hex-encoded jwt secret.
JWTSecret string `toml:",omitempty"` JWTSecret string `toml:",omitempty"`

@ -46,17 +46,19 @@ var (
// DefaultConfig contains reasonable default settings. // DefaultConfig contains reasonable default settings.
var DefaultConfig = Config{ var DefaultConfig = Config{
DataDir: DefaultDataDir(), DataDir: DefaultDataDir(),
HTTPPort: DefaultHTTPPort, HTTPPort: DefaultHTTPPort,
AuthAddr: DefaultAuthHost, AuthAddr: DefaultAuthHost,
AuthPort: DefaultAuthPort, AuthPort: DefaultAuthPort,
AuthVirtualHosts: DefaultAuthVhosts, AuthVirtualHosts: DefaultAuthVhosts,
HTTPModules: []string{"net", "web3"}, HTTPModules: []string{"net", "web3"},
HTTPVirtualHosts: []string{"localhost"}, HTTPVirtualHosts: []string{"localhost"},
HTTPTimeouts: rpc.DefaultHTTPTimeouts, HTTPTimeouts: rpc.DefaultHTTPTimeouts,
WSPort: DefaultWSPort, WSPort: DefaultWSPort,
WSModules: []string{"net", "web3"}, WSModules: []string{"net", "web3"},
GraphQLVirtualHosts: []string{"localhost"}, BatchRequestLimit: 1000,
BatchResponseMaxSize: 25 * 1000 * 1000,
GraphQLVirtualHosts: []string{"localhost"},
P2P: p2p.Config{ P2P: p2p.Config{
ListenAddr: ":30303", ListenAddr: ":30303",
MaxPeers: 50, MaxPeers: 50,

@ -101,10 +101,11 @@ func New(conf *Config) (*Node, error) {
if strings.HasSuffix(conf.Name, ".ipc") { if strings.HasSuffix(conf.Name, ".ipc") {
return nil, errors.New(`Config.Name cannot end in ".ipc"`) return nil, errors.New(`Config.Name cannot end in ".ipc"`)
} }
server := rpc.NewServer()
server.SetBatchLimits(conf.BatchRequestLimit, conf.BatchResponseMaxSize)
node := &Node{ node := &Node{
config: conf, config: conf,
inprocHandler: rpc.NewServer(), inprocHandler: server,
eventmux: new(event.TypeMux), eventmux: new(event.TypeMux),
log: conf.Logger, log: conf.Logger,
stop: make(chan struct{}), stop: make(chan struct{}),
@ -403,6 +404,11 @@ func (n *Node) startRPC() error {
openAPIs, allAPIs = n.getAPIs() openAPIs, allAPIs = n.getAPIs()
) )
rpcConfig := rpcEndpointConfig{
batchItemLimit: n.config.BatchRequestLimit,
batchResponseSizeLimit: n.config.BatchResponseMaxSize,
}
initHttp := func(server *httpServer, port int) error { initHttp := func(server *httpServer, port int) error {
if err := server.setListenAddr(n.config.HTTPHost, port); err != nil { if err := server.setListenAddr(n.config.HTTPHost, port); err != nil {
return err return err
@ -412,6 +418,7 @@ func (n *Node) startRPC() error {
Vhosts: n.config.HTTPVirtualHosts, Vhosts: n.config.HTTPVirtualHosts,
Modules: n.config.HTTPModules, Modules: n.config.HTTPModules,
prefix: n.config.HTTPPathPrefix, prefix: n.config.HTTPPathPrefix,
rpcEndpointConfig: rpcConfig,
}); err != nil { }); err != nil {
return err return err
} }
@ -425,9 +432,10 @@ func (n *Node) startRPC() error {
return err return err
} }
if err := server.enableWS(openAPIs, wsConfig{ if err := server.enableWS(openAPIs, wsConfig{
Modules: n.config.WSModules, Modules: n.config.WSModules,
Origins: n.config.WSOrigins, Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix, prefix: n.config.WSPathPrefix,
rpcEndpointConfig: rpcConfig,
}); err != nil { }); err != nil {
return err return err
} }
@ -441,26 +449,29 @@ func (n *Node) startRPC() error {
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
return err return err
} }
sharedConfig := rpcConfig
sharedConfig.jwtSecret = secret
if err := server.enableRPC(allAPIs, httpConfig{ if err := server.enableRPC(allAPIs, httpConfig{
CorsAllowedOrigins: DefaultAuthCors, CorsAllowedOrigins: DefaultAuthCors,
Vhosts: n.config.AuthVirtualHosts, Vhosts: n.config.AuthVirtualHosts,
Modules: DefaultAuthModules, Modules: DefaultAuthModules,
prefix: DefaultAuthPrefix, prefix: DefaultAuthPrefix,
jwtSecret: secret, rpcEndpointConfig: sharedConfig,
}); err != nil { }); err != nil {
return err return err
} }
servers = append(servers, server) servers = append(servers, server)
// Enable auth via WS // Enable auth via WS
server = n.wsServerForPort(port, true) server = n.wsServerForPort(port, true)
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
return err return err
} }
if err := server.enableWS(allAPIs, wsConfig{ if err := server.enableWS(allAPIs, wsConfig{
Modules: DefaultAuthModules, Modules: DefaultAuthModules,
Origins: DefaultAuthOrigins, Origins: DefaultAuthOrigins,
prefix: DefaultAuthPrefix, prefix: DefaultAuthPrefix,
jwtSecret: secret, rpcEndpointConfig: sharedConfig,
}); err != nil { }); err != nil {
return err return err
} }

@ -41,15 +41,21 @@ type httpConfig struct {
CorsAllowedOrigins []string CorsAllowedOrigins []string
Vhosts []string Vhosts []string
prefix string // path prefix on which to mount http handler prefix string // path prefix on which to mount http handler
jwtSecret []byte // optional JWT secret rpcEndpointConfig
} }
// wsConfig is the JSON-RPC/Websocket configuration // wsConfig is the JSON-RPC/Websocket configuration
type wsConfig struct { type wsConfig struct {
Origins []string Origins []string
Modules []string Modules []string
prefix string // path prefix on which to mount ws handler prefix string // path prefix on which to mount ws handler
jwtSecret []byte // optional JWT secret rpcEndpointConfig
}
type rpcEndpointConfig struct {
jwtSecret []byte // optional JWT secret
batchItemLimit int
batchResponseSizeLimit int
} }
type rpcHandler struct { type rpcHandler struct {
@ -297,6 +303,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
// Create RPC server and handler. // Create RPC server and handler.
srv := rpc.NewServer() srv := rpc.NewServer()
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
if err := RegisterApis(apis, config.Modules, srv); err != nil { if err := RegisterApis(apis, config.Modules, srv); err != nil {
return err return err
} }
@ -328,6 +335,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
} }
// Create RPC server and handler. // Create RPC server and handler.
srv := rpc.NewServer() srv := rpc.NewServer()
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
if err := RegisterApis(apis, config.Modules, srv); err != nil { if err := RegisterApis(apis, config.Modules, srv); err != nil {
return err return err
} }

@ -339,8 +339,10 @@ func TestJWT(t *testing.T) {
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret) ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
return ss return ss
} }
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, cfg := rpcEndpointConfig{jwtSecret: []byte("secret")}
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil) httpcfg := &httpConfig{rpcEndpointConfig: cfg}
wscfg := &wsConfig{Origins: []string{"*"}, rpcEndpointConfig: cfg}
srv := createAndStartServer(t, httpcfg, true, wscfg, nil)
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) htUrl := fmt.Sprintf("http://%v", srv.listenAddr())

@ -34,14 +34,15 @@ import (
var ( var (
ErrBadResult = errors.New("bad result in JSON-RPC response") ErrBadResult = errors.New("bad result in JSON-RPC response")
ErrClientQuit = errors.New("client is closed") ErrClientQuit = errors.New("client is closed")
ErrNoResult = errors.New("no result in JSON-RPC response") ErrNoResult = errors.New("JSON-RPC response has no result")
ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call")
ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")
errClientReconnected = errors.New("client reconnected") errClientReconnected = errors.New("client reconnected")
errDead = errors.New("connection lost") errDead = errors.New("connection lost")
) )
// Timeouts
const ( const (
// Timeouts
defaultDialTimeout = 10 * time.Second // used if context has no deadline defaultDialTimeout = 10 * time.Second // used if context has no deadline
subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls
) )
@ -84,6 +85,10 @@ type Client struct {
// This function, if non-nil, is called when the connection is lost. // This function, if non-nil, is called when the connection is lost.
reconnectFunc reconnectFunc reconnectFunc reconnectFunc
// config fields
batchItemLimit int
batchResponseMaxSize int
// writeConn is used for writing to the connection on the caller's goroutine. It should // writeConn is used for writing to the connection on the caller's goroutine. It should
// only be accessed outside of dispatch, with the write lock held. The write lock is // only be accessed outside of dispatch, with the write lock held. The write lock is
// taken by sending on reqInit and released by sending on reqSent. // taken by sending on reqInit and released by sending on reqSent.
@ -114,7 +119,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, clientContextKey{}, c) ctx = context.WithValue(ctx, clientContextKey{}, c)
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo()) ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
handler := newHandler(ctx, conn, c.idgen, c.services) handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize)
return &clientConn{conn, handler} return &clientConn{conn, handler}
} }
@ -128,14 +133,17 @@ type readOp struct {
batch bool batch bool
} }
// requestOp represents a pending request. This is used for both batch and non-batch
// requests.
type requestOp struct { type requestOp struct {
ids []json.RawMessage ids []json.RawMessage
err error err error
resp chan *jsonrpcMessage // receives up to len(ids) responses resp chan []*jsonrpcMessage // the response goes here
sub *ClientSubscription // only set for EthSubscribe requests sub *ClientSubscription // set for Subscribe requests.
hadResponse bool // true when the request was responded to
} }
func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) { func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
// Send the timeout to dispatch so it can remove the request IDs. // Send the timeout to dispatch so it can remove the request IDs.
@ -211,7 +219,7 @@ func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
} }
return newClient(ctx, reconnect) return newClient(ctx, cfg, reconnect)
} }
// ClientFromContext retrieves the client from the context, if any. This can be used to perform // ClientFromContext retrieves the client from the context, if any. This can be used to perform
@ -221,33 +229,42 @@ func ClientFromContext(ctx context.Context) (*Client, bool) {
return client, ok return client, ok
} }
func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) {
conn, err := connect(initctx) conn, err := connect(initctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c := initClient(conn, randomIDGenerator(), new(serviceRegistry)) c := initClient(conn, new(serviceRegistry), cfg)
c.reconnectFunc = connect c.reconnectFunc = connect
return c, nil return c, nil
} }
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client {
_, isHTTP := conn.(*httpConn) _, isHTTP := conn.(*httpConn)
c := &Client{ c := &Client{
isHTTP: isHTTP, isHTTP: isHTTP,
idgen: idgen, services: services,
services: services, idgen: cfg.idgen,
writeConn: conn, batchItemLimit: cfg.batchItemLimit,
close: make(chan struct{}), batchResponseMaxSize: cfg.batchResponseLimit,
closing: make(chan struct{}), writeConn: conn,
didClose: make(chan struct{}), close: make(chan struct{}),
reconnected: make(chan ServerCodec), closing: make(chan struct{}),
readOp: make(chan readOp), didClose: make(chan struct{}),
readErr: make(chan error), reconnected: make(chan ServerCodec),
reqInit: make(chan *requestOp), readOp: make(chan readOp),
reqSent: make(chan error, 1), readErr: make(chan error),
reqTimeout: make(chan *requestOp), reqInit: make(chan *requestOp),
} reqSent: make(chan error, 1),
reqTimeout: make(chan *requestOp),
}
// Set defaults.
if c.idgen == nil {
c.idgen = randomIDGenerator()
}
// Launch the main loop.
if !isHTTP { if !isHTTP {
go c.dispatch(conn) go c.dispatch(conn)
} }
@ -325,7 +342,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
if err != nil { if err != nil {
return err return err
} }
op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} op := &requestOp{
ids: []json.RawMessage{msg.ID},
resp: make(chan []*jsonrpcMessage, 1),
}
if c.isHTTP { if c.isHTTP {
err = c.sendHTTP(ctx, op, msg) err = c.sendHTTP(ctx, op, msg)
@ -337,9 +357,12 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
} }
// dispatch has accepted the request and will close the channel when it quits. // dispatch has accepted the request and will close the channel when it quits.
switch resp, err := op.wait(ctx, c); { batchresp, err := op.wait(ctx, c)
case err != nil: if err != nil {
return err return err
}
resp := batchresp[0]
switch {
case resp.Error != nil: case resp.Error != nil:
return resp.Error return resp.Error
case len(resp.Result) == 0: case len(resp.Result) == 0:
@ -380,7 +403,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
) )
op := &requestOp{ op := &requestOp{
ids: make([]json.RawMessage, len(b)), ids: make([]json.RawMessage, len(b)),
resp: make(chan *jsonrpcMessage, len(b)), resp: make(chan []*jsonrpcMessage, 1),
} }
for i, elem := range b { for i, elem := range b {
msg, err := c.newMessage(elem.Method, elem.Args...) msg, err := c.newMessage(elem.Method, elem.Args...)
@ -398,28 +421,48 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
} else { } else {
err = c.send(ctx, op, msgs) err = c.send(ctx, op, msgs)
} }
if err != nil {
return err
}
batchresp, err := op.wait(ctx, c)
if err != nil {
return err
}
// Wait for all responses to come back. // Wait for all responses to come back.
for n := 0; n < len(b) && err == nil; n++ { for n := 0; n < len(batchresp) && err == nil; n++ {
var resp *jsonrpcMessage resp := batchresp[n]
resp, err = op.wait(ctx, c) if resp == nil {
if err != nil { // Ignore null responses. These can happen for batches sent via HTTP.
break continue
} }
// Find the element corresponding to this response. // Find the element corresponding to this response.
// The element is guaranteed to be present because dispatch index, ok := byID[string(resp.ID)]
// only sends valid IDs to our channel. if !ok {
elem := &b[byID[string(resp.ID)]]
if resp.Error != nil {
elem.Error = resp.Error
continue continue
} }
if len(resp.Result) == 0 { delete(byID, string(resp.ID))
// Assign result and error.
elem := &b[index]
switch {
case resp.Error != nil:
elem.Error = resp.Error
case resp.Result == nil:
elem.Error = ErrNoResult elem.Error = ErrNoResult
continue default:
elem.Error = json.Unmarshal(resp.Result, elem.Result)
} }
elem.Error = json.Unmarshal(resp.Result, elem.Result)
} }
// Check that all expected responses have been received.
for _, index := range byID {
elem := &b[index]
elem.Error = ErrMissingBatchResponse
}
return err return err
} }
@ -480,7 +523,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
} }
op := &requestOp{ op := &requestOp{
ids: []json.RawMessage{msg.ID}, ids: []json.RawMessage{msg.ID},
resp: make(chan *jsonrpcMessage), resp: make(chan []*jsonrpcMessage, 1),
sub: newClientSubscription(c, namespace, chanVal), sub: newClientSubscription(c, namespace, chanVal),
} }

@ -28,11 +28,18 @@ type ClientOption interface {
} }
type clientConfig struct { type clientConfig struct {
// HTTP settings
httpClient *http.Client httpClient *http.Client
httpHeaders http.Header httpHeaders http.Header
httpAuth HTTPAuth httpAuth HTTPAuth
// WebSocket options
wsDialer *websocket.Dialer wsDialer *websocket.Dialer
// RPC handler options
idgen func() ID
batchItemLimit int
batchResponseLimit int
} }
func (cfg *clientConfig) initHeaders() { func (cfg *clientConfig) initHeaders() {
@ -104,3 +111,25 @@ func WithHTTPAuth(a HTTPAuth) ClientOption {
// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add // Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
// auth information to the request. // auth information to the request.
type HTTPAuth func(h http.Header) error type HTTPAuth func(h http.Header) error
// WithBatchItemLimit changes the maximum number of items allowed in batch requests.
//
// Note: this option applies when processing incoming batch requests. It does not affect
// batch requests sent by the client.
func WithBatchItemLimit(limit int) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.batchItemLimit = limit
})
}
// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be
// generated for batch requests. When this limit is reached, further calls in the batch
// will not be processed.
//
// Note: this option applies when processing incoming batch requests. It does not affect
// batch requests sent by the client.
func WithBatchResponseSizeLimit(sizeLimit int) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.batchResponseLimit = sizeLimit
})
}

@ -169,10 +169,12 @@ func TestClientBatchRequest(t *testing.T) {
} }
} }
// This checks that, for HTTP connections, the length of batch responses is validated to
// match the request exactly.
func TestClientBatchRequest_len(t *testing.T) { func TestClientBatchRequest_len(t *testing.T) {
b, err := json.Marshal([]jsonrpcMessage{ b, err := json.Marshal([]jsonrpcMessage{
{Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)}, {Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)},
{Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)}, {Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)},
}) })
if err != nil { if err != nil {
t.Fatal("failed to encode jsonrpc message:", err) t.Fatal("failed to encode jsonrpc message:", err)
@ -185,37 +187,102 @@ func TestClientBatchRequest_len(t *testing.T) {
})) }))
t.Cleanup(s.Close) t.Cleanup(s.Close)
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
t.Run("too-few", func(t *testing.T) { t.Run("too-few", func(t *testing.T) {
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
batch := []BatchElem{ batch := []BatchElem{
{Method: "foo"}, {Method: "foo", Result: new(string)},
{Method: "bar"}, {Method: "bar", Result: new(string)},
{Method: "baz"}, {Method: "baz", Result: new(string)},
} }
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn() defer cancelFn()
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
t.Errorf("expected %q but got: %v", ErrBadResult, err) if err := client.BatchCallContext(ctx, batch); err != nil {
t.Fatal("error:", err)
}
for i, elem := range batch[:2] {
if elem.Error != nil {
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
}
}
for i, elem := range batch[2:] {
if elem.Error != ErrMissingBatchResponse {
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
}
} }
}) })
t.Run("too-many", func(t *testing.T) { t.Run("too-many", func(t *testing.T) {
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
batch := []BatchElem{ batch := []BatchElem{
{Method: "foo"}, {Method: "foo", Result: new(string)},
} }
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn() defer cancelFn()
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
t.Errorf("expected %q but got: %v", ErrBadResult, err) if err := client.BatchCallContext(ctx, batch); err != nil {
t.Fatal("error:", err)
}
for i, elem := range batch[:1] {
if elem.Error != nil {
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
}
}
for i, elem := range batch[1:] {
if elem.Error != ErrMissingBatchResponse {
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
}
} }
}) })
} }
// This checks that the client can handle the case where the server doesn't
// respond to all requests in a batch.
func TestClientBatchRequestLimit(t *testing.T) {
server := newTestServer()
defer server.Stop()
server.SetBatchLimits(2, 100000)
client := DialInProc(server)
batch := []BatchElem{
{Method: "foo"},
{Method: "bar"},
{Method: "baz"},
}
err := client.BatchCall(batch)
if err != nil {
t.Fatal("unexpected error:", err)
}
// Check that the first response indicates an error with batch size.
var err0 Error
if !errors.As(batch[0].Error, &err0) {
t.Log("error zero:", batch[0].Error)
t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error)
} else {
if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge {
t.Fatalf("wrong error on batch elem zero: %v", err0)
}
}
// Check that remaining response batch elements are reported as absent.
for i, elem := range batch[1:] {
if elem.Error != ErrMissingBatchResponse {
t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error)
}
}
}
func TestClientNotify(t *testing.T) { func TestClientNotify(t *testing.T) {
server := newTestServer() server := newTestServer()
defer server.Stop() defer server.Stop()
@ -310,7 +377,7 @@ func testClientCancel(transport string, t *testing.T) {
_, hasDeadline := ctx.Deadline() _, hasDeadline := ctx.Deadline()
t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline)
// default: // default:
// t.Logf("got expected error with %v wait time: %v", timeout, err) // t.Logf("got expected error with %v wait time: %v", timeout, err)
} }
cancel() cancel()
} }
@ -487,7 +554,8 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) {
defer srv.Stop() defer srv.Stop()
// Create the client on the other end of the pipe. // Create the client on the other end of the pipe.
client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) { cfg := new(clientConfig)
client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) {
return NewCodec(p2), nil return NewCodec(p2), nil
}) })
defer client.Close() defer client.Close()

@ -61,12 +61,15 @@ const (
errcodeDefault = -32000 errcodeDefault = -32000
errcodeNotificationsUnsupported = -32001 errcodeNotificationsUnsupported = -32001
errcodeTimeout = -32002 errcodeTimeout = -32002
errcodeResponseTooLarge = -32003
errcodePanic = -32603 errcodePanic = -32603
errcodeMarshalError = -32603 errcodeMarshalError = -32603
) )
const ( const (
errMsgTimeout = "request timed out" errMsgTimeout = "request timed out"
errMsgResponseTooLarge = "response too large"
errMsgBatchTooLarge = "batch too large"
) )
type methodNotFoundError struct{ method string } type methodNotFoundError struct{ method string }

@ -49,17 +49,19 @@ import (
// h.removeRequestOp(op) // timeout, etc. // h.removeRequestOp(op) // timeout, etc.
// } // }
type handler struct { type handler struct {
reg *serviceRegistry reg *serviceRegistry
unsubscribeCb *callback unsubscribeCb *callback
idgen func() ID // subscription ID generator idgen func() ID // subscription ID generator
respWait map[string]*requestOp // active client requests respWait map[string]*requestOp // active client requests
clientSubs map[string]*ClientSubscription // active client subscriptions clientSubs map[string]*ClientSubscription // active client subscriptions
callWG sync.WaitGroup // pending call goroutines callWG sync.WaitGroup // pending call goroutines
rootCtx context.Context // canceled by close() rootCtx context.Context // canceled by close()
cancelRoot func() // cancel function for rootCtx cancelRoot func() // cancel function for rootCtx
conn jsonWriter // where responses will be sent conn jsonWriter // where responses will be sent
log log.Logger log log.Logger
allowSubscribe bool allowSubscribe bool
batchRequestLimit int
batchResponseMaxSize int
subLock sync.Mutex subLock sync.Mutex
serverSubs map[ID]*Subscription serverSubs map[ID]*Subscription
@ -70,19 +72,21 @@ type callProc struct {
notifiers []*Notifier notifiers []*Notifier
} }
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler { func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler {
rootCtx, cancelRoot := context.WithCancel(connCtx) rootCtx, cancelRoot := context.WithCancel(connCtx)
h := &handler{ h := &handler{
reg: reg, reg: reg,
idgen: idgen, idgen: idgen,
conn: conn, conn: conn,
respWait: make(map[string]*requestOp), respWait: make(map[string]*requestOp),
clientSubs: make(map[string]*ClientSubscription), clientSubs: make(map[string]*ClientSubscription),
rootCtx: rootCtx, rootCtx: rootCtx,
cancelRoot: cancelRoot, cancelRoot: cancelRoot,
allowSubscribe: true, allowSubscribe: true,
serverSubs: make(map[ID]*Subscription), serverSubs: make(map[ID]*Subscription),
log: log.Root(), log: log.Root(),
batchRequestLimit: batchRequestLimit,
batchResponseMaxSize: batchResponseMaxSize,
} }
if conn.remoteAddr() != "" { if conn.remoteAddr() != "" {
h.log = h.log.New("conn", conn.remoteAddr()) h.log = h.log.New("conn", conn.remoteAddr())
@ -134,16 +138,15 @@ func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
b.doWrite(ctx, conn, false) b.doWrite(ctx, conn, false)
} }
// timeout sends the responses added so far. For the remaining unanswered call // respondWithError sends the responses added so far. For the remaining unanswered call
// messages, it sends a timeout error response. // messages, it responds with the given error.
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) { func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) {
b.mutex.Lock() b.mutex.Lock()
defer b.mutex.Unlock() defer b.mutex.Unlock()
for _, msg := range b.calls { for _, msg := range b.calls {
if !msg.isNotification() { if !msg.isNotification() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) b.resp = append(b.resp, msg.errorResponse(err))
b.resp = append(b.resp, resp)
} }
} }
b.doWrite(ctx, conn, true) b.doWrite(ctx, conn, true)
@ -171,17 +174,24 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
}) })
return return
} }
// Apply limit on total number of requests.
if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit {
h.startCallProc(func(cp *callProc) {
h.respondWithBatchTooLarge(cp, msgs)
})
return
}
// Handle non-call messages first: // Handle non-call messages first.
// Here we need to find the requestOp that sent the request batch.
calls := make([]*jsonrpcMessage, 0, len(msgs)) calls := make([]*jsonrpcMessage, 0, len(msgs))
for _, msg := range msgs { h.handleResponses(msgs, func(msg *jsonrpcMessage) {
if handled := h.handleImmediate(msg); !handled { calls = append(calls, msg)
calls = append(calls, msg) })
}
}
if len(calls) == 0 { if len(calls) == 0 {
return return
} }
// Process calls on a goroutine because they may block indefinitely: // Process calls on a goroutine because they may block indefinitely:
h.startCallProc(func(cp *callProc) { h.startCallProc(func(cp *callProc) {
var ( var (
@ -199,10 +209,12 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
if timeout, ok := ContextRequestTimeout(cp.ctx); ok { if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() { timer = time.AfterFunc(timeout, func() {
cancel() cancel()
callBuffer.timeout(cp.ctx, h.conn) err := &internalServerError{errcodeTimeout, errMsgTimeout}
callBuffer.respondWithError(cp.ctx, h.conn, err)
}) })
} }
responseBytes := 0
for { for {
// No need to handle rest of calls if timed out. // No need to handle rest of calls if timed out.
if cp.ctx.Err() != nil { if cp.ctx.Err() != nil {
@ -214,59 +226,86 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
} }
resp := h.handleCallMsg(cp, msg) resp := h.handleCallMsg(cp, msg)
callBuffer.pushResponse(resp) callBuffer.pushResponse(resp)
if resp != nil && h.batchResponseMaxSize != 0 {
responseBytes += len(resp.Result)
if responseBytes > h.batchResponseMaxSize {
err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge}
callBuffer.respondWithError(cp.ctx, h.conn, err)
break
}
}
} }
if timer != nil { if timer != nil {
timer.Stop() timer.Stop()
} }
callBuffer.write(cp.ctx, h.conn)
h.addSubscriptions(cp.notifiers) h.addSubscriptions(cp.notifiers)
callBuffer.write(cp.ctx, h.conn)
for _, n := range cp.notifiers { for _, n := range cp.notifiers {
n.activate() n.activate()
} }
}) })
} }
// handleMsg handles a single message. func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) {
func (h *handler) handleMsg(msg *jsonrpcMessage) { resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge})
if ok := h.handleImmediate(msg); ok { // Find the first call and add its "id" field to the error.
return // This is the best we can do, given that the protocol doesn't have a way
// of reporting an error for the entire batch.
for _, msg := range batch {
if msg.isCall() {
resp.ID = msg.ID
break
}
} }
h.startCallProc(func(cp *callProc) { h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true)
var ( }
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the // handleMsg handles a single non-batch message.
// running method might not return immediately on timeout, we must wait for the func (h *handler) handleMsg(msg *jsonrpcMessage) {
// timeout concurrently with processing the request. msgs := []*jsonrpcMessage{msg}
if timeout, ok := ContextRequestTimeout(cp.ctx); ok { h.handleResponses(msgs, func(msg *jsonrpcMessage) {
timer = time.AfterFunc(timeout, func() { h.startCallProc(func(cp *callProc) {
cancel() h.handleNonBatchCall(cp, msg)
responded.Do(func() { })
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) })
h.conn.writeJSON(cp.ctx, resp, true) }
})
})
}
answer := h.handleCallMsg(cp, msg) func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) {
if timer != nil { var (
timer.Stop() responded sync.Once
} timer *time.Timer
h.addSubscriptions(cp.notifiers) cancel context.CancelFunc
if answer != nil { )
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the
// running method might not return immediately on timeout, we must wait for the
// timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
responded.Do(func() { responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false) resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
h.conn.writeJSON(cp.ctx, resp, true)
}) })
} })
for _, n := range cp.notifiers { }
n.activate()
} answer := h.handleCallMsg(cp, msg)
}) if timer != nil {
timer.Stop()
}
h.addSubscriptions(cp.notifiers)
if answer != nil {
responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false)
})
}
for _, n := range cp.notifiers {
n.activate()
}
} }
// close cancels all requests except for inflightReq and waits for // close cancels all requests except for inflightReq and waits for
@ -349,23 +388,60 @@ func (h *handler) startCallProc(fn func(*callProc)) {
}() }()
} }
// handleImmediate executes non-call messages. It returns false if the message is a // handleResponse processes method call responses.
// call or requires a reply. func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) {
func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { var resolvedops []*requestOp
start := time.Now() handleResp := func(msg *jsonrpcMessage) {
switch { op := h.respWait[string(msg.ID)]
case msg.isNotification(): if op == nil {
if strings.HasSuffix(msg.Method, notificationMethodSuffix) { h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
h.handleSubscriptionResult(msg) return
return true }
resolvedops = append(resolvedops, op)
delete(h.respWait, string(msg.ID))
// For subscription responses, start the subscription if the server
// indicates success. EthSubscribe gets unblocked in either case through
// the op.resp channel.
if op.sub != nil {
if msg.Error != nil {
op.err = msg.Error
} else {
op.err = json.Unmarshal(msg.Result, &op.sub.subid)
if op.err == nil {
go op.sub.run()
h.clientSubs[op.sub.subid] = op.sub
}
}
}
if !op.hadResponse {
op.hadResponse = true
op.resp <- batch
} }
return false }
case msg.isResponse():
h.handleResponse(msg) for _, msg := range batch {
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) start := time.Now()
return true switch {
default: case msg.isResponse():
return false handleResp(msg)
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
case msg.isNotification():
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
h.handleSubscriptionResult(msg)
continue
}
handleCall(msg)
default:
handleCall(msg)
}
}
for _, op := range resolvedops {
h.removeRequestOp(op)
} }
} }
@ -381,33 +457,6 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
} }
} }
// handleResponse processes method call responses.
func (h *handler) handleResponse(msg *jsonrpcMessage) {
op := h.respWait[string(msg.ID)]
if op == nil {
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
return
}
delete(h.respWait, string(msg.ID))
// For normal responses, just forward the reply to Call/BatchCall.
if op.sub == nil {
op.resp <- msg
return
}
// For subscription responses, start the subscription if the server
// indicates success. EthSubscribe gets unblocked in either case through
// the op.resp channel.
defer close(op.resp)
if msg.Error != nil {
op.err = msg.Error
return
}
if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
go op.sub.run()
h.clientSubs[op.sub.subid] = op.sub
}
}
// handleCallMsg executes a call message and returns the answer. // handleCallMsg executes a call message and returns the answer.
func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
start := time.Now() start := time.Now()
@ -416,6 +465,7 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
h.handleCall(ctx, msg) h.handleCall(ctx, msg)
h.log.Debug("Served "+msg.Method, "duration", time.Since(start)) h.log.Debug("Served "+msg.Method, "duration", time.Since(start))
return nil return nil
case msg.isCall(): case msg.isCall():
resp := h.handleCall(ctx, msg) resp := h.handleCall(ctx, msg)
var ctx []interface{} var ctx []interface{}
@ -430,8 +480,10 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
h.log.Debug("Served "+msg.Method, ctx...) h.log.Debug("Served "+msg.Method, ctx...)
} }
return resp return resp
case msg.hasValidID(): case msg.hasValidID():
return msg.errorResponse(&invalidRequestError{"invalid request"}) return msg.errorResponse(&invalidRequestError{"invalid request"})
default: default:
return errorMessage(&invalidRequestError{"invalid request"}) return errorMessage(&invalidRequestError{"invalid request"})
} }
@ -451,12 +503,14 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
if callb == nil { if callb == nil {
return msg.errorResponse(&methodNotFoundError{method: msg.Method}) return msg.errorResponse(&methodNotFoundError{method: msg.Method})
} }
args, err := parsePositionalArguments(msg.Params, callb.argTypes) args, err := parsePositionalArguments(msg.Params, callb.argTypes)
if err != nil { if err != nil {
return msg.errorResponse(&invalidParamsError{err.Error()}) return msg.errorResponse(&invalidParamsError{err.Error()})
} }
start := time.Now() start := time.Now()
answer := h.runMethod(cp.ctx, msg, callb, args) answer := h.runMethod(cp.ctx, msg, callb, args)
// Collect the statistics for RPC calls if metrics is enabled. // Collect the statistics for RPC calls if metrics is enabled.
// We only care about pure rpc call. Filter out subscription. // We only care about pure rpc call. Filter out subscription.
if callb != h.unsubscribeCb { if callb != h.unsubscribeCb {
@ -469,6 +523,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
rpcServingTimer.UpdateSince(start) rpcServingTimer.UpdateSince(start)
updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start)) updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start))
} }
return answer return answer
} }

@ -139,7 +139,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
var cfg clientConfig var cfg clientConfig
cfg.httpClient = client cfg.httpClient = client
fn := newClientTransportHTTP(endpoint, &cfg) fn := newClientTransportHTTP(endpoint, &cfg)
return newClient(context.Background(), fn) return newClient(context.Background(), &cfg, fn)
} }
func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
@ -176,11 +176,12 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
} }
defer respBody.Close() defer respBody.Close()
var respmsg jsonrpcMessage var resp jsonrpcMessage
if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil { batch := [1]*jsonrpcMessage{&resp}
if err := json.NewDecoder(respBody).Decode(&resp); err != nil {
return err return err
} }
op.resp <- &respmsg op.resp <- batch[:]
return nil return nil
} }
@ -191,16 +192,12 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr
return err return err
} }
defer respBody.Close() defer respBody.Close()
var respmsgs []jsonrpcMessage
var respmsgs []*jsonrpcMessage
if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil {
return err return err
} }
if len(respmsgs) != len(msgs) { op.resp <- respmsgs
return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult)
}
for i := 0; i < len(respmsgs); i++ {
op.resp <- &respmsgs[i]
}
return nil return nil
} }

@ -24,7 +24,8 @@ import (
// DialInProc attaches an in-process connection to the given RPC server. // DialInProc attaches an in-process connection to the given RPC server.
func DialInProc(handler *Server) *Client { func DialInProc(handler *Server) *Client {
initctx := context.Background() initctx := context.Background()
c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { cfg := new(clientConfig)
c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) {
p1, p2 := net.Pipe() p1, p2 := net.Pipe()
go handler.ServeCodec(NewCodec(p1), 0) go handler.ServeCodec(NewCodec(p1), 0)
return NewCodec(p2), nil return NewCodec(p2), nil

@ -46,7 +46,8 @@ func (s *Server) ServeListener(l net.Listener) error {
// The context is used for the initial connection establishment. It does not // The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client. // affect subsequent interactions with the client.
func DialIPC(ctx context.Context, endpoint string) (*Client, error) { func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
return newClient(ctx, newClientTransportIPC(endpoint)) cfg := new(clientConfig)
return newClient(ctx, cfg, newClientTransportIPC(endpoint))
} }
func newClientTransportIPC(endpoint string) reconnectFunc { func newClientTransportIPC(endpoint string) reconnectFunc {

@ -46,9 +46,11 @@ type Server struct {
services serviceRegistry services serviceRegistry
idgen func() ID idgen func() ID
mutex sync.Mutex mutex sync.Mutex
codecs map[ServerCodec]struct{} codecs map[ServerCodec]struct{}
run atomic.Bool run atomic.Bool
batchItemLimit int
batchResponseLimit int
} }
// NewServer creates a new server instance with no registered handlers. // NewServer creates a new server instance with no registered handlers.
@ -65,6 +67,17 @@ func NewServer() *Server {
return server return server
} }
// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit'
// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of
// response bytes across all requests in a batch.
//
// This method should be called before processing any requests via ServeCodec, ServeHTTP,
// ServeListener etc.
func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) {
s.batchItemLimit = itemLimit
s.batchResponseLimit = maxResponseSize
}
// RegisterName creates a service for the given receiver type under the given name. When no // RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either a RPC method or a // methods on the given receiver match the criteria to be either a RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the // subscription an error is returned. Otherwise a new service is created and added to the
@ -86,7 +99,12 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
} }
defer s.untrackCodec(codec) defer s.untrackCodec(codec)
c := initClient(codec, s.idgen, &s.services) cfg := &clientConfig{
idgen: s.idgen,
batchItemLimit: s.batchItemLimit,
batchResponseLimit: s.batchResponseLimit,
}
c := initClient(codec, &s.services, cfg)
<-codec.closed() <-codec.closed()
c.Close() c.Close()
} }
@ -118,7 +136,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
return return
} }
h := newHandler(ctx, codec, s.idgen, &s.services) h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit)
h.allowSubscribe = false h.allowSubscribe = false
defer h.close(io.EOF, nil) defer h.close(io.EOF, nil)

@ -70,6 +70,7 @@ func TestServer(t *testing.T) {
func runTestScript(t *testing.T, file string) { func runTestScript(t *testing.T, file string) {
server := newTestServer() server := newTestServer()
server.SetBatchLimits(4, 100000)
content, err := os.ReadFile(file) content, err := os.ReadFile(file)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -152,3 +153,41 @@ func TestServerShortLivedConn(t *testing.T) {
} }
} }
} }
func TestServerBatchResponseSizeLimit(t *testing.T) {
server := newTestServer()
defer server.Stop()
server.SetBatchLimits(100, 60)
var (
batch []BatchElem
client = DialInProc(server)
)
for i := 0; i < 5; i++ {
batch = append(batch, BatchElem{
Method: "test_echo",
Args: []any{"x", 1},
Result: new(echoResult),
})
}
if err := client.BatchCall(batch); err != nil {
t.Fatal("error sending batch:", err)
}
for i := range batch {
// We expect the first two queries to be ok, but after that the size limit takes effect.
if i < 2 {
if batch[i].Error != nil {
t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error)
}
continue
}
// After two, we expect an error.
re, ok := batch[i].Error.(Error)
if !ok {
t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error)
}
wantedCode := errcodeResponseTooLarge
if re.ErrorCode() != wantedCode {
t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode)
}
}
}

@ -32,7 +32,8 @@ func DialStdIO(ctx context.Context) (*Client, error) {
// DialIO creates a client which uses the given IO channels // DialIO creates a client which uses the given IO channels
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
return newClient(ctx, newClientTransportIO(in, out)) cfg := new(clientConfig)
return newClient(ctx, cfg, newClientTransportIO(in, out))
} }
func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc { func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {

@ -0,0 +1,13 @@
// This file checks the behavior of the batch item limit code.
// In tests, the batch item limit is set to 4. So to trigger the error,
// all batches in this file have 5 elements.
// For batches that do not contain any calls, a response message with "id" == null
// is returned.
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}]
// For batches with at least one call, the call's "id" is used.
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}]

@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newClient(ctx, connect) return newClient(ctx, cfg, connect)
} }
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
@ -214,7 +214,7 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newClient(ctx, connect) return newClient(ctx, cfg, connect)
} }
func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {

Loading…
Cancel
Save