Feature complete on MVP slack auth integration. T710

pull/232/head
Nick Gerakines 5 years ago
parent 13121cb266
commit 462f87919a
  1. 33
      config/config.go
  2. 28
      database.go
  3. 6
      database_test.go
  4. 40
      db/alter.go
  5. 56
      db/alter_test.go
  6. 29
      db/create.go
  7. 43
      db/dialect.go
  8. 1
      migrations/migrations.go
  9. 41
      migrations/v5.go
  10. 186
      oauth.go
  11. 148
      oauth_slack.go
  12. 72
      oauth_test.go
  13. 91
      oauth_writeas.go
  14. 12
      routes.go

@ -56,23 +56,18 @@ type (
Port int `ini:"port"`
}
OAuthCfg struct {
Enabled bool `ini:"enable"`
// write.as
WriteAsProviderAuthLocation string `ini:"wa_auth_location"`
WriteAsProviderTokenLocation string `ini:"wa_token_location"`
WriteAsProviderInspectLocation string `ini:"wa_inspect_location"`
WriteAsClientCallbackLocation string `ini:"wa_callback_location"`
WriteAsClientID string `ini:"wa_client_id"`
WriteAsClientSecret string `ini:"wa_client_secret"`
WriteAsAuthLocation string
// slack
SlackClientID string `ini:"slack_client_id"`
SlackClientSecret string `ini:"slack_client_secret"`
SlackTeamID string `init:"slack_team_id"`
SlackAuthLocation string
WriteAsOauthCfg struct {
ClientID string `ini:"client_id"`
ClientSecret string `ini:"client_secret"`
AuthLocation string `ini:"auth_location"`
TokenLocation string `ini:"token_location"`
InspectLocation string `ini:"inspect_location"`
}
SlackOauthCfg struct {
ClientID string `ini:"client_id"`
ClientSecret string `ini:"client_secret"`
TeamID string `ini:"team_id"`
}
// AppCfg holds values that affect how the application functions
@ -113,8 +108,6 @@ type (
// Defaults
DefaultVisibility string `ini:"default_visibility"`
OAuth OAuthCfg `ini:"oauth"`
}
// Config holds the complete configuration for running a writefreely instance
@ -122,6 +115,8 @@ type (
Server ServerCfg `ini:"server"`
Database DatabaseCfg `ini:"database"`
App AppCfg `ini:"app"`
SlackOauth SlackOauthCfg `ini:"oauth.slack"`
WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"`
}
)

@ -14,6 +14,7 @@ import (
"context"
"database/sql"
"fmt"
wf_db "github.com/writeas/writefreely/db"
"net/http"
"strings"
"time"
@ -125,9 +126,9 @@ type writestore interface {
GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error)
GetIDForRemoteUser(context.Context, int64) (int64, error)
RecordRemoteUserID(context.Context, int64, int64) error
ValidateOAuthState(context.Context, string, string, string) error
GetIDForRemoteUser(context.Context, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string) error
ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error)
DatabaseInitialized() bool
@ -2470,8 +2471,16 @@ func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID
return state, nil
}
func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, clientID string) error {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ? AND provider = ? AND client_id = ?", state, provider, clientID)
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
var provider string
var clientID string
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_state WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID)
if err != nil {
return err
}
res, err := tx.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state)
if err != nil {
return err
}
@ -2483,9 +2492,14 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, cl
return fmt.Errorf("state not found")
}
return nil
})
if err != nil {
return "", "", nil
}
return provider, clientID, nil
}
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error {
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID string) error {
var err error
if db.driverName == driverSQLite {
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
@ -2499,7 +2513,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote
}
// GetIDForRemoteUser returns a user ID associated with a remote user ID.
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID string) (int64, error) {
var userID int64 = -1
err := db.
QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID).

@ -18,19 +18,19 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "",
}
state, err := ds.GenerateOAuthState(ctx, "", "")
state, err := ds.GenerateOAuthState(ctx, "test", "development")
assert.NoError(t, err)
assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state)
err = ds.ValidateOAuthState(ctx, state)
_, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state)
var localUserID int64 = 99
var remoteUserID int64 = 100
var remoteUserID = "100"
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID)
assert.NoError(t, err)

