Feature complete on MVP slack auth integration. T710

pull/232/head
Nick Gerakines 5 years ago
parent 13121cb266
commit 462f87919a
  1. 39
      config/config.go
  2. 46
      database.go
  3. 6
      database_test.go
  4. 40
      db/alter.go
  5. 56
      db/alter_test.go
  6. 29
      db/create.go
  7. 43
      db/dialect.go
  8. 1
      migrations/migrations.go
  9. 41
      migrations/v5.go
  10. 288
      oauth.go
  11. 148
      oauth_slack.go
  12. 102
      oauth_test.go
  13. 91
      oauth_writeas.go
  14. 12
      routes.go

@ -56,23 +56,18 @@ type (
Port int `ini:"port"` Port int `ini:"port"`
} }
OAuthCfg struct { WriteAsOauthCfg struct {
Enabled bool `ini:"enable"` ClientID string `ini:"client_id"`
ClientSecret string `ini:"client_secret"`
// write.as AuthLocation string `ini:"auth_location"`
WriteAsProviderAuthLocation string `ini:"wa_auth_location"` TokenLocation string `ini:"token_location"`
WriteAsProviderTokenLocation string `ini:"wa_token_location"` InspectLocation string `ini:"inspect_location"`
WriteAsProviderInspectLocation string `ini:"wa_inspect_location"` }
WriteAsClientCallbackLocation string `ini:"wa_callback_location"`
WriteAsClientID string `ini:"wa_client_id"` SlackOauthCfg struct {
WriteAsClientSecret string `ini:"wa_client_secret"` ClientID string `ini:"client_id"`
WriteAsAuthLocation string ClientSecret string `ini:"client_secret"`
TeamID string `ini:"team_id"`
// slack
SlackClientID string `ini:"slack_client_id"`
SlackClientSecret string `ini:"slack_client_secret"`
SlackTeamID string `init:"slack_team_id"`
SlackAuthLocation string
} }
// AppCfg holds values that affect how the application functions // AppCfg holds values that affect how the application functions
@ -113,15 +108,15 @@ type (
// Defaults // Defaults
DefaultVisibility string `ini:"default_visibility"` DefaultVisibility string `ini:"default_visibility"`
OAuth OAuthCfg `ini:"oauth"`
} }
// Config holds the complete configuration for running a writefreely instance // Config holds the complete configuration for running a writefreely instance
Config struct { Config struct {
Server ServerCfg `ini:"server"` Server ServerCfg `ini:"server"`
Database DatabaseCfg `ini:"database"` Database DatabaseCfg `ini:"database"`
App AppCfg `ini:"app"` App AppCfg `ini:"app"`
SlackOauth SlackOauthCfg `ini:"oauth.slack"`
WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"`
} }
) )

@ -14,6 +14,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
wf_db "github.com/writeas/writefreely/db"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -125,9 +126,9 @@ type writestore interface {
GetUserLastPostTime(id int64) (*time.Time, error) GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error)
GetIDForRemoteUser(context.Context, int64) (int64, error) GetIDForRemoteUser(context.Context, string) (int64, error)
RecordRemoteUserID(context.Context, int64, int64) error RecordRemoteUserID(context.Context, int64, string) error
ValidateOAuthState(context.Context, string, string, string) error ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error) GenerateOAuthState(context.Context, string, string) (string, error)
DatabaseInitialized() bool DatabaseInitialized() bool
@ -2470,22 +2471,35 @@ func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID
return state, nil return state, nil
} }
func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, clientID string) error { func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ? AND provider = ? AND client_id = ?", state, provider, clientID) var provider string
if err != nil { var clientID string
return err err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
} err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_state WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID)
rowsAffected, err := res.RowsAffected() if err != nil {
return err
}
res, err := tx.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return fmt.Errorf("state not found")
}
return nil
})
if err != nil { if err != nil {
return err return "", "", nil
} }
if rowsAffected != 1 { return provider, clientID, nil
return fmt.Errorf("state not found")
}
return nil
} }
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error { func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID string) error {
var err error var err error
if db.driverName == driverSQLite { if db.driverName == driverSQLite {
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
@ -2499,7 +2513,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote
} }
// GetIDForRemoteUser returns a user ID associated with a remote user ID. // GetIDForRemoteUser returns a user ID associated with a remote user ID.
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID string) (int64, error) {
var userID int64 = -1 var userID int64 = -1
err := db. err := db.
QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID). QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID).

