diff --git a/config/config.go b/config/config.go index 4b9586e..7a539f1 100644 --- a/config/config.go +++ b/config/config.go @@ -56,6 +56,25 @@ type ( Port int `ini:"port"` } + OAuthCfg struct { + Enabled bool `ini:"enable"` + + // write.as + WriteAsProviderAuthLocation string `ini:"wa_auth_location"` + WriteAsProviderTokenLocation string `ini:"wa_token_location"` + WriteAsProviderInspectLocation string `ini:"wa_inspect_location"` + WriteAsClientCallbackLocation string `ini:"wa_callback_location"` + WriteAsClientID string `ini:"wa_client_id"` + WriteAsClientSecret string `ini:"wa_client_secret"` + WriteAsAuthLocation string + + // slack + SlackClientID string `ini:"slack_client_id"` + SlackClientSecret string `ini:"slack_client_secret"` + SlackTeamID string `init:"slack_team_id"` + SlackAuthLocation string + } + // AppCfg holds values that affect how the application functions AppCfg struct { SiteName string `ini:"site_name"` @@ -92,17 +111,10 @@ type ( LocalTimeline bool `ini:"local_timeline"` UserInvites string `ini:"user_invites"` - // OAuth - EnableOAuth bool `ini:"enable_oauth"` - OAuthProviderAuthLocation string `ini:"oauth_auth_location"` - OAuthProviderTokenLocation string `ini:"oauth_token_location"` - OAuthProviderInspectLocation string `ini:"oauth_inspect_location"` - OAuthClientCallbackLocation string `ini:"oauth_callback_location"` - OAuthClientID string `ini:"oauth_client_id"` - OAuthClientSecret string `ini:"oauth_client_secret"` - // Defaults DefaultVisibility string `ini:"default_visibility"` + + OAuth OAuthCfg `ini:"oauth"` } // Config holds the complete configuration for running a writefreely instance diff --git a/database.go b/database.go index 56035dd..a4c79d4 100644 --- a/database.go +++ b/database.go @@ -125,10 +125,10 @@ type writestore interface { GetUserLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error) - GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) - RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error - ValidateOAuthState(ctx context.Context, state string) error - GenerateOAuthState(ctx context.Context) (string, error) + GetIDForRemoteUser(context.Context, int64) (int64, error) + RecordRemoteUserID(context.Context, int64, int64) error + ValidateOAuthState(context.Context, string, string, string) error + GenerateOAuthState(context.Context, string, string) (string, error) DatabaseInitialized() bool } @@ -138,6 +138,8 @@ type datastore struct { driverName string } +var _ writestore = &datastore{} + func (db *datastore) now() string { if db.driverName == driverSQLite { return "strftime('%Y-%m-%d %H:%M:%S','now')" @@ -2459,17 +2461,17 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { return &t, nil } -func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { +func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { state := store.Generate62RandomString(24) - _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID) if err != nil { return "", fmt.Errorf("unable to record oauth client state: %w", err) } return state, nil } -func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error { - res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state) +func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, clientID string) error { + res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ? AND provider = ? AND client_id = ?", state, provider, clientID) if err != nil { return err } diff --git a/database_test.go b/database_test.go index 4a45dd0..b19c861 100644 --- a/database_test.go +++ b/database_test.go @@ -18,7 +18,7 @@ func TestOAuthDatastore(t *testing.T) { driverName: "", } - state, err := ds.GenerateOAuthState(ctx) + state, err := ds.GenerateOAuthState(ctx, "", "") assert.NoError(t, err) assert.Len(t, state, 24) diff --git a/go.mod b/go.mod index 372b7ba..29c08db 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/guregu/null v3.4.0+incompatible github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 github.com/jtolds/gls v4.2.1+incompatible // indirect + github.com/kr/pretty v0.1.0 github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec github.com/lunixbochs/vtclean v1.0.0 // indirect github.com/manifoldco/promptui v0.3.2 diff --git a/go.sum b/go.sum index ee3a418..035538e 100644 --- a/go.sum +++ b/go.sum @@ -64,6 +64,7 @@ github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpR github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= diff --git a/oauth.go b/oauth.go index d918f7f..70ae064 100644 --- a/oauth.go +++ b/oauth.go @@ -2,14 +2,17 @@ package writefreely import ( "context" + "encoding/hex" "encoding/json" "fmt" + "github.com/gorilla/mux" "github.com/gorilla/sessions" "github.com/guregu/null/zero" "github.com/writeas/nerds/store" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/log" "github.com/writeas/writefreely/config" + "hash/fnv" "io" "io/ioutil" "net/http" @@ -55,11 +58,12 @@ type OAuthDatastoreProvider interface { // OAuthDatastore provides a minimal interface of data store methods used in // oauth functionality. type OAuthDatastore interface { - GenerateOAuthState(context.Context) (string, error) - ValidateOAuthState(context.Context, string) error GetIDForRemoteUser(context.Context, int64) (int64, error) - CreateUser(*config.Config, *User, string) error RecordRemoteUserID(context.Context, int64, int64) error + ValidateOAuthState(context.Context, string, string, string) error + GenerateOAuthState(context.Context, string, string) (string, error) + + CreateUser(*config.Config, *User, string) error GetUserForAuthByID(int64) (*User, error) } @@ -75,8 +79,8 @@ type oauthHandler struct { } // buildAuthURL returns a URL used to initiate authentication. -func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) { - state, err := db.GenerateOAuthState(ctx) +func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) { + state, err := db.GenerateOAuthState(ctx, provider, clientID) if err != nil { return "", err } @@ -95,9 +99,8 @@ func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation return u.String(), nil } -// app *App, w http.ResponseWriter, r *http.Request -func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) { - location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation) +func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) { + location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation) if err != nil { failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") return @@ -105,94 +108,128 @@ func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, location, http.StatusTemporaryRedirect) } -func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - code := r.FormValue("code") - state := r.FormValue("state") - - err := h.DB.ValidateOAuthState(ctx, state) +func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) { + location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation) if err != nil { - failOAuthRequest(w, http.StatusInternalServerError, err.Error()) + failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") return } + http.Redirect(w, r, location, http.StatusTemporaryRedirect) +} - tokenResponse, err := h.exchangeOauthCode(ctx, code) - if err != nil { - failOAuthRequest(w, http.StatusInternalServerError, err.Error()) - return +func (h oauthHandler) configureRoutes(r *mux.Router) { + if h.Config.App.OAuth.Enabled { + if h.Config.App.OAuth.WriteAsClientID != "" { + callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID) + log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash) + r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET") + r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET") + } + if h.Config.App.OAuth.SlackClientID != "" { + callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID) + log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash) + r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET") + r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET") + } } - // Now that we have the access token, let's use it real quick to make sur - // it really really works. - tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) - if err != nil { - failOAuthRequest(w, http.StatusInternalServerError, err.Error()) - return - } +} - localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID) - if err != nil { - failOAuthRequest(w, http.StatusInternalServerError, err.Error()) - return - } +func oauthProviderHash(provider, clientID string) string { + hasher := fnv.New32() + return hex.EncodeToString(hasher.Sum([]byte(provider + clientID))) +} - fmt.Println("local user id", localUserID) +func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - if localUserID == -1 { - // We don't have, nor do we want, the password from the origin, so we - //create a random string. If the user needs to set a password, they - //can do so through the settings page or through the password reset - //flow. - randPass := store.Generate62RandomString(14) - hashedPass, err := auth.HashPass([]byte(randPass)) + code := r.FormValue("code") + state := r.FormValue("state") + + err := h.DB.ValidateOAuthState(ctx, state, provider, clientID) if err != nil { - log.ErrorLog.Println(err) - failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash") + failOAuthRequest(w, http.StatusInternalServerError, err.Error()) return } - newUser := &User{ - Username: tokenInfo.Username, - HashedPass: hashedPass, - HasPass: true, - Email: zero.NewString("", tokenInfo.Email != ""), - Created: time.Now().Truncate(time.Second).UTC(), - } - err = h.DB.CreateUser(h.Config, newUser, newUser.Username) + tokenResponse, err := h.exchangeOauthCode(ctx, code) if err != nil { failOAuthRequest(w, http.StatusInternalServerError, err.Error()) return } - err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID) + // Now that we have the access token, let's use it real quick to make sur + // it really really works. + tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) if err != nil { failOAuthRequest(w, http.StatusInternalServerError, err.Error()) return } - if err := loginOrFail(h.Store, w, r, newUser); err != nil { + localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID) + if err != nil { failOAuthRequest(w, http.StatusInternalServerError, err.Error()) + return } - return - } - user, err := h.DB.GetUserForAuthByID(localUserID) - if err != nil { - failOAuthRequest(w, http.StatusInternalServerError, err.Error()) - return - } - if err = loginOrFail(h.Store, w, r, user); err != nil { - failOAuthRequest(w, http.StatusInternalServerError, err.Error()) + fmt.Println("local user id", localUserID) + + if localUserID == -1 { + // We don't have, nor do we want, the password from the origin, so we + //create a random string. If the user needs to set a password, they + //can do so through the settings page or through the password reset + //flow. + randPass := store.Generate62RandomString(14) + hashedPass, err := auth.HashPass([]byte(randPass)) + if err != nil { + log.ErrorLog.Println(err) + failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash") + return + } + newUser := &User{ + Username: tokenInfo.Username, + HashedPass: hashedPass, + HasPass: true, + Email: zero.NewString("", tokenInfo.Email != ""), + Created: time.Now().Truncate(time.Second).UTC(), + } + + err = h.DB.CreateUser(h.Config, newUser, newUser.Username) + if err != nil { + failOAuthRequest(w, http.StatusInternalServerError, err.Error()) + return + } + + err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID) + if err != nil { + failOAuthRequest(w, http.StatusInternalServerError, err.Error()) + return + } + + if err := loginOrFail(h.Store, w, r, newUser); err != nil { + failOAuthRequest(w, http.StatusInternalServerError, err.Error()) + } + return + } + + user, err := h.DB.GetUserForAuthByID(localUserID) + if err != nil { + failOAuthRequest(w, http.StatusInternalServerError, err.Error()) + return + } + if err = loginOrFail(h.Store, w, r, user); err != nil { + failOAuthRequest(w, http.StatusInternalServerError, err.Error()) + } } } func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { form := url.Values{} form.Add("grant_type", "authorization_code") - form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation) + form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation) form.Add("code", code) - req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode())) + req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode())) if err != nil { return nil, err } @@ -200,7 +237,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke req.Header.Set("User-Agent", "writefreely") req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret) + req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret) resp, err := h.HttpClient.Do(req) if err != nil { @@ -224,7 +261,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke } func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { - req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil) + req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil) if err != nil { return nil, err } diff --git a/routes.go b/routes.go index e286c3e..2bca288 100644 --- a/routes.go +++ b/routes.go @@ -86,9 +86,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { DB: apper.App().DB(), Store: apper.App().SessionStore(), } - - write.HandleFunc("/oauth/write.as", oauthHandler.viewOauthInit).Methods("GET") - write.HandleFunc("/oauth/callback", oauthHandler.viewOauthCallback).Methods("GET") + oauthHandler.configureRoutes(write) // Handle logged in user sections me := write.PathPrefix("/me").Subrouter()