@ -0,0 +1,40 @@
package db
import (
"fmt"
"strings"
)
type AlterTableSqlBuilder struct {
Dialect DialectType
Name string
Changes []string
}
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
}
return b
}
func (b *AlterTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("ALTER TABLE ")
str.WriteString(b.Name)
str.WriteString(" ")
if len(b.Changes) == 0 {
return "", fmt.Errorf("no changes provide for table: %s", b.Name)
}
changeCount := len(b.Changes)
for i, thing := range b.Changes {
str.WriteString(thing)
if i < changeCount-1 {
str.WriteString(", ")
}
}
return str.String(), nil
}

@ -0,0 +1,56 @@
package db
import "testing"
func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Changes []string
}
tests := []struct {
name string
builder *AlterTableSqlBuilder
want string
wantErr bool
}{
{
name: "MySQL add int",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
wantErr: false,
},
{
name: "MySQL add string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
wantErr: false,
},
{
name: "MySQL add int and string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.builder.ToSQL()
if (err != nil) != tt.wantErr {
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ToSQL() got = %v, want %v", got, tt.want)
}
})
}
}

@ -5,7 +5,6 @@ import (
"strings"
)
type DialectType int
type ColumnType int
type OptionalInt struct {
@ -41,11 +40,6 @@ type CreateTableSqlBuilder struct {
Constraints []string
}
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)
const (
ColumnTypeBool ColumnType = iota
ColumnTypeSmallInt ColumnType = iota
@ -61,28 +55,6 @@ var _ SQLBuilder = &CreateTableSqlBuilder{}
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
if dialect != DialectMySQL && dialect != DialectSQLite {
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
@ -269,3 +241,4 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
return str.String(), nil
}

@ -0,0 +1,43 @@
package db
import "fmt"
type DialectType int
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
switch d {
case DialectSQLite:
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

@ -60,6 +60,7 @@ var migrations = []Migration{
New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0)
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0)
New("support oauth", oauth), // V3 -> V4
New("support slack oauth", oauth_slack), // V4 -> v5
}
// CurrentVer returns the current migration version the application is on

