Merge pull request #243 from writeas/T713-oauth-account-management

OAuth account management
pull/288/head
Matt Baer 5 years ago committed by GitHub
commit 0acc630af5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 68
      account.go
  2. 54
      database.go
  3. 4
      database_test.go
  4. 2
      go.mod
  5. 4
      go.sum
  6. 1
      migrations/migrations.go
  7. 36
      migrations/v7.go
  8. 48
      oauth.go
  9. 14
      oauth_test.go
  10. 1
      routes.go
  11. BIN
      static/img/mark/gitlab.png
  12. BIN
      static/img/mark/slack.png
  13. BIN
      static/img/mark/writeas.png
  14. 1
      templates.go
  15. 64
      templates/user/settings.tmpl

@ -27,6 +27,7 @@ import (
"github.com/writeas/web-core/auth" "github.com/writeas/web-core/auth"
"github.com/writeas/web-core/data" "github.com/writeas/web-core/data"
"github.com/writeas/web-core/log" "github.com/writeas/web-core/log"
"github.com/writeas/writefreely/author" "github.com/writeas/writefreely/author"
"github.com/writeas/writefreely/config" "github.com/writeas/writefreely/config"
"github.com/writeas/writefreely/page" "github.com/writeas/writefreely/page"
@ -70,7 +71,7 @@ func canUserInvite(cfg *config.Config, isAdmin bool) bool {
} }
func (up *UserPage) SetMessaging(u *User) { func (up *UserPage) SetMessaging(u *User) {
//up.NeedsAuth = app.db.DoesUserNeedAuth(u.ID) // up.NeedsAuth = app.db.DoesUserNeedAuth(u.ID)
} }
const ( const (
@ -1042,18 +1043,52 @@ func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) err
flashes, _ := getSessionFlashes(app, w, r, nil) flashes, _ := getSessionFlashes(app, w, r, nil)
enableOauthSlack := app.Config().SlackOauth.ClientID != ""
enableOauthWriteAs := app.Config().WriteAsOauth.ClientID != ""
enableOauthGitLab := app.Config().GitlabOauth.ClientID != ""
oauthAccounts, err := app.db.GetOauthAccounts(r.Context(), u.ID)
if err != nil {
log.Error("Unable to get oauth accounts for settings: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data. The humans have been alerted."}
}
for _, oauthAccount := range oauthAccounts {
switch oauthAccount.Provider {
case "slack":
enableOauthSlack = false
case "write.as":
enableOauthWriteAs = false
case "gitlab":
enableOauthGitLab = false
}
}
displayOauthSection := enableOauthSlack || enableOauthWriteAs || enableOauthGitLab || len(oauthAccounts) > 0
obj := struct { obj := struct {
*UserPage *UserPage
Email string Email string
HasPass bool HasPass bool
IsLogOut bool IsLogOut bool
Silenced bool Silenced bool
OauthSection bool
OauthAccounts []oauthAccountInfo
OauthSlack bool
OauthWriteAs bool
OauthGitLab bool
GitLabDisplayName string
}{ }{
UserPage: NewUserPage(app, r, u, "Account Settings", flashes), UserPage: NewUserPage(app, r, u, "Account Settings", flashes),
Email: fullUser.EmailClear(app.keys), Email: fullUser.EmailClear(app.keys),
HasPass: passIsSet, HasPass: passIsSet,
IsLogOut: r.FormValue("logout") == "1", IsLogOut: r.FormValue("logout") == "1",
Silenced: fullUser.IsSilenced(), Silenced: fullUser.IsSilenced(),
OauthSection: displayOauthSection,
OauthAccounts: oauthAccounts,
OauthSlack: enableOauthSlack,
OauthWriteAs: enableOauthWriteAs,
OauthGitLab: enableOauthGitLab,
GitLabDisplayName: config.OrDefaultString(app.Config().GitlabOauth.DisplayName, gitlabDisplayName),
} }
showUserPage(w, "settings", obj) showUserPage(w, "settings", obj)
@ -1098,6 +1133,19 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s
return s return s
} }
func removeOauth(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
provider := r.FormValue("provider")
clientID := r.FormValue("client_id")
remoteUserID := r.FormValue("remote_user_id")
err := app.db.RemoveOauth(r.Context(), u.ID, provider, clientID, remoteUserID)
if err != nil {
return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
}
return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
}
func prepareUserEmail(input string, emailKey []byte) zero.String { func prepareUserEmail(input string, emailKey []byte) zero.String {
email := zero.NewString("", input != "") email := zero.NewString("", input != "")
if len(input) > 0 { if len(input) > 0 {

@ -130,8 +130,10 @@ type writestore interface {
GetIDForRemoteUser(context.Context, string, string, string) (int64, error) GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error) ValidateOAuthState(context.Context, string) (string, string, int64, error)
GenerateOAuthState(context.Context, string, string) (string, error) GenerateOAuthState(context.Context, string, string, int64) (string, error)
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error
DatabaseInitialized() bool DatabaseInitialized() bool
} }
@ -2510,20 +2512,24 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return &t, nil return &t, nil
} }
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) {
state := store.Generate62RandomString(24) state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, "+db.now()+")", state, provider, clientID) attachUserVal := sql.NullInt64{Valid: attachUser > 0, Int64: attachUser}
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, "+db.now()+", ?)", state, provider, clientID, attachUserVal)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err) return "", fmt.Errorf("unable to record oauth client state: %w", err)
} }
return state, nil return state, nil
} }
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
var provider string var provider string
var clientID string var clientID string
var attachUserID sql.NullInt64
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID) err := tx.
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
Scan(&provider, &clientID, &attachUserID)
if err != nil { if err != nil {
return err return err
} }
@ -2542,9 +2548,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri
return nil return nil
}) })
if err != nil { if err != nil {
return "", "", nil return "", "", 0, nil
} }
return provider, clientID, nil return provider, clientID, attachUserID.Int64, nil
} }
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
@ -2573,6 +2579,33 @@ func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provi
return userID, nil return userID, nil
} }
type oauthAccountInfo struct {
Provider string
ClientID string
RemoteUserID string
}
func (db *datastore) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) {
rows, err := db.QueryContext(ctx, "SELECT provider, client_id, remote_user_id FROM oauth_users WHERE user_id = ? ", userID)
if err != nil {
log.Error("Failed selecting from oauth_users: %v", err)
return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user oauth accounts."}
}
defer rows.Close()
var records []oauthAccountInfo
for rows.Next() {
info := oauthAccountInfo{}
err = rows.Scan(&info.Provider, &info.ClientID, &info.RemoteUserID)
if err != nil {
log.Error("Failed scanning GetAllUsers() row: %v", err)
break
}
records = append(records, info)
}
return records, nil
}
// DatabaseInitialized returns whether or not the current datastore has been // DatabaseInitialized returns whether or not the current datastore has been
// initialized with the correct schema. // initialized with the correct schema.
// Currently, it checks to see if the `users` table exists. // Currently, it checks to see if the `users` table exists.
@ -2595,6 +2628,11 @@ func (db *datastore) DatabaseInitialized() bool {
return true return true
} }
func (db *datastore) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error {
_, err := db.ExecContext(ctx, `DELETE FROM oauth_users WHERE user_id = ? AND provider = ? AND client_id = ? AND remote_user_id = ?`, userID, provider, clientID, remoteUserID)
return err
}
func stringLogln(log *string, s string, v ...interface{}) { func stringLogln(log *string, s string, v ...interface{}) {
*log += fmt.Sprintf(s+"\n", v...) *log += fmt.Sprintf(s+"\n", v...)
} }

