Code cleanup from PR 255 feedback. T705

pull/225/head
Nick Gerakines 5 years ago
parent bf3b6a5ba0
commit 4266154749
  1. 35
      database.go
  2. 7
      oauth.go
  3. 3
      oauth_test.go

@ -12,11 +12,8 @@ package writefreely
import ( import (
"context" "context"
"crypto/rand"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/pkg/errors"
"math/big"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -2463,11 +2460,8 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
} }
func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
state, err := randString(24) state := store.Generate62RandomString(24)
if err != nil { _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state)
return "", err
}
_, err = db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err) return "", fmt.Errorf("unable to record oauth client state: %w", err)
} }
@ -2494,7 +2488,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote
if db.driverName == driverSQLite { if db.driverName == driverSQLite {
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
} else { } else {
_, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id") + " user_id = ?", localUserID, remoteUserID, localUserID) _, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID)
} }
if err != nil { if err != nil {
log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err) log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err)
@ -2545,26 +2539,3 @@ func handleFailedPostInsert(err error) error {
log.Error("Couldn't insert into posts: %v", err) log.Error("Couldn't insert into posts: %v", err)
return err return err
} }
func randString(length int) (string, error) {
// every printable character on a US keyboard
charset := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")
out := make([]rune, length)
setLen := big.NewInt(int64(len(charset)))
for idx := 0; idx < length; idx++ {
offset, err := rand.Int(rand.Reader, setLen)
if err != nil {
return "", err
}
if !offset.IsUint64() {
// this should (in theory) never happen
return "", errors.Errorf("Non-Uint64 offset returned from rand.Int")
}
out[idx] = charset[offset.Uint64()]
}
return string(out), nil
}

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/guregu/null/zero" "github.com/guregu/null/zero"
"github.com/writeas/nerds/store"
"github.com/writeas/web-core/auth" "github.com/writeas/web-core/auth"
"github.com/writeas/web-core/log" "github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config" "github.com/writeas/writefreely/config"
@ -143,11 +144,7 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request)
//create a random string. If the user needs to set a password, they //create a random string. If the user needs to set a password, they
//can do so through the settings page or through the password reset //can do so through the settings page or through the password reset
//flow. //flow.
randPass, err := randString(14) randPass := store.Generate62RandomString(14)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
hashedPass, err := auth.HashPass([]byte(randPass)) hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil { if err != nil {
log.ErrorLog.Println(err) log.ErrorLog.Println(err)

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/writeas/nerds/store"
"github.com/writeas/writefreely/config" "github.com/writeas/writefreely/config"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -120,7 +121,7 @@ func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, er
if m.DoGenerateOAuthState != nil { if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx) return m.DoGenerateOAuthState(ctx)
} }
return randString(14) return store.Generate62RandomString(14), nil
} }
func TestViewOauthInit(t *testing.T) { func TestViewOauthInit(t *testing.T) {

Loading…
Cancel
Save