Changed oauth table names per PR feedback. T705

pull/225/head
Nick Gerakines 5 years ago
parent 6bcc4cfa46
commit b5f716135b
  1. 12
      database.go
  2. 6
      database_test.go
  3. 4
      migrations/v4.go

@ -2461,7 +2461,7 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
state := store.Generate62RandomString(24) state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, used, created_at) VALUES (?, FALSE, NOW())", state)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err) return "", fmt.Errorf("unable to record oauth client state: %w", err)
} }
@ -2469,7 +2469,7 @@ func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
} }
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error { func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state) res, err := db.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state)
if err != nil { if err != nil {
return err return err
} }
@ -2486,12 +2486,12 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error { func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error {
var err error var err error
if db.driverName == driverSQLite { if db.driverName == driverSQLite {
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
} else { } else {
_, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID) _, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID)
} }
if err != nil { if err != nil {
log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err) log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err)
} }
return err return err
} }
@ -2500,7 +2500,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
var userID int64 = -1 var userID int64 = -1
err := db. err := db.
QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID). QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ?", remoteUserID).
Scan(&userID) Scan(&userID)
// Not finding a record is OK. // Not finding a record is OK.
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {

@ -22,19 +22,19 @@ func TestOAuthDatastore(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, state, 24) assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
err = ds.ValidateOAuthState(ctx, state) err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err) assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)
var localUserID int64 = 99 var localUserID int64 = 99
var remoteUserID int64 = 100 var remoteUserID int64 = 100
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID) err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID)
assert.NoError(t, err) assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `users_oauth` WHERE `user_id` = ? AND `remote_user_id` = ?", localUserID, remoteUserID) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ?", localUserID, remoteUserID)
foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID) foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID)
assert.NoError(t, err) assert.NoError(t, err)

@ -14,7 +14,7 @@ func oauth(db *datastore) error {
} }
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
createTableUsersOauth, err := dialect. createTableUsersOauth, err := dialect.
Table("users_oauth"). Table("oauth_users").
SetIfNotExists(true). SetIfNotExists(true).
Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
@ -25,7 +25,7 @@ func oauth(db *datastore) error {
return err return err
} }
createTableOauthClientState, err := dialect. createTableOauthClientState, err := dialect.
Table("oauth_client_state"). Table("oauth_client_states").
SetIfNotExists(true). SetIfNotExists(true).
Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})).
Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)).

Loading…
Cancel
Save