@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "", driverName: "",
} }
state, err := ds.GenerateOAuthState(ctx, "test", "development") state, err := ds.GenerateOAuthState(ctx, "test", "development", 0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, state, 24) assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
_, _, err = ds.ValidateOAuthState(ctx, state) _, _, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err) assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)

@ -21,6 +21,7 @@ require (
github.com/guregu/null v3.4.0+incompatible github.com/guregu/null v3.4.0+incompatible
github.com/hashicorp/go-multierror v1.0.0 github.com/hashicorp/go-multierror v1.0.0
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2
github.com/jteeuwen/go-bindata v3.0.7+incompatible // indirect
github.com/jtolds/gls v4.2.1+incompatible // indirect github.com/jtolds/gls v4.2.1+incompatible // indirect
github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec
github.com/lunixbochs/vtclean v1.0.0 // indirect github.com/lunixbochs/vtclean v1.0.0 // indirect
@ -51,6 +52,7 @@ require (
github.com/writeas/slug v1.2.0 github.com/writeas/slug v1.2.0
github.com/writeas/web-core v1.2.0 github.com/writeas/web-core v1.2.0
github.com/writefreely/go-nodeinfo v1.2.0 github.com/writefreely/go-nodeinfo v1.2.0
golang.org/dl v0.0.0-20200319204010-bf12898a6070 // indirect
golang.org/x/crypto v0.0.0-20200109152110-61a87790db17 golang.org/x/crypto v0.0.0-20200109152110-61a87790db17
golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1 // indirect golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1 // indirect
golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 // indirect golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 // indirect

@ -73,6 +73,8 @@ github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uP
github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk=
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 h1:wIdDEle9HEy7vBPjC6oKz6ejs3Ut+jmsYvuOoAW2pSM= github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 h1:wIdDEle9HEy7vBPjC6oKz6ejs3Ut+jmsYvuOoAW2pSM=
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2/go.mod h1:WtaVKD9TeruTED9ydiaOJU08qGoEPP/LyzTKiD3jEsw= github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2/go.mod h1:WtaVKD9TeruTED9ydiaOJU08qGoEPP/LyzTKiD3jEsw=
github.com/jteeuwen/go-bindata v3.0.7+incompatible h1:91Uy4d9SYVr1kyTJ15wJsog+esAZZl7JmEfTkwmhJts=
github.com/jteeuwen/go-bindata v3.0.7+incompatible/go.mod h1:JVvhzYOiGBnFSYRyV00iY8q7/0PThjIYav1p9h5dmKs=
github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpRVWLVmUEE= github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpRVWLVmUEE=
github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU=
@ -165,6 +167,8 @@ github.com/writeas/web-core v1.2.0 h1:CYqvBd+byi1cK4mCr1NZ6CjILuMOFmiFecv+OACcmG
github.com/writeas/web-core v1.2.0/go.mod h1:vTYajviuNBAxjctPp2NUYdgjofywVkxUGpeaERF3SfI= github.com/writeas/web-core v1.2.0/go.mod h1:vTYajviuNBAxjctPp2NUYdgjofywVkxUGpeaERF3SfI=
github.com/writefreely/go-nodeinfo v1.2.0 h1:La+YbTCvmpTwFhBSlebWDDL81N88Qf/SCAvRLR7F8ss= github.com/writefreely/go-nodeinfo v1.2.0 h1:La+YbTCvmpTwFhBSlebWDDL81N88Qf/SCAvRLR7F8ss=
github.com/writefreely/go-nodeinfo v1.2.0/go.mod h1:UTvE78KpcjYOlRHupZIiSEFcXHioTXuacCbHU+CAcPg= github.com/writefreely/go-nodeinfo v1.2.0/go.mod h1:UTvE78KpcjYOlRHupZIiSEFcXHioTXuacCbHU+CAcPg=
golang.org/dl v0.0.0-20200319204010-bf12898a6070 h1:m3RoSUFYtel4F/gCw0tosY5Exe7hm2NbeNv/737FbSo=
golang.org/dl v0.0.0-20200319204010-bf12898a6070/go.mod h1:IUMfjQLJQd4UTqG1Z90tenwKoCX93Gn3MAQJMOSBsDQ=
golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59 h1:hk3yo72LXLapY9EXVttc3Z1rLOxT9IuAPPX3GpY2+jo= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59 h1:hk3yo72LXLapY9EXVttc3Z1rLOxT9IuAPPX3GpY2+jo=
golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=

@ -62,6 +62,7 @@ var migrations = []Migration{
New("support oauth", oauth), // V3 -> V4 New("support oauth", oauth), // V3 -> V4
New("support slack oauth", oauthSlack), // V4 -> v5 New("support slack oauth", oauthSlack), // V4 -> v5
New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6 (v0.12.0) New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6 (v0.12.0)
New("support oauth attach", oauthAttach), // V6 -> V7
} }
// CurrentVer returns the current migration version the application is on // CurrentVer returns the current migration version the application is on

@ -0,0 +1,36 @@
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writeas/writefreely/db"
)
func oauthAttach(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.
Column(
"attach_user_id",
wf_db.ColumnTypeInteger,
wf_db.OptionalInt{Set: true, Value: 24}).SetNullable(true)),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

