mirror of https://github.com/writeas/writefreely
commit
2aea9560bc
@ -0,0 +1,52 @@ |
||||
package db |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strings" |
||||
) |
||||
|
||||
type AlterTableSqlBuilder struct { |
||||
Dialect DialectType |
||||
Name string |
||||
Changes []string |
||||
} |
||||
|
||||
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { |
||||
if colVal, err := col.String(); err == nil { |
||||
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) |
||||
} |
||||
return b |
||||
} |
||||
|
||||
func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { |
||||
if colVal, err := col.String(); err == nil { |
||||
b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) |
||||
} |
||||
return b |
||||
} |
||||
|
||||
func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { |
||||
b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) |
||||
return b |
||||
} |
||||
|
||||
func (b *AlterTableSqlBuilder) ToSQL() (string, error) { |
||||
var str strings.Builder |
||||
|
||||
str.WriteString("ALTER TABLE ") |
||||
str.WriteString(b.Name) |
||||
str.WriteString(" ") |
||||
|
||||
if len(b.Changes) == 0 { |
||||
return "", fmt.Errorf("no changes provide for table: %s", b.Name) |
||||
} |
||||
changeCount := len(b.Changes) |
||||
for i, thing := range b.Changes { |
||||
str.WriteString(thing) |
||||
if i < changeCount-1 { |
||||
str.WriteString(", ") |
||||
} |
||||
} |
||||
|
||||
return str.String(), nil |
||||
} |
@ -0,0 +1,56 @@ |
||||
package db |
||||
|
||||
import "testing" |
||||
|
||||
func TestAlterTableSqlBuilder_ToSQL(t *testing.T) { |
||||
type fields struct { |
||||
Dialect DialectType |
||||
Name string |
||||
Changes []string |
||||
} |
||||
tests := []struct { |
||||
name string |
||||
builder *AlterTableSqlBuilder |
||||
want string |
||||
wantErr bool |
||||
}{ |
||||
{ |
||||
name: "MySQL add int", |
||||
builder: DialectMySQL. |
||||
AlterTable("the_table"). |
||||
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)), |
||||
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL", |
||||
wantErr: false, |
||||
}, |
||||
{ |
||||
name: "MySQL add string", |
||||
builder: DialectMySQL. |
||||
AlterTable("the_table"). |
||||
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})), |
||||
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL", |
||||
wantErr: false, |
||||
}, |
||||
|
||||
{ |
||||
name: "MySQL add int and string", |
||||
builder: DialectMySQL. |
||||
AlterTable("the_table"). |
||||
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)). |
||||
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})), |
||||
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL", |
||||
wantErr: false, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
got, err := tt.builder.ToSQL() |
||||
if (err != nil) != tt.wantErr { |
||||
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) |
||||
return |
||||
} |
||||
if got != tt.want { |
||||
t.Errorf("ToSQL() got = %v, want %v", got, tt.want) |
||||
} |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,76 @@ |
||||
package db |
||||
|
||||
import "fmt" |
||||
|
||||
type DialectType int |
||||
|
||||
const ( |
||||
DialectSQLite DialectType = iota |
||||
DialectMySQL DialectType = iota |
||||
) |
||||
|
||||
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { |
||||
switch d { |
||||
case DialectSQLite: |
||||
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} |
||||
case DialectMySQL: |
||||
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} |
||||
default: |
||||
panic(fmt.Sprintf("unexpected dialect: %d", d)) |
||||
} |
||||
} |
||||
|
||||
func (d DialectType) Table(name string) *CreateTableSqlBuilder { |
||||
switch d { |
||||
case DialectSQLite: |
||||
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} |
||||
case DialectMySQL: |
||||
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} |
||||
default: |
||||
panic(fmt.Sprintf("unexpected dialect: %d", d)) |
||||
} |
||||
} |
||||
|
||||
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { |
||||
switch d { |
||||
case DialectSQLite: |
||||
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} |
||||
case DialectMySQL: |
||||
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} |
||||
default: |
||||
panic(fmt.Sprintf("unexpected dialect: %d", d)) |
||||
} |
||||
} |
||||
|
||||
func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { |
||||
switch d { |
||||
case DialectSQLite: |
||||
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} |
||||
case DialectMySQL: |
||||
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} |
||||
default: |
||||
panic(fmt.Sprintf("unexpected dialect: %d", d)) |
||||
} |
||||
} |
||||
|
||||
func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { |
||||
switch d { |
||||
case DialectSQLite: |
||||
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} |
||||
case DialectMySQL: |
||||
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} |
||||
default: |
||||
panic(fmt.Sprintf("unexpected dialect: %d", d)) |
||||
} |
||||
} |
||||
|
||||
func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { |
||||
switch d { |
||||
case DialectSQLite: |
||||
return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} |
||||
case DialectMySQL: |
||||
return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} |
||||
default: |
||||
panic(fmt.Sprintf("unexpected dialect: %d", d)) |
||||
} |
||||
} |
@ -0,0 +1,53 @@ |
||||
package db |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strings" |
||||
) |
||||
|
||||
type CreateIndexSqlBuilder struct { |
||||
Dialect DialectType |
||||
Name string |
||||
Table string |
||||
Unique bool |
||||
Columns []string |
||||
} |
||||
|
||||
type DropIndexSqlBuilder struct { |
||||
Dialect DialectType |
||||
Name string |
||||
Table string |
||||
} |
||||
|
||||
func (b *CreateIndexSqlBuilder) ToSQL() (string, error) { |
||||
var str strings.Builder |
||||
|
||||
str.WriteString("CREATE ") |
||||
if b.Unique { |
||||
str.WriteString("UNIQUE ") |
||||
} |
||||
str.WriteString("INDEX ") |
||||
str.WriteString(b.Name) |
||||
str.WriteString(" on ") |
||||
str.WriteString(b.Table) |
||||
|
||||
if len(b.Columns) == 0 { |
||||
return "", fmt.Errorf("columns provided for this index: %s", b.Name) |
||||
} |
||||
|
||||
str.WriteString(" (") |
||||
columnCount := len(b.Columns) |
||||
for i, thing := range b.Columns { |
||||
str.WriteString(thing) |
||||
if i < columnCount-1 { |
||||
str.WriteString(", ") |
||||
} |
||||
} |
||||
str.WriteString(")") |
||||
|
||||
return str.String(), nil |
||||
} |
||||
|
||||
func (b *DropIndexSqlBuilder) ToSQL() (string, error) { |
||||
return fmt.Sprintf("DROP INDEX %s on %s", b.Name, b.Table), nil |
||||
} |
@ -0,0 +1,9 @@ |
||||
package db |
||||
|
||||
type RawSqlBuilder struct { |
||||
Query string |
||||
} |
||||
|
||||
func (b *RawSqlBuilder) ToSQL() (string, error) { |
||||
return b.Query, nil |
||||
} |
@ -0,0 +1,67 @@ |
||||
package migrations |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
|
||||
wf_db "github.com/writeas/writefreely/db" |
||||
) |
||||
|
||||
func oauthSlack(db *datastore) error { |
||||
dialect := wf_db.DialectMySQL |
||||
if db.driverName == driverSQLite { |
||||
dialect = wf_db.DialectSQLite |
||||
} |
||||
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { |
||||
builders := []wf_db.SQLBuilder{ |
||||
dialect. |
||||
AlterTable("oauth_client_states"). |
||||
AddColumn(dialect. |
||||
Column( |
||||
"provider", |
||||
wf_db.ColumnTypeVarChar, |
||||
wf_db.OptionalInt{Set: true, Value: 24,})). |
||||
AddColumn(dialect. |
||||
Column( |
||||
"client_id", |
||||
wf_db.ColumnTypeVarChar, |
||||
wf_db.OptionalInt{Set: true, Value: 128,})), |
||||
dialect. |
||||
AlterTable("oauth_users"). |
||||
ChangeColumn("remote_user_id", |
||||
dialect. |
||||
Column( |
||||
"remote_user_id", |
||||
wf_db.ColumnTypeVarChar, |
||||
wf_db.OptionalInt{Set: true, Value: 128,})). |
||||
AddColumn(dialect. |
||||
Column( |
||||
"provider", |
||||
wf_db.ColumnTypeVarChar, |
||||
wf_db.OptionalInt{Set: true, Value: 24,})). |
||||
AddColumn(dialect. |
||||
Column( |
||||
"client_id", |
||||
wf_db.ColumnTypeVarChar, |
||||
wf_db.OptionalInt{Set: true, Value: 128,})). |
||||
AddColumn(dialect. |
||||
Column( |
||||
"access_token", |
||||
wf_db.ColumnTypeVarChar, |
||||
wf_db.OptionalInt{Set: true, Value: 512,})), |
||||
dialect.DropIndex("remote_user_id", "oauth_users"), |
||||
dialect.DropIndex("user_id", "oauth_users"), |
||||
dialect.CreateUniqueIndex("oauth_users", "oauth_users", "user_id", "provider", "client_id"), |
||||
} |
||||
for _, builder := range builders { |
||||
query, err := builder.ToSQL() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if _, err := tx.ExecContext(ctx, query); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
}) |
||||
} |
@ -0,0 +1,164 @@ |
||||
package writefreely |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"github.com/writeas/slug" |
||||
"net/http" |
||||
"net/url" |
||||
"strings" |
||||
) |
||||
|
||||
type slackOauthClient struct { |
||||
ClientID string |
||||
ClientSecret string |
||||
TeamID string |
||||
CallbackLocation string |
||||
HttpClient HttpClient |
||||
} |
||||
|
||||
type slackExchangeResponse struct { |
||||
OK bool `json:"ok"` |
||||
AccessToken string `json:"access_token"` |
||||
Scope string `json:"scope"` |
||||
TeamName string `json:"team_name"` |
||||
TeamID string `json:"team_id"` |
||||
Error string `json:"error"` |
||||
} |
||||
|
||||
type slackIdentity struct { |
||||
Name string `json:"name"` |
||||
ID string `json:"id"` |
||||
Email string `json:"email"` |
||||
} |
||||
|
||||
type slackTeam struct { |
||||
Name string `json:"name"` |
||||
ID string `json:"id"` |
||||
} |
||||
|
||||
type slackUserIdentityResponse struct { |
||||
OK bool `json:"ok"` |
||||
User slackIdentity `json:"user"` |
||||
Team slackTeam `json:"team"` |
||||
Error string `json:"error"` |
||||
} |
||||
|
||||
const ( |
||||
slackAuthLocation = "https://slack.com/oauth/authorize" |
||||
slackExchangeLocation = "https://slack.com/api/oauth.access" |
||||
slackIdentityLocation = "https://slack.com/api/users.identity" |
||||
) |
||||
|
||||
var _ oauthClient = slackOauthClient{} |
||||
|
||||
func (c slackOauthClient) GetProvider() string { |
||||
return "slack" |
||||
} |
||||
|
||||
func (c slackOauthClient) GetClientID() string { |
||||
return c.ClientID |
||||
} |
||||
|
||||
func (c slackOauthClient) buildLoginURL(state string) (string, error) { |
||||
u, err := url.Parse(slackAuthLocation) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
q := u.Query() |
||||
q.Set("client_id", c.ClientID) |
||||
q.Set("scope", "identity.basic identity.email identity.team") |
||||
q.Set("redirect_uri", c.CallbackLocation) |
||||
q.Set("state", state) |
||||
|
||||
// If this param is not set, the user can select which team they
|
||||
// authenticate through and then we'd have to match the configured team
|
||||
// against the profile get. That is extra work in the post-auth phase
|
||||
// that we don't want to do.
|
||||
q.Set("team", c.TeamID) |
||||
|
||||
// The Slack OAuth docs don't explicitly list this one, but it is part of
|
||||
// the spec, so we include it anyway.
|
||||
q.Set("response_type", "code") |
||||
u.RawQuery = q.Encode() |
||||
return u.String(), nil |
||||
} |
||||
|
||||
func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { |
||||
form := url.Values{} |
||||
// The oauth.access documentation doesn't explicitly mention this
|
||||
// parameter, but it is part of the spec, so we include it anyway.
|
||||
// https://api.slack.com/methods/oauth.access
|
||||
form.Add("grant_type", "authorization_code") |
||||
form.Add("redirect_uri", c.CallbackLocation) |
||||
form.Add("code", code) |
||||
req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode())) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
req.WithContext(ctx) |
||||
req.Header.Set("User-Agent", "writefreely") |
||||
req.Header.Set("Accept", "application/json") |
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
req.SetBasicAuth(c.ClientID, c.ClientSecret) |
||||
|
||||
resp, err := c.HttpClient.Do(req) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if resp.StatusCode != http.StatusOK { |
||||
return nil, errors.New("unable to exchange code for access token") |
||||
} |
||||
|
||||
var tokenResponse slackExchangeResponse |
||||
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { |
||||
return nil, err |
||||
} |
||||
if !tokenResponse.OK { |
||||
return nil, errors.New(tokenResponse.Error) |
||||
} |
||||
return tokenResponse.TokenResponse(), nil |
||||
} |
||||
|
||||
func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { |
||||
req, err := http.NewRequest("GET", slackIdentityLocation, nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
req.WithContext(ctx) |
||||
req.Header.Set("User-Agent", "writefreely") |
||||
req.Header.Set("Accept", "application/json") |
||||
req.Header.Set("Authorization", "Bearer "+accessToken) |
||||
|
||||
resp, err := c.HttpClient.Do(req) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if resp.StatusCode != http.StatusOK { |
||||
return nil, errors.New("unable to inspect access token") |
||||
} |
||||
|
||||
var inspectResponse slackUserIdentityResponse |
||||
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { |
||||
return nil, err |
||||
} |
||||
if !inspectResponse.OK { |
||||
return nil, errors.New(inspectResponse.Error) |
||||
} |
||||
return inspectResponse.InspectResponse(), nil |
||||
} |
||||
|
||||
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse { |
||||
return &InspectResponse{ |
||||
UserID: resp.User.ID, |
||||
Username: slug.Make(resp.User.Name), |
||||
DisplayName: resp.User.Name, |
||||
Email: resp.User.Email, |
||||
} |
||||
} |
||||
|
||||
func (resp slackExchangeResponse) TokenResponse() *TokenResponse { |
||||
return &TokenResponse{ |
||||
AccessToken: resp.AccessToken, |
||||
} |
||||
} |
@ -0,0 +1,110 @@ |
||||
package writefreely |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"net/http" |
||||
"net/url" |
||||
"strings" |
||||
) |
||||
|
||||
type writeAsOauthClient struct { |
||||
ClientID string |
||||
ClientSecret string |
||||
AuthLocation string |
||||
ExchangeLocation string |
||||
InspectLocation string |
||||
CallbackLocation string |
||||
HttpClient HttpClient |
||||
} |
||||
|
||||
var _ oauthClient = writeAsOauthClient{} |
||||
|
||||
const ( |
||||
writeAsAuthLocation = "https://write.as/oauth/login" |
||||
writeAsExchangeLocation = "https://write.as/oauth/token" |
||||
writeAsIdentityLocation = "https://write.as/oauth/inspect" |
||||
) |
||||
|
||||
func (c writeAsOauthClient) GetProvider() string { |
||||
return "write.as" |
||||
} |
||||
|
||||
func (c writeAsOauthClient) GetClientID() string { |
||||
return c.ClientID |
||||
} |
||||
|
||||
func (c writeAsOauthClient) buildLoginURL(state string) (string, error) { |
||||
u, err := url.Parse(c.AuthLocation) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
q := u.Query() |
||||
q.Set("client_id", c.ClientID) |
||||
q.Set("redirect_uri", c.CallbackLocation) |
||||
q.Set("response_type", "code") |
||||
q.Set("state", state) |
||||
u.RawQuery = q.Encode() |
||||
return u.String(), nil |
||||
} |
||||
|
||||
func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { |
||||
form := url.Values{} |
||||
form.Add("grant_type", "authorization_code") |
||||
form.Add("redirect_uri", c.CallbackLocation) |
||||
form.Add("code", code) |
||||
req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
req.WithContext(ctx) |
||||
req.Header.Set("User-Agent", "writefreely") |
||||
req.Header.Set("Accept", "application/json") |
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
req.SetBasicAuth(c.ClientID, c.ClientSecret) |
||||
|
||||
resp, err := c.HttpClient.Do(req) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if resp.StatusCode != http.StatusOK { |
||||
return nil, errors.New("unable to exchange code for access token") |
||||
} |
||||
|
||||
var tokenResponse TokenResponse |
||||
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { |
||||
return nil, err |
||||
} |
||||
if tokenResponse.Error != "" { |
||||
return nil, errors.New(tokenResponse.Error) |
||||
} |
||||
return &tokenResponse, nil |
||||
} |
||||
|
||||
func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { |
||||
req, err := http.NewRequest("GET", c.InspectLocation, nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
req.WithContext(ctx) |
||||
req.Header.Set("User-Agent", "writefreely") |
||||
req.Header.Set("Accept", "application/json") |
||||
req.Header.Set("Authorization", "Bearer "+accessToken) |
||||
|
||||
resp, err := c.HttpClient.Do(req) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if resp.StatusCode != http.StatusOK { |
||||
return nil, errors.New("unable to inspect access token") |
||||
} |
||||
|
||||
var inspectResponse InspectResponse |
||||
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { |
||||
return nil, err |
||||
} |
||||
if inspectResponse.Error != "" { |
||||
return nil, errors.New(inspectResponse.Error) |
||||
} |
||||
return &inspectResponse, nil |
||||
} |
Loading…
Reference in new issue