diff --git a/.travis.yml b/.travis.yml index 1e58d6b..fddc71c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: go go: - - "1.11.x" + - "1.13.x" env: - GO111MODULE=on diff --git a/account.go b/account.go index 6fb8053..2dcfd27 100644 --- a/account.go +++ b/account.go @@ -306,12 +306,16 @@ func viewLogin(app *App, w http.ResponseWriter, r *http.Request) error { Message template.HTML Flashes []template.HTML LoginUsername string + OauthSlack bool + OauthWriteAs bool }{ pageForReq(app, r), r.FormValue("to"), template.HTML(""), []template.HTML{}, getTempInfo(app, "login-user", r, w), + app.Config().SlackOauth.ClientID != "", + app.Config().WriteAsOauth.ClientID != "", } if earlyError != "" { diff --git a/config/config.go b/config/config.go index 996c1df..2616e9e 100644 --- a/config/config.go +++ b/config/config.go @@ -42,6 +42,8 @@ type ( PagesParentDir string `ini:"pages_parent_dir"` KeysParentDir string `ini:"keys_parent_dir"` + HashSeed string `ini:"hash_seed"` + Dev bool `ini:"-"` } @@ -57,17 +59,21 @@ type ( } WriteAsOauthCfg struct { - ClientID string `ini:"client_id"` - ClientSecret string `ini:"client_secret"` - AuthLocation string `ini:"auth_location"` - TokenLocation string `ini:"token_location"` - InspectLocation string `ini:"inspect_location"` + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + AuthLocation string `ini:"auth_location"` + TokenLocation string `ini:"token_location"` + InspectLocation string `ini:"inspect_location"` + CallbackProxy string `ini:"callback_proxy"` + CallbackProxyAPI string `ini:"callback_proxy_api"` } SlackOauthCfg struct { - ClientID string `ini:"client_id"` - ClientSecret string `ini:"client_secret"` - TeamID string `ini:"team_id"` + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + TeamID string `ini:"team_id"` + CallbackProxy string `ini:"callback_proxy"` + CallbackProxyAPI string `ini:"callback_proxy_api"` } // AppCfg holds values that affect how the application functions diff --git a/less/core.less b/less/core.less index 8844c84..3669d76 100644 --- a/less/core.less +++ b/less/core.less @@ -684,18 +684,19 @@ select.inputform, textarea.inputform { border: 1px solid #999; } -input, button, select.inputform, textarea.inputform { +input, button, select.inputform, textarea.inputform, a.btn { padding: 0.5em; font-family: @serifFont; font-size: 100%; .rounded(.25em); - &[type=submit], &.submit { + &[type=submit], &.submit, &.cta { border: 1px solid @primary; background: @primary; color: white; .transition(0.2s); &:hover { background-color: lighten(@primary, 3%); + text-decoration: none; } &:disabled { cursor: default; diff --git a/oauth.go b/oauth.go index 4758e0f..eb47c91 100644 --- a/oauth.go +++ b/oauth.go @@ -7,13 +7,13 @@ import ( "github.com/gorilla/mux" "github.com/gorilla/sessions" "github.com/writeas/impart" - "github.com/writeas/nerds/store" - "github.com/writeas/web-core/auth" "github.com/writeas/web-core/log" "github.com/writeas/writefreely/config" "io" "io/ioutil" "net/http" + "net/url" + "strings" "time" ) @@ -73,17 +73,25 @@ type HttpClient interface { type oauthClient interface { GetProvider() string GetClientID() string + GetCallbackLocation() string buildLoginURL(state string) (string, error) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) } +type callbackProxyClient struct { + server string + callbackLocation string + httpClient HttpClient +} + type oauthHandler struct { - Config *config.Config - DB OAuthDatastore - Store sessions.Store - EmailKey []byte - oauthClient oauthClient + Config *config.Config + DB OAuthDatastore + Store sessions.Store + EmailKey []byte + oauthClient oauthClient + callbackProxy *callbackProxyClient } func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { @@ -92,6 +100,13 @@ func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Req if err != nil { return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} } + + if h.callbackProxy != nil { + if err := h.callbackProxy.register(ctx, state); err != nil { + return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} + } + } + location, err := h.oauthClient.buildLoginURL(state) if err != nil { return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} @@ -101,19 +116,42 @@ func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Req func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().SlackOauth.ClientID != "" { + callbackLocation := app.Config().App.Host + "/oauth/callback" + + var stateRegisterClient *callbackProxyClient = nil + if app.Config().SlackOauth.CallbackProxyAPI != "" { + stateRegisterClient = &callbackProxyClient{ + server: app.Config().SlackOauth.CallbackProxyAPI, + callbackLocation: app.Config().App.Host + "/oauth/callback", + httpClient: config.DefaultHTTPClient(), + } + callbackLocation = app.Config().SlackOauth.CallbackProxy + } oauthClient := slackOauthClient{ ClientID: app.Config().SlackOauth.ClientID, ClientSecret: app.Config().SlackOauth.ClientSecret, TeamID: app.Config().SlackOauth.TeamID, - CallbackLocation: app.Config().App.Host + "/oauth/callback", HttpClient: config.DefaultHTTPClient(), + CallbackLocation: callbackLocation, } - configureOauthRoutes(parentHandler, r, app, oauthClient) + configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient) } } func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().WriteAsOauth.ClientID != "" { + callbackLocation := app.Config().App.Host + "/oauth/callback" + + var callbackProxy *callbackProxyClient = nil + if app.Config().WriteAsOauth.CallbackProxy != "" { + callbackProxy = &callbackProxyClient{ + server: app.Config().WriteAsOauth.CallbackProxyAPI, + callbackLocation: app.Config().App.Host + "/oauth/callback", + httpClient: config.DefaultHTTPClient(), + } + callbackLocation = app.Config().SlackOauth.CallbackProxy + } + oauthClient := writeAsOauthClient{ ClientID: app.Config().WriteAsOauth.ClientID, ClientSecret: app.Config().WriteAsOauth.ClientSecret, @@ -121,22 +159,24 @@ func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), HttpClient: config.DefaultHTTPClient(), - CallbackLocation: app.Config().App.Host + "/oauth/callback", + CallbackLocation: callbackLocation, } - configureOauthRoutes(parentHandler, r, app, oauthClient) + configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } -func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient) { +func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) { handler := &oauthHandler{ - Config: app.Config(), - DB: app.DB(), - Store: app.SessionStore(), - oauthClient: oauthClient, - EmailKey: app.keys.EmailKey, + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + oauthClient: oauthClient, + EmailKey: app.keys.EmailKey, + callbackProxy: callbackProxy, } r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") r.HandleFunc("/oauth/callback", parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") + r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST") } func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { @@ -171,51 +211,53 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http return impart.HTTPError{http.StatusInternalServerError, err.Error()} } - if localUserID == -1 { - // We don't have, nor do we want, the password from the origin, so we - //create a random string. If the user needs to set a password, they - //can do so through the settings page or through the password reset - //flow. - randPass := store.Generate62RandomString(14) - hashedPass, err := auth.HashPass([]byte(randPass)) - if err != nil { - return impart.HTTPError{http.StatusInternalServerError, "unable to create password hash"} - } - newUser := &User{ - Username: tokenInfo.Username, - HashedPass: hashedPass, - HasPass: true, - Email: prepareUserEmail(tokenInfo.Email, h.EmailKey), - Created: time.Now().Truncate(time.Second).UTC(), - } - displayName := tokenInfo.DisplayName - if len(displayName) == 0 { - displayName = tokenInfo.Username - } - - err = h.DB.CreateUser(h.Config, newUser, displayName) + if localUserID != -1 { + user, err := h.DB.GetUserByID(localUserID) if err != nil { + log.Error("Unable to GetUserByID %d: %s", localUserID, err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } - - err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) - if err != nil { - return impart.HTTPError{http.StatusInternalServerError, err.Error()} - } - - if err := loginOrFail(h.Store, w, r, newUser); err != nil { + if err = loginOrFail(h.Store, w, r, user); err != nil { + log.Error("Unable to loginOrFail %d: %s", localUserID, err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } return nil } - user, err := h.DB.GetUserByID(localUserID) + tp := &oauthSignupPageParams{ + AccessToken: tokenResponse.AccessToken, + TokenUsername: tokenInfo.Username, + TokenAlias: tokenInfo.DisplayName, + TokenEmail: tokenInfo.Email, + TokenRemoteUser: tokenInfo.UserID, + Provider: provider, + ClientID: clientID, + } + tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) + + return h.showOauthSignupPage(app, w, r, tp, nil) +} + +func (r *callbackProxyClient) register(ctx context.Context, state string) error { + form := url.Values{} + form.Add("state", state) + form.Add("location", r.callbackLocation) + req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode())) + if err != nil { + return err + } + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := r.httpClient.Do(req) if err != nil { - return impart.HTTPError{http.StatusInternalServerError, err.Error()} + return err } - if err = loginOrFail(h.Store, w, r, user); err != nil { - return impart.HTTPError{http.StatusInternalServerError, err.Error()} + if resp.StatusCode != http.StatusCreated { + return fmt.Errorf("unable register state location: %d", resp.StatusCode) } + return nil } diff --git a/oauth_signup.go b/oauth_signup.go new file mode 100644 index 0000000..cf90af6 --- /dev/null +++ b/oauth_signup.go @@ -0,0 +1,206 @@ +package writefreely + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "github.com/writeas/impart" + "github.com/writeas/web-core/auth" + "github.com/writeas/web-core/log" + "github.com/writeas/writefreely/page" + "html/template" + "net/http" + "strings" + "time" +) + +type viewOauthSignupVars struct { + page.StaticPage + To string + Message template.HTML + Flashes []template.HTML + + AccessToken string + TokenUsername string + TokenAlias string + TokenEmail string + TokenRemoteUser string + Provider string + ClientID string + TokenHash string + + Username string + Alias string + Email string +} + +const ( + oauthParamAccessToken = "access_token" + oauthParamTokenUsername = "token_username" + oauthParamTokenAlias = "token_alias" + oauthParamTokenEmail = "token_email" + oauthParamTokenRemoteUserID = "token_remote_user" + oauthParamClientID = "client_id" + oauthParamProvider = "provider" + oauthParamHash = "signature" + oauthParamUsername = "username" + oauthParamAlias = "alias" + oauthParamEmail = "email" + oauthParamPassword = "password" +) + +type oauthSignupPageParams struct { + AccessToken string + TokenUsername string + TokenAlias string + TokenEmail string + TokenRemoteUser string + ClientID string + Provider string + TokenHash string +} + +func (p oauthSignupPageParams) HashTokenParams(key string) string { + hasher := sha256.New() + hasher.Write([]byte(key)) + hasher.Write([]byte(p.AccessToken)) + hasher.Write([]byte(p.TokenUsername)) + hasher.Write([]byte(p.TokenAlias)) + hasher.Write([]byte(p.TokenEmail)) + hasher.Write([]byte(p.TokenRemoteUser)) + hasher.Write([]byte(p.ClientID)) + hasher.Write([]byte(p.Provider)) + return hex.EncodeToString(hasher.Sum(nil)) +} + +func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.Request) error { + tp := &oauthSignupPageParams{ + AccessToken: r.FormValue(oauthParamAccessToken), + TokenUsername: r.FormValue(oauthParamTokenUsername), + TokenAlias: r.FormValue(oauthParamTokenAlias), + TokenEmail: r.FormValue(oauthParamTokenEmail), + TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID), + ClientID: r.FormValue(oauthParamClientID), + Provider: r.FormValue(oauthParamProvider), + } + if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."} + } + tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) + if err := h.validateOauthSignup(r); err != nil { + return h.showOauthSignupPage(app, w, r, tp, err) + } + + hashedPass, err := auth.HashPass([]byte(r.FormValue(oauthParamPassword))) + if err != nil { + return h.showOauthSignupPage(app, w, r, tp, fmt.Errorf("unable to hash password")) + } + newUser := &User{ + Username: r.FormValue(oauthParamUsername), + HashedPass: hashedPass, + HasPass: true, + Email: prepareUserEmail(r.FormValue(oauthParamEmail), h.EmailKey), + Created: time.Now().Truncate(time.Second).UTC(), + } + displayName := r.FormValue(oauthParamAlias) + if len(displayName) == 0 { + displayName = r.FormValue(oauthParamUsername) + } + + err = h.DB.CreateUser(h.Config, newUser, displayName) + if err != nil { + return h.showOauthSignupPage(app, w, r, tp, err) + } + + err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken)) + if err != nil { + return h.showOauthSignupPage(app, w, r, tp, err) + } + + if err := loginOrFail(h.Store, w, r, newUser); err != nil { + return h.showOauthSignupPage(app, w, r, tp, err) + } + return nil +} + +func (h oauthHandler) validateOauthSignup(r *http.Request) error { + username := r.FormValue(oauthParamUsername) + if len(username) < h.Config.App.MinUsernameLen { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too short."} + } + if len(username) > 100 { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too long."} + } + alias := r.FormValue(oauthParamAlias) + if len(alias) == 0 { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Alias is too short."} + } + password := r.FormValue("password") + if len(password) == 0 { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Password is too short."} + } + email := r.FormValue(oauthParamEmail) + if len(email) > 0 { + parts := strings.Split(email, "@") + if len(parts) != 2 || (len(parts[0]) < 1 || len(parts[1]) < 1) { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Invalid email address"} + } + } + return nil +} + +func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *http.Request, tp *oauthSignupPageParams, errMsg error) error { + username := tp.TokenUsername + alias := tp.TokenAlias + email := tp.TokenEmail + + session, err := app.sessionStore.Get(r, cookieName) + if err != nil { + // Ignore this + log.Error("Unable to get session; ignoring: %v", err) + } + + if tmpValue := r.FormValue(oauthParamUsername); len(tmpValue) > 0 { + username = tmpValue + } + if tmpValue := r.FormValue(oauthParamAlias); len(tmpValue) > 0 { + alias = tmpValue + } + if tmpValue := r.FormValue(oauthParamEmail); len(tmpValue) > 0 { + email = tmpValue + } + + p := &viewOauthSignupVars{ + StaticPage: pageForReq(app, r), + To: r.FormValue("to"), + Flashes: []template.HTML{}, + + AccessToken: tp.AccessToken, + TokenUsername: tp.TokenUsername, + TokenAlias: tp.TokenAlias, + TokenEmail: tp.TokenEmail, + TokenRemoteUser: tp.TokenRemoteUser, + Provider: tp.Provider, + ClientID: tp.ClientID, + TokenHash: tp.TokenHash, + + Username: username, + Alias: alias, + Email: email, + } + + // Display any error messages + flashes, _ := getSessionFlashes(app, w, r, session) + for _, flash := range flashes { + p.Flashes = append(p.Flashes, template.HTML(flash)) + } + if errMsg != nil { + p.Flashes = append(p.Flashes, template.HTML(errMsg.Error())) + } + err = pages["signup-oauth.tmpl"].ExecuteTemplate(w, "base", p) + if err != nil { + log.Error("Unable to render signup-oauth: %v", err) + return err + } + return nil +} diff --git a/oauth_slack.go b/oauth_slack.go index 066aa18..8cf4992 100644 --- a/oauth_slack.go +++ b/oauth_slack.go @@ -3,6 +3,8 @@ package writefreely import ( "context" "errors" + "fmt" + "github.com/writeas/nerds/store" "github.com/writeas/slug" "net/http" "net/url" @@ -60,6 +62,10 @@ func (c slackOauthClient) GetClientID() string { return c.ClientID } +func (c slackOauthClient) GetCallbackLocation() string { + return c.CallbackLocation +} + func (c slackOauthClient) buildLoginURL(state string) (string, error) { u, err := url.Parse(slackAuthLocation) if err != nil { @@ -151,7 +157,7 @@ func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessTok func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse { return &InspectResponse{ UserID: resp.User.ID, - Username: slug.Make(resp.User.Name), + Username: fmt.Sprintf("%s-%s", slug.Make(resp.User.Name), store.Generate62RandomString(5)), DisplayName: resp.User.Name, Email: resp.User.Email, } diff --git a/oauth_writeas.go b/oauth_writeas.go index eb12f64..6251a16 100644 --- a/oauth_writeas.go +++ b/oauth_writeas.go @@ -34,6 +34,10 @@ func (c writeAsOauthClient) GetClientID() string { return c.ClientID } +func (c writeAsOauthClient) GetCallbackLocation() string { + return c.CallbackLocation +} + func (c writeAsOauthClient) buildLoginURL(state string) (string, error) { u, err := url.Parse(c.AuthLocation) if err != nil { diff --git a/pages/login.tmpl b/pages/login.tmpl index 1c8e862..345b171 100644 --- a/pages/login.tmpl +++ b/pages/login.tmpl @@ -1,7 +1,38 @@ {{define "head"}}
or
+