@ -0,0 +1,41 @@
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writeas/writefreely/db"
)
func oauth_slack(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_state").
AddColumn(dialect.
Column(
"provider",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 24,})).
AddColumn(dialect.
Column(
"client_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128,})),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

@ -2,7 +2,6 @@ package writefreely
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
@ -10,14 +9,10 @@ import (
"github.com/guregu/null/zero"
"github.com/writeas/nerds/store"
"github.com/writeas/web-core/auth"
"github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config"
"hash/fnv"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
)
@ -33,7 +28,7 @@ type TokenResponse struct {
// InspectResponse contains data returned when an access token is inspected.
type InspectResponse struct {
ClientID string `json:"client_id"`
UserID int64 `json:"user_id"`
UserID string `json:"user_id"`
ExpiresAt time.Time `json:"expires_at"`
Username string `json:"username"`
Email string `json:"email"`
@ -58,9 +53,9 @@ type OAuthDatastoreProvider interface {
// OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality.
type OAuthDatastore interface {
GetIDForRemoteUser(context.Context, int64) (int64, error)
RecordRemoteUserID(context.Context, int64, int64) error
ValidateOAuthState(context.Context, string, string, string) error
GetIDForRemoteUser(context.Context, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string) error
ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error)
CreateUser(*config.Config, *User, string) error
@ -71,36 +66,28 @@ type HttpClient interface {
Do(req *http.Request) (*http.Response, error)
}
type oauthClient interface {
GetProvider() string
GetClientID() string
buildLoginURL(state string) (string, error)
exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
}
type oauthHandler struct {
Config *config.Config
DB OAuthDatastore
Store sessions.Store
HttpClient HttpClient
oauthClient oauthClient
}
// buildAuthURL returns a URL used to initiate authentication.
func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) {
state, err := db.GenerateOAuthState(ctx, provider, clientID)
if err != nil {
return "", err
}
u, err := url.Parse(authLocation)
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
if err != nil {
return "", err
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
}
q := u.Query()
q.Set("client_id", clientID)
q.Set("redirect_uri", callbackURL)
q.Set("response_type", "code")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), nil
}
func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
location, err := h.oauthClient.buildLoginURL(state)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return
@ -108,52 +95,58 @@ func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Reques
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
}
func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return
func configureSlackOauth(r *mux.Router, app *App) {
if app.Config().SlackOauth.ClientID != "" {
oauthClient := slackOauthClient{
ClientID: app.Config().SlackOauth.ClientID,
ClientSecret: app.Config().SlackOauth.ClientSecret,
TeamID: app.Config().SlackOauth.TeamID,
CallbackLocation: app.Config().App.Host + "/oauth/callback",
HttpClient: &http.Client{Timeout: 10 * time.Second},
}
configureOauthRoutes(r, app, oauthClient)
}
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
}
func (h oauthHandler) configureRoutes(r *mux.Router) {
if h.Config.App.OAuth.Enabled {
if h.Config.App.OAuth.WriteAsClientID != "" {
callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID)
log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET")
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET")
}
if h.Config.App.OAuth.SlackClientID != "" {
callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID)
log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET")
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET")
func configureWriteAsOauth(r *mux.Router, app *App) {
if app.Config().WriteAsOauth.ClientID != "" {
oauthClient := writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
HttpClient: &http.Client{Timeout: 10 * time.Second},
CallbackLocation: app.Config().App.Host + "/oauth/callback",
}
configureOauthRoutes(r, app, oauthClient)
}
}
func oauthProviderHash(provider, clientID string) string {
hasher := fnv.New32()
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID)))
func configureOauthRoutes(r *mux.Router, app *App, oauthClient oauthClient) {
handler := &oauthHandler{
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
oauthClient: oauthClient,
}
r.HandleFunc("/oauth/"+oauthClient.GetProvider(), handler.viewOauthInit).Methods("GET")
r.HandleFunc("/oauth/callback", handler.viewOauthCallback).Methods("GET")
}
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
code := r.FormValue("code")
state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID)
_, _, err := h.DB.ValidateOAuthState(ctx, state)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
tokenResponse, err := h.exchangeOauthCode(ctx, code)
tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
@ -161,7 +154,7 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
// Now that we have the access token, let's use it real quick to make sur
// it really really works.
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
@ -173,8 +166,6 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
return
}
fmt.Println("local user id", localUserID)
if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they
@ -183,7 +174,6 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil {
log.ErrorLog.Println(err)
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
return
}
@ -191,7 +181,7 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
Username: tokenInfo.Username,
HashedPass: hashedPass,
HasPass: true,
Email: zero.NewString("", tokenInfo.Email != ""),
Email: zero.NewString(tokenInfo.Email, tokenInfo.Email != ""),
Created: time.Now().Truncate(time.Second).UTC(),
}
@ -221,74 +211,18 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
if err = loginOrFail(h.Store, w, r, user); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
}
}
}
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret)
resp, err := h.HttpClient.Do(req)
if err != nil {
return nil, err
}
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io.LimitReader(resp.Body, tokenRequestMaxLen)
body, err := ioutil.ReadAll(lr)
if err != nil {
return nil, err
}
var tokenResponse TokenResponse
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return nil, err
}
return &tokenResponse, nil
}
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := h.HttpClient.Do(req)
if err != nil {
return nil, err
}
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io.LimitReader(resp.Body, infoRequestMaxLen)
body, err := ioutil.ReadAll(lr)
func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
lr := io.LimitReader(body, int64(n+1))
data, err := ioutil.ReadAll(lr)
if err != nil {
return nil, err
return err
}
var inspectResponse InspectResponse
err = json.Unmarshal(body, &inspectResponse)
if err != nil {
return nil, err
if len(data) == n+1 {
return fmt.Errorf("content larger than max read allowance: %d", n)
}
return &inspectResponse, nil
return json.Unmarshal(data, thing)
}
func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {

@ -0,0 +1,148 @@
package writefreely
import (
"context"
"github.com/writeas/slug"
"net/http"
"net/url"
"strings"
)
type slackOauthClient struct {
ClientID string
ClientSecret string
TeamID string
CallbackLocation string
HttpClient HttpClient
}
type slackExchangeResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TeamName string `json:"team_name"`
TeamID string `json:"team_id"`
}
type slackIdentity struct {
Name string `json:"name"`
ID string `json:"id"`
Email string `json:"email"`
}
type slackTeam struct {
Name string `json:"name"`
ID string `json:"id"`
}
type slackUserIdentityResponse struct {
OK bool `json:"ok"`
User slackIdentity `json:"user"`
Team slackTeam `json:"team"`
Error string `json:"error"`
}
const (
slackAuthLocation = "https://slack.com/oauth/authorize"
slackExchangeLocation = "https://slack.com/api/oauth.access"
slackIdentityLocation = "https://slack.com/api/users.identity"
)
var _ oauthClient = slackOauthClient{}
func (c slackOauthClient) GetProvider() string {
return "slack"
}
func (c slackOauthClient) GetClientID() string {
return c.ClientID
}
func (c slackOauthClient) buildLoginURL(state string) (string, error) {
u, err := url.Parse(slackAuthLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", c.ClientID)
q.Set("scope", "identity.basic identity.email identity.team")
q.Set("redirect_uri", c.CallbackLocation)
q.Set("state", state)
// If this param is not set, the user can select which team they
// authenticate through and then we'd have to match the configured team
// against the profile get. That is extra work in the post-auth phase
// that we don't want to do.
q.Set("team", c.TeamID)
// The Slack OAuth docs don't explicitly list this one, but it is part of
// the spec, so we include it anyway.
q.Set("response_type", "code")
u.RawQuery = q.Encode()
return u.String(), nil
}
func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
// The oauth.access documentation doesn't explicitly mention this
// parameter, but it is part of the spec, so we include it anyway.
// https://api.slack.com/methods/oauth.access
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", c.CallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(c.ClientID, c.ClientSecret)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var tokenResponse slackExchangeResponse
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
return nil, err
}
return tokenResponse.TokenResponse(), nil
}
func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", slackIdentityLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var inspectResponse slackUserIdentityResponse
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
return nil, err
}
return inspectResponse.InspectResponse(), nil
}
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse {
return &InspectResponse{
UserID: resp.User.ID,
Username: slug.Make(resp.User.Name),
Email: resp.User.Email,
}
}
func (resp slackExchangeResponse) TokenResponse() *TokenResponse {
return &TokenResponse{
AccessToken: resp.AccessToken,
}
}

