Merging T705-oauth into T710-oauth-slack. T705,T710

pull/232/head
Nick Gerakines 5 years ago
parent 4266154749
commit 13121cb266
  1. 30
      config/config.go
  2. 18
      database.go
  3. 2
      database_test.go
  4. 1
      go.mod
  5. 1
      go.sum
  6. 65
      oauth.go
  7. 4
      routes.go

@ -56,6 +56,25 @@ type (
Port int `ini:"port"`
}
OAuthCfg struct {
Enabled bool `ini:"enable"`
// write.as
WriteAsProviderAuthLocation string `ini:"wa_auth_location"`
WriteAsProviderTokenLocation string `ini:"wa_token_location"`
WriteAsProviderInspectLocation string `ini:"wa_inspect_location"`
WriteAsClientCallbackLocation string `ini:"wa_callback_location"`
WriteAsClientID string `ini:"wa_client_id"`
WriteAsClientSecret string `ini:"wa_client_secret"`
WriteAsAuthLocation string
// slack
SlackClientID string `ini:"slack_client_id"`
SlackClientSecret string `ini:"slack_client_secret"`
SlackTeamID string `init:"slack_team_id"`
SlackAuthLocation string
}
// AppCfg holds values that affect how the application functions
AppCfg struct {
SiteName string `ini:"site_name"`
@ -92,17 +111,10 @@ type (
LocalTimeline bool `ini:"local_timeline"`
UserInvites string `ini:"user_invites"`
// OAuth
EnableOAuth bool `ini:"enable_oauth"`
OAuthProviderAuthLocation string `ini:"oauth_auth_location"`
OAuthProviderTokenLocation string `ini:"oauth_token_location"`
OAuthProviderInspectLocation string `ini:"oauth_inspect_location"`
OAuthClientCallbackLocation string `ini:"oauth_callback_location"`
OAuthClientID string `ini:"oauth_client_id"`
OAuthClientSecret string `ini:"oauth_client_secret"`
// Defaults
DefaultVisibility string `ini:"default_visibility"`
OAuth OAuthCfg `ini:"oauth"`
}
// Config holds the complete configuration for running a writefreely instance

@ -125,10 +125,10 @@ type writestore interface {
GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error)
GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error)
RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error
ValidateOAuthState(ctx context.Context, state string) error
GenerateOAuthState(ctx context.Context) (string, error)
GetIDForRemoteUser(context.Context, int64) (int64, error)
RecordRemoteUserID(context.Context, int64, int64) error
ValidateOAuthState(context.Context, string, string, string) error
GenerateOAuthState(context.Context, string, string) (string, error)
DatabaseInitialized() bool
}
@ -138,6 +138,8 @@ type datastore struct {
driverName string
}
var _ writestore = &datastore{}
func (db *datastore) now() string {
if db.driverName == driverSQLite {
return "strftime('%Y-%m-%d %H:%M:%S','now')"
@ -2459,17 +2461,17 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return &t, nil
}
func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) {
state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID)
if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err)
}
return state, nil
}
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state)
func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, clientID string) error {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ? AND provider = ? AND client_id = ?", state, provider, clientID)
if err != nil {
return err
}

@ -18,7 +18,7 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "",
}
state, err := ds.GenerateOAuthState(ctx)
state, err := ds.GenerateOAuthState(ctx, "", "")
assert.NoError(t, err)
assert.Len(t, state, 24)

@ -19,6 +19,7 @@ require (
github.com/guregu/null v3.4.0+incompatible
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2
github.com/jtolds/gls v4.2.1+incompatible // indirect
github.com/kr/pretty v0.1.0
github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec
github.com/lunixbochs/vtclean v1.0.0 // indirect
github.com/manifoldco/promptui v0.3.2

@ -64,6 +64,7 @@ github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpR
github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU=
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=

@ -2,14 +2,17 @@ package writefreely
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/guregu/null/zero"
"github.com/writeas/nerds/store"
"github.com/writeas/web-core/auth"
"github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config"
"hash/fnv"
"io"
"io/ioutil"
"net/http"
@ -55,11 +58,12 @@ type OAuthDatastoreProvider interface {
// OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality.
type OAuthDatastore interface {
GenerateOAuthState(context.Context) (string, error)
ValidateOAuthState(context.Context, string) error
GetIDForRemoteUser(context.Context, int64) (int64, error)
CreateUser(*config.Config, *User, string) error
RecordRemoteUserID(context.Context, int64, int64) error
ValidateOAuthState(context.Context, string, string, string) error
GenerateOAuthState(context.Context, string, string) (string, error)
CreateUser(*config.Config, *User, string) error
GetUserForAuthByID(int64) (*User, error)
}
@ -75,8 +79,8 @@ type oauthHandler struct {
}
// buildAuthURL returns a URL used to initiate authentication.
func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) {
state, err := db.GenerateOAuthState(ctx)
func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) {
state, err := db.GenerateOAuthState(ctx, provider, clientID)
if err != nil {
return "", err
}
@ -95,9 +99,17 @@ func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation
return u.String(), nil
}
// app *App, w http.ResponseWriter, r *http.Request
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation)
func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return
}
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
}
func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return
@ -105,13 +117,37 @@ func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
}
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
func (h oauthHandler) configureRoutes(r *mux.Router) {
if h.Config.App.OAuth.Enabled {
if h.Config.App.OAuth.WriteAsClientID != "" {
callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID)
log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET")
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET")
}
if h.Config.App.OAuth.SlackClientID != "" {
callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID)
log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET")
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET")
}
}
}
func oauthProviderHash(provider, clientID string) string {
hasher := fnv.New32()
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID)))
}
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
code := r.FormValue("code")
state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state)
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
@ -186,13 +222,14 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request)
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
}
}
}
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation)
form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode()))
req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
@ -200,7 +237,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret)
req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret)
resp, err := h.HttpClient.Do(req)
if err != nil {
@ -224,7 +261,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke
}
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil)
req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil)
if err != nil {
return nil, err
}

@ -86,9 +86,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
DB: apper.App().DB(),
Store: apper.App().SessionStore(),
}
write.HandleFunc("/oauth/write.as", oauthHandler.viewOauthInit).Methods("GET")
write.HandleFunc("/oauth/callback", oauthHandler.viewOauthCallback).Methods("GET")
oauthHandler.configureRoutes(write)
// Handle logged in user sections
me := write.PathPrefix("/me").Subrouter()

Loading…
Cancel
Save