@ -4,17 +4,19 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/writeas/impart"
"github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/writeas/impart"
"github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config"
) )
// TokenResponse contains data returned when a token is created either // TokenResponse contains data returned when a token is created either
@ -59,8 +61,8 @@ type OAuthDatastoreProvider interface {
type OAuthDatastore interface { type OAuthDatastore interface {
GetIDForRemoteUser(context.Context, string, string, string) (int64, error) GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error) ValidateOAuthState(context.Context, string) (string, string, int64, error)
GenerateOAuthState(context.Context, string, string) (string, error) GenerateOAuthState(context.Context, string, string, int64) (string, error)
CreateUser(*config.Config, *User, string) error CreateUser(*config.Config, *User, string) error
GetUserByID(int64) (*User, error) GetUserByID(int64) (*User, error)
@ -96,19 +98,32 @@ type oauthHandler struct {
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
ctx := r.Context() ctx := r.Context()
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
var attachUser int64
if attach := r.URL.Query().Get("attach"); attach == "t" {
user, _ := getUserAndSession(app, r)
if user == nil {
return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
}
attachUser = user.ID
}
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser)
if err != nil { if err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
} }
if h.callbackProxy != nil { if h.callbackProxy != nil {
if err := h.callbackProxy.register(ctx, state); err != nil { if err := h.callbackProxy.register(ctx, state); err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
} }
} }
location, err := h.oauthClient.buildLoginURL(state) location, err := h.oauthClient.buildLoginURL(state)
if err != nil { if err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
} }
return impart.HTTPError{http.StatusTemporaryRedirect, location} return impart.HTTPError{http.StatusTemporaryRedirect, location}
@ -213,7 +228,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
code := r.FormValue("code") code := r.FormValue("code")
state := r.FormValue("state") state := r.FormValue("state")
provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state)
if err != nil { if err != nil {
log.Error("Unable to ValidateOAuthState: %s", err) log.Error("Unable to ValidateOAuthState: %s", err)
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
@ -239,6 +254,13 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
} }
if localUserID != -1 && attachUserID > 0 {
if err = addSessionFlash(app, w, r, "This Slack account is already attached to another user.", nil); err != nil {
return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
}
return impart.HTTPError{http.StatusFound, "/me/settings"}
}
if localUserID != -1 { if localUserID != -1 {
user, err := h.DB.GetUserByID(localUserID) user, err := h.DB.GetUserByID(localUserID)
if err != nil { if err != nil {
@ -251,6 +273,14 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
} }
return nil return nil
} }
if attachUserID > 0 {
log.Info("attaching to user %d", attachUserID)
err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
if err != nil {
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
}
return impart.HTTPError{http.StatusFound, "/me/settings"}
}
displayName := tokenInfo.DisplayName displayName := tokenInfo.DisplayName
if len(displayName) == 0 { if len(displayName) == 0 {

@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct {
} }
type MockOAuthDatastore struct { type MockOAuthDatastore struct {
DoGenerateOAuthState func(context.Context, string, string) (string, error) DoGenerateOAuthState func(context.Context, string, string, int64) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, error) DoValidateOAuthState func(context.Context, string) (string, string, int64, error)
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
return cfg return cfg
} }
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
if m.DoValidateOAuthState != nil { if m.DoValidateOAuthState != nil {
return m.DoValidateOAuthState(ctx, state) return m.DoValidateOAuthState(ctx, state)
} }
return "", "", nil return "", "", 0, nil
} }
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
@ -125,9 +125,9 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
return user, nil return user, nil
} }
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) { func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) {
if m.DoGenerateOAuthState != nil { if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx, provider, clientID) return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID)
} }
return store.Generate62RandomString(14), nil return store.Generate62RandomString(14), nil
} }
@ -173,7 +173,7 @@ func TestViewOauthInit(t *testing.T) {
app := &MockOAuthDatastoreProvider{ app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore { DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{ return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) { DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) {
return "", fmt.Errorf("pretend unable to write state error") return "", fmt.Errorf("pretend unable to write state error")
}, },
} }