@ -21,14 +21,16 @@ type MockOAuthDatastoreProvider struct {
}
type MockOAuthDatastore struct {
DoGenerateOAuthState func(ctx context.Context) (string, error)
DoValidateOAuthState func(context.Context, string) error
DoGetIDForRemoteUser func(context.Context, int64) (int64, error)
DoGenerateOAuthState func(context.Context, string, string) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, error)
DoGetIDForRemoteUser func(context.Context, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, int64) error
DoRecordRemoteUserID func(context.Context, int64, string) error
DoGetUserForAuthByID func(int64) (*User, error)
}
var _ OAuthDatastore = &MockOAuthDatastore{}
type StringReadCloser struct {
*strings.Reader
}
@ -68,24 +70,29 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
}
cfg := config.New()
cfg.UseSQLite(true)
cfg.App.EnableOAuth = true
cfg.App.OAuthProviderAuthLocation = "https://write.as/oauth/login"
cfg.App.OAuthProviderTokenLocation = "https://write.as/oauth/token"
cfg.App.OAuthProviderInspectLocation = "https://write.as/oauth/inspect"
cfg.App.OAuthClientCallbackLocation = "http://localhost/oauth/callback"
cfg.App.OAuthClientID = "development"
cfg.App.OAuthClientSecret = "development"
cfg.WriteAsOauth = config.WriteAsOauthCfg{
ClientID: "development",
ClientSecret: "development",
AuthLocation: "https://write.as/oauth/login",
TokenLocation: "https://write.as/oauth/token",
InspectLocation: "https://write.as/oauth/inspect",
}
cfg.SlackOauth = config.SlackOauthCfg{
ClientID: "development",
ClientSecret: "development",
TeamID: "development",
}
return cfg
}
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) error {
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
if m.DoValidateOAuthState != nil {
return m.DoValidateOAuthState(ctx, state)
}
return nil
return "", "", nil
}
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID string) (int64, error) {
if m.DoGetIDForRemoteUser != nil {
return m.DoGetIDForRemoteUser(ctx, remoteUserID)
}
@ -100,7 +107,7 @@ func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username st
return nil
}
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID int64) error {
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID string) error {
if m.DoRecordRemoteUserID != nil {
return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID)
}
@ -117,9 +124,9 @@ func (m *MockOAuthDatastore) GetUserForAuthByID(userID int64) (*User, error) {
return user, nil
}
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, error) {
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) {
if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx)
return m.DoGenerateOAuthState(ctx, provider, clientID)
}
return store.Generate62RandomString(14), nil
}
@ -132,6 +139,15 @@ func TestViewOauthInit(t *testing.T) {
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
oauthClient:writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
}
req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err)
@ -151,7 +167,7 @@ func TestViewOauthInit(t *testing.T) {
app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context) (string, error) {
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) {
return "", fmt.Errorf("pretend unable to write state error")
},
}
@ -161,6 +177,15 @@ func TestViewOauthInit(t *testing.T) {
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
oauthClient:writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
}
req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err)
@ -179,6 +204,13 @@ func TestViewOauthCallback(t *testing.T) {
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
oauthClient: writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: &MockHTTPClient{
DoDo: func(req *http.Request) (*http.Response, error) {
switch req.URL.String() {
@ -190,7 +222,7 @@ func TestViewOauthCallback(t *testing.T) {
case "https://write.as/oauth/inspect":
return &http.Response{
StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": 1, "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
}, nil
}
@ -199,6 +231,7 @@ func TestViewOauthCallback(t *testing.T) {
}, nil
},
},
},
}
req, err := http.NewRequest("GET", "/oauth/callback", nil)
assert.NoError(t, err)
@ -206,6 +239,5 @@ func TestViewOauthCallback(t *testing.T) {
h.viewOauthCallback(rr, req)
assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
})
}

