/ *
* Copyright © 2020 - 2021 A Bunch Tell LLC and respective authors .
*
* This file is part of WriteFreely .
*
* WriteFreely is free software : you can redistribute it and / or modify
* it under the terms of the GNU Affero General Public License , included
* in the LICENSE file in this source code package .
* /
package writefreely
import (
"context"
"errors"
"fmt"
"github.com/writeas/web-core/log"
"net/http"
"net/url"
"strings"
)
type genericOauthClient struct {
ClientID string
ClientSecret string
AuthLocation string
ExchangeLocation string
InspectLocation string
CallbackLocation string
Scope string
MapUserID string
MapUsername string
MapDisplayName string
MapEmail string
HttpClient HttpClient
}
var _ oauthClient = genericOauthClient { }
const (
genericOauthDisplayName = "OAuth"
)
func ( c genericOauthClient ) GetProvider ( ) string {
return "generic"
}
func ( c genericOauthClient ) GetClientID ( ) string {
return c . ClientID
}
func ( c genericOauthClient ) GetCallbackLocation ( ) string {
return c . CallbackLocation
}
func ( c genericOauthClient ) buildLoginURL ( state string ) ( string , error ) {
u , err := url . Parse ( c . AuthLocation )
if err != nil {
return "" , err
}
q := u . Query ( )
q . Set ( "client_id" , c . ClientID )
q . Set ( "redirect_uri" , c . CallbackLocation )
q . Set ( "response_type" , "code" )
q . Set ( "state" , state )
q . Set ( "scope" , c . Scope )
u . RawQuery = q . Encode ( )
return u . String ( ) , nil
}
func ( c genericOauthClient ) exchangeOauthCode ( ctx context . Context , code string ) ( * TokenResponse , error ) {
form := url . Values { }
form . Add ( "grant_type" , "authorization_code" )
form . Add ( "redirect_uri" , c . CallbackLocation )
form . Add ( "scope" , c . Scope )
form . Add ( "code" , code )
req , err := http . NewRequest ( "POST" , c . ExchangeLocation , strings . NewReader ( form . Encode ( ) ) )
if err != nil {
return nil , err
}
req . WithContext ( ctx )
req . Header . Set ( "User-Agent" , ServerUserAgent ( "" ) )
req . Header . Set ( "Accept" , "application/json" )
req . Header . Set ( "Content-Type" , "application/x-www-form-urlencoded" )
req . SetBasicAuth ( c . ClientID , c . ClientSecret )
resp , err := c . HttpClient . Do ( req )
if err != nil {
return nil , err
}
if resp . StatusCode != http . StatusOK {
return nil , errors . New ( "unable to exchange code for access token" )
}
var tokenResponse TokenResponse
if err := limitedJsonUnmarshal ( resp . Body , tokenRequestMaxLen , & tokenResponse ) ; err != nil {
return nil , err
}
if tokenResponse . Error != "" {
return nil , errors . New ( tokenResponse . Error )
}
return & tokenResponse , nil
}
func ( c genericOauthClient ) inspectOauthAccessToken ( ctx context . Context , accessToken string ) ( * InspectResponse , error ) {
req , err := http . NewRequest ( "GET" , c . InspectLocation , nil )
if err != nil {
return nil , err
}
req . WithContext ( ctx )
req . Header . Set ( "User-Agent" , ServerUserAgent ( "" ) )
req . Header . Set ( "Accept" , "application/json" )
req . Header . Set ( "Authorization" , "Bearer " + accessToken )
resp , err := c . HttpClient . Do ( req )
if err != nil {
return nil , err
}
if resp . StatusCode != http . StatusOK {
return nil , errors . New ( "unable to inspect access token" )
}
// since we don't know what the JSON from the server will look like, we create a
// generic interface and then map manually to values set in the config
var genericInterface map [ string ] interface { }
if err := limitedJsonUnmarshal ( resp . Body , infoRequestMaxLen , & genericInterface ) ; err != nil {
return nil , err
}
// map each relevant field in inspectResponse to the mapped field from the config
var inspectResponse InspectResponse
inspectResponse . UserID , _ = genericInterface [ c . MapUserID ] . ( string )
if inspectResponse . UserID == "" {
log . Error ( "[CONFIGURATION ERROR] Generic OAuth provider returned empty UserID value (`%s`).\n Do you need to configure a different `map_user_id` value for this provider?" , c . MapUserID )
return nil , fmt . Errorf ( "no UserID (`%s`) value returned" , c . MapUserID )
}
inspectResponse . Username , _ = genericInterface [ c . MapUsername ] . ( string )
inspectResponse . DisplayName , _ = genericInterface [ c . MapDisplayName ] . ( string )
inspectResponse . Email , _ = genericInterface [ c . MapEmail ] . ( string )
return & inspectResponse , nil
}