@ -18,19 +18,19 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "", driverName: "",
} }
state, err := ds.GenerateOAuthState(ctx, "", "") state, err := ds.GenerateOAuthState(ctx, "test", "development")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, state, 24) assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state)
err = ds.ValidateOAuthState(ctx, state) _, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err) assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state)
var localUserID int64 = 99 var localUserID int64 = 99
var remoteUserID int64 = 100 var remoteUserID = "100"
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID) err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID)
assert.NoError(t, err) assert.NoError(t, err)

@ -0,0 +1,40 @@
package db
import (
"fmt"
"strings"
)
type AlterTableSqlBuilder struct {
Dialect DialectType
Name string
Changes []string
}
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
}
return b
}
func (b *AlterTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("ALTER TABLE ")
str.WriteString(b.Name)
str.WriteString(" ")
if len(b.Changes) == 0 {
return "", fmt.Errorf("no changes provide for table: %s", b.Name)
}
changeCount := len(b.Changes)
for i, thing := range b.Changes {
str.WriteString(thing)
if i < changeCount-1 {
str.WriteString(", ")
}
}
return str.String(), nil
}

@ -0,0 +1,56 @@
package db
import "testing"
func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Changes []string
}
tests := []struct {
name string
builder *AlterTableSqlBuilder
want string
wantErr bool
}{
{
name: "MySQL add int",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
wantErr: false,
},
{
name: "MySQL add string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
wantErr: false,
},
{
name: "MySQL add int and string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.builder.ToSQL()
if (err != nil) != tt.wantErr {
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ToSQL() got = %v, want %v", got, tt.want)
}
})
}
}

@ -5,7 +5,6 @@ import (
"strings" "strings"
) )
type DialectType int
type ColumnType int type ColumnType int
type OptionalInt struct { type OptionalInt struct {
@ -41,11 +40,6 @@ type CreateTableSqlBuilder struct {
Constraints []string Constraints []string
} }
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)
const ( const (
ColumnTypeBool ColumnType = iota ColumnTypeBool ColumnType = iota
ColumnTypeSmallInt ColumnType = iota ColumnTypeSmallInt ColumnType = iota
@ -61,28 +55,6 @@ var _ SQLBuilder = &CreateTableSqlBuilder{}
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
if dialect != DialectMySQL && dialect != DialectSQLite { if dialect != DialectMySQL && dialect != DialectSQLite {
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
@ -269,3 +241,4 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
return str.String(), nil return str.String(), nil
} }

@ -0,0 +1,43 @@
package db
import "fmt"
type DialectType int
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
switch d {
case DialectSQLite:
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

@ -60,6 +60,7 @@ var migrations = []Migration{
New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0)
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0)
New("support oauth", oauth), // V3 -> V4 New("support oauth", oauth), // V3 -> V4
New("support slack oauth", oauth_slack), // V4 -> v5
} }
// CurrentVer returns the current migration version the application is on // CurrentVer returns the current migration version the application is on

