@ -130,8 +130,10 @@ type writestore interface {
GetIDForRemoteUser ( context . Context , string , string , string ) ( int64 , error )
GetIDForRemoteUser ( context . Context , string , string , string ) ( int64 , error )
RecordRemoteUserID ( context . Context , int64 , string , string , string , string ) error
RecordRemoteUserID ( context . Context , int64 , string , string , string , string ) error
ValidateOAuthState ( context . Context , string ) ( string , string , error )
ValidateOAuthState ( context . Context , string ) ( string , string , int64 , error )
GenerateOAuthState ( context . Context , string , string ) ( string , error )
GenerateOAuthState ( context . Context , string , string , int64 ) ( string , error )
GetOauthAccounts ( ctx context . Context , userID int64 ) ( [ ] oauthAccountInfo , error )
RemoveOauth ( ctx context . Context , userID int64 , provider string , clientID string , remoteUserID string ) error
DatabaseInitialized ( ) bool
DatabaseInitialized ( ) bool
}
}
@ -2510,20 +2512,24 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return & t , nil
return & t , nil
}
}
func ( db * datastore ) GenerateOAuthState ( ctx context . Context , provider , clientID string ) ( string , error ) {
func ( db * datastore ) GenerateOAuthState ( ctx context . Context , provider string , clientID string , attachUser int64 ) ( string , error ) {
state := store . Generate62RandomString ( 24 )
state := store . Generate62RandomString ( 24 )
_ , err := db . ExecContext ( ctx , "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, " + db . now ( ) + ")" , state , provider , clientID )
attachUserVal := sql . NullInt64 { Valid : attachUser > 0 , Int64 : attachUser }
_ , err := db . ExecContext ( ctx , "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, " + db . now ( ) + ", ?)" , state , provider , clientID , attachUserVal )
if err != nil {
if err != nil {
return "" , fmt . Errorf ( "unable to record oauth client state: %w" , err )
return "" , fmt . Errorf ( "unable to record oauth client state: %w" , err )
}
}
return state , nil
return state , nil
}
}
func ( db * datastore ) ValidateOAuthState ( ctx context . Context , state string ) ( string , string , error ) {
func ( db * datastore ) ValidateOAuthState ( ctx context . Context , state string ) ( string , string , int64 , error ) {
var provider string
var provider string
var clientID string
var clientID string
var attachUserID sql . NullInt64
err := wf_db . RunTransactionWithOptions ( ctx , db . DB , & sql . TxOptions { } , func ( ctx context . Context , tx * sql . Tx ) error {
err := wf_db . RunTransactionWithOptions ( ctx , db . DB , & sql . TxOptions { } , func ( ctx context . Context , tx * sql . Tx ) error {
err := tx . QueryRow ( "SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE" , state ) . Scan ( & provider , & clientID )
err := tx .
QueryRowContext ( ctx , "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE" , state ) .
Scan ( & provider , & clientID , & attachUserID )
if err != nil {
if err != nil {
return err
return err
}
}
@ -2542,9 +2548,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri
return nil
return nil
} )
} )
if err != nil {
if err != nil {
return "" , "" , nil
return "" , "" , 0 , nil
}
}
return provider , clientID , nil
return provider , clientID , attachUserID . Int64 , nil
}
}
func ( db * datastore ) RecordRemoteUserID ( ctx context . Context , localUserID int64 , remoteUserID , provider , clientID , accessToken string ) error {
func ( db * datastore ) RecordRemoteUserID ( ctx context . Context , localUserID int64 , remoteUserID , provider , clientID , accessToken string ) error {
@ -2573,6 +2579,33 @@ func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provi
return userID , nil
return userID , nil
}
}
type oauthAccountInfo struct {
Provider string
ClientID string
RemoteUserID string
}
func ( db * datastore ) GetOauthAccounts ( ctx context . Context , userID int64 ) ( [ ] oauthAccountInfo , error ) {
rows , err := db . QueryContext ( ctx , "SELECT provider, client_id, remote_user_id FROM oauth_users WHERE user_id = ? " , userID )
if err != nil {
log . Error ( "Failed selecting from oauth_users: %v" , err )
return nil , impart . HTTPError { http . StatusInternalServerError , "Couldn't retrieve user oauth accounts." }
}
defer rows . Close ( )
var records [ ] oauthAccountInfo
for rows . Next ( ) {
info := oauthAccountInfo { }
err = rows . Scan ( & info . Provider , & info . ClientID , & info . RemoteUserID )
if err != nil {
log . Error ( "Failed scanning GetAllUsers() row: %v" , err )
break
}
records = append ( records , info )
}
return records , nil
}
// DatabaseInitialized returns whether or not the current datastore has been
// DatabaseInitialized returns whether or not the current datastore has been
// initialized with the correct schema.
// initialized with the correct schema.
// Currently, it checks to see if the `users` table exists.
// Currently, it checks to see if the `users` table exists.
@ -2595,6 +2628,11 @@ func (db *datastore) DatabaseInitialized() bool {
return true
return true
}
}
func ( db * datastore ) RemoveOauth ( ctx context . Context , userID int64 , provider string , clientID string , remoteUserID string ) error {
_ , err := db . ExecContext ( ctx , ` DELETE FROM oauth_users WHERE user_id = ? AND provider = ? AND client_id = ? AND remote_user_id = ? ` , userID , provider , clientID , remoteUserID )
return err
}
func stringLogln ( log * string , s string , v ... interface { } ) {
func stringLogln ( log * string , s string , v ... interface { } ) {
* log += fmt . Sprintf ( s + "\n" , v ... )
* log += fmt . Sprintf ( s + "\n" , v ... )
}
}