package writefreely
import (
"context"
"database/sql"
"github.com/stretchr/testify/assert"
"testing"
)
// TestOAuthDatastore walks the OAuth-related datastore methods through a
// single scenario against the test database: generating and validating a
// state value, recording a remote user ID twice with different access tokens,
// and resolving the remote user back to the local user ID.
//
// The test is skipped unless runMySQLTests reports true.
func TestOAuthDatastore(t *testing.T) {
	if !runMySQLTests() {
		t.Skip("skipping mysql tests")
	}

	// Queries used to inspect the tables after each datastore call.
	const (
		unusedStateQuery = "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false"
		usedStateQuery   = "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true"
		tokenAQuery      = "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'"
		tokenBQuery      = "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'"
		allUsersQuery    = "SELECT COUNT(*) FROM `oauth_users`"
	)

	withTestDB(t, func(db *sql.DB) {
		var (
			ctx   = context.Background()
			store = &datastore{DB: db, driverName: ""}

			localUserID  int64 = 99
			remoteUserID       = "100"
		)

		// A generated state should be 24 characters long and present exactly
		// once with used = false.
		oauthState, genErr := store.GenerateOAuthState(ctx, "test", "development")
		assert.NoError(t, genErr)
		assert.Len(t, oauthState, 24)
		countRows(t, ctx, db, 1, unusedStateQuery, oauthState)

		// After validation the same state row should be found with used = true.
		_, _, validateErr := store.ValidateOAuthState(ctx, oauthState)
		assert.NoError(t, validateErr)
		countRows(t, ctx, db, 1, usedStateQuery, oauthState)

		// First recording: one row carrying access_token_a is expected.
		assert.NoError(t, store.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_a"))
		countRows(t, ctx, db, 1, tokenAQuery, localUserID, remoteUserID)

		// Second recording for the same user pair with a new token: the test
		// expects one row carrying access_token_b and still only one row in
		// oauth_users overall.
		assert.NoError(t, store.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_b"))
		countRows(t, ctx, db, 1, tokenBQuery, localUserID, remoteUserID)
		countRows(t, ctx, db, 1, allUsersQuery)

		// The remote user ID should resolve back to the local user ID.
		gotUserID, lookupErr := store.GetIDForRemoteUser(ctx, remoteUserID, "test", "test")
		assert.NoError(t, lookupErr)
		assert.Equal(t, localUserID, gotUserID)
	})
}