@ -0,0 +1,41 @@
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writeas/writefreely/db"
)
func oauth_slack(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_state").
AddColumn(dialect.
Column(
"provider",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 24,})).
AddColumn(dialect.
Column(
"client_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128,})),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

@ -2,7 +2,6 @@ package writefreely
import ( import (
"context" "context"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -10,14 +9,10 @@ import (
"github.com/guregu/null/zero" "github.com/guregu/null/zero"
"github.com/writeas/nerds/store" "github.com/writeas/nerds/store"
"github.com/writeas/web-core/auth" "github.com/writeas/web-core/auth"
"github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config" "github.com/writeas/writefreely/config"
"hash/fnv"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
) )
@ -33,7 +28,7 @@ type TokenResponse struct {
// InspectResponse contains data returned when an access token is inspected. // InspectResponse contains data returned when an access token is inspected.
type InspectResponse struct { type InspectResponse struct {
ClientID string `json:"client_id"` ClientID string `json:"client_id"`
UserID int64 `json:"user_id"` UserID string `json:"user_id"`
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
Username string `json:"username"` Username string `json:"username"`
Email string `json:"email"` Email string `json:"email"`
@ -58,9 +53,9 @@ type OAuthDatastoreProvider interface {
// OAuthDatastore provides a minimal interface of data store methods used in // OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality. // oauth functionality.
type OAuthDatastore interface { type OAuthDatastore interface {
GetIDForRemoteUser(context.Context, int64) (int64, error) GetIDForRemoteUser(context.Context, string) (int64, error)
RecordRemoteUserID(context.Context, int64, int64) error RecordRemoteUserID(context.Context, int64, string) error
ValidateOAuthState(context.Context, string, string, string) error ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error) GenerateOAuthState(context.Context, string, string) (string, error)
CreateUser(*config.Config, *User, string) error CreateUser(*config.Config, *User, string) error
@ -71,45 +66,28 @@ type HttpClient interface {
Do(req *http.Request) (*http.Response, error) Do(req *http.Request) (*http.Response, error)
} }
type oauthHandler struct { type oauthClient interface {
Config *config.Config GetProvider() string
DB OAuthDatastore GetClientID() string
Store sessions.Store buildLoginURL(state string) (string, error)
HttpClient HttpClient exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
} }
// buildAuthURL returns a URL used to initiate authentication. type oauthHandler struct {
func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) { Config *config.Config
state, err := db.GenerateOAuthState(ctx, provider, clientID) DB OAuthDatastore
if err != nil { Store sessions.Store
return "", err oauthClient oauthClient
}
u, err := url.Parse(authLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", clientID)
q.Set("redirect_uri", callbackURL)
q.Set("response_type", "code")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), nil
} }
func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) { func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation) ctx := r.Context()
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return
} }
http.Redirect(w, r, location, http.StatusTemporaryRedirect) location, err := h.oauthClient.buildLoginURL(state)
}
func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return return
@ -117,178 +95,134 @@ func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request)
http.Redirect(w, r, location, http.StatusTemporaryRedirect) http.Redirect(w, r, location, http.StatusTemporaryRedirect)
} }
func (h oauthHandler) configureRoutes(r *mux.Router) { func configureSlackOauth(r *mux.Router, app *App) {
if h.Config.App.OAuth.Enabled { if app.Config().SlackOauth.ClientID != "" {
if h.Config.App.OAuth.WriteAsClientID != "" { oauthClient := slackOauthClient{
callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID) ClientID: app.Config().SlackOauth.ClientID,
log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash) ClientSecret: app.Config().SlackOauth.ClientSecret,
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET") TeamID: app.Config().SlackOauth.TeamID,
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET") CallbackLocation: app.Config().App.Host + "/oauth/callback",
} HttpClient: &http.Client{Timeout: 10 * time.Second},
if h.Config.App.OAuth.SlackClientID != "" {
callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID)
log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET")
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET")
} }
configureOauthRoutes(r, app, oauthClient)
} }
}
func configureWriteAsOauth(r *mux.Router, app *App) {
if app.Config().WriteAsOauth.ClientID != "" {
oauthClient := writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
HttpClient: &http.Client{Timeout: 10 * time.Second},
CallbackLocation: app.Config().App.Host + "/oauth/callback",
}
configureOauthRoutes(r, app, oauthClient)
}
} }
func oauthProviderHash(provider, clientID string) string { func configureOauthRoutes(r *mux.Router, app *App, oauthClient oauthClient) {
hasher := fnv.New32() handler := &oauthHandler{
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID))) Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
oauthClient: oauthClient,
}
r.HandleFunc("/oauth/"+oauthClient.GetProvider(), handler.viewOauthInit).Methods("GET")
r.HandleFunc("/oauth/callback", handler.viewOauthCallback).Methods("GET")
} }
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc { func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context()
ctx := r.Context()
code := r.FormValue("code") code := r.FormValue("code")
state := r.FormValue("state") state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID) _, _, err := h.DB.ValidateOAuthState(ctx, state)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
} }
tokenResponse, err := h.exchangeOauthCode(ctx, code) tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
} }
// Now that we have the access token, let's use it real quick to make sur
// it really really works.
tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
// Now that we have the access token, let's use it real quick to make sur localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
// it really really works. if err != nil {
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they
//can do so through the settings page or through the password reset
//flow.
randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
return return
} }
newUser := &User{
Username: tokenInfo.Username,
HashedPass: hashedPass,
HasPass: true,
Email: zero.NewString(tokenInfo.Email, tokenInfo.Email != ""),
Created: time.Now().Truncate(time.Second).UTC(),
}
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID) err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
} }
fmt.Println("local user id", localUserID) err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they
//can do so through the settings page or through the password reset
//flow.
randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil {
log.ErrorLog.Println(err)
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
return
}
newUser := &User{
Username: tokenInfo.Username,
HashedPass: hashedPass,
HasPass: true,
Email: zero.NewString("", tokenInfo.Email != ""),
Created: time.Now().Truncate(time.Second).UTC(),
}
err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
}
return
}
user, err := h.DB.GetUserForAuthByID(localUserID)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
} }
if err = loginOrFail(h.Store, w, r, user); err != nil {
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
} }
} return
}
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret)
resp, err := h.HttpClient.Do(req)
if err != nil {
return nil, err
} }
// Nick: I like using limited readers to reduce the risk of an endpoint user, err := h.DB.GetUserForAuthByID(localUserID)
// being broken or compromised.
lr := io.LimitReader(resp.Body, tokenRequestMaxLen)
body, err := ioutil.ReadAll(lr)
if err != nil { if err != nil {
return nil, err failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
} }
if err = loginOrFail(h.Store, w, r, user); err != nil {
var tokenResponse TokenResponse failOAuthRequest(w, http.StatusInternalServerError, err.Error())
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return nil, err
} }
return &tokenResponse, nil
} }
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil) lr := io.LimitReader(body, int64(n+1))
data, err := ioutil.ReadAll(lr)
if err != nil { if err != nil {
return nil, err return err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := h.HttpClient.Do(req)
if err != nil {
return nil, err
}
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io.LimitReader(resp.Body, infoRequestMaxLen)
body, err := ioutil.ReadAll(lr)
if err != nil {
return nil, err
} }
if len(data) == n+1 {
var inspectResponse InspectResponse return fmt.Errorf("content larger than max read allowance: %d", n)
err = json.Unmarshal(body, &inspectResponse)
if err != nil {
return nil, err
} }
return &inspectResponse, nil return json.Unmarshal(data, thing)
} }
func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {

@ -0,0 +1,148 @@
package writefreely
import (
"context"
"github.com/writeas/slug"
"net/http"
"net/url"
"strings"
)
type slackOauthClient struct {
ClientID string
ClientSecret string
TeamID string
CallbackLocation string
HttpClient HttpClient
}
type slackExchangeResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TeamName string `json:"team_name"`
TeamID string `json:"team_id"`
}
type slackIdentity struct {
Name string `json:"name"`
ID string `json:"id"`
Email string `json:"email"`
}
type slackTeam struct {
Name string `json:"name"`
ID string `json:"id"`
}
type slackUserIdentityResponse struct {
OK bool `json:"ok"`
User slackIdentity `json:"user"`
Team slackTeam `json:"team"`
Error string `json:"error"`
}
const (
slackAuthLocation = "https://slack.com/oauth/authorize"
slackExchangeLocation = "https://slack.com/api/oauth.access"
slackIdentityLocation = "https://slack.com/api/users.identity"
)
var _ oauthClient = slackOauthClient{}
func (c slackOauthClient) GetProvider() string {
return "slack"
}
func (c slackOauthClient) GetClientID() string {
return c.ClientID
}
func (c slackOauthClient) buildLoginURL(state string) (string, error) {
u, err := url.Parse(slackAuthLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", c.ClientID)
q.Set("scope", "identity.basic identity.email identity.team")
q.Set("redirect_uri", c.CallbackLocation)
q.Set("state", state)
// If this param is not set, the user can select which team they
// authenticate through and then we'd have to match the configured team
// against the profile get. That is extra work in the post-auth phase
// that we don't want to do.
q.Set("team", c.TeamID)
// The Slack OAuth docs don't explicitly list this one, but it is part of
// the spec, so we include it anyway.
q.Set("response_type", "code")
u.RawQuery = q.Encode()
return u.String(), nil
}
func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
// The oauth.access documentation doesn't explicitly mention this
// parameter, but it is part of the spec, so we include it anyway.
// https://api.slack.com/methods/oauth.access
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", c.CallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(c.ClientID, c.ClientSecret)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var tokenResponse slackExchangeResponse
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
return nil, err
}
return tokenResponse.TokenResponse(), nil
}
func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", slackIdentityLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var inspectResponse slackUserIdentityResponse
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
return nil, err
}
return inspectResponse.InspectResponse(), nil
}
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse {
return &InspectResponse{
UserID: resp.User.ID,
Username: slug.Make(resp.User.Name),
Email: resp.User.Email,
}
}
func (resp slackExchangeResponse) TokenResponse() *TokenResponse {
return &TokenResponse{
AccessToken: resp.AccessToken,
}
}

