@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/guregu/null/zero"
"github.com/writeas/impart"
@ -14,8 +15,6 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
)
@ -31,11 +30,13 @@ type TokenResponse struct {
// InspectResponse contains data returned when an access token is inspected.
type InspectResponse struct {
ClientID string ` json:"client_id" `
UserID int64 ` json:"user_id" `
ExpiresAt time . Time ` json:"expires_at" `
Username string ` json:"username" `
Email string ` json:"email" `
ClientID string ` json:"client_id" `
UserID string ` json:"user_id" `
ExpiresAt time . Time ` json:"expires_at" `
Username string ` json:"username" `
DisplayName string ` json:"-" `
Email string ` json:"email" `
Error string ` json:"error" `
}
// tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token
@ -57,11 +58,12 @@ type OAuthDatastoreProvider interface {
// OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality.
type OAuthDatastore interface {
GenerateOAuthState ( context . Context ) ( string , error )
ValidateOAuthState ( context . Context , string ) error
GetIDForRemoteUser ( context . Context , int64 ) ( int64 , error )
GetIDForRemoteUser ( context . Context , string , string , string ) ( int64 , error )
RecordRemoteUserID ( context . Context , int64 , string , string , string , string ) error
ValidateOAuthState ( context . Context , string ) ( string , string , error )
GenerateOAuthState ( context . Context , string , string ) ( string , error )
CreateUser ( * config . Config , * User , string ) error
RecordRemoteUserID ( context . Context , int64 , int64 ) error
GetUserForAuthByID ( int64 ) ( * User , error )
}
@ -69,41 +71,74 @@ type HttpClient interface {
Do ( req * http . Request ) ( * http . Response , error )
}
type oauthClient interface {
GetProvider ( ) string
GetClientID ( ) string
buildLoginURL ( state string ) ( string , error )
exchangeOauthCode ( ctx context . Context , code string ) ( * TokenResponse , error )
inspectOauthAccessToken ( ctx context . Context , accessToken string ) ( * InspectResponse , error )
}
type oauthHandler struct {
Config * config . Config
DB OAuthDatastore
Store sessions . Store
HttpClient HttpClient
Config * config . Config
DB OAuthDatastore
Store sessions . Store
oauthClient oauth Client
}
// buildAuthURL returns a URL used to initiate authentication.
func buildAuthURL ( db OAuthDatastore , ctx context . Context , clientID , authLocation , callbackURL string ) ( string , error ) {
state , err := db . GenerateOAuthState ( ctx )
func ( h oauthHandler ) viewOauthInit ( app * App , w http . ResponseWriter , r * http . Request ) error {
ctx := r . Context ( )
state , err := h . DB . GenerateOAuthState ( ctx , h . oauthClient . GetProvider ( ) , h . oauthClient . GetClientID ( ) )
if err != nil {
return impart . HTTPError { http . StatusInternalServerError , "could not prepare oauth redirect url" }
}
location , err := h . oauthClient . buildLoginURL ( state )
if err != nil {
return "" , err
return impart . HTTPError { http . StatusInternalServerError , "could not prepare oauth redirect url" }
}
return impart . HTTPError { http . StatusTemporaryRedirect , location }
}
u , err := url . Parse ( authLocation )
if err != nil {
return "" , err
func configureSlackOauth ( parentHandler * Handler , r * mux . Router , app * App ) {
if app . Config ( ) . SlackOauth . ClientID != "" {
oauthClient := slackOauthClient {
ClientID : app . Config ( ) . SlackOauth . ClientID ,
ClientSecret : app . Config ( ) . SlackOauth . ClientSecret ,
TeamID : app . Config ( ) . SlackOauth . TeamID ,
CallbackLocation : app . Config ( ) . App . Host + "/oauth/callback" ,
HttpClient : config . DefaultHTTPClient ( ) ,
}
configureOauthRoutes ( parentHandler , r , app , oauthClient )
}
q := u . Query ( )
q . Set ( "client_id" , clientID )
q . Set ( "redirect_uri" , callbackURL )
q . Set ( "response_type" , "code" )
q . Set ( "state" , state )
u . RawQuery = q . Encode ( )
}
func configureWriteAsOauth ( parentHandler * Handler , r * mux . Router , app * App ) {
if app . Config ( ) . WriteAsOauth . ClientID != "" {
oauthClient := writeAsOauthClient {
ClientID : app . Config ( ) . WriteAsOauth . ClientID ,
ClientSecret : app . Config ( ) . WriteAsOauth . ClientSecret ,
ExchangeLocation : config . OrDefaultString ( app . Config ( ) . WriteAsOauth . TokenLocation , writeAsExchangeLocation ) ,
InspectLocation : config . OrDefaultString ( app . Config ( ) . WriteAsOauth . InspectLocation , writeAsIdentityLocation ) ,
AuthLocation : config . OrDefaultString ( app . Config ( ) . WriteAsOauth . AuthLocation , writeAsAuthLocation ) ,
HttpClient : config . DefaultHTTPClient ( ) ,
CallbackLocation : app . Config ( ) . App . Host + "/oauth/callback" ,
}
if oauthClient . ExchangeLocation == "" {
return u . String ( ) , nil
}
configureOauthRoutes ( parentHandler , r , app , oauthClient )
}
}
// app *App, w http.ResponseWriter, r *http.Request
func ( h oauthHandler ) viewOauthInit ( app * App , w http . ResponseWriter , r * http . Request ) error {
location , err := buildAuthURL ( h . DB , r . Context ( ) , h . Config . App . OAuthClientID , h . Config . App . OAuthProviderAuthLocation , h . Config . App . OAuthClientCallbackLocation )
if err != nil {
return impart . HTTPError { http . StatusInternalServerError , "could not prepare oauth redirect url" }
func configureOauthRoutes ( parentHandler * Handler , r * mux . Router , app * App , oauthClient oauthClient ) {
handler := & oauthHandler {
Config : app . Config ( ) ,
DB : app . DB ( ) ,
Store : app . SessionStore ( ) ,
oauthClient : oauthClient ,
}
return impart . HTTPError { http . StatusTemporaryRedirect , location }
r . HandleFunc ( "/oauth/" + oauthClient . GetProvider ( ) , parentHandler . OAuth ( handler . viewOauthInit ) ) . Methods ( "GET" )
r . HandleFunc ( "/oauth/callback" , parentHandler . OAuth ( handler . viewOauthCallback ) ) . Methods ( "GET" )
}
func ( h oauthHandler ) viewOauthCallback ( app * App , w http . ResponseWriter , r * http . Request ) error {
@ -112,13 +147,13 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
code := r . FormValue ( "code" )
state := r . FormValue ( "state" )
err := h . DB . ValidateOAuthState ( ctx , state )
provider , clientID , err := h . DB . ValidateOAuthState ( ctx , state )
if err != nil {
log . Error ( "Unable to ValidateOAuthState: %s" , err )
return impart . HTTPError { http . StatusInternalServerError , err . Error ( ) }
}
tokenResponse , err := h . exchangeOauthCode ( ctx , code )
tokenResponse , err := h . oauthClient . exchangeOauthCode ( ctx , code )
if err != nil {
log . Error ( "Unable to exchangeOauthCode: %s" , err )
return impart . HTTPError { http . StatusInternalServerError , err . Error ( ) }
@ -126,20 +161,18 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
// Now that we have the access token, let's use it real quick to make sur
// it really really works.
tokenInfo , err := h . inspectOauthAccessToken ( ctx , tokenResponse . AccessToken )
tokenInfo , err := h . oauthClient . inspectOauthAccessToken ( ctx , tokenResponse . AccessToken )
if err != nil {
log . Error ( "Unable to inspectOauthAccessToken: %s" , err )
return impart . HTTPError { http . StatusInternalServerError , err . Error ( ) }
}
localUserID , err := h . DB . GetIDForRemoteUser ( ctx , tokenInfo . UserID )
localUserID , err := h . DB . GetIDForRemoteUser ( ctx , tokenInfo . UserID , provider , clientID )
if err != nil {
log . Error ( "Unable to GetIDForRemoteUser: %s" , err )
return impart . HTTPError { http . StatusInternalServerError , err . Error ( ) }
}
fmt . Println ( "local user id" , localUserID )
if localUserID == - 1 {
// We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they
@ -148,23 +181,26 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
randPass := store . Generate62RandomString ( 14 )
hashedPass , err := auth . HashPass ( [ ] byte ( randPass ) )
if err != nil {
log . ErrorLog . Println ( err )
return impart . HTTPError { http . StatusInternalServerError , "unable to create password hash" }
}
newUser := & User {
Username : tokenInfo . Username ,
HashedPass : hashedPass ,
HasPass : true ,
Email : zero . NewString ( "" , tokenInfo . Email != "" ) ,
Email : zero . NewString ( tokenInfo . Email , tokenInfo . Email != "" ) ,
Created : time . Now ( ) . Truncate ( time . Second ) . UTC ( ) ,
}
displayName := tokenInfo . DisplayName
if len ( displayName ) == 0 {
displayName = tokenInfo . Username
}
err = h . DB . CreateUser ( h . Config , newUser , newUser . Username )
err = h . DB . CreateUser ( h . Config , newUser , displayN ame)
if err != nil {
return impart . HTTPError { http . StatusInternalServerError , err . Error ( ) }
}
err = h . DB . RecordRemoteUserID ( ctx , newUser . ID , tokenInfo . UserID )
err = h . DB . RecordRemoteUserID ( ctx , newUser . ID , tokenInfo . UserID , provider , clientID , tokenResponse . AccessToken )
if err != nil {
return impart . HTTPError { http . StatusInternalServerError , err . Error ( ) }
}
@ -185,76 +221,16 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
return nil
}
func ( h oauthHandler ) exchangeOauthCode ( ctx context . Context , code string ) ( * TokenResponse , error ) {
form := url . Values { }
form . Add ( "grant_type" , "authorization_code" )
form . Add ( "redirect_uri" , h . Config . App . OAuthClientCallbackLocation )
form . Add ( "code" , code )
req , err := http . NewRequest ( "POST" , h . Config . App . OAuthProviderTokenLocation , strings . NewReader ( form . Encode ( ) ) )
func limitedJsonUnmarshal ( body io . ReadCloser , n int , thing interface { } ) error {
lr := io . LimitReader ( body , int64 ( n + 1 ) )
data , err := ioutil . ReadAll ( lr )
if err != nil {
return nil , err
}
req . WithContext ( ctx )
req . Header . Set ( "User-Agent" , "writefreely" )
req . Header . Set ( "Accept" , "application/json" )
req . Header . Set ( "Content-Type" , "application/x-www-form-urlencoded" )
req . SetBasicAuth ( h . Config . App . OAuthClientID , h . Config . App . OAuthClientSecret )
resp , err := h . HttpClient . Do ( req )
if err != nil {
return nil , err
}
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io . LimitReader ( resp . Body , tokenRequestMaxLen )
body , err := ioutil . ReadAll ( lr )
if err != nil {
return nil , err
}
var tokenResponse TokenResponse
err = json . Unmarshal ( body , & tokenResponse )
if err != nil {
return nil , err
}
// Check the response for an error message, and return it if there is one.
if tokenResponse . Error != "" {
return nil , fmt . Errorf ( tokenResponse . Error )
}
return & tokenResponse , nil
}
func ( h oauthHandler ) inspectOauthAccessToken ( ctx context . Context , accessToken string ) ( * InspectResponse , error ) {
req , err := http . NewRequest ( "GET" , h . Config . App . OAuthProviderInspectLocation , nil )
if err != nil {
return nil , err
}
req . WithContext ( ctx )
req . Header . Set ( "User-Agent" , "writefreely" )
req . Header . Set ( "Accept" , "application/json" )
req . Header . Set ( "Authorization" , "Bearer " + accessToken )
resp , err := h . HttpClient . Do ( req )
if err != nil {
return nil , err
}
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io . LimitReader ( resp . Body , infoRequestMaxLen )
body , err := ioutil . ReadAll ( lr )
if err != nil {
return nil , err
return err
}
var inspectResponse InspectResponse
err = json . Unmarshal ( body , & inspectResponse )
if err != nil {
return nil , err
if len ( data ) == n + 1 {
return fmt . Errorf ( "content larger than max read allowance: %d" , n )
}
return & inspectResponse , nil
return json . Unmarshal ( data , thing )
}
func loginOrFail ( store sessions . Store , w http . ResponseWriter , r * http . Request , user * User ) error {