From 07d909ff3283fb50a1312a1919787768e54a1c81 Mon Sep 17 00:00:00 2001
From: rene <41963722+renaynay@users.noreply.github.com>
Date: Wed, 8 Apr 2020 13:33:12 +0200
Subject: [PATCH] node: allow websocket and HTTP on the same port (#20810)
This change makes it possible to run geth with JSON-RPC over HTTP and
WebSocket on the same TCP port. The default port for WebSocket
is still 8546.
geth --rpc --rpcport 8545 --ws --wsport 8545
This also removes a lot of deprecated API surface from package rpc.
The rpc package is now purely about serving JSON-RPC and no longer
provides a way to start an HTTP server.
---
cmd/clef/main.go | 9 ++-
cmd/geth/retesteth.go | 10 ++-
graphql/service.go | 14 +++-
node/api.go | 2 +-
node/endpoints.go | 99 ++++++++++++++++++++++++++
node/node.go | 69 +++++++++++++++---
node/node_test.go | 58 +++++++++++++++
node/rpcstack.go | 159 ++++++++++++++++++++++++++++++++++++++++++
node/rpcstack_test.go | 38 ++++++++++
rpc/endpoints.go | 83 ----------------------
rpc/gzip.go | 66 ------------------
rpc/http.go | 97 --------------------------
rpc/websocket.go | 7 --
13 files changed, 443 insertions(+), 268 deletions(-)
create mode 100644 node/endpoints.go
create mode 100644 node/rpcstack.go
create mode 100644 node/rpcstack_test.go
delete mode 100644 rpc/gzip.go
diff --git a/cmd/clef/main.go b/cmd/clef/main.go
index f4533a6e85..801c7e9efd 100644
--- a/cmd/clef/main.go
+++ b/cmd/clef/main.go
@@ -583,9 +583,16 @@ func signer(c *cli.Context) error {
vhosts := splitAndTrim(c.GlobalString(utils.RPCVirtualHostsFlag.Name))
cors := splitAndTrim(c.GlobalString(utils.RPCCORSDomainFlag.Name))
+ srv := rpc.NewServer()
+ err := node.RegisterApisFromWhitelist(rpcAPI, []string{"account"}, srv, false)
+ if err != nil {
+ utils.Fatalf("Could not register API: %w", err)
+ }
+ handler := node.NewHTTPHandlerStack(srv, cors, vhosts)
+
// start http server
httpEndpoint := fmt.Sprintf("%s:%d", c.GlobalString(utils.RPCListenAddrFlag.Name), c.Int(rpcPortFlag.Name))
- listener, _, err := rpc.StartHTTPEndpoint(httpEndpoint, rpcAPI, []string{"account"}, cors, vhosts, rpc.DefaultHTTPTimeouts)
+ listener, err := node.StartHTTPEndpoint(httpEndpoint, rpc.DefaultHTTPTimeouts, handler)
if err != nil {
utils.Fatalf("Could not start RPC api: %v", err)
}
diff --git a/cmd/geth/retesteth.go b/cmd/geth/retesteth.go
index 05331d12dd..eccc8cd670 100644
--- a/cmd/geth/retesteth.go
+++ b/cmd/geth/retesteth.go
@@ -890,6 +890,14 @@ func retesteth(ctx *cli.Context) error {
vhosts := splitAndTrim(ctx.GlobalString(utils.RPCVirtualHostsFlag.Name))
cors := splitAndTrim(ctx.GlobalString(utils.RPCCORSDomainFlag.Name))
+ // register apis and create handler stack
+ srv := rpc.NewServer()
+ err := node.RegisterApisFromWhitelist(rpcAPI, []string{"test", "eth", "debug", "web3"}, srv, false)
+ if err != nil {
+ utils.Fatalf("Could not register RPC apis: %w", err)
+ }
+ handler := node.NewHTTPHandlerStack(srv, cors, vhosts)
+
// start http server
var RetestethHTTPTimeouts = rpc.HTTPTimeouts{
ReadTimeout: 120 * time.Second,
@@ -897,7 +905,7 @@ func retesteth(ctx *cli.Context) error {
IdleTimeout: 120 * time.Second,
}
httpEndpoint := fmt.Sprintf("%s:%d", ctx.GlobalString(utils.RPCListenAddrFlag.Name), ctx.Int(rpcPortFlag.Name))
- listener, _, err := rpc.StartHTTPEndpoint(httpEndpoint, rpcAPI, []string{"test", "eth", "debug", "web3"}, cors, vhosts, RetestethHTTPTimeouts)
+ listener, err := node.StartHTTPEndpoint(httpEndpoint, RetestethHTTPTimeouts, handler)
if err != nil {
utils.Fatalf("Could not start RPC api: %v", err)
}
diff --git a/graphql/service.go b/graphql/service.go
index 21892402db..a206053024 100644
--- a/graphql/service.go
+++ b/graphql/service.go
@@ -23,6 +23,7 @@ import (
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc"
"github.com/graph-gophers/graphql-go"
@@ -68,7 +69,18 @@ func (s *Service) Start(server *p2p.Server) error {
if s.listener, err = net.Listen("tcp", s.endpoint); err != nil {
return err
}
- go rpc.NewHTTPServer(s.cors, s.vhosts, s.timeouts, s.handler).Serve(s.listener)
+ // create handler stack and wrap the graphql handler
+ handler := node.NewHTTPHandlerStack(s.handler, s.cors, s.vhosts)
+ // make sure timeout values are meaningful
+ node.CheckTimeouts(&s.timeouts)
+ // create http server
+ httpSrv := &http.Server{
+ Handler: handler,
+ ReadTimeout: s.timeouts.ReadTimeout,
+ WriteTimeout: s.timeouts.WriteTimeout,
+ IdleTimeout: s.timeouts.IdleTimeout,
+ }
+ go httpSrv.Serve(s.listener)
log.Info("GraphQL endpoint opened", "url", fmt.Sprintf("http://%s", s.endpoint))
return nil
}
diff --git a/node/api.go b/node/api.go
index 66cd1dde33..1a73d1321d 100644
--- a/node/api.go
+++ b/node/api.go
@@ -186,7 +186,7 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
}
}
- if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts); err != nil {
+ if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts, api.node.config.WSOrigins); err != nil {
return false, err
}
return true, nil
diff --git a/node/endpoints.go b/node/endpoints.go
new file mode 100644
index 0000000000..8cd6b4d1c8
--- /dev/null
+++ b/node/endpoints.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package node
+
+import (
+ "net"
+ "net/http"
+ "time"
+
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/rpc"
+)
+
+// StartHTTPEndpoint starts the HTTP RPC endpoint.
+func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (net.Listener, error) {
+ // start the HTTP listener
+ var (
+ listener net.Listener
+ err error
+ )
+ if listener, err = net.Listen("tcp", endpoint); err != nil {
+ return nil, err
+ }
+ // make sure timeout values are meaningful
+ CheckTimeouts(&timeouts)
+ // Bundle and start the HTTP server
+ httpSrv := &http.Server{
+ Handler: handler,
+ ReadTimeout: timeouts.ReadTimeout,
+ WriteTimeout: timeouts.WriteTimeout,
+ IdleTimeout: timeouts.IdleTimeout,
+ }
+ go httpSrv.Serve(listener)
+ return listener, err
+}
+
+// startWSEndpoint starts a websocket endpoint.
+func startWSEndpoint(endpoint string, handler http.Handler) (net.Listener, error) {
+ // start the HTTP listener
+ var (
+ listener net.Listener
+ err error
+ )
+ if listener, err = net.Listen("tcp", endpoint); err != nil {
+ return nil, err
+ }
+ wsSrv := &http.Server{Handler: handler}
+ go wsSrv.Serve(listener)
+ return listener, err
+}
+
+// checkModuleAvailability checks that all names given in modules are actually
+// available API services. It assumes that the MetadataApi module ("rpc") is always available;
+// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints.
+func checkModuleAvailability(modules []string, apis []rpc.API) (bad, available []string) {
+ availableSet := make(map[string]struct{})
+ for _, api := range apis {
+ if _, ok := availableSet[api.Namespace]; !ok {
+ availableSet[api.Namespace] = struct{}{}
+ available = append(available, api.Namespace)
+ }
+ }
+ for _, name := range modules {
+ if _, ok := availableSet[name]; !ok && name != rpc.MetadataApi {
+ bad = append(bad, name)
+ }
+ }
+ return bad, available
+}
+
+// CheckTimeouts ensures that timeout values are meaningful
+func CheckTimeouts(timeouts *rpc.HTTPTimeouts) {
+ if timeouts.ReadTimeout < time.Second {
+ log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", rpc.DefaultHTTPTimeouts.ReadTimeout)
+ timeouts.ReadTimeout = rpc.DefaultHTTPTimeouts.ReadTimeout
+ }
+ if timeouts.WriteTimeout < time.Second {
+ log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", rpc.DefaultHTTPTimeouts.WriteTimeout)
+ timeouts.WriteTimeout = rpc.DefaultHTTPTimeouts.WriteTimeout
+ }
+ if timeouts.IdleTimeout < time.Second {
+ log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", rpc.DefaultHTTPTimeouts.IdleTimeout)
+ timeouts.IdleTimeout = rpc.DefaultHTTPTimeouts.IdleTimeout
+ }
+}
diff --git a/node/node.go b/node/node.go
index 39e15e0a74..1d14317fc1 100644
--- a/node/node.go
+++ b/node/node.go
@@ -291,17 +291,21 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error {
n.stopInProc()
return err
}
- if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts); err != nil {
+ if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts, n.config.WSOrigins); err != nil {
n.stopIPC()
n.stopInProc()
return err
}
- if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil {
- n.stopHTTP()
- n.stopIPC()
- n.stopInProc()
- return err
+ // if endpoints are not the same, start separate servers
+ if n.httpEndpoint != n.wsEndpoint {
+ if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil {
+ n.stopHTTP()
+ n.stopIPC()
+ n.stopInProc()
+ return err
+ }
}
+
// All API endpoints started successfully
n.rpcAPIs = apis
return nil
@@ -359,22 +363,36 @@ func (n *Node) stopIPC() {
}
// startHTTP initializes and starts the HTTP RPC endpoint.
-func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts) error {
+func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts, wsOrigins []string) error {
// Short circuit if the HTTP endpoint isn't being exposed
if endpoint == "" {
return nil
}
- listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts)
+ // register apis and create handler stack
+ srv := rpc.NewServer()
+ err := RegisterApisFromWhitelist(apis, modules, srv, false)
+ if err != nil {
+ return err
+ }
+ handler := NewHTTPHandlerStack(srv, cors, vhosts)
+ // wrap handler in websocket handler only if websocket port is the same as http rpc
+ if n.httpEndpoint == n.wsEndpoint {
+ handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins))
+ }
+ listener, err := StartHTTPEndpoint(endpoint, timeouts, handler)
if err != nil {
return err
}
n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()),
"cors", strings.Join(cors, ","),
"vhosts", strings.Join(vhosts, ","))
+ if n.httpEndpoint == n.wsEndpoint {
+ n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", listener.Addr()))
+ }
// All listeners booted successfully
n.httpEndpoint = endpoint
n.httpListener = listener
- n.httpHandler = handler
+ n.httpHandler = srv
return nil
}
@@ -399,7 +417,14 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig
if endpoint == "" {
return nil
}
- listener, handler, err := rpc.StartWSEndpoint(endpoint, apis, modules, wsOrigins, exposeAll)
+
+ srv := rpc.NewServer()
+ handler := srv.WebsocketHandler(wsOrigins)
+ err := RegisterApisFromWhitelist(apis, modules, srv, exposeAll)
+ if err != nil {
+ return err
+ }
+ listener, err := startWSEndpoint(endpoint, handler)
if err != nil {
return err
}
@@ -407,7 +432,7 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig
// All listeners booted successfully
n.wsEndpoint = endpoint
n.wsListener = listener
- n.wsHandler = handler
+ n.wsHandler = srv
return nil
}
@@ -664,3 +689,25 @@ func (n *Node) apis() []rpc.API {
},
}
}
+
+// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules,
+// and then registers all of the APIs exposed by the services.
+func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error {
+ if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
+ log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available)
+ }
+ // Generate the whitelist based on the allowed modules
+ whitelist := make(map[string]bool)
+ for _, module := range modules {
+ whitelist[module] = true
+ }
+ // Register all the APIs exposed by the services
+ for _, api := range apis {
+ if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
+ if err := srv.RegisterName(api.Namespace, api.Service); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
diff --git a/node/node_test.go b/node/node_test.go
index c464771cd8..e246731fef 100644
--- a/node/node_test.go
+++ b/node/node_test.go
@@ -19,6 +19,7 @@ package node
import (
"errors"
"io/ioutil"
+ "net/http"
"os"
"reflect"
"testing"
@@ -27,6 +28,8 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc"
+
+ "github.com/stretchr/testify/assert"
)
var (
@@ -597,3 +600,58 @@ func TestAPIGather(t *testing.T) {
}
}
}
+
+func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) {
+ node := startHTTP(t)
+ defer node.stopHTTP()
+
+ wsReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
+ if err != nil {
+ t.Error("could not issue new http request ", err)
+ }
+ wsReq.Header.Set("Connection", "upgrade")
+ wsReq.Header.Set("Upgrade", "websocket")
+ wsReq.Header.Set("Sec-WebSocket-Version", "13")
+ wsReq.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
+
+ resp := doHTTPRequest(t, wsReq)
+ assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
+}
+
+func TestWebsocketHTTPOnSamePort_HTTPRequest(t *testing.T) {
+ node := startHTTP(t)
+ defer node.stopHTTP()
+
+ httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
+ if err != nil {
+ t.Error("could not issue new http request ", err)
+ }
+ httpReq.Header.Set("Accept-Encoding", "gzip")
+
+ resp := doHTTPRequest(t, httpReq)
+ assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
+}
+
+func startHTTP(t *testing.T) *Node {
+ conf := &Config{HTTPPort: 7453, WSPort: 7453}
+ node, err := New(conf)
+ if err != nil {
+ t.Error("could not create a new node ", err)
+ }
+
+ err = node.startHTTP("127.0.0.1:7453", []rpc.API{}, []string{}, []string{}, []string{}, rpc.HTTPTimeouts{}, []string{})
+ if err != nil {
+ t.Error("could not start http service on node ", err)
+ }
+
+ return node
+}
+
+func doHTTPRequest(t *testing.T, req *http.Request) *http.Response {
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Error("could not issue a GET request to the given endpoint", err)
+ }
+ return resp
+}
diff --git a/node/rpcstack.go b/node/rpcstack.go
new file mode 100644
index 0000000000..793061968d
--- /dev/null
+++ b/node/rpcstack.go
@@ -0,0 +1,159 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package node
+
+import (
+ "compress/gzip"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "strings"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/rs/cors"
+)
+
+// NewHTTPHandlerStack returns wrapped http-related handlers
+func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler {
+ // Wrap the CORS-handler within a host-handler
+ handler := newCorsHandler(srv, cors)
+ handler = newVHostHandler(vhosts, handler)
+ return newGzipHandler(handler)
+}
+
+func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
+ // disable CORS support if user has not specified a custom CORS configuration
+ if len(allowedOrigins) == 0 {
+ return srv
+ }
+ c := cors.New(cors.Options{
+ AllowedOrigins: allowedOrigins,
+ AllowedMethods: []string{http.MethodPost, http.MethodGet},
+ MaxAge: 600,
+ AllowedHeaders: []string{"*"},
+ })
+ return c.Handler(srv)
+}
+
+// virtualHostHandler is a handler which validates the Host-header of incoming requests.
+// Using virtual hosts can help prevent DNS rebinding attacks, where a 'random' domain name points to
+// the service ip address (but without CORS headers). By verifying the targeted virtual host, we can
+// ensure that it's a destination that the node operator has defined.
+type virtualHostHandler struct {
+ vhosts map[string]struct{}
+ next http.Handler
+}
+
+func newVHostHandler(vhosts []string, next http.Handler) http.Handler {
+ vhostMap := make(map[string]struct{})
+ for _, allowedHost := range vhosts {
+ vhostMap[strings.ToLower(allowedHost)] = struct{}{}
+ }
+ return &virtualHostHandler{vhostMap, next}
+}
+
+// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler
+func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // if r.Host is not set, we can continue serving since a browser would set the Host header
+ if r.Host == "" {
+ h.next.ServeHTTP(w, r)
+ return
+ }
+ host, _, err := net.SplitHostPort(r.Host)
+ if err != nil {
+ // Either invalid (too many colons) or no port specified
+ host = r.Host
+ }
+ if ipAddr := net.ParseIP(host); ipAddr != nil {
+ // It's an IP address, we can serve that
+ h.next.ServeHTTP(w, r)
+ return
+
+ }
+ // Not an IP address, but a hostname. Need to validate
+ if _, exist := h.vhosts["*"]; exist {
+ h.next.ServeHTTP(w, r)
+ return
+ }
+ if _, exist := h.vhosts[host]; exist {
+ h.next.ServeHTTP(w, r)
+ return
+ }
+ http.Error(w, "invalid host specified", http.StatusForbidden)
+}
+
+var gzPool = sync.Pool{
+ New: func() interface{} {
+ w := gzip.NewWriter(ioutil.Discard)
+ return w
+ },
+}
+
+type gzipResponseWriter struct {
+ io.Writer
+ http.ResponseWriter
+}
+
+func (w *gzipResponseWriter) WriteHeader(status int) {
+ w.Header().Del("Content-Length")
+ w.ResponseWriter.WriteHeader(status)
+}
+
+func (w *gzipResponseWriter) Write(b []byte) (int, error) {
+ return w.Writer.Write(b)
+}
+
+func newGzipHandler(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ w.Header().Set("Content-Encoding", "gzip")
+
+ gz := gzPool.Get().(*gzip.Writer)
+ defer gzPool.Put(gz)
+
+ gz.Reset(w)
+ defer gz.Close()
+
+ next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
+ })
+}
+
+// NewWebsocketUpgradeHandler returns a websocket handler that serves an incoming request only if it contains an upgrade
+// request to the websocket protocol. If not, serves the the request with the http handler.
+func NewWebsocketUpgradeHandler(h http.Handler, ws http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if isWebsocket(r) {
+ ws.ServeHTTP(w, r)
+ log.Debug("serving websocket request")
+ return
+ }
+
+ h.ServeHTTP(w, r)
+ })
+}
+
+// isWebsocket checks the header of an http request for a websocket upgrade request.
+func isWebsocket(r *http.Request) bool {
+ return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" &&
+ strings.ToLower(r.Header.Get("Connection")) == "upgrade"
+}
diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go
new file mode 100644
index 0000000000..9db03181c9
--- /dev/null
+++ b/node/rpcstack_test.go
@@ -0,0 +1,38 @@
+package node
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/rpc"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNewWebsocketUpgradeHandler_websocket(t *testing.T) {
+ srv := rpc.NewServer()
+
+ handler := NewWebsocketUpgradeHandler(nil, srv.WebsocketHandler([]string{}))
+ ts := httptest.NewServer(handler)
+ defer ts.Close()
+
+ responses := make(chan *http.Response)
+ go func(responses chan *http.Response) {
+ client := &http.Client{}
+
+ req, _ := http.NewRequest(http.MethodGet, ts.URL, nil)
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Upgrade", "websocket")
+ req.Header.Set("Sec-WebSocket-Version", "13")
+ req.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Error("could not issue a GET request to the test http server", err)
+ }
+ responses <- resp
+ }(responses)
+
+ response := <-responses
+ assert.Equal(t, "websocket", response.Header.Get("Upgrade"))
+}
diff --git a/rpc/endpoints.go b/rpc/endpoints.go
index 07aaa44122..9fc0705172 100644
--- a/rpc/endpoints.go
+++ b/rpc/endpoints.go
@@ -22,89 +22,6 @@ import (
"github.com/ethereum/go-ethereum/log"
)
-// checkModuleAvailability checks that all names given in modules are actually
-// available API services. It assumes that the MetadataApi module ("rpc") is always available;
-// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints.
-func checkModuleAvailability(modules []string, apis []API) (bad, available []string) {
- availableSet := make(map[string]struct{})
- for _, api := range apis {
- if _, ok := availableSet[api.Namespace]; !ok {
- availableSet[api.Namespace] = struct{}{}
- available = append(available, api.Namespace)
- }
- }
- for _, name := range modules {
- if _, ok := availableSet[name]; !ok && name != MetadataApi {
- bad = append(bad, name)
- }
- }
- return bad, available
-}
-
-// StartHTTPEndpoint starts the HTTP RPC endpoint, configured with cors/vhosts/modules.
-func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts) (net.Listener, *Server, error) {
- if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
- log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available)
- }
- // Generate the whitelist based on the allowed modules
- whitelist := make(map[string]bool)
- for _, module := range modules {
- whitelist[module] = true
- }
- // Register all the APIs exposed by the services
- handler := NewServer()
- for _, api := range apis {
- if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
- if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
- return nil, nil, err
- }
- log.Debug("HTTP registered", "namespace", api.Namespace)
- }
- }
- // All APIs registered, start the HTTP listener
- var (
- listener net.Listener
- err error
- )
- if listener, err = net.Listen("tcp", endpoint); err != nil {
- return nil, nil, err
- }
- go NewHTTPServer(cors, vhosts, timeouts, handler).Serve(listener)
- return listener, handler, err
-}
-
-// StartWSEndpoint starts a websocket endpoint.
-func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []string, exposeAll bool) (net.Listener, *Server, error) {
- if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
- log.Error("Unavailable modules in WS API list", "unavailable", bad, "available", available)
- }
- // Generate the whitelist based on the allowed modules
- whitelist := make(map[string]bool)
- for _, module := range modules {
- whitelist[module] = true
- }
- // Register all the APIs exposed by the services
- handler := NewServer()
- for _, api := range apis {
- if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
- if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
- return nil, nil, err
- }
- log.Debug("WebSocket registered", "service", api.Service, "namespace", api.Namespace)
- }
- }
- // All APIs registered, start the HTTP listener
- var (
- listener net.Listener
- err error
- )
- if listener, err = net.Listen("tcp", endpoint); err != nil {
- return nil, nil, err
- }
- go NewWSServer(wsOrigins, handler).Serve(listener)
- return listener, handler, err
-}
-
// StartIPCEndpoint starts an IPC endpoint.
func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) {
// Register all the APIs exposed by the services.
diff --git a/rpc/gzip.go b/rpc/gzip.go
deleted file mode 100644
index a14fd09d54..0000000000
--- a/rpc/gzip.go
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2019 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-package rpc
-
-import (
- "compress/gzip"
- "io"
- "io/ioutil"
- "net/http"
- "strings"
- "sync"
-)
-
-var gzPool = sync.Pool{
- New: func() interface{} {
- w := gzip.NewWriter(ioutil.Discard)
- return w
- },
-}
-
-type gzipResponseWriter struct {
- io.Writer
- http.ResponseWriter
-}
-
-func (w *gzipResponseWriter) WriteHeader(status int) {
- w.Header().Del("Content-Length")
- w.ResponseWriter.WriteHeader(status)
-}
-
-func (w *gzipResponseWriter) Write(b []byte) (int, error) {
- return w.Writer.Write(b)
-}
-
-func newGzipHandler(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
- next.ServeHTTP(w, r)
- return
- }
-
- w.Header().Set("Content-Encoding", "gzip")
-
- gz := gzPool.Get().(*gzip.Writer)
- defer gzPool.Put(gz)
-
- gz.Reset(w)
- defer gz.Close()
-
- next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
- })
-}
diff --git a/rpc/http.go b/rpc/http.go
index 1cf6d90393..b3ce0a5b5e 100644
--- a/rpc/http.go
+++ b/rpc/http.go
@@ -25,14 +25,9 @@ import (
"io"
"io/ioutil"
"mime"
- "net"
"net/http"
- "strings"
"sync"
"time"
-
- "github.com/ethereum/go-ethereum/log"
- "github.com/rs/cors"
)
const (
@@ -209,37 +204,6 @@ func (t *httpServerConn) RemoteAddr() string {
// SetWriteDeadline does nothing and always returns nil.
func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil }
-// NewHTTPServer creates a new HTTP RPC server around an API provider.
-//
-// Deprecated: Server implements http.Handler
-func NewHTTPServer(cors []string, vhosts []string, timeouts HTTPTimeouts, srv http.Handler) *http.Server {
- // Wrap the CORS-handler within a host-handler
- handler := newCorsHandler(srv, cors)
- handler = newVHostHandler(vhosts, handler)
- handler = newGzipHandler(handler)
-
- // Make sure timeout values are meaningful
- if timeouts.ReadTimeout < time.Second {
- log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", DefaultHTTPTimeouts.ReadTimeout)
- timeouts.ReadTimeout = DefaultHTTPTimeouts.ReadTimeout
- }
- if timeouts.WriteTimeout < time.Second {
- log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", DefaultHTTPTimeouts.WriteTimeout)
- timeouts.WriteTimeout = DefaultHTTPTimeouts.WriteTimeout
- }
- if timeouts.IdleTimeout < time.Second {
- log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", DefaultHTTPTimeouts.IdleTimeout)
- timeouts.IdleTimeout = DefaultHTTPTimeouts.IdleTimeout
- }
- // Bundle and start the HTTP server
- return &http.Server{
- Handler: handler,
- ReadTimeout: timeouts.ReadTimeout,
- WriteTimeout: timeouts.WriteTimeout,
- IdleTimeout: timeouts.IdleTimeout,
- }
-}
-
// ServeHTTP serves JSON-RPC requests over HTTP.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Permit dumb empty requests for remote health-checks (AWS)
@@ -296,64 +260,3 @@ func validateRequest(r *http.Request) (int, error) {
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
return http.StatusUnsupportedMediaType, err
}
-
-func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
- // disable CORS support if user has not specified a custom CORS configuration
- if len(allowedOrigins) == 0 {
- return srv
- }
- c := cors.New(cors.Options{
- AllowedOrigins: allowedOrigins,
- AllowedMethods: []string{http.MethodPost, http.MethodGet},
- MaxAge: 600,
- AllowedHeaders: []string{"*"},
- })
- return c.Handler(srv)
-}
-
-// virtualHostHandler is a handler which validates the Host-header of incoming requests.
-// The virtualHostHandler can prevent DNS rebinding attacks, which do not utilize CORS-headers,
-// since they do in-domain requests against the RPC api. Instead, we can see on the Host-header
-// which domain was used, and validate that against a whitelist.
-type virtualHostHandler struct {
- vhosts map[string]struct{}
- next http.Handler
-}
-
-// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler
-func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- // if r.Host is not set, we can continue serving since a browser would set the Host header
- if r.Host == "" {
- h.next.ServeHTTP(w, r)
- return
- }
- host, _, err := net.SplitHostPort(r.Host)
- if err != nil {
- // Either invalid (too many colons) or no port specified
- host = r.Host
- }
- if ipAddr := net.ParseIP(host); ipAddr != nil {
- // It's an IP address, we can serve that
- h.next.ServeHTTP(w, r)
- return
-
- }
- // Not an IP address, but a hostname. Need to validate
- if _, exist := h.vhosts["*"]; exist {
- h.next.ServeHTTP(w, r)
- return
- }
- if _, exist := h.vhosts[host]; exist {
- h.next.ServeHTTP(w, r)
- return
- }
- http.Error(w, "invalid host specified", http.StatusForbidden)
-}
-
-func newVHostHandler(vhosts []string, next http.Handler) http.Handler {
- vhostMap := make(map[string]struct{})
- for _, allowedHost := range vhosts {
- vhostMap[strings.ToLower(allowedHost)] = struct{}{}
- }
- return &virtualHostHandler{vhostMap, next}
-}
diff --git a/rpc/websocket.go b/rpc/websocket.go
index b7ec56c6a6..6e37b8522d 100644
--- a/rpc/websocket.go
+++ b/rpc/websocket.go
@@ -38,13 +38,6 @@ const (
var wsBufferPool = new(sync.Pool)
-// NewWSServer creates a new websocket RPC server around an API provider.
-//
-// Deprecated: use Server.WebsocketHandler
-func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
- return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
-}
-
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins should be a comma-separated list of allowed origin URLs.