@ -21,14 +21,16 @@ type MockOAuthDatastoreProvider struct {
} }
type MockOAuthDatastore struct { type MockOAuthDatastore struct {
DoGenerateOAuthState func(ctx context.Context) (string, error) DoGenerateOAuthState func(context.Context, string, string) (string, error)
DoValidateOAuthState func(context.Context, string) error DoValidateOAuthState func(context.Context, string) (string, string, error)
DoGetIDForRemoteUser func(context.Context, int64) (int64, error) DoGetIDForRemoteUser func(context.Context, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, int64) error DoRecordRemoteUserID func(context.Context, int64, string) error
DoGetUserForAuthByID func(int64) (*User, error) DoGetUserForAuthByID func(int64) (*User, error)
} }
var _ OAuthDatastore = &MockOAuthDatastore{}
type StringReadCloser struct { type StringReadCloser struct {
*strings.Reader *strings.Reader
} }
@ -68,24 +70,29 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
} }
cfg := config.New() cfg := config.New()
cfg.UseSQLite(true) cfg.UseSQLite(true)
cfg.App.EnableOAuth = true cfg.WriteAsOauth = config.WriteAsOauthCfg{
cfg.App.OAuthProviderAuthLocation = "https://write.as/oauth/login" ClientID: "development",
cfg.App.OAuthProviderTokenLocation = "https://write.as/oauth/token" ClientSecret: "development",
cfg.App.OAuthProviderInspectLocation = "https://write.as/oauth/inspect" AuthLocation: "https://write.as/oauth/login",
cfg.App.OAuthClientCallbackLocation = "http://localhost/oauth/callback" TokenLocation: "https://write.as/oauth/token",
cfg.App.OAuthClientID = "development" InspectLocation: "https://write.as/oauth/inspect",
cfg.App.OAuthClientSecret = "development" }
cfg.SlackOauth = config.SlackOauthCfg{
ClientID: "development",
ClientSecret: "development",
TeamID: "development",
}
return cfg return cfg
} }
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) error { func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
if m.DoValidateOAuthState != nil { if m.DoValidateOAuthState != nil {
return m.DoValidateOAuthState(ctx, state) return m.DoValidateOAuthState(ctx, state)
} }
return nil return "", "", nil
} }
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID string) (int64, error) {
if m.DoGetIDForRemoteUser != nil { if m.DoGetIDForRemoteUser != nil {
return m.DoGetIDForRemoteUser(ctx, remoteUserID) return m.DoGetIDForRemoteUser(ctx, remoteUserID)
} }
@ -100,7 +107,7 @@ func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username st
return nil return nil
} }
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID int64) error { func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID string) error {
if m.DoRecordRemoteUserID != nil { if m.DoRecordRemoteUserID != nil {
return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID) return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID)
} }
@ -117,9 +124,9 @@ func (m *MockOAuthDatastore) GetUserForAuthByID(userID int64) (*User, error) {
return user, nil return user, nil
} }
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, error) { func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) {
if m.DoGenerateOAuthState != nil { if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx) return m.DoGenerateOAuthState(ctx, provider, clientID)
} }
return store.Generate62RandomString(14), nil return store.Generate62RandomString(14), nil
} }
@ -132,6 +139,15 @@ func TestViewOauthInit(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
oauthClient:writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
} }
req, err := http.NewRequest("GET", "/oauth/client", nil) req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err) assert.NoError(t, err)
@ -151,7 +167,7 @@ func TestViewOauthInit(t *testing.T) {
app := &MockOAuthDatastoreProvider{ app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore { DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{ return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context) (string, error) { DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) {
return "", fmt.Errorf("pretend unable to write state error") return "", fmt.Errorf("pretend unable to write state error")
}, },
} }
@ -161,6 +177,15 @@ func TestViewOauthInit(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
oauthClient:writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
} }
req, err := http.NewRequest("GET", "/oauth/client", nil) req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err) assert.NoError(t, err)
@ -179,24 +204,32 @@ func TestViewOauthCallback(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
HttpClient: &MockHTTPClient{ oauthClient: writeAsOauthClient{
DoDo: func(req *http.Request) (*http.Response, error) { ClientID: app.Config().WriteAsOauth.ClientID,
switch req.URL.String() { ClientSecret: app.Config().WriteAsOauth.ClientSecret,
case "https://write.as/oauth/token": ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
return &http.Response{ InspectLocation: app.Config().WriteAsOauth.InspectLocation,
StatusCode: 200, AuthLocation: app.Config().WriteAsOauth.AuthLocation,
Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, CallbackLocation: "http://localhost/oauth/callback",
}, nil HttpClient: &MockHTTPClient{
case "https://write.as/oauth/inspect": DoDo: func(req *http.Request) (*http.Response, error) {
switch req.URL.String() {
case "https://write.as/oauth/token":
return &http.Response{
StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
}, nil
case "https://write.as/oauth/inspect":
return &http.Response{
StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
}, nil
}
return &http.Response{ return &http.Response{
StatusCode: 200, StatusCode: http.StatusNotFound,
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": 1, "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
}, nil }, nil
} },
return &http.Response{
StatusCode: http.StatusNotFound,
}, nil
}, },
}, },
} }
@ -206,6 +239,5 @@ func TestViewOauthCallback(t *testing.T) {
h.viewOauthCallback(rr, req) h.viewOauthCallback(rr, req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
}) })
} }

