|
|
@ -6,6 +6,7 @@ import ( |
|
|
|
"fmt" |
|
|
|
"fmt" |
|
|
|
"github.com/gorilla/sessions" |
|
|
|
"github.com/gorilla/sessions" |
|
|
|
"github.com/guregu/null/zero" |
|
|
|
"github.com/guregu/null/zero" |
|
|
|
|
|
|
|
"github.com/writeas/impart" |
|
|
|
"github.com/writeas/nerds/store" |
|
|
|
"github.com/writeas/nerds/store" |
|
|
|
"github.com/writeas/web-core/auth" |
|
|
|
"github.com/writeas/web-core/auth" |
|
|
|
"github.com/writeas/web-core/log" |
|
|
|
"github.com/writeas/web-core/log" |
|
|
@ -96,16 +97,15 @@ func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// app *App, w http.ResponseWriter, r *http.Request
|
|
|
|
// app *App, w http.ResponseWriter, r *http.Request
|
|
|
|
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) { |
|
|
|
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { |
|
|
|
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation) |
|
|
|
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
http.Redirect(w, r, location, http.StatusTemporaryRedirect) |
|
|
|
return impart.HTTPError{http.StatusTemporaryRedirect, location} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) { |
|
|
|
func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { |
|
|
|
ctx := r.Context() |
|
|
|
ctx := r.Context() |
|
|
|
|
|
|
|
|
|
|
|
code := r.FormValue("code") |
|
|
|
code := r.FormValue("code") |
|
|
@ -113,28 +113,24 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) |
|
|
|
|
|
|
|
|
|
|
|
err := h.DB.ValidateOAuthState(ctx, state) |
|
|
|
err := h.DB.ValidateOAuthState(ctx, state) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
tokenResponse, err := h.exchangeOauthCode(ctx, code) |
|
|
|
tokenResponse, err := h.exchangeOauthCode(ctx, code) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Now that we have the access token, let's use it real quick to make sur
|
|
|
|
// Now that we have the access token, let's use it real quick to make sur
|
|
|
|
// it really really works.
|
|
|
|
// it really really works.
|
|
|
|
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) |
|
|
|
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID) |
|
|
|
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
fmt.Println("local user id", localUserID) |
|
|
|
fmt.Println("local user id", localUserID) |
|
|
@ -148,8 +144,7 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) |
|
|
|
hashedPass, err := auth.HashPass([]byte(randPass)) |
|
|
|
hashedPass, err := auth.HashPass([]byte(randPass)) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
log.ErrorLog.Println(err) |
|
|
|
log.ErrorLog.Println(err) |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash") |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, "unable to create password hash"} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
newUser := &User{ |
|
|
|
newUser := &User{ |
|
|
|
Username: tokenInfo.Username, |
|
|
|
Username: tokenInfo.Username, |
|
|
@ -161,30 +156,28 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) |
|
|
|
|
|
|
|
|
|
|
|
err = h.DB.CreateUser(h.Config, newUser, newUser.Username) |
|
|
|
err = h.DB.CreateUser(h.Config, newUser, newUser.Username) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID) |
|
|
|
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if err := loginOrFail(h.Store, w, r, newUser); err != nil { |
|
|
|
if err := loginOrFail(h.Store, w, r, newUser); err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
} |
|
|
|
} |
|
|
|
return |
|
|
|
return nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
user, err := h.DB.GetUserForAuthByID(localUserID) |
|
|
|
user, err := h.DB.GetUserForAuthByID(localUserID) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
return |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
if err = loginOrFail(h.Store, w, r, user); err != nil { |
|
|
|
if err = loginOrFail(h.Store, w, r, user); err != nil { |
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { |
|
|
|
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { |
|
|
@ -265,16 +258,3 @@ func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, u |
|
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect) |
|
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect) |
|
|
|
return nil |
|
|
|
return nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// failOAuthRequest is an HTTP handler helper that formats returned error
|
|
|
|
|
|
|
|
// messages.
|
|
|
|
|
|
|
|
func failOAuthRequest(w http.ResponseWriter, statusCode int, message string) { |
|
|
|
|
|
|
|
w.Header().Set("Content-Type", "application/json") |
|
|
|
|
|
|
|
w.WriteHeader(statusCode) |
|
|
|
|
|
|
|
err := json.NewEncoder(w).Encode(map[string]interface{}{ |
|
|
|
|
|
|
|
"error": message, |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
panic(err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|