@ -18,7 +18,10 @@ package node
import (
"bytes"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"testing"
@ -31,25 +34,27 @@ import (
// TestCorsHandler makes sure CORS are properly handled on the http server.
func TestCorsHandler ( t * testing . T ) {
srv := createAndStartServer ( t , httpConfig { CorsAllowedOrigins : [ ] string { "test" , "test.com" } } , false , wsConfig { } )
srv := createAndStartServer ( t , & httpConfig { CorsAllowedOrigins : [ ] string { "test" , "test.com" } } , false , & wsConfig { } )
defer srv . stop ( )
url := "http://" + srv . listenAddr ( )
resp := test Request( t , "origin" , "test.com" , "" , srv )
resp := rpc Request( t , url , "origin" , "test.com" )
assert . Equal ( t , "test.com" , resp . Header . Get ( "Access-Control-Allow-Origin" ) )
resp2 := test Request( t , "origin" , "bad" , "" , srv )
resp2 := rpc Request( t , url , "origin" , "bad" )
assert . Equal ( t , "" , resp2 . Header . Get ( "Access-Control-Allow-Origin" ) )
}
// TestVhosts makes sure vhosts are properly handled on the http server.
func TestVhosts ( t * testing . T ) {
srv := createAndStartServer ( t , httpConfig { Vhosts : [ ] string { "test" } } , false , wsConfig { } )
srv := createAndStartServer ( t , & httpConfig { Vhosts : [ ] string { "test" } } , false , & wsConfig { } )
defer srv . stop ( )
url := "http://" + srv . listenAddr ( )
resp := test Request( t , "" , "" , "test" , srv )
resp := rpc Request( t , url , "host " , "test" )
assert . Equal ( t , resp . StatusCode , http . StatusOK )
resp2 := test Request( t , "" , "" , "bad" , srv )
resp2 := rpc Request( t , url , "host " , "bad" )
assert . Equal ( t , resp2 . StatusCode , http . StatusForbidden )
}
@ -138,14 +143,15 @@ func TestWebsocketOrigins(t *testing.T) {
} ,
}
for _ , tc := range tests {
srv := createAndStartServer ( t , httpConfig { } , true , wsConfig { Origins : splitAndTrim ( tc . spec ) } )
srv := createAndStartServer ( t , & httpConfig { } , true , & wsConfig { Origins : splitAndTrim ( tc . spec ) } )
url := fmt . Sprintf ( "ws://%v" , srv . listenAddr ( ) )
for _ , origin := range tc . expOk {
if err := attemptWebsocketConnectionFromOrigin ( t , srv , origin ) ; err != nil {
if err := wsRequest ( t , url , origin ) ; err != nil {
t . Errorf ( "spec '%v', origin '%v': expected ok, got %v" , tc . spec , origin , err )
}
}
for _ , origin := range tc . expFail {
if err := attemptWebsocketConnectionFromOrigin ( t , srv , origin ) ; err == nil {
if err := wsRequest ( t , url , origin ) ; err == nil {
t . Errorf ( "spec '%v', origin '%v': expected not to allow, got ok" , tc . spec , origin )
}
}
@ -168,47 +174,118 @@ func TestIsWebsocket(t *testing.T) {
assert . True ( t , isWebsocket ( r ) )
}
func createAndStartServer ( t * testing . T , conf httpConfig , ws bool , wsConf wsConfig ) * httpServer {
func Test_checkPath ( t * testing . T ) {
tests := [ ] struct {
req * http . Request
prefix string
expected bool
} {
{
req : & http . Request { URL : & url . URL { Path : "/test" } } ,
prefix : "/test" ,
expected : true ,
} ,
{
req : & http . Request { URL : & url . URL { Path : "/testing" } } ,
prefix : "/test" ,
expected : true ,
} ,
{
req : & http . Request { URL : & url . URL { Path : "/" } } ,
prefix : "/test" ,
expected : false ,
} ,
{
req : & http . Request { URL : & url . URL { Path : "/fail" } } ,
prefix : "/test" ,
expected : false ,
} ,
{
req : & http . Request { URL : & url . URL { Path : "/" } } ,
prefix : "" ,
expected : true ,
} ,
{
req : & http . Request { URL : & url . URL { Path : "/fail" } } ,
prefix : "" ,
expected : false ,
} ,
{
req : & http . Request { URL : & url . URL { Path : "/" } } ,
prefix : "/" ,
expected : true ,
} ,
{
req : & http . Request { URL : & url . URL { Path : "/testing" } } ,
prefix : "/" ,
expected : true ,
} ,
}
for i , tt := range tests {
t . Run ( strconv . Itoa ( i ) , func ( t * testing . T ) {
assert . Equal ( t , tt . expected , checkPath ( tt . req , tt . prefix ) )
} )
}
}
func createAndStartServer ( t * testing . T , conf * httpConfig , ws bool , wsConf * wsConfig ) * httpServer {
t . Helper ( )
srv := newHTTPServer ( testlog . Logger ( t , log . LvlDebug ) , rpc . DefaultHTTPTimeouts )
assert . NoError ( t , srv . enableRPC ( nil , conf ) )
assert . NoError ( t , srv . enableRPC ( nil , * conf ) )
if ws {
assert . NoError ( t , srv . enableWS ( nil , wsConf ) )
assert . NoError ( t , srv . enableWS ( nil , * wsConf ) )
}
assert . NoError ( t , srv . setListenAddr ( "localhost" , 0 ) )
assert . NoError ( t , srv . start ( ) )
return srv
}
func attemptWebsocketConnectionFromOrigin ( t * testing . T , srv * httpServer , browserOrigin string ) error {
// wsRequest attempts to open a WebSocket connection to the given URL.
func wsRequest ( t * testing . T , url , browserOrigin string ) error {
t . Helper ( )
dialer := websocket . DefaultDialer
_ , _ , err := dialer . Dial ( "ws://" + srv . listenAddr ( ) , http . Header {
"Content-type" : [ ] string { "application/json" } ,
"Sec-WebSocket-Version" : [ ] string { "13" } ,
"Origin" : [ ] string { browserOrigin } ,
} )
t . Logf ( "checking WebSocket on %s (origin %q)" , url , browserOrigin )
headers := make ( http . Header )
if browserOrigin != "" {
headers . Set ( "Origin" , browserOrigin )
}
conn , _ , err := websocket . DefaultDialer . Dial ( url , headers )
if conn != nil {
conn . Close ( )
}
return err
}
func testRequest ( t * testing . T , key , value , host string , srv * httpServer ) * http . Response {
// rpcRequest performs a JSON-RPC request to the given URL.
func rpcRequest ( t * testing . T , url string , extraHeaders ... string ) * http . Response {
t . Helper ( )
body := bytes . NewReader ( [ ] byte ( ` { "jsonrpc":"2.0","id":1,method":"rpc_modules"} ` ) )
req , _ := http . NewRequest ( "POST" , "http://" + srv . listenAddr ( ) , body )
// Create the request.
body := bytes . NewReader ( [ ] byte ( ` { "jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]} ` ) )
req , err := http . NewRequest ( "POST" , url , body )
if err != nil {
t . Fatal ( "could not create http request:" , err )
}
req . Header . Set ( "content-type" , "application/json" )
if key != "" && value != "" {
// Apply extra headers.
if len ( extraHeaders ) % 2 != 0 {
panic ( "odd extraHeaders length" )
}
for i := 0 ; i < len ( extraHeaders ) ; i += 2 {
key , value := extraHeaders [ i ] , extraHeaders [ i + 1 ]
if strings . ToLower ( key ) == "host" {
req . Host = value
} else {
req . Header . Set ( key , value )
}
if host != "" {
req . Host = host
}
client := http . DefaultClient
resp , err := client . Do ( req )
// Perform the request.
t . Logf ( "checking RPC/HTTP on %s %v" , url , extraHeaders )
resp , err := http . DefaultClient . Do ( req )
if err != nil {
t . Fatal ( err )
}