|
|
|
@ -2,14 +2,17 @@ package writefreely |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"encoding/hex" |
|
|
|
|
"encoding/json" |
|
|
|
|
"fmt" |
|
|
|
|
"github.com/gorilla/mux" |
|
|
|
|
"github.com/gorilla/sessions" |
|
|
|
|
"github.com/guregu/null/zero" |
|
|
|
|
"github.com/writeas/nerds/store" |
|
|
|
|
"github.com/writeas/web-core/auth" |
|
|
|
|
"github.com/writeas/web-core/log" |
|
|
|
|
"github.com/writeas/writefreely/config" |
|
|
|
|
"hash/fnv" |
|
|
|
|
"io" |
|
|
|
|
"io/ioutil" |
|
|
|
|
"net/http" |
|
|
|
@ -55,11 +58,12 @@ type OAuthDatastoreProvider interface { |
|
|
|
|
// OAuthDatastore provides a minimal interface of data store methods used in
|
|
|
|
|
// oauth functionality.
|
|
|
|
|
type OAuthDatastore interface { |
|
|
|
|
GenerateOAuthState(context.Context) (string, error) |
|
|
|
|
ValidateOAuthState(context.Context, string) error |
|
|
|
|
GetIDForRemoteUser(context.Context, int64) (int64, error) |
|
|
|
|
CreateUser(*config.Config, *User, string) error |
|
|
|
|
RecordRemoteUserID(context.Context, int64, int64) error |
|
|
|
|
ValidateOAuthState(context.Context, string, string, string) error |
|
|
|
|
GenerateOAuthState(context.Context, string, string) (string, error) |
|
|
|
|
|
|
|
|
|
CreateUser(*config.Config, *User, string) error |
|
|
|
|
GetUserForAuthByID(int64) (*User, error) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -75,8 +79,8 @@ type oauthHandler struct { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// buildAuthURL returns a URL used to initiate authentication.
|
|
|
|
|
func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) { |
|
|
|
|
state, err := db.GenerateOAuthState(ctx) |
|
|
|
|
func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) { |
|
|
|
|
state, err := db.GenerateOAuthState(ctx, provider, clientID) |
|
|
|
|
if err != nil { |
|
|
|
|
return "", err |
|
|
|
|
} |
|
|
|
@ -95,9 +99,17 @@ func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation |
|
|
|
|
return u.String(), nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// app *App, w http.ResponseWriter, r *http.Request
|
|
|
|
|
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation) |
|
|
|
|
func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation) |
|
|
|
|
if err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
http.Redirect(w, r, location, http.StatusTemporaryRedirect) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation) |
|
|
|
|
if err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") |
|
|
|
|
return |
|
|
|
@ -105,13 +117,37 @@ func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
http.Redirect(w, r, location, http.StatusTemporaryRedirect) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
func (h oauthHandler) configureRoutes(r *mux.Router) { |
|
|
|
|
if h.Config.App.OAuth.Enabled { |
|
|
|
|
if h.Config.App.OAuth.WriteAsClientID != "" { |
|
|
|
|
callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID) |
|
|
|
|
log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash) |
|
|
|
|
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET") |
|
|
|
|
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET") |
|
|
|
|
} |
|
|
|
|
if h.Config.App.OAuth.SlackClientID != "" { |
|
|
|
|
callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID) |
|
|
|
|
log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash) |
|
|
|
|
r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET") |
|
|
|
|
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET") |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func oauthProviderHash(provider, clientID string) string { |
|
|
|
|
hasher := fnv.New32() |
|
|
|
|
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID))) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc { |
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
ctx := r.Context() |
|
|
|
|
|
|
|
|
|
code := r.FormValue("code") |
|
|
|
|
state := r.FormValue("state") |
|
|
|
|
|
|
|
|
|
err := h.DB.ValidateOAuthState(ctx, state) |
|
|
|
|
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID) |
|
|
|
|
if err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
|
return |
|
|
|
@ -185,14 +221,15 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) |
|
|
|
|
if err = loginOrFail(h.Store, w, r, user); err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { |
|
|
|
|
form := url.Values{} |
|
|
|
|
form.Add("grant_type", "authorization_code") |
|
|
|
|
form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation) |
|
|
|
|
form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation) |
|
|
|
|
form.Add("code", code) |
|
|
|
|
req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode())) |
|
|
|
|
req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode())) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
@ -200,7 +237,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke |
|
|
|
|
req.Header.Set("User-Agent", "writefreely") |
|
|
|
|
req.Header.Set("Accept", "application/json") |
|
|
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
|
|
|
|
req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret) |
|
|
|
|
req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret) |
|
|
|
|
|
|
|
|
|
resp, err := h.HttpClient.Do(req) |
|
|
|
|
if err != nil { |
|
|
|
@ -224,7 +261,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { |
|
|
|
|
req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil) |
|
|
|
|
req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|