|
|
|
@ -12,11 +12,8 @@ package writefreely |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"crypto/rand" |
|
|
|
|
"database/sql" |
|
|
|
|
"fmt" |
|
|
|
|
"github.com/pkg/errors" |
|
|
|
|
"math/big" |
|
|
|
|
"net/http" |
|
|
|
|
"strings" |
|
|
|
|
"time" |
|
|
|
@ -2463,11 +2460,8 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { |
|
|
|
|
state, err := randString(24) |
|
|
|
|
if err != nil { |
|
|
|
|
return "", err |
|
|
|
|
} |
|
|
|
|
_, err = db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) |
|
|
|
|
state := store.Generate62RandomString(24) |
|
|
|
|
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) |
|
|
|
|
if err != nil { |
|
|
|
|
return "", fmt.Errorf("unable to record oauth client state: %w", err) |
|
|
|
|
} |
|
|
|
@ -2494,7 +2488,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote |
|
|
|
|
if db.driverName == driverSQLite { |
|
|
|
|
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) |
|
|
|
|
} else { |
|
|
|
|
_, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id") + " user_id = ?", localUserID, remoteUserID, localUserID) |
|
|
|
|
_, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID) |
|
|
|
|
} |
|
|
|
|
if err != nil { |
|
|
|
|
log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err) |
|
|
|
@ -2545,26 +2539,3 @@ func handleFailedPostInsert(err error) error { |
|
|
|
|
log.Error("Couldn't insert into posts: %v", err) |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func randString(length int) (string, error) { |
|
|
|
|
// every printable character on a US keyboard
|
|
|
|
|
charset := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") |
|
|
|
|
out := make([]rune, length) |
|
|
|
|
|
|
|
|
|
setLen := big.NewInt(int64(len(charset))) |
|
|
|
|
for idx := 0; idx < length; idx++ { |
|
|
|
|
offset, err := rand.Int(rand.Reader, setLen) |
|
|
|
|
if err != nil { |
|
|
|
|
return "", err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if !offset.IsUint64() { |
|
|
|
|
// this should (in theory) never happen
|
|
|
|
|
return "", errors.Errorf("Non-Uint64 offset returned from rand.Int") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
out[idx] = charset[offset.Uint64()] |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return string(out), nil |
|
|
|
|
} |
|
|
|
|