@ -0,0 +1,91 @@
package writefreely
import (
"context"
"net/http"
"net/url"
"strings"
)
type writeAsOauthClient struct {
ClientID string
ClientSecret string
AuthLocation string
ExchangeLocation string
InspectLocation string
CallbackLocation string
HttpClient HttpClient
}
var _ oauthClient = writeAsOauthClient{}
func (c writeAsOauthClient) GetProvider() string {
return "write.as"
}
func (c writeAsOauthClient) GetClientID() string {
return c.ClientID
}
func (c writeAsOauthClient) buildLoginURL(state string) (string, error) {
u, err := url.Parse(c.AuthLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", c.ClientID)
q.Set("redirect_uri", c.CallbackLocation)
q.Set("response_type", "code")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), nil
}
func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", c.CallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(c.ClientID, c.ClientSecret)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var tokenResponse TokenResponse
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
return nil, err
}
return &tokenResponse, nil
}
func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", c.InspectLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var inspectResponse InspectResponse
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
return nil, err
}
return &inspectResponse, nil
}

@ -70,6 +70,9 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover)))
write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo)))
configureSlackOauth(write, apper.App())
configureWriteAsOauth(write, apper.App())
// Set up dyamic page handlers
// Handle auth
auth := write.PathPrefix("/api/auth/").Subrouter()
@ -80,14 +83,6 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST")
auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE")
oauthHandler := oauthHandler{
HttpClient: &http.Client{},
Config: apper.App().Config(),
DB: apper.App().DB(),
Store: apper.App().SessionStore(),
}
oauthHandler.configureRoutes(write)
// Handle logged in user sections
me := write.PathPrefix("/me").Subrouter()
me.HandleFunc("/", handler.Redirect("/me", UserLevelUser))
@ -189,6 +184,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
}
write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional))
write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional))
return r
}

Loading…
Cancel
Save