|
|
@ -7,6 +7,7 @@ import ( |
|
|
|
"github.com/gorilla/mux" |
|
|
|
"github.com/gorilla/mux" |
|
|
|
"github.com/gorilla/sessions" |
|
|
|
"github.com/gorilla/sessions" |
|
|
|
"github.com/guregu/null/zero" |
|
|
|
"github.com/guregu/null/zero" |
|
|
|
|
|
|
|
"github.com/writeas/impart" |
|
|
|
"github.com/writeas/nerds/store" |
|
|
|
"github.com/writeas/nerds/store" |
|
|
|
"github.com/writeas/web-core/auth" |
|
|
|
"github.com/writeas/web-core/auth" |
|
|
|
"github.com/writeas/web-core/log" |
|
|
|
"github.com/writeas/web-core/log" |
|
|
@ -85,21 +86,20 @@ type oauthHandler struct { |
|
|
|
oauthClient oauthClient |
|
|
|
oauthClient oauthClient |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) { |
|
|
|
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { |
|
|
|
ctx := r.Context() |
|
|
|
ctx := r.Context() |
|
|
|
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID()) |
|
|
|
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID()) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} |
|
|
|
} |
|
|
|
} |
|
|
|
location, err := h.oauthClient.buildLoginURL(state) |
|
|
|
location, err := h.oauthClient.buildLoginURL(state) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
http.Redirect(w, r, location, http.StatusTemporaryRedirect) |
|
|
|
return impart.HTTPError{http.StatusTemporaryRedirect, location} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func configureSlackOauth(r *mux.Router, app *App) { |
|
|
|
func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { |
|
|
|
if app.Config().SlackOauth.ClientID != "" { |
|
|
|
if app.Config().SlackOauth.ClientID != "" { |
|
|
|
oauthClient := slackOauthClient{ |
|
|
|
oauthClient := slackOauthClient{ |
|
|
|
ClientID: app.Config().SlackOauth.ClientID, |
|
|
|
ClientID: app.Config().SlackOauth.ClientID, |
|
|
@ -108,11 +108,11 @@ func configureSlackOauth(r *mux.Router, app *App) { |
|
|
|
CallbackLocation: app.Config().App.Host + "/oauth/callback", |
|
|
|
CallbackLocation: app.Config().App.Host + "/oauth/callback", |
|
|
|
HttpClient: config.DefaultHTTPClient(), |
|
|
|
HttpClient: config.DefaultHTTPClient(), |
|
|
|
} |
|
|
|
} |
|
|
|
configureOauthRoutes(r, app, oauthClient) |
|
|
|
configureOauthRoutes(parentHandler, r, app, oauthClient) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func configureWriteAsOauth(r *mux.Router, app *App) { |
|
|
|
func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { |
|
|
|
if app.Config().WriteAsOauth.ClientID != "" { |
|
|
|
if app.Config().WriteAsOauth.ClientID != "" { |
|
|
|
oauthClient := writeAsOauthClient{ |
|
|
|
oauthClient := writeAsOauthClient{ |
|
|
|
ClientID: app.Config().WriteAsOauth.ClientID, |
|
|
|
ClientID: app.Config().WriteAsOauth.ClientID, |
|
|
@ -126,22 +126,22 @@ func configureWriteAsOauth(r *mux.Router, app *App) { |
|
|
|
if oauthClient.ExchangeLocation == "" { |
|
|
|
if oauthClient.ExchangeLocation == "" { |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
configureOauthRoutes(r, app, oauthClient) |
|
|
|
configureOauthRoutes(parentHandler, r, app, oauthClient) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func configureOauthRoutes(r *mux.Router, app *App, oauthClient oauthClient) { |
|
|
|
func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient) { |
|
|
|
handler := &oauthHandler{ |
|
|
|
handler := &oauthHandler{ |
|
|
|
Config: app.Config(), |
|
|
|
Config: app.Config(), |
|
|
|
DB: app.DB(), |
|
|
|
DB: app.DB(), |
|
|
|
Store: app.SessionStore(), |
|
|
|
Store: app.SessionStore(), |
|
|
|
oauthClient: oauthClient, |
|
|
|
oauthClient: oauthClient, |
|
|
|
} |
|
|
|
} |
|
|
|
r.HandleFunc("/oauth/"+oauthClient.GetProvider(), handler.viewOauthInit).Methods("GET") |
|
|
|
r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") |
|
|
|
r.HandleFunc("/oauth/callback", handler.viewOauthCallback).Methods("GET") |
|
|
|
r.HandleFunc("/oauth/callback", parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) { |
|
|
|
func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { |
|
|
|
ctx := r.Context() |
|
|
|
ctx := r.Context() |
|
|
|
|
|
|
|
|
|
|
|
code := r.FormValue("code") |
|
|
|
code := r.FormValue("code") |
|
|
@ -150,15 +150,13 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) |
|
|
|
provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) |
|
|
|
provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
log.Error("Unable to ValidateOAuthState: %s", err) |
|
|
|
log.Error("Unable to ValidateOAuthState: %s", err) |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) |
|
|
|
tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
log.Error("Unable to exchangeOauthCode: %s", err) |
|
|
|
log.Error("Unable to exchangeOauthCode: %s", err) |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Now that we have the access token, let's use it real quick to make sur
|
|
|
|
// Now that we have the access token, let's use it real quick to make sur
|
|
|
@ -166,15 +164,13 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) |
|
|
|
tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) |
|
|
|
tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
log.Error("Unable to inspectOauthAccessToken: %s", err) |
|
|
|
log.Error("Unable to inspectOauthAccessToken: %s", err) |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) |
|
|
|
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
log.Error("Unable to GetIDForRemoteUser: %s", err) |
|
|
|
log.Error("Unable to GetIDForRemoteUser: %s", err) |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if localUserID == -1 { |
|
|
|
if localUserID == -1 { |
|
|
@ -185,8 +181,7 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) |
|
|
|
randPass := store.Generate62RandomString(14) |
|
|
|
randPass := store.Generate62RandomString(14) |
|
|
|
hashedPass, err := auth.HashPass([]byte(randPass)) |
|
|
|
hashedPass, err := auth.HashPass([]byte(randPass)) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash") |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, "unable to create password hash"} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
newUser := &User{ |
|
|
|
newUser := &User{ |
|
|
|
Username: tokenInfo.Username, |
|
|
|
Username: tokenInfo.Username, |
|
|
@ -202,30 +197,28 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) |
|
|
|
|
|
|
|
|
|
|
|
err = h.DB.CreateUser(h.Config, newUser, displayName) |
|
|
|
err = h.DB.CreateUser(h.Config, newUser, displayName) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) |
|
|
|
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if err := loginOrFail(h.Store, w, r, newUser); err != nil { |
|
|
|
if err := loginOrFail(h.Store, w, r, newUser); err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
} |
|
|
|
} |
|
|
|
return |
|
|
|
return nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
user, err := h.DB.GetUserForAuthByID(localUserID) |
|
|
|
user, err := h.DB.GetUserForAuthByID(localUserID) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
if err = loginOrFail(h.Store, w, r, user); err != nil { |
|
|
|
if err = loginOrFail(h.Store, w, r, user); err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { |
|
|
|
func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { |
|
|
@ -251,16 +244,3 @@ func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, u |
|
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect) |
|
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect) |
|
|
|
return nil |
|
|
|
return nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// failOAuthRequest is an HTTP handler helper that formats returned error
|
|
|
|
|
|
|
|
// messages.
|
|
|
|
|
|
|
|
func failOAuthRequest(w http.ResponseWriter, statusCode int, message string) { |
|
|
|
|
|
|
|
w.Header().Set("Content-Type", "application/json") |
|
|
|
|
|
|
|
w.WriteHeader(statusCode) |
|
|
|
|
|
|
|
err := json.NewEncoder(w).Encode(map[string]interface{}{ |
|
|
|
|
|
|
|
"error": message, |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
panic(err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|