|
|
|
@ -2,7 +2,6 @@ package writefreely |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"encoding/hex" |
|
|
|
|
"encoding/json" |
|
|
|
|
"fmt" |
|
|
|
|
"github.com/gorilla/mux" |
|
|
|
@ -10,14 +9,10 @@ import ( |
|
|
|
|
"github.com/guregu/null/zero" |
|
|
|
|
"github.com/writeas/nerds/store" |
|
|
|
|
"github.com/writeas/web-core/auth" |
|
|
|
|
"github.com/writeas/web-core/log" |
|
|
|
|
"github.com/writeas/writefreely/config" |
|
|
|
|
"hash/fnv" |
|
|
|
|
"io" |
|
|
|
|
"io/ioutil" |
|
|
|
|
"net/http" |
|
|
|
|
"net/url" |
|
|
|
|
"strings" |
|
|
|
|
"time" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
@ -33,7 +28,7 @@ type TokenResponse struct { |
|
|
|
|
// InspectResponse contains data returned when an access token is inspected.
|
|
|
|
|
type InspectResponse struct { |
|
|
|
|
ClientID string `json:"client_id"` |
|
|
|
|
UserID int64 `json:"user_id"` |
|
|
|
|
UserID string `json:"user_id"` |
|
|
|
|
ExpiresAt time.Time `json:"expires_at"` |
|
|
|
|
Username string `json:"username"` |
|
|
|
|
Email string `json:"email"` |
|
|
|
@ -58,9 +53,9 @@ type OAuthDatastoreProvider interface { |
|
|
|
|
// OAuthDatastore provides a minimal interface of data store methods used in
|
|
|
|
|
// oauth functionality.
|
|
|
|
|
type OAuthDatastore interface { |
|
|
|
|
GetIDForRemoteUser(context.Context, int64) (int64, error) |
|
|
|
|
RecordRemoteUserID(context.Context, int64, int64) error |
|
|
|
|
ValidateOAuthState(context.Context, string, string, string) error |
|
|
|
|
GetIDForRemoteUser(context.Context, string) (int64, error) |
|
|
|
|
RecordRemoteUserID(context.Context, int64, string) error |
|
|
|
|
ValidateOAuthState(context.Context, string) (string, string, error) |
|
|
|
|
GenerateOAuthState(context.Context, string, string) (string, error) |
|
|
|
|
|
|
|
|
|
CreateUser(*config.Config, *User, string) error |
|
|
|
@ -71,36 +66,28 @@ type HttpClient interface { |
|
|
|
|
Do(req *http.Request) (*http.Response, error) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type oauthClient interface { |
|
|
|
|
GetProvider() string |
|
|
|
|
GetClientID() string |
|
|
|
|
buildLoginURL(state string) (string, error) |
|
|
|
|
exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) |
|
|
|
|
inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type oauthHandler struct { |
|
|
|
|
Config *config.Config |
|
|
|
|
DB OAuthDatastore |
|
|
|
|
Store sessions.Store |
|
|
|
|
HttpClient HttpClient |
|
|
|
|
oauthClient oauthClient |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// buildAuthURL returns a URL used to initiate authentication.
|
|
|
|
|
func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) { |
|
|
|
|
state, err := db.GenerateOAuthState(ctx, provider, clientID) |
|
|
|
|
if err != nil { |
|
|
|
|
return "", err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
u, err := url.Parse(authLocation) |
|
|
|
|
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
ctx := r.Context() |
|
|
|
|
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID()) |
|
|
|
|
if err != nil { |
|
|
|
|
return "", err |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") |
|
|
|
|
} |
|
|
|
|
q := u.Query() |
|
|
|
|
q.Set("client_id", clientID) |
|
|
|
|
q.Set("redirect_uri", callbackURL) |
|
|
|
|
q.Set("response_type", "code") |
|
|
|
|
q.Set("state", state) |
|
|
|
|
u.RawQuery = q.Encode() |
|
|
|
|
|
|
|
|
|
return u.String(), nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation) |
|
|
|
|
location, err := h.oauthClient.buildLoginURL(state) |
|
|
|
|
if err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") |
|
|
|
|
return |
|
|
|
@ -108,52 +95,58 @@ func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Reques |
|
|
|
|
http.Redirect(w, r, location, http.StatusTemporaryRedirect) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation) |
|
|
|
|
if err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") |
|
|
|
|
return |
|
|
|
|
func configureSlackOauth(r *mux.Router, app *App) { |
|
|
|
|
if app.Config().SlackOauth.ClientID != "" { |
|
|
|
|
oauthClient := slackOauthClient{ |
|
|
|
|
ClientID: app.Config().SlackOauth.ClientID, |
|
|
|
|
ClientSecret: app.Config().SlackOauth.ClientSecret, |
|
|
|
|
TeamID: app.Config().SlackOauth.TeamID, |
|
|
|
|
CallbackLocation: app.Config().App.Host + "/oauth/callback", |
|
|
|
|
HttpClient: &http.Client{Timeout: 10 * time.Second}, |
|
|
|
|
} |
|
|
|
|
configureOauthRoutes(r, app, oauthClient) |
|
|
|
|
} |
|
|
|
|
http.Redirect(w, r, location, http.StatusTemporaryRedirect) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) configureRoutes(r *mux.Router) { |
|
|
|
|
if h.Config.App.OAuth.Enabled { |
|
|
|
|
if h.Config.App.OAuth.WriteAsClientID != "" { |
|
|
|
|
callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID) |
|
|
|
|
log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash) |
|
|
|
|
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET") |
|
|
|
|
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET") |
|
|
|
|
} |
|
|
|
|
if h.Config.App.OAuth.SlackClientID != "" { |
|
|
|
|
callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID) |
|
|
|
|
log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash) |
|
|
|
|
r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET") |
|
|
|
|
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET") |
|
|
|
|
func configureWriteAsOauth(r *mux.Router, app *App) { |
|
|
|
|
if app.Config().WriteAsOauth.ClientID != "" { |
|
|
|
|
oauthClient := writeAsOauthClient{ |
|
|
|
|
ClientID: app.Config().WriteAsOauth.ClientID, |
|
|
|
|
ClientSecret: app.Config().WriteAsOauth.ClientSecret, |
|
|
|
|
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, |
|
|
|
|
InspectLocation: app.Config().WriteAsOauth.InspectLocation, |
|
|
|
|
AuthLocation: app.Config().WriteAsOauth.AuthLocation, |
|
|
|
|
HttpClient: &http.Client{Timeout: 10 * time.Second}, |
|
|
|
|
CallbackLocation: app.Config().App.Host + "/oauth/callback", |
|
|
|
|
} |
|
|
|
|
configureOauthRoutes(r, app, oauthClient) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func oauthProviderHash(provider, clientID string) string { |
|
|
|
|
hasher := fnv.New32() |
|
|
|
|
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID))) |
|
|
|
|
func configureOauthRoutes(r *mux.Router, app *App, oauthClient oauthClient) { |
|
|
|
|
handler := &oauthHandler{ |
|
|
|
|
Config: app.Config(), |
|
|
|
|
DB: app.DB(), |
|
|
|
|
Store: app.SessionStore(), |
|
|
|
|
oauthClient: oauthClient, |
|
|
|
|
} |
|
|
|
|
r.HandleFunc("/oauth/"+oauthClient.GetProvider(), handler.viewOauthInit).Methods("GET") |
|
|
|
|
r.HandleFunc("/oauth/callback", handler.viewOauthCallback).Methods("GET") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc { |
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
ctx := r.Context() |
|
|
|
|
|
|
|
|
|
code := r.FormValue("code") |
|
|
|
|
state := r.FormValue("state") |
|
|
|
|
|
|
|
|
|
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID) |
|
|
|
|
_, _, err := h.DB.ValidateOAuthState(ctx, state) |
|
|
|
|
if err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
tokenResponse, err := h.exchangeOauthCode(ctx, code) |
|
|
|
|
tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) |
|
|
|
|
if err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
|
return |
|
|
|
@ -161,7 +154,7 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF |
|
|
|
|
|
|
|
|
|
// Now that we have the access token, let's use it real quick to make sur
|
|
|
|
|
// it really really works.
|
|
|
|
|
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) |
|
|
|
|
tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) |
|
|
|
|
if err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
|
return |
|
|
|
@ -173,8 +166,6 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fmt.Println("local user id", localUserID) |
|
|
|
|
|
|
|
|
|
if localUserID == -1 { |
|
|
|
|
// We don't have, nor do we want, the password from the origin, so we
|
|
|
|
|
//create a random string. If the user needs to set a password, they
|
|
|
|
@ -183,7 +174,6 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF |
|
|
|
|
randPass := store.Generate62RandomString(14) |
|
|
|
|
hashedPass, err := auth.HashPass([]byte(randPass)) |
|
|
|
|
if err != nil { |
|
|
|
|
log.ErrorLog.Println(err) |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash") |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
@ -191,7 +181,7 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF |
|
|
|
|
Username: tokenInfo.Username, |
|
|
|
|
HashedPass: hashedPass, |
|
|
|
|
HasPass: true, |
|
|
|
|
Email: zero.NewString("", tokenInfo.Email != ""), |
|
|
|
|
Email: zero.NewString(tokenInfo.Email, tokenInfo.Email != ""), |
|
|
|
|
Created: time.Now().Truncate(time.Second).UTC(), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -221,74 +211,18 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF |
|
|
|
|
if err = loginOrFail(h.Store, w, r, user); err != nil { |
|
|
|
|
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { |
|
|
|
|
form := url.Values{} |
|
|
|
|
form.Add("grant_type", "authorization_code") |
|
|
|
|
form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation) |
|
|
|
|
form.Add("code", code) |
|
|
|
|
req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode())) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
req.WithContext(ctx) |
|
|
|
|
req.Header.Set("User-Agent", "writefreely") |
|
|
|
|
req.Header.Set("Accept", "application/json") |
|
|
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
|
|
|
|
req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret) |
|
|
|
|
|
|
|
|
|
resp, err := h.HttpClient.Do(req) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Nick: I like using limited readers to reduce the risk of an endpoint
|
|
|
|
|
// being broken or compromised.
|
|
|
|
|
lr := io.LimitReader(resp.Body, tokenRequestMaxLen) |
|
|
|
|
body, err := ioutil.ReadAll(lr) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var tokenResponse TokenResponse |
|
|
|
|
err = json.Unmarshal(body, &tokenResponse) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
return &tokenResponse, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { |
|
|
|
|
req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
req.WithContext(ctx) |
|
|
|
|
req.Header.Set("User-Agent", "writefreely") |
|
|
|
|
req.Header.Set("Accept", "application/json") |
|
|
|
|
req.Header.Set("Authorization", "Bearer "+accessToken) |
|
|
|
|
|
|
|
|
|
resp, err := h.HttpClient.Do(req) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Nick: I like using limited readers to reduce the risk of an endpoint
|
|
|
|
|
// being broken or compromised.
|
|
|
|
|
lr := io.LimitReader(resp.Body, infoRequestMaxLen) |
|
|
|
|
body, err := ioutil.ReadAll(lr) |
|
|
|
|
func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { |
|
|
|
|
lr := io.LimitReader(body, int64(n+1)) |
|
|
|
|
data, err := ioutil.ReadAll(lr) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var inspectResponse InspectResponse |
|
|
|
|
err = json.Unmarshal(body, &inspectResponse) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
if len(data) == n+1 { |
|
|
|
|
return fmt.Errorf("content larger than max read allowance: %d", n) |
|
|
|
|
} |
|
|
|
|
return &inspectResponse, nil |
|
|
|
|
return json.Unmarshal(data, thing) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { |
|
|
|
|