@ -19,7 +19,9 @@ package node
import (
import (
"bytes"
"bytes"
"fmt"
"fmt"
"io"
"net/http"
"net/http"
"net/http/httptest"
"net/url"
"net/url"
"strconv"
"strconv"
"strings"
"strings"
@ -34,29 +36,31 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert"
)
)
const testMethod = "rpc_modules"
// TestCorsHandler makes sure CORS are properly handled on the http server.
// TestCorsHandler makes sure CORS are properly handled on the http server.
func TestCorsHandler ( t * testing . T ) {
func TestCorsHandler ( t * testing . T ) {
srv := createAndStartServer ( t , & httpConfig { CorsAllowedOrigins : [ ] string { "test" , "test.com" } } , false , & wsConfig { } )
srv := createAndStartServer ( t , & httpConfig { CorsAllowedOrigins : [ ] string { "test" , "test.com" } } , false , & wsConfig { } , nil )
defer srv . stop ( )
defer srv . stop ( )
url := "http://" + srv . listenAddr ( )
url := "http://" + srv . listenAddr ( )
resp := rpcRequest ( t , url , "origin" , "test.com" )
resp := rpcRequest ( t , url , testMethod , "origin" , "test.com" )
assert . Equal ( t , "test.com" , resp . Header . Get ( "Access-Control-Allow-Origin" ) )
assert . Equal ( t , "test.com" , resp . Header . Get ( "Access-Control-Allow-Origin" ) )
resp2 := rpcRequest ( t , url , "origin" , "bad" )
resp2 := rpcRequest ( t , url , testMethod , "origin" , "bad" )
assert . Equal ( t , "" , resp2 . Header . Get ( "Access-Control-Allow-Origin" ) )
assert . Equal ( t , "" , resp2 . Header . Get ( "Access-Control-Allow-Origin" ) )
}
}
// TestVhosts makes sure vhosts are properly handled on the http server.
// TestVhosts makes sure vhosts are properly handled on the http server.
func TestVhosts ( t * testing . T ) {
func TestVhosts ( t * testing . T ) {
srv := createAndStartServer ( t , & httpConfig { Vhosts : [ ] string { "test" } } , false , & wsConfig { } )
srv := createAndStartServer ( t , & httpConfig { Vhosts : [ ] string { "test" } } , false , & wsConfig { } , nil )
defer srv . stop ( )
defer srv . stop ( )
url := "http://" + srv . listenAddr ( )
url := "http://" + srv . listenAddr ( )
resp := rpcRequest ( t , url , "host" , "test" )
resp := rpcRequest ( t , url , testMethod , "host" , "test" )
assert . Equal ( t , resp . StatusCode , http . StatusOK )
assert . Equal ( t , resp . StatusCode , http . StatusOK )
resp2 := rpcRequest ( t , url , "host" , "bad" )
resp2 := rpcRequest ( t , url , testMethod , "host" , "bad" )
assert . Equal ( t , resp2 . StatusCode , http . StatusForbidden )
assert . Equal ( t , resp2 . StatusCode , http . StatusForbidden )
}
}
@ -145,7 +149,7 @@ func TestWebsocketOrigins(t *testing.T) {
} ,
} ,
}
}
for _ , tc := range tests {
for _ , tc := range tests {
srv := createAndStartServer ( t , & httpConfig { } , true , & wsConfig { Origins : splitAndTrim ( tc . spec ) } )
srv := createAndStartServer ( t , & httpConfig { } , true , & wsConfig { Origins : splitAndTrim ( tc . spec ) } , nil )
url := fmt . Sprintf ( "ws://%v" , srv . listenAddr ( ) )
url := fmt . Sprintf ( "ws://%v" , srv . listenAddr ( ) )
for _ , origin := range tc . expOk {
for _ , origin := range tc . expOk {
if err := wsRequest ( t , url , "Origin" , origin ) ; err != nil {
if err := wsRequest ( t , url , "Origin" , origin ) ; err != nil {
@ -231,11 +235,14 @@ func Test_checkPath(t *testing.T) {
}
}
}
}
func createAndStartServer ( t * testing . T , conf * httpConfig , ws bool , wsConf * wsConfig ) * httpServer {
func createAndStartServer ( t * testing . T , conf * httpConfig , ws bool , wsConf * wsConfig , timeouts * rpc . HTTPTimeouts ) * httpServer {
t . Helper ( )
t . Helper ( )
srv := newHTTPServer ( testlog . Logger ( t , log . LvlDebug ) , rpc . DefaultHTTPTimeouts )
if timeouts == nil {
assert . NoError ( t , srv . enableRPC ( nil , * conf ) )
timeouts = & rpc . DefaultHTTPTimeouts
}
srv := newHTTPServer ( testlog . Logger ( t , log . LvlDebug ) , * timeouts )
assert . NoError ( t , srv . enableRPC ( apis ( ) , * conf ) )
if ws {
if ws {
assert . NoError ( t , srv . enableWS ( nil , * wsConf ) )
assert . NoError ( t , srv . enableWS ( nil , * wsConf ) )
}
}
@ -266,16 +273,33 @@ func wsRequest(t *testing.T, url string, extraHeaders ...string) error {
}
}
// rpcRequest performs a JSON-RPC request to the given URL.
// rpcRequest performs a JSON-RPC request to the given URL.
func rpcRequest ( t * testing . T , url string , extraHeaders ... string ) * http . Response {
func rpcRequest ( t * testing . T , url , method string , extraHeaders ... string ) * http . Response {
t . Helper ( )
body := fmt . Sprintf ( ` { "jsonrpc":"2.0","id":1,"method":"%s","params":[]} ` , method )
return baseRpcRequest ( t , url , body , extraHeaders ... )
}
func batchRpcRequest ( t * testing . T , url string , methods [ ] string , extraHeaders ... string ) * http . Response {
reqs := make ( [ ] string , len ( methods ) )
for i , m := range methods {
reqs [ i ] = fmt . Sprintf ( ` { "jsonrpc":"2.0","id":1,"method":"%s","params":[]} ` , m )
}
body := fmt . Sprintf ( ` [%s] ` , strings . Join ( reqs , "," ) )
return baseRpcRequest ( t , url , body , extraHeaders ... )
}
func baseRpcRequest ( t * testing . T , url , bodyStr string , extraHeaders ... string ) * http . Response {
t . Helper ( )
t . Helper ( )
// Create the request.
// Create the request.
body := bytes . NewReader ( [ ] byte ( ` { "jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]} ` ) )
body := bytes . NewReader ( [ ] byte ( bodyStr ) )
req , err := http . NewRequest ( "POST" , url , body )
req , err := http . NewRequest ( "POST" , url , body )
if err != nil {
if err != nil {
t . Fatal ( "could not create http request:" , err )
t . Fatal ( "could not create http request:" , err )
}
}
req . Header . Set ( "content-type" , "application/json" )
req . Header . Set ( "content-type" , "application/json" )
req . Header . Set ( "accept-encoding" , "identity" )
// Apply extra headers.
// Apply extra headers.
if len ( extraHeaders ) % 2 != 0 {
if len ( extraHeaders ) % 2 != 0 {
@ -315,7 +339,7 @@ func TestJWT(t *testing.T) {
return ss
return ss
}
}
srv := createAndStartServer ( t , & httpConfig { jwtSecret : [ ] byte ( "secret" ) } ,
srv := createAndStartServer ( t , & httpConfig { jwtSecret : [ ] byte ( "secret" ) } ,
true , & wsConfig { Origins : [ ] string { "*" } , jwtSecret : [ ] byte ( "secret" ) } )
true , & wsConfig { Origins : [ ] string { "*" } , jwtSecret : [ ] byte ( "secret" ) } , nil )
wsUrl := fmt . Sprintf ( "ws://%v" , srv . listenAddr ( ) )
wsUrl := fmt . Sprintf ( "ws://%v" , srv . listenAddr ( ) )
htUrl := fmt . Sprintf ( "http://%v" , srv . listenAddr ( ) )
htUrl := fmt . Sprintf ( "http://%v" , srv . listenAddr ( ) )
@ -348,7 +372,7 @@ func TestJWT(t *testing.T) {
t . Errorf ( "test %d-ws, token '%v': expected ok, got %v" , i , token , err )
t . Errorf ( "test %d-ws, token '%v': expected ok, got %v" , i , token , err )
}
}
token = tokenFn ( )
token = tokenFn ( )
if resp := rpcRequest ( t , htUrl , "Authorization" , token ) ; resp . StatusCode != 200 {
if resp := rpcRequest ( t , htUrl , testMethod , "Authorization" , token ) ; resp . StatusCode != 200 {
t . Errorf ( "test %d-http, token '%v': expected ok, got %v" , i , token , resp . StatusCode )
t . Errorf ( "test %d-http, token '%v': expected ok, got %v" , i , token , resp . StatusCode )
}
}
}
}
@ -414,10 +438,176 @@ func TestJWT(t *testing.T) {
}
}
token = tokenFn ( )
token = tokenFn ( )
resp := rpcRequest ( t , htUrl , "Authorization" , token )
resp := rpcRequest ( t , htUrl , testMethod , "Authorization" , token )
if resp . StatusCode != http . StatusUnauthorized {
if resp . StatusCode != http . StatusUnauthorized {
t . Errorf ( "tc %d-http, token '%v': expected not to allow, got %v" , i , token , resp . StatusCode )
t . Errorf ( "tc %d-http, token '%v': expected not to allow, got %v" , i , token , resp . StatusCode )
}
}
}
}
srv . stop ( )
srv . stop ( )
}
}
func TestGzipHandler ( t * testing . T ) {
type gzipTest struct {
name string
handler http . HandlerFunc
status int
isGzip bool
header map [ string ] string
}
tests := [ ] gzipTest {
{
name : "Write" ,
handler : func ( w http . ResponseWriter , r * http . Request ) {
w . Write ( [ ] byte ( "response" ) )
} ,
isGzip : true ,
status : 200 ,
} ,
{
name : "WriteHeader" ,
handler : func ( w http . ResponseWriter , r * http . Request ) {
w . Header ( ) . Set ( "x-foo" , "bar" )
w . WriteHeader ( 205 )
w . Write ( [ ] byte ( "response" ) )
} ,
isGzip : true ,
status : 205 ,
header : map [ string ] string { "x-foo" : "bar" } ,
} ,
{
name : "WriteContentLength" ,
handler : func ( w http . ResponseWriter , r * http . Request ) {
w . Header ( ) . Set ( "content-length" , "8" )
w . Write ( [ ] byte ( "response" ) )
} ,
isGzip : true ,
status : 200 ,
} ,
{
name : "Flush" ,
handler : func ( w http . ResponseWriter , r * http . Request ) {
w . Write ( [ ] byte ( "res" ) )
w . ( http . Flusher ) . Flush ( )
w . Write ( [ ] byte ( "ponse" ) )
} ,
isGzip : true ,
status : 200 ,
} ,
{
name : "disable" ,
handler : func ( w http . ResponseWriter , r * http . Request ) {
w . Header ( ) . Set ( "transfer-encoding" , "identity" )
w . Header ( ) . Set ( "x-foo" , "bar" )
w . Write ( [ ] byte ( "response" ) )
} ,
isGzip : false ,
status : 200 ,
header : map [ string ] string { "x-foo" : "bar" } ,
} ,
{
name : "disable-WriteHeader" ,
handler : func ( w http . ResponseWriter , r * http . Request ) {
w . Header ( ) . Set ( "transfer-encoding" , "identity" )
w . Header ( ) . Set ( "x-foo" , "bar" )
w . WriteHeader ( 205 )
w . Write ( [ ] byte ( "response" ) )
} ,
isGzip : false ,
status : 205 ,
header : map [ string ] string { "x-foo" : "bar" } ,
} ,
}
for _ , test := range tests {
test := test
t . Run ( test . name , func ( t * testing . T ) {
srv := httptest . NewServer ( newGzipHandler ( test . handler ) )
defer srv . Close ( )
resp , err := http . Get ( srv . URL )
if err != nil {
t . Fatal ( err )
}
defer resp . Body . Close ( )
content , err := io . ReadAll ( resp . Body )
if err != nil {
t . Fatal ( err )
}
wasGzip := resp . Uncompressed
if string ( content ) != "response" {
t . Fatalf ( "wrong response content %q" , content )
}
if wasGzip != test . isGzip {
t . Fatalf ( "response gzipped == %t, want %t" , wasGzip , test . isGzip )
}
if resp . StatusCode != test . status {
t . Fatalf ( "response status == %d, want %d" , resp . StatusCode , test . status )
}
for name , expectedValue := range test . header {
if v := resp . Header . Get ( name ) ; v != expectedValue {
t . Fatalf ( "response header %s == %s, want %s" , name , v , expectedValue )
}
}
} )
}
}
func TestHTTPWriteTimeout ( t * testing . T ) {
const (
timeoutRes = ` { "jsonrpc":"2.0","id":1,"error": { "code":-32002,"message":"request timed out"}} `
greetRes = ` { "jsonrpc":"2.0","id":1,"result":"Hello"} `
)
// Set-up server
timeouts := rpc . DefaultHTTPTimeouts
timeouts . WriteTimeout = time . Second
srv := createAndStartServer ( t , & httpConfig { Modules : [ ] string { "test" } } , false , & wsConfig { } , & timeouts )
url := fmt . Sprintf ( "http://%v" , srv . listenAddr ( ) )
// Send normal request
t . Run ( "message" , func ( t * testing . T ) {
resp := rpcRequest ( t , url , "test_sleep" )
defer resp . Body . Close ( )
body , err := io . ReadAll ( resp . Body )
if err != nil {
t . Fatal ( err )
}
if string ( body ) != timeoutRes {
t . Errorf ( "wrong response. have %s, want %s" , string ( body ) , timeoutRes )
}
} )
// Batch request
t . Run ( "batch" , func ( t * testing . T ) {
want := fmt . Sprintf ( "[%s,%s,%s]" , greetRes , timeoutRes , timeoutRes )
resp := batchRpcRequest ( t , url , [ ] string { "test_greet" , "test_sleep" , "test_greet" } )
defer resp . Body . Close ( )
body , err := io . ReadAll ( resp . Body )
if err != nil {
t . Fatal ( err )
}
if string ( body ) != want {
t . Errorf ( "wrong response. have %s, want %s" , string ( body ) , want )
}
} )
}
func apis ( ) [ ] rpc . API {
return [ ] rpc . API {
{
Namespace : "test" ,
Service : & testService { } ,
} ,
}
}
type testService struct { }
func ( s * testService ) Greet ( ) string {
return "Hello"
}
func ( s * testService ) Sleep ( ) {
time . Sleep ( 1500 * time . Millisecond )
}