@ -115,6 +115,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
apiMe.HandleFunc("/self", handler.All(updateSettings)).Methods("POST") apiMe.HandleFunc("/self", handler.All(updateSettings)).Methods("POST")
apiMe.HandleFunc("/invites", handler.User(handleCreateUserInvite)).Methods("POST") apiMe.HandleFunc("/invites", handler.User(handleCreateUserInvite)).Methods("POST")
apiMe.HandleFunc("/import", handler.User(handleImport)).Methods("POST") apiMe.HandleFunc("/import", handler.User(handleImport)).Methods("POST")
apiMe.HandleFunc("/oauth/remove", handler.User(removeOauth)).Methods("POST")
// Sign up validation // Sign up validation
write.HandleFunc("/api/alias", handler.All(handleUsernameCheck)).Methods("POST") write.HandleFunc("/api/alias", handler.All(handleUsernameCheck)).Methods("POST")

Binary file not shown.

After

Width:  |  Height:  |  Size: 1005 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

@ -37,6 +37,7 @@ var (
"localstr": localStr, "localstr": localStr,
"localhtml": localHTML, "localhtml": localHTML,
"tolower": strings.ToLower, "tolower": strings.ToLower,
"title": strings.Title,
} }
) )

@ -4,7 +4,13 @@
<style type="text/css"> <style type="text/css">
.option { margin: 2em 0em; } .option { margin: 2em 0em; }
h3 { font-weight: normal; } h3 { font-weight: normal; }
.section > *:not(input) { font-size: 0.86em; } .section p, .section label {
font-size: 0.86em;
}
.oauth-provider img {
max-height: 2.75em;
vertical-align: middle;
}
</style> </style>
<div class="content-container snug"> <div class="content-container snug">
{{if .Silenced}} {{if .Silenced}}
@ -62,10 +68,64 @@ h3 { font-weight: normal; }
</div> </div>
</div> </div>
<div class="option" style="text-align: center; margin-top: 4em;"> <div class="option" style="text-align: center;">
<input type="submit" value="Save changes" tabindex="4" /> <input type="submit" value="Save changes" tabindex="4" />
</div> </div>
</form> </form>
{{ if .OauthSection }}
<hr />
{{ if .OauthAccounts }}
<div class="option">
<h2>Linked Accounts</h2>
<p>These are your linked external accounts.</p>
{{ range $oauth_account := .OauthAccounts }}
<form method="post" action="/api/me/oauth/remove" autocomplete="false">
<input type="hidden" name="provider" value="{{ $oauth_account.Provider }}" />
<input type="hidden" name="client_id" value="{{ $oauth_account.ClientID }}" />
<input type="hidden" name="remote_user_id" value="{{ $oauth_account.RemoteUserID }}" />
<div class="section oauth-provider">
<img src="/img/mark/{{$oauth_account.Provider}}.png" alt="{{ $oauth_account.Provider | title }}" />
<input type="submit" value="Remove {{ $oauth_account.Provider | title }}" />
</div>
</form>
{{ end }}
</div>
{{ end }}
{{ if or .OauthSlack .OauthWriteAs .OauthGitLab }}
<div class="option">
<h2>Link External Accounts</h2>
<p>Connect additional accounts to enable logging in with those providers, instead of using your username and password.</p>
<div class="row">
{{ if .OauthWriteAs }}
<div class="section oauth-provider">
<img src="/img/mark/writeas.png" alt="Write.as" />
<a class="btn cta loginbtn" id="writeas-login" href="/oauth/write.as?attach=t">
Link <strong>Write.as</strong>
</a>
</div>
{{ end }}
{{ if .OauthSlack }}
<div class="section oauth-provider">
<img src="/img/mark/slack.png" alt="Slack" />
<a class="btn cta loginbtn" href="/oauth/slack?attach=t">
Link <strong>Slack</strong>
</a>
</div>
{{ end }}
{{ if .OauthGitLab }}
<div class="section oauth-provider">
<img src="/img/mark/gitlab.png" alt="GitLab" />
<a class="btn cta loginbtn" id="gitlab-login" href="/oauth/gitlab?attach=t">
Link <strong>{{.GitLabDisplayName}}</strong>
</a>
</div>
{{ end }}
</div>
</div>
{{ end }}
{{ end }}
</div> </div>
<script> <script>

Loading…
Cancel
Save