diff --git a/rpc/client.go b/rpc/client.go index 984f56bc6..91e68e73e 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -85,7 +85,7 @@ type Client struct { // writeConn is used for writing to the connection on the caller's goroutine. It should // only be accessed outside of dispatch, with the write lock held. The write lock is - // taken by sending on requestOp and released by sending on sendDone. + // taken by sending on reqInit and released by sending on reqSent. writeConn jsonWriter // for dispatch @@ -260,6 +260,19 @@ func (c *Client) Close() { } } +// SetHeader adds a custom HTTP header to the client's requests. +// This method only works for clients using HTTP, it doesn't have +// any effect for clients using another transport. +func (c *Client) SetHeader(key, value string) { + if !c.isHTTP { + return + } + conn := c.writeConn.(*httpConn) + conn.mu.Lock() + conn.headers.Set(key, value) + conn.mu.Unlock() +} + // Call performs a JSON-RPC call with the given arguments and unmarshals into // result if no error occurred. // diff --git a/rpc/client_test.go b/rpc/client_test.go index 19c2facb5..5b1f96035 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -26,6 +26,7 @@ import ( "os" "reflect" "runtime" + "strings" "sync" "testing" "time" @@ -429,6 +430,42 @@ func TestClientNotificationStorm(t *testing.T) { doTest(23000, true) } +func TestClientSetHeader(t *testing.T) { + var gotHeader bool + srv := newTestServer() + httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("test") == "ok" { + gotHeader = true + } + srv.ServeHTTP(w, r) + })) + defer httpsrv.Close() + defer srv.Stop() + + client, err := Dial(httpsrv.URL) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + client.SetHeader("test", "ok") + if _, err := client.SupportedModules(); err != nil { + t.Fatal(err) + } + if !gotHeader { + t.Fatal("client did not set custom header") + } + + // Check that Content-Type can be replaced. + client.SetHeader("content-type", "application/x-garbage") + _, err = client.SupportedModules() + if err == nil { + t.Fatal("no error for invalid content-type header") + } else if !strings.Contains(err.Error(), "Unsupported Media Type") { + t.Fatalf("error is not related to content-type: %q", err) + } +} + func TestClientHTTP(t *testing.T) { server := newTestServer() defer server.Stop() diff --git a/rpc/http.go b/rpc/http.go index 78fbd9f8f..87a96e49e 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -26,6 +26,7 @@ import ( "io/ioutil" "mime" "net/http" + "net/url" "sync" "time" ) @@ -40,9 +41,11 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic type httpConn struct { client *http.Client - req *http.Request + url string closeOnce sync.Once closeCh chan interface{} + mu sync.Mutex // protects headers + headers http.Header } // httpConn is treated specially by Client. @@ -51,7 +54,7 @@ func (hc *httpConn) writeJSON(context.Context, interface{}) error { } func (hc *httpConn) remoteAddr() string { - return hc.req.URL.String() + return hc.url } func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) { @@ -102,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{ // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP // using the provided HTTP Client. func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { - req, err := http.NewRequest(http.MethodPost, endpoint, nil) + // Sanity check URL so we don't end up with a client that will fail every request. + _, err := url.Parse(endpoint) if err != nil { return nil, err } - req.Header.Set("Content-Type", contentType) - req.Header.Set("Accept", contentType) initctx := context.Background() + headers := make(http.Header, 2) + headers.Set("accept", contentType) + headers.Set("content-type", contentType) return newClient(initctx, func(context.Context) (ServerCodec, error) { - return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil + hc := &httpConn{ + client: client, + headers: headers, + url: endpoint, + closeCh: make(chan interface{}), + } + return hc, nil }) } @@ -131,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e if respBody != nil { buf := new(bytes.Buffer) if _, err2 := buf.ReadFrom(respBody); err2 == nil { - return fmt.Errorf("%v %v", err, buf.String()) + return fmt.Errorf("%v: %v", err, buf.String()) } } return err @@ -166,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos if err != nil { return nil, err } - req := hc.req.WithContext(ctx) - req.Body = ioutil.NopCloser(bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body))) + if err != nil { + return nil, err + } req.ContentLength = int64(len(body)) + // set headers + hc.mu.Lock() + req.Header = hc.headers.Clone() + hc.mu.Unlock() + + // do request resp, err := hc.client.Do(req) if err != nil { return nil, err