diff --git a/database.go b/database.go index 5b3aaa8..56035dd 100644 --- a/database.go +++ b/database.go @@ -12,11 +12,8 @@ package writefreely import ( "context" - "crypto/rand" "database/sql" "fmt" - "github.com/pkg/errors" - "math/big" "net/http" "strings" "time" @@ -2463,11 +2460,8 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { } func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { - state, err := randString(24) - if err != nil { - return "", err - } - _, err = db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) + state := store.Generate62RandomString(24) + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) if err != nil { return "", fmt.Errorf("unable to record oauth client state: %w", err) } @@ -2494,7 +2488,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote if db.driverName == driverSQLite { _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) } else { - _, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id") + " user_id = ?", localUserID, remoteUserID, localUserID) + _, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID) } if err != nil { log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err) @@ -2545,26 +2539,3 @@ func handleFailedPostInsert(err error) error { log.Error("Couldn't insert into posts: %v", err) return err } - -func randString(length int) (string, error) { - // every printable character on a US keyboard - charset := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") - out := make([]rune, length) - - setLen := big.NewInt(int64(len(charset))) - for idx := 0; idx < length; idx++ { - offset, err := rand.Int(rand.Reader, setLen) - if err != nil { - return "", err - } - - if !offset.IsUint64() { - // this should (in theory) never happen - return "", errors.Errorf("Non-Uint64 offset returned from rand.Int") - } - - out[idx] = charset[offset.Uint64()] - } - - return string(out), nil -} diff --git a/oauth.go b/oauth.go index aafdb51..d918f7f 100644 --- a/oauth.go +++ b/oauth.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/gorilla/sessions" "github.com/guregu/null/zero" + "github.com/writeas/nerds/store" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/log" "github.com/writeas/writefreely/config" @@ -143,11 +144,7 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) //create a random string. If the user needs to set a password, they //can do so through the settings page or through the password reset //flow. - randPass, err := randString(14) - if err != nil { - failOAuthRequest(w, http.StatusInternalServerError, err.Error()) - return - } + randPass := store.Generate62RandomString(14) hashedPass, err := auth.HashPass([]byte(randPass)) if err != nil { log.ErrorLog.Println(err) diff --git a/oauth_test.go b/oauth_test.go index 9418721..482a846 100644 --- a/oauth_test.go +++ b/oauth_test.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/gorilla/sessions" "github.com/stretchr/testify/assert" + "github.com/writeas/nerds/store" "github.com/writeas/writefreely/config" "net/http" "net/http/httptest" @@ -120,7 +121,7 @@ func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, er if m.DoGenerateOAuthState != nil { return m.DoGenerateOAuthState(ctx) } - return randString(14) + return store.Generate62RandomString(14), nil } func TestViewOauthInit(t *testing.T) {