@ -0,0 +1,91 @@
package writefreely
import (
"context"
"net/http"
"net/url"
"strings"
)
type writeAsOauthClient struct {
ClientID string
ClientSecret string
AuthLocation string
ExchangeLocation string
InspectLocation string
CallbackLocation string
HttpClient HttpClient
}
var _ oauthClient = writeAsOauthClient{}
func (c writeAsOauthClient) GetProvider() string {
return "write.as"
}
func (c writeAsOauthClient) GetClientID() string {
return c.ClientID
}
func (c writeAsOauthClient) buildLoginURL(state string) (string, error) {
u, err := url.Parse(c.AuthLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", c.ClientID)
q.Set("redirect_uri", c.CallbackLocation)
q.Set("response_type", "code")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), nil
}
func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", c.CallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(c.ClientID, c.ClientSecret)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var tokenResponse TokenResponse
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
return nil, err
}
return &tokenResponse, nil
}
func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", c.InspectLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var inspectResponse InspectResponse
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
return nil, err
}
return &inspectResponse, nil
}

@ -70,6 +70,9 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover))) write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover)))
write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo))) write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo)))
configureSlackOauth(write, apper.App())
configureWriteAsOauth(write, apper.App())
// Set up dyamic page handlers // Set up dyamic page handlers
// Handle auth // Handle auth
auth := write.PathPrefix("/api/auth/").Subrouter() auth := write.PathPrefix("/api/auth/").Subrouter()
@ -80,14 +83,6 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST") auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST")
auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE") auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE")
oauthHandler := oauthHandler{
HttpClient: &http.Client{},
Config: apper.App().Config(),
DB: apper.App().DB(),
Store: apper.App().SessionStore(),
}
oauthHandler.configureRoutes(write)
// Handle logged in user sections // Handle logged in user sections
me := write.PathPrefix("/me").Subrouter() me := write.PathPrefix("/me").Subrouter()
me.HandleFunc("/", handler.Redirect("/me", UserLevelUser)) me.HandleFunc("/", handler.Redirect("/me", UserLevelUser))
@ -189,6 +184,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
} }
write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional)) write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional))
write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional)) write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional))
return r return r
} }

Loading…
Cancel
Save