diff --git a/account.go b/account.go index d32f503..56a4f84 100644 --- a/account.go +++ b/account.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018-2019 A Bunch Tell LLC. + * Copyright © 2018-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -168,11 +168,7 @@ func signupWithRegistration(app *App, signup userRegistration, w http.ResponseWr // Log invite if needed if signup.InviteCode != "" { - cu, err := app.db.GetUserForAuth(signup.Alias) - if err != nil { - return nil, err - } - err = app.db.CreateInvitedUser(signup.InviteCode, cu.ID) + err = app.db.CreateInvitedUser(signup.InviteCode, u.ID) if err != nil { return nil, err } @@ -493,6 +489,9 @@ func login(app *App, w http.ResponseWriter, r *http.Request) error { return impart.HTTPError{http.StatusPreconditionFailed, "This user never added a password or email address. Please contact us for help."} } } + if len(u.HashedPass) == 0 { + return impart.HTTPError{http.StatusUnauthorized, "This user never set a password. Perhaps try logging in via OAuth?"} + } if !auth.Authenticated(u.HashedPass, []byte(signin.Pass)) { return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."} } diff --git a/database.go b/database.go index 128e436..0eee612 100644 --- a/database.go +++ b/database.go @@ -132,8 +132,8 @@ type writestore interface { GetIDForRemoteUser(context.Context, string, string, string) (int64, error) RecordRemoteUserID(context.Context, int64, string, string, string, string) error - ValidateOAuthState(context.Context, string) (string, string, int64, error) - GenerateOAuthState(context.Context, string, string, int64) (string, error) + ValidateOAuthState(context.Context, string) (string, string, int64, string, error) + GenerateOAuthState(context.Context, string, string, int64, string) (string, error) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error @@ -178,6 +178,7 @@ func (db *datastore) dateSub(l int, unit string) string { return fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d %s)", l, unit) } +// CreateUser creates a new user in the database from the given User, UPDATING it in the process with the user's ID. func (db *datastore) CreateUser(cfg *config.Config, u *User, collectionTitle string) error { if db.PostIDExists(u.Username) { return impart.HTTPError{http.StatusConflict, "Invalid collection name."} @@ -2516,24 +2517,26 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { return &t, nil } -func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) { +func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64, inviteCode string) (string, error) { state := store.Generate62RandomString(24) attachUserVal := sql.NullInt64{Valid: attachUser > 0, Int64: attachUser} - _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, "+db.now()+", ?)", state, provider, clientID, attachUserVal) + inviteCodeVal := sql.NullString{Valid: inviteCode != "", String: inviteCode} + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id, invite_code) VALUES (?, ?, ?, FALSE, "+db.now()+", ?, ?)", state, provider, clientID, attachUserVal, inviteCodeVal) if err != nil { return "", fmt.Errorf("unable to record oauth client state: %w", err) } return state, nil } -func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) { +func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) { var provider string var clientID string var attachUserID sql.NullInt64 + var inviteCode sql.NullString err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { err := tx. - QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state). - Scan(&provider, &clientID, &attachUserID) + QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id, invite_code FROM oauth_client_states WHERE state = ? AND used = FALSE", state). + Scan(&provider, &clientID, &attachUserID, &inviteCode) if err != nil { return err } @@ -2552,9 +2555,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri return nil }) if err != nil { - return "", "", 0, nil + return "", "", 0, "", nil } - return provider, clientID, attachUserID.Int64, nil + return provider, clientID, attachUserID.Int64, inviteCode.String, nil } func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { diff --git a/database_test.go b/database_test.go index 569d020..c114077 100644 --- a/database_test.go +++ b/database_test.go @@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) { driverName: "", } - state, err := ds.GenerateOAuthState(ctx, "test", "development", 0) + state, err := ds.GenerateOAuthState(ctx, "test", "development", 0, "") assert.NoError(t, err) assert.Len(t, state, 24) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) - _, _, _, err = ds.ValidateOAuthState(ctx, state) + _, _, _, _, err = ds.ValidateOAuthState(ctx, state) assert.NoError(t, err) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) diff --git a/invites.go b/invites.go index c1c7d95..10416b2 100644 --- a/invites.go +++ b/invites.go @@ -1,5 +1,5 @@ /* - * Copyright © 2019 A Bunch Tell LLC. + * Copyright © 2019-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -42,6 +42,18 @@ func (i Invite) Expired() bool { return i.Expires != nil && i.Expires.Before(time.Now()) } +func (i Invite) Active(db *datastore) bool { + if i.Expired() { + return false + } + if i.MaxUses.Valid && i.MaxUses.Int64 > 0 { + if c := db.GetUsersInvitedCount(i.ID); c >= i.MaxUses.Int64 { + return false + } + } + return true +} + func (i Invite) ExpiresFriendly() string { return i.Expires.Format("January 2, 2006, 3:04 PM") } @@ -161,9 +173,11 @@ func handleViewInvite(app *App, w http.ResponseWriter, r *http.Request) error { Error string Flashes []template.HTML Invite string + OAuth *OAuthButtons }{ StaticPage: pageForReq(app, r), Invite: inviteCode, + OAuth: NewOAuthButtons(app.cfg), } if expired { diff --git a/less/app.less b/less/app.less index ec3472d..e1cf5ea 100644 --- a/less/app.less +++ b/less/app.less @@ -5,6 +5,7 @@ @import "post-temp"; @import "effects"; @import "admin"; +@import "login"; @import "pages/error"; @import "lib/elements"; @import "lib/material"; diff --git a/less/login.less b/less/login.less new file mode 100644 index 0000000..473d26f --- /dev/null +++ b/less/login.less @@ -0,0 +1,45 @@ +/* + * Copyright © 2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +.row.signinbtns { + justify-content: space-evenly; + font-size: 1em; + margin-top: 2em; + margin-bottom: 1em; + + .loginbtn { + height: 40px; + } + + #writeas-login, #gitlab-login { + box-sizing: border-box; + font-size: 17px; + } +} + +.or { + text-align: center; + margin-bottom: 3.5em; + + p { + display: inline-block; + background-color: white; + padding: 0 1em; + } + + hr { + margin-top: -1.6em; + margin-bottom: 0; + } + + hr.short { + max-width: 30rem; + } +} \ No newline at end of file diff --git a/migrations/migrations.go b/migrations/migrations.go index 31ae43c..a0b3f25 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -62,7 +62,8 @@ var migrations = []Migration{ New("support oauth", oauth), // V3 -> V4 New("support slack oauth", oauthSlack), // V4 -> v5 New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6 - New("support oauth attach", oauthAttach), // V6 -> V7 (v0.12.0) + New("support oauth attach", oauthAttach), // V6 -> V7 + New("support oauth via invite", oauthInvites), // V7 -> V8 (v0.12.0) } // CurrentVer returns the current migration version the application is on diff --git a/migrations/v8.go b/migrations/v8.go new file mode 100644 index 0000000..2318c4e --- /dev/null +++ b/migrations/v8.go @@ -0,0 +1,45 @@ +/* + * Copyright © 2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package migrations + +import ( + "context" + "database/sql" + + wf_db "github.com/writeas/writefreely/db" +) + +func oauthInvites(db *datastore) error { + dialect := wf_db.DialectMySQL + if db.driverName == driverSQLite { + dialect = wf_db.DialectSQLite + } + return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + builders := []wf_db.SQLBuilder{ + dialect. + AlterTable("oauth_client_states"). + AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{ + Set: true, + Value: 6, + }).SetNullable(true)), + } + for _, builder := range builders { + query, err := builder.ToSQL() + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return err + } + } + return nil + }) +} diff --git a/oauth.go b/oauth.go index 9073f75..b5c88aa 100644 --- a/oauth.go +++ b/oauth.go @@ -1,3 +1,13 @@ +/* + * Copyright © 2019-2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + package writefreely import ( @@ -15,10 +25,27 @@ import ( "github.com/gorilla/sessions" "github.com/writeas/impart" "github.com/writeas/web-core/log" - "github.com/writeas/writefreely/config" ) +// OAuthButtons holds display information for different OAuth providers we support. +type OAuthButtons struct { + SlackEnabled bool + WriteAsEnabled bool + GitLabEnabled bool + GitLabDisplayName string +} + +// NewOAuthButtons creates a new OAuthButtons struct based on our app configuration. +func NewOAuthButtons(cfg *config.Config) *OAuthButtons { + return &OAuthButtons{ + SlackEnabled: cfg.SlackOauth.ClientID != "", + WriteAsEnabled: cfg.WriteAsOauth.ClientID != "", + GitLabEnabled: cfg.GitlabOauth.ClientID != "", + GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName), + } +} + // TokenResponse contains data returned when a token is created either // through a code exchange or using a refresh token. type TokenResponse struct { @@ -61,8 +88,8 @@ type OAuthDatastoreProvider interface { type OAuthDatastore interface { GetIDForRemoteUser(context.Context, string, string, string) (int64, error) RecordRemoteUserID(context.Context, int64, string, string, string, string) error - ValidateOAuthState(context.Context, string) (string, string, int64, error) - GenerateOAuthState(context.Context, string, string, int64) (string, error) + ValidateOAuthState(context.Context, string) (string, string, int64, string, error) + GenerateOAuthState(context.Context, string, string, int64, string) (string, error) CreateUser(*config.Config, *User, string) error GetUserByID(int64) (*User, error) @@ -108,7 +135,7 @@ func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Req attachUser = user.ID } - state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser) + state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code")) if err != nil { log.Error("viewOauthInit error: %s", err) return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} @@ -228,7 +255,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http code := r.FormValue("code") state := r.FormValue("state") - provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state) + provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state) if err != nil { log.Error("Unable to ValidateOAuthState: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} @@ -240,7 +267,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http return impart.HTTPError{http.StatusInternalServerError, err.Error()} } - // Now that we have the access token, let's use it real quick to make sur + // Now that we have the access token, let's use it real quick to make sure // it really really works. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) if err != nil { @@ -262,6 +289,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http } if localUserID != -1 { + // Existing user, so log in now user, err := h.DB.GetUserByID(localUserID) if err != nil { log.Error("Unable to GetUserByID %d: %s", localUserID, err) @@ -282,6 +310,22 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http return impart.HTTPError{http.StatusFound, "/me/settings"} } + // New user registration below. + // First, verify that user is allowed to register + if inviteCode != "" { + // Verify invite code is valid + i, err := app.db.GetUserInvite(inviteCode) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + if !i.Active(app.db) { + return impart.HTTPError{http.StatusNotFound, "Invite link has expired."} + } + } else if !app.cfg.App.OpenRegistration { + addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil) + return impart.HTTPError{http.StatusFound, "/login"} + } + displayName := tokenInfo.DisplayName if len(displayName) == 0 { displayName = tokenInfo.Username @@ -295,6 +339,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http TokenRemoteUser: tokenInfo.UserID, Provider: provider, ClientID: clientID, + InviteCode: inviteCode, } tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) diff --git a/oauth_signup.go b/oauth_signup.go index 220afbd..cbe4f60 100644 --- a/oauth_signup.go +++ b/oauth_signup.go @@ -38,6 +38,7 @@ type viewOauthSignupVars struct { Provider string ClientID string TokenHash string + InviteCode string LoginUsername string Alias string // TODO: rename this to match the data it represents: the collection title @@ -57,6 +58,7 @@ const ( oauthParamAlias = "alias" oauthParamEmail = "email" oauthParamPassword = "password" + oauthParamInviteCode = "invite_code" ) type oauthSignupPageParams struct { @@ -68,6 +70,7 @@ type oauthSignupPageParams struct { ClientID string Provider string TokenHash string + InviteCode string } func (p oauthSignupPageParams) HashTokenParams(key string) string { @@ -92,6 +95,7 @@ func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.R TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID), ClientID: r.FormValue(oauthParamClientID), Provider: r.FormValue(oauthParamProvider), + InviteCode: r.FormValue(oauthParamInviteCode), } if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) { return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."} @@ -128,6 +132,14 @@ func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.R return h.showOauthSignupPage(app, w, r, tp, err) } + // Log invite if needed + if tp.InviteCode != "" { + err = app.db.CreateInvitedUser(tp.InviteCode, newUser.ID) + if err != nil { + return err + } + } + err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken)) if err != nil { return h.showOauthSignupPage(app, w, r, tp, err) @@ -195,6 +207,7 @@ func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *ht Provider: tp.Provider, ClientID: tp.ClientID, TokenHash: tp.TokenHash, + InviteCode: tp.InviteCode, LoginUsername: username, Alias: collTitle, diff --git a/oauth_slack.go b/oauth_slack.go index 35db156..c881ab6 100644 --- a/oauth_slack.go +++ b/oauth_slack.go @@ -13,8 +13,6 @@ package writefreely import ( "context" "errors" - "fmt" - "github.com/writeas/nerds/store" "github.com/writeas/slug" "net/http" "net/url" @@ -167,7 +165,7 @@ func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessTok func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse { return &InspectResponse{ UserID: resp.User.ID, - Username: fmt.Sprintf("%s-%s", slug.Make(resp.User.Name), store.GenerateRandomString("0123456789bcdfghjklmnpqrstvwxyz", 5)), + Username: slug.Make(resp.User.Name), DisplayName: resp.User.Name, Email: resp.User.Email, } diff --git a/oauth_test.go b/oauth_test.go index c23eadd..96f65b2 100644 --- a/oauth_test.go +++ b/oauth_test.go @@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct { } type MockOAuthDatastore struct { - DoGenerateOAuthState func(context.Context, string, string, int64) (string, error) - DoValidateOAuthState func(context.Context, string) (string, string, int64, error) + DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error) + DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error) DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) DoCreateUser func(*config.Config, *User, string) error DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error @@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config { return cfg } -func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) { +func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) { if m.DoValidateOAuthState != nil { return m.DoValidateOAuthState(ctx, state) } - return "", "", 0, nil + return "", "", 0, "", nil } func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { @@ -119,15 +119,13 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) { if m.DoGetUserByID != nil { return m.DoGetUserByID(userID) } - user := &User{ - - } + user := &User{} return user, nil } -func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) { +func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) { if m.DoGenerateOAuthState != nil { - return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID) + return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode) } return store.Generate62RandomString(14), nil } @@ -173,7 +171,7 @@ func TestViewOauthInit(t *testing.T) { app := &MockOAuthDatastoreProvider{ DoDB: func() OAuthDatastore { return &MockOAuthDatastore{ - DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) { + DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) { return "", fmt.Errorf("pretend unable to write state error") }, } diff --git a/pages/login.tmpl b/pages/login.tmpl index 94087f3..5a338a4 100644 --- a/pages/login.tmpl +++ b/pages/login.tmpl @@ -3,39 +3,6 @@ {{end}} {{define "content"}} diff --git a/pages/signup-oauth.tmpl b/pages/signup-oauth.tmpl index ecf5db0..fcd70d2 100644 --- a/pages/signup-oauth.tmpl +++ b/pages/signup-oauth.tmpl @@ -1,6 +1,4 @@ -{{define "head"}}