diff --git a/.travis.yml b/.travis.yml index 1e58d6b..fddc71c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: go go: - - "1.11.x" + - "1.13.x" env: - GO111MODULE=on diff --git a/Makefile b/Makefile index 757bcfd..85f02d3 100644 --- a/Makefile +++ b/Makefile @@ -25,28 +25,40 @@ build-no-sqlite: assets-no-sqlite deps-no-sqlite build-linux: deps @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ - $(GOGET) -u github.com/karalabe/xgo; \ + $(GOGET) -u src.techknowlogick.com/xgo; \ fi xgo --targets=linux/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely build-windows: deps @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ - $(GOGET) -u github.com/karalabe/xgo; \ + $(GOGET) -u src.techknowlogick.com/xgo; \ fi xgo --targets=windows/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely build-darwin: deps @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ - $(GOGET) -u github.com/karalabe/xgo; \ + $(GOGET) -u src.techknowlogick.com/xgo; \ fi xgo --targets=darwin/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely +build-arm6: deps + @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GOGET) -u src.techknowlogick.com/xgo; \ + fi + xgo --targets=linux/arm-6, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely + build-arm7: deps @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ - $(GOGET) -u github.com/karalabe/xgo; \ + $(GOGET) -u src.techknowlogick.com/xgo; \ fi xgo --targets=linux/arm-7, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely +build-arm64: deps + @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GOGET) -u src.techknowlogick.com/xgo; \ + fi + xgo --targets=linux/arm64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely + build-docker : $(DOCKERCMD) build -t $(IMAGE_NAME):latest -t $(IMAGE_NAME):$(GITREV) . @@ -79,10 +91,18 @@ release : clean ui assets mv build/$(BINARY_NAME)-linux-amd64 $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_amd64.tar.gz -C build $(BINARY_NAME) rm $(BUILDPATH)/$(BINARY_NAME) + $(MAKE) build-arm6 + mv build/$(BINARY_NAME)-linux-arm-6 $(BUILDPATH)/$(BINARY_NAME) + tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm6.tar.gz -C build $(BINARY_NAME) + rm $(BUILDPATH)/$(BINARY_NAME) $(MAKE) build-arm7 mv build/$(BINARY_NAME)-linux-arm-7 $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm7.tar.gz -C build $(BINARY_NAME) rm $(BUILDPATH)/$(BINARY_NAME) + $(MAKE) build-arm64 + mv build/$(BINARY_NAME)-linux-arm64 $(BUILDPATH)/$(BINARY_NAME) + tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm64.tar.gz -C build $(BINARY_NAME) + rm $(BUILDPATH)/$(BINARY_NAME) $(MAKE) build-darwin mv build/$(BINARY_NAME)-darwin-10.6-amd64 $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_macos_amd64.tar.gz -C build $(BINARY_NAME) @@ -135,7 +155,7 @@ $(TMPBIN)/go-bindata: deps $(TMPBIN) $(GOBUILD) -o $(TMPBIN)/go-bindata github.com/jteeuwen/go-bindata/go-bindata $(TMPBIN)/xgo: deps $(TMPBIN) - $(GOBUILD) -o $(TMPBIN)/xgo github.com/karalabe/xgo + $(GOBUILD) -o $(TMPBIN)/xgo src.techknowlogick.com/xgo ci-assets : $(TMPBIN)/go-bindata $(TMPBIN)/go-bindata -pkg writefreely -ignore=\\.gitignore -tags="!wflib" schema.sql sqlite.sql diff --git a/account.go b/account.go index 180e9b0..5dba924 100644 --- a/account.go +++ b/account.go @@ -156,17 +156,9 @@ func signupWithRegistration(app *App, signup userRegistration, w http.ResponseWr Username: signup.Alias, HashedPass: hashedPass, HasPass: createdWithPass, - Email: zero.NewString("", signup.Email != ""), + Email: prepareUserEmail(signup.Email, app.keys.EmailKey), Created: time.Now().Truncate(time.Second).UTC(), } - if signup.Email != "" { - encEmail, err := data.Encrypt(app.keys.EmailKey, signup.Email) - if err != nil { - log.Error("Unable to encrypt email: %s\n", err) - } else { - u.Email.String = string(encEmail) - } - } // Create actual user if err := app.db.CreateUser(app.cfg, u, desiredUsername); err != nil { @@ -314,12 +306,16 @@ func viewLogin(app *App, w http.ResponseWriter, r *http.Request) error { Message template.HTML Flashes []template.HTML LoginUsername string + OauthSlack bool + OauthWriteAs bool }{ pageForReq(app, r), r.FormValue("to"), template.HTML(""), []template.HTML{}, getTempInfo(app, "login-user", r, w), + app.Config().SlackOauth.ClientID != "", + app.Config().WriteAsOauth.ClientID != "", } if earlyError != "" { @@ -750,14 +746,20 @@ func viewArticles(app *App, u *User, w http.ResponseWriter, r *http.Request) err log.Error("unable to fetch collections: %v", err) } + silenced, err := app.db.IsUserSilenced(u.ID) + if err != nil { + log.Error("view articles: %v", err) + } d := struct { *UserPage AnonymousPosts *[]PublicPost Collections *[]Collection + Silenced bool }{ UserPage: NewUserPage(app, r, u, u.Username+"'s Posts", f), AnonymousPosts: p, Collections: c, + Silenced: silenced, } d.UserPage.SetMessaging(u) w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") @@ -779,6 +781,11 @@ func viewCollections(app *App, u *User, w http.ResponseWriter, r *http.Request) uc, _ := app.db.GetUserCollectionCount(u.ID) // TODO: handle any errors + silenced, err := app.db.IsUserSilenced(u.ID) + if err != nil { + log.Error("view collections %v", err) + return fmt.Errorf("view collections: %v", err) + } d := struct { *UserPage Collections *[]Collection @@ -786,11 +793,13 @@ func viewCollections(app *App, u *User, w http.ResponseWriter, r *http.Request) UsedCollections, TotalCollections int NewBlogsDisabled bool + Silenced bool }{ UserPage: NewUserPage(app, r, u, u.Username+"'s Blogs", f), Collections: c, UsedCollections: int(uc), NewBlogsDisabled: !app.cfg.App.CanCreateBlogs(uc), + Silenced: silenced, } d.UserPage.SetMessaging(u) showUserPage(w, "collections", d) @@ -808,13 +817,20 @@ func viewEditCollection(app *App, u *User, w http.ResponseWriter, r *http.Reques return ErrCollectionNotFound } + silenced, err := app.db.IsUserSilenced(u.ID) + if err != nil { + log.Error("view edit collection %v", err) + return fmt.Errorf("view edit collection: %v", err) + } flashes, _ := getSessionFlashes(app, w, r, nil) obj := struct { *UserPage *Collection + Silenced bool }{ UserPage: NewUserPage(app, r, u, "Edit "+c.DisplayTitle(), flashes), Collection: c, + Silenced: silenced, } showUserPage(w, "collection", obj) @@ -976,17 +992,24 @@ func viewStats(app *App, u *User, w http.ResponseWriter, r *http.Request) error titleStats = c.DisplayTitle() + " " } + silenced, err := app.db.IsUserSilenced(u.ID) + if err != nil { + log.Error("view stats: %v", err) + return err + } obj := struct { *UserPage VisitsBlog string Collection *Collection TopPosts *[]PublicPost APFollowers int + Silenced bool }{ UserPage: NewUserPage(app, r, u, titleStats+"Stats", flashes), VisitsBlog: alias, Collection: c, TopPosts: topPosts, + Silenced: silenced, } if app.cfg.App.Federation { folls, err := app.db.GetAPFollowers(c) @@ -1020,11 +1043,13 @@ func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) err Email string HasPass bool IsLogOut bool + Silenced bool }{ UserPage: NewUserPage(app, r, u, "Account Settings", flashes), Email: fullUser.EmailClear(app.keys), HasPass: passIsSet, IsLogOut: r.FormValue("logout") == "1", + Silenced: fullUser.IsSilenced(), } showUserPage(w, "settings", obj) @@ -1068,3 +1093,16 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s // Return value return s } + +func prepareUserEmail(input string, emailKey []byte) zero.String { + email := zero.NewString("", input != "") + if len(input) > 0 { + encEmail, err := data.Encrypt(emailKey, input) + if err != nil { + log.Error("Unable to encrypt email: %s\n", err) + } else { + email.String = string(encEmail) + } + } + return email +} diff --git a/account_import.go b/account_import.go new file mode 100644 index 0000000..b34f3a7 --- /dev/null +++ b/account_import.go @@ -0,0 +1,195 @@ +package writefreely + +import ( + "encoding/json" + "fmt" + "html/template" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/writeas/impart" + wfimport "github.com/writeas/import" + "github.com/writeas/web-core/log" +) + +func viewImport(app *App, u *User, w http.ResponseWriter, r *http.Request) error { + // Fetch extra user data + p := NewUserPage(app, r, u, "Import Posts", nil) + + c, err := app.db.GetCollections(u, app.Config().App.Host) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("unable to fetch collections: %v", err)} + } + + d := struct { + *UserPage + Collections *[]Collection + Flashes []template.HTML + Message string + InfoMsg bool + }{ + UserPage: p, + Collections: c, + Flashes: []template.HTML{}, + } + + flashes, _ := getSessionFlashes(app, w, r, nil) + for _, flash := range flashes { + if strings.HasPrefix(flash, "SUCCESS: ") { + d.Message = strings.TrimPrefix(flash, "SUCCESS: ") + } else if strings.HasPrefix(flash, "INFO: ") { + d.Message = strings.TrimPrefix(flash, "INFO: ") + d.InfoMsg = true + } else { + d.Flashes = append(d.Flashes, template.HTML(flash)) + } + } + + showUserPage(w, "import", d) + return nil +} + +func handleImport(app *App, u *User, w http.ResponseWriter, r *http.Request) error { + // limit 10MB per submission + r.ParseMultipartForm(10 << 20) + + collAlias := r.PostFormValue("collection") + coll := &Collection{ + ID: 0, + } + var err error + if collAlias != "" { + coll, err = app.db.GetCollection(collAlias) + if err != nil { + log.Error("Unable to get collection for import: %s", err) + return err + } + // Only allow uploading to collection if current user is owner + if coll.OwnerID != u.ID { + err := ErrUnauthorizedGeneral + _ = addSessionFlash(app, w, r, err.Message, nil) + return err + } + coll.hostName = app.cfg.App.Host + } + + fileDates := make(map[string]int64) + err = json.Unmarshal([]byte(r.FormValue("fileDates")), &fileDates) + if err != nil { + log.Error("invalid form data for file dates: %v", err) + return impart.HTTPError{http.StatusBadRequest, "form data for file dates was invalid"} + } + files := r.MultipartForm.File["files"] + var fileErrs []error + filesSubmitted := len(files) + var filesImported int + for _, formFile := range files { + fname := "" + ok := func() bool { + file, err := formFile.Open() + if err != nil { + fileErrs = append(fileErrs, fmt.Errorf("Unable to read file %s", formFile.Filename)) + log.Error("import file: open from form: %v", err) + return false + } + defer file.Close() + + tempFile, err := ioutil.TempFile("", "post-upload-*.txt") + if err != nil { + fileErrs = append(fileErrs, fmt.Errorf("Internal error for %s", formFile.Filename)) + log.Error("import file: create temp file %s: %v", formFile.Filename, err) + return false + } + defer tempFile.Close() + + _, err = io.Copy(tempFile, file) + if err != nil { + fileErrs = append(fileErrs, fmt.Errorf("Internal error for %s", formFile.Filename)) + log.Error("import file: copy to temp location %s: %v", formFile.Filename, err) + return false + } + + info, err := tempFile.Stat() + if err != nil { + fileErrs = append(fileErrs, fmt.Errorf("Internal error for %s", formFile.Filename)) + log.Error("import file: stat temp file %s: %v", formFile.Filename, err) + return false + } + fname = info.Name() + return true + }() + if !ok { + continue + } + + post, err := wfimport.FromFile(filepath.Join(os.TempDir(), fname)) + if err == wfimport.ErrEmptyFile { + // not a real error so don't log + _ = addSessionFlash(app, w, r, fmt.Sprintf("%s was empty, import skipped", formFile.Filename), nil) + continue + } else if err == wfimport.ErrInvalidContentType { + // same as above + _ = addSessionFlash(app, w, r, fmt.Sprintf("%s is not a supported post file", formFile.Filename), nil) + continue + } else if err != nil { + fileErrs = append(fileErrs, fmt.Errorf("failed to read copy of %s", formFile.Filename)) + log.Error("import textfile: file to post: %v", err) + continue + } + + if collAlias != "" { + post.Collection = collAlias + } + dateTime := time.Unix(fileDates[formFile.Filename], 0) + post.Created = &dateTime + created := post.Created.Format("2006-01-02T15:04:05Z") + submittedPost := SubmittedPost{ + Title: &post.Title, + Content: &post.Content, + Font: "norm", + Created: &created, + } + rp, err := app.db.CreatePost(u.ID, coll.ID, &submittedPost) + if err != nil { + fileErrs = append(fileErrs, fmt.Errorf("failed to create post from %s", formFile.Filename)) + log.Error("import textfile: create db post: %v", err) + continue + } + + // Federate post, if necessary + if app.cfg.App.Federation && coll.ID > 0 { + go federatePost( + app, + &PublicPost{ + Post: rp, + Collection: &CollectionObj{ + Collection: *coll, + }, + }, + coll.ID, + false, + ) + } + filesImported++ + } + if len(fileErrs) != 0 { + _ = addSessionFlash(app, w, r, multierror.ListFormatFunc(fileErrs), nil) + } + + if filesImported == filesSubmitted { + verb := "posts" + if filesSubmitted == 1 { + verb = "post" + } + _ = addSessionFlash(app, w, r, fmt.Sprintf("SUCCESS: Import complete, %d %s imported.", filesImported, verb), nil) + } else if filesImported > 0 { + _ = addSessionFlash(app, w, r, fmt.Sprintf("INFO: %d of %d posts imported, see details below.", filesImported, filesSubmitted), nil) + } + return impart.HTTPError{http.StatusFound, "/me/import"} +} diff --git a/activitypub.go b/activitypub.go index 80d484e..f15773f 100644 --- a/activitypub.go +++ b/activitypub.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018-2019 A Bunch Tell LLC. + * Copyright © 2018-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -37,6 +37,8 @@ import ( const ( // TODO: delete. don't use this! apCustomHandleDefault = "blog" + + apCacheTime = time.Minute ) type RemoteUser struct { @@ -44,6 +46,7 @@ type RemoteUser struct { ActorID string Inbox string SharedInbox string + Handle string } func (ru *RemoteUser) AsPerson() *activitystreams.Person { @@ -62,6 +65,12 @@ func (ru *RemoteUser) AsPerson() *activitystreams.Person { } } +func activityPubClient() *http.Client { + return &http.Client{ + Timeout: 15 * time.Second, + } +} + func handleFetchCollectionActivities(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) @@ -80,10 +89,19 @@ func handleFetchCollectionActivities(app *App, w http.ResponseWriter, r *http.Re if err != nil { return err } + silenced, err := app.db.IsUserSilenced(c.OwnerID) + if err != nil { + log.Error("fetch collection activities: %v", err) + return ErrInternalGeneral + } + if silenced { + return ErrCollectionNotFound + } c.hostName = app.cfg.App.Host p := c.PersonObject() + setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, p, http.StatusOK) } @@ -105,6 +123,14 @@ func handleFetchCollectionOutbox(app *App, w http.ResponseWriter, r *http.Reques if err != nil { return err } + silenced, err := app.db.IsUserSilenced(c.OwnerID) + if err != nil { + log.Error("fetch collection outbox: %v", err) + return ErrInternalGeneral + } + if silenced { + return ErrCollectionNotFound + } c.hostName = app.cfg.App.Host if app.cfg.App.SingleUser { @@ -132,11 +158,12 @@ func handleFetchCollectionOutbox(app *App, w http.ResponseWriter, r *http.Reques posts, err := app.db.GetPosts(app.cfg, c, p, false, true, false) for _, pp := range *posts { pp.Collection = res - o := pp.ActivityObject(app.cfg) + o := pp.ActivityObject(app) a := activitystreams.NewCreateActivity(o) ocp.OrderedItems = append(ocp.OrderedItems, *a) } + setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, ocp, http.StatusOK) } @@ -158,6 +185,14 @@ func handleFetchCollectionFollowers(app *App, w http.ResponseWriter, r *http.Req if err != nil { return err } + silenced, err := app.db.IsUserSilenced(c.OwnerID) + if err != nil { + log.Error("fetch collection followers: %v", err) + return ErrInternalGeneral + } + if silenced { + return ErrCollectionNotFound + } c.hostName = app.cfg.App.Host accountRoot := c.FederatedAccount() @@ -183,6 +218,7 @@ func handleFetchCollectionFollowers(app *App, w http.ResponseWriter, r *http.Req ocp.OrderedItems = append(ocp.OrderedItems, f.ActorID) } */ + setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, ocp, http.StatusOK) } @@ -204,6 +240,14 @@ func handleFetchCollectionFollowing(app *App, w http.ResponseWriter, r *http.Req if err != nil { return err } + silenced, err := app.db.IsUserSilenced(c.OwnerID) + if err != nil { + log.Error("fetch collection following: %v", err) + return ErrInternalGeneral + } + if silenced { + return ErrCollectionNotFound + } c.hostName = app.cfg.App.Host accountRoot := c.FederatedAccount() @@ -219,6 +263,7 @@ func handleFetchCollectionFollowing(app *App, w http.ResponseWriter, r *http.Req // Return outbox page ocp := activitystreams.NewOrderedCollectionPage(accountRoot, "following", 0, p) ocp.OrderedItems = []interface{}{} + setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, ocp, http.StatusOK) } @@ -238,6 +283,14 @@ func handleFetchCollectionInbox(app *App, w http.ResponseWriter, r *http.Request // TODO: return Reject? return err } + silenced, err := app.db.IsUserSilenced(c.OwnerID) + if err != nil { + log.Error("fetch collection inbox: %v", err) + return ErrInternalGeneral + } + if silenced { + return ErrCollectionNotFound + } c.hostName = app.cfg.App.Host if debugging { @@ -342,6 +395,11 @@ func handleFetchCollectionInbox(app *App, w http.ResponseWriter, r *http.Request } go func() { + if to == nil { + log.Error("No to! %v", err) + return + } + time.Sleep(2 * time.Second) am, err := a.Serialize() if err != nil { @@ -350,10 +408,6 @@ func handleFetchCollectionInbox(app *App, w http.ResponseWriter, r *http.Request } am["@context"] = []string{activitystreams.Namespace} - if to == nil { - log.Error("No to! %v", err) - return - } err = makeActivityPost(app.cfg.App.Host, p, fullActor.Inbox, am) if err != nil { log.Error("Unable to make activity POST: %v", err) @@ -462,7 +516,7 @@ func makeActivityPost(hostName string, p *activitystreams.Person, url string, m } } - resp, err := http.DefaultClient.Do(r) + resp, err := activityPubClient().Do(r) if err != nil { return err } @@ -498,7 +552,7 @@ func resolveIRI(hostName, url string) ([]byte, error) { } } - resp, err := http.DefaultClient.Do(r) + resp, err := activityPubClient().Do(r) if err != nil { return nil, err } @@ -524,7 +578,7 @@ func deleteFederatedPost(app *App, p *PublicPost, collID int64) error { } p.Collection.hostName = app.cfg.App.Host actor := p.Collection.PersonObject(collID) - na := p.ActivityObject(app.cfg) + na := p.ActivityObject(app) // Add followers p.Collection.ID = collID @@ -570,7 +624,7 @@ func federatePost(app *App, p *PublicPost, collID int64, isUpdate bool) error { } } actor := p.Collection.PersonObject(collID) - na := p.ActivityObject(app.cfg) + na := p.ActivityObject(app) // Add followers p.Collection.ID = collID @@ -588,18 +642,25 @@ func federatePost(app *App, p *PublicPost, collID int64, isUpdate bool) error { inbox = f.Inbox } if _, ok := inboxes[inbox]; ok { + // check if we're already sending to this shared inbox inboxes[inbox] = append(inboxes[inbox], f.ActorID) } else { + // add the new shared inbox to the list inboxes[inbox] = []string{f.ActorID} } } + var activity *activitystreams.Activity + // for each one of the shared inboxes for si, instFolls := range inboxes { + // add all followers from that instance + // to the CC field na.CC = []string{} for _, f := range instFolls { na.CC = append(na.CC, f) } - var activity *activitystreams.Activity + // create a new "Create" activity + // with our article as object if isUpdate { activity = activitystreams.NewUpdateActivity(na) } else { @@ -607,17 +668,42 @@ func federatePost(app *App, p *PublicPost, collID int64, isUpdate bool) error { activity.To = na.To activity.CC = na.CC } + // and post it to that sharedInbox err = makeActivityPost(app.cfg.App.Host, actor, si, activity) if err != nil { log.Error("Couldn't post! %v", err) } } + + // re-create the object so that the CC list gets reset and has + // the mentioned users. This might seem wasteful but the code is + // cleaner than adding the mentioned users to CC here instead of + // in p.ActivityObject() + na = p.ActivityObject(app) + for _, tag := range na.Tag { + if tag.Type == "Mention" { + activity = activitystreams.NewCreateActivity(na) + activity.To = na.To + activity.CC = na.CC + // This here might be redundant in some cases as we might have already + // sent this to the sharedInbox of this instance above, but we need too + // much logic to catch this at the expense of the odd extra request. + // I don't believe we'd ever have too many mentions in a single post that this + // could become a burden. + remoteUser, err := getRemoteUser(app, tag.HRef) + err = makeActivityPost(app.cfg.App.Host, actor, remoteUser.Inbox, activity) + if err != nil { + log.Error("Couldn't post! %v", err) + } + } + } + return nil } func getRemoteUser(app *App, actorID string) (*RemoteUser, error) { u := RemoteUser{ActorID: actorID} - err := app.db.QueryRow("SELECT id, inbox, shared_inbox FROM remoteusers WHERE actor_id = ?", actorID).Scan(&u.ID, &u.Inbox, &u.SharedInbox) + err := app.db.QueryRow("SELECT id, inbox, shared_inbox, handle FROM remoteusers WHERE actor_id = ?", actorID).Scan(&u.ID, &u.Inbox, &u.SharedInbox, &u.Handle) switch { case err == sql.ErrNoRows: return nil, impart.HTTPError{http.StatusNotFound, "No remote user with that ID."} @@ -629,6 +715,21 @@ func getRemoteUser(app *App, actorID string) (*RemoteUser, error) { return &u, nil } +// getRemoteUserFromHandle retrieves the profile page of a remote user +// from the @user@server.tld handle +func getRemoteUserFromHandle(app *App, handle string) (*RemoteUser, error) { + u := RemoteUser{Handle: handle} + err := app.db.QueryRow("SELECT id, actor_id, inbox, shared_inbox FROM remoteusers WHERE handle = ?", handle).Scan(&u.ID, &u.ActorID, &u.Inbox, &u.SharedInbox) + switch { + case err == sql.ErrNoRows: + return nil, ErrRemoteUserNotFound + case err != nil: + log.Error("Couldn't get remote user %s: %v", handle, err) + return nil, err + } + return &u, nil +} + func getActor(app *App, actorIRI string) (*activitystreams.Person, *RemoteUser, error) { log.Info("Fetching actor %s locally", actorIRI) actor := &activitystreams.Person{} @@ -703,3 +804,7 @@ func unmarshalActor(actorResp []byte, actor *activitystreams.Person) error { return nil } + +func setCacheControl(w http.ResponseWriter, ttl time.Duration) { + w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%.0f", ttl.Seconds())) +} diff --git a/admin.go b/admin.go index a1c2dac..99124ae 100644 --- a/admin.go +++ b/admin.go @@ -16,12 +16,14 @@ import ( "net/http" "runtime" "strconv" + "strings" "time" "github.com/gorilla/mux" "github.com/writeas/impart" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/log" + "github.com/writeas/web-core/passgen" "github.com/writeas/writefreely/appstats" "github.com/writeas/writefreely/config" ) @@ -169,11 +171,12 @@ func handleViewAdminUser(app *App, u *User, w http.ResponseWriter, r *http.Reque Config config.AppCfg Message string - User *User - Colls []inspectedCollection - LastPost string - - TotalPosts int64 + User *User + Colls []inspectedCollection + LastPost string + NewPassword string + TotalPosts int64 + ClearEmail string }{ Config: app.cfg.App, Message: r.FormValue("m"), @@ -183,7 +186,19 @@ func handleViewAdminUser(app *App, u *User, w http.ResponseWriter, r *http.Reque var err error p.User, err = app.db.GetUserForAuth(username) if err != nil { - return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get user: %v", err)} + if err == ErrUserNotFound { + return err + } + log.Error("Could not get user: %v", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + flashes, _ := getSessionFlashes(app, w, r, nil) + for _, flash := range flashes { + if strings.HasPrefix(flash, "SUCCESS: ") { + p.NewPassword = strings.TrimPrefix(flash, "SUCCESS: ") + p.ClearEmail = p.User.EmailClear(app.keys) + } } p.UserPage = NewUserPage(app, r, u, p.User.Username, nil) p.TotalPosts = app.db.GetUserPostsCount(p.User.ID) @@ -229,6 +244,62 @@ func handleViewAdminUser(app *App, u *User, w http.ResponseWriter, r *http.Reque return nil } +func handleAdminToggleUserStatus(app *App, u *User, w http.ResponseWriter, r *http.Request) error { + vars := mux.Vars(r) + username := vars["username"] + if username == "" { + return impart.HTTPError{http.StatusFound, "/admin/users"} + } + + user, err := app.db.GetUserForAuth(username) + if err != nil { + log.Error("failed to get user: %v", err) + return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get user from username: %v", err)} + } + if user.IsSilenced() { + err = app.db.SetUserStatus(user.ID, UserActive) + } else { + err = app.db.SetUserStatus(user.ID, UserSilenced) + } + if err != nil { + log.Error("toggle user silenced: %v", err) + return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not toggle user status: %v", err)} + } + return impart.HTTPError{http.StatusFound, fmt.Sprintf("/admin/user/%s#status", username)} +} + +func handleAdminResetUserPass(app *App, u *User, w http.ResponseWriter, r *http.Request) error { + vars := mux.Vars(r) + username := vars["username"] + if username == "" { + return impart.HTTPError{http.StatusFound, "/admin/users"} + } + + // Generate new random password since none supplied + pass := passgen.NewWordish() + hashedPass, err := auth.HashPass([]byte(pass)) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not create password hash: %v", err)} + } + + userIDVal := r.FormValue("user") + log.Info("ADMIN: Changing user %s password", userIDVal) + id, err := strconv.Atoi(userIDVal) + if err != nil { + return impart.HTTPError{http.StatusBadRequest, fmt.Sprintf("Invalid user ID: %v", err)} + } + + err = app.db.ChangePassphrase(int64(id), true, "", hashedPass) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not update passphrase: %v", err)} + } + log.Info("ADMIN: Successfully changed.") + + addSessionFlash(app, w, r, fmt.Sprintf("SUCCESS: %s", pass), nil) + + return impart.HTTPError{http.StatusFound, fmt.Sprintf("/admin/user/%s", username)} +} + func handleViewAdminPages(app *App, u *User, w http.ResponseWriter, r *http.Request) error { p := struct { *UserPage diff --git a/app.go b/app.go index 5cdaac2..dd05c95 100644 --- a/app.go +++ b/app.go @@ -56,7 +56,7 @@ var ( debugging bool // Software version can be set from git env using -ldflags - softwareVer = "0.10.0" + softwareVer = "0.11.2" // DEPRECATED VARS isSingleUser bool @@ -70,7 +70,7 @@ type App struct { cfg *config.Config cfgFile string keys *key.Keychain - sessionStore *sessions.CookieStore + sessionStore sessions.Store formDecoder *schema.Decoder updates *updatesCache @@ -102,6 +102,14 @@ func (app *App) SetKeys(k *key.Keychain) { app.keys = k } +func (app *App) SessionStore() sessions.Store { + return app.sessionStore +} + +func (app *App) SetSessionStore(s sessions.Store) { + app.sessionStore = s +} + // Apper is the interface for getting data into and out of a WriteFreely // instance (or "App"). // @@ -684,6 +692,52 @@ func ResetPassword(apper Apper, username string) error { return nil } +// DoDeleteAccount runs the confirmation and account delete process. +func DoDeleteAccount(apper Apper, username string) error { + // Connect to the database + apper.LoadConfig() + connectToDatabase(apper.App()) + defer shutdown(apper.App()) + + // check user exists + u, err := apper.App().db.GetUserForAuth(username) + if err != nil { + log.Error("%s", err) + os.Exit(1) + } + userID := u.ID + + // do not delete the admin account + // TODO: check for other admins and skip? + if u.IsAdmin() { + log.Error("Can not delete admin account") + os.Exit(1) + } + + // confirm deletion, w/ w/out posts + prompt := promptui.Prompt{ + Templates: &promptui.PromptTemplates{ + Success: "{{ . | bold | faint }}: ", + }, + Label: fmt.Sprintf("Really delete user : %s", username), + IsConfirm: true, + } + _, err = prompt.Run() + if err != nil { + log.Info("Aborted...") + os.Exit(0) + } + + log.Info("Deleting...") + err = apper.App().db.DeleteAccount(userID) + if err != nil { + log.Error("%s", err) + os.Exit(1) + } + log.Info("Success.") + return nil +} + func connectToDatabase(app *App) { log.Info("Connecting to %s database...", app.cfg.Database.Type) diff --git a/author/author.go b/author/author.go index bf3bfe1..0114905 100644 --- a/author/author.go +++ b/author/author.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -65,6 +65,7 @@ var reservedUsernames = map[string]bool{ "metadata": true, "new": true, "news": true, + "oauth": true, "post": true, "posts": true, "privacy": true, diff --git a/cmd/writefreely/main.go b/cmd/writefreely/main.go index 48993c7..7fc2342 100644 --- a/cmd/writefreely/main.go +++ b/cmd/writefreely/main.go @@ -13,11 +13,12 @@ package main import ( "flag" "fmt" + "os" + "strings" + "github.com/gorilla/mux" "github.com/writeas/web-core/log" "github.com/writeas/writefreely" - "os" - "strings" ) func main() { @@ -38,6 +39,7 @@ func main() { // Admin actions createAdmin := flag.String("create-admin", "", "Create an admin with the given username:password") createUser := flag.String("create-user", "", "Create a regular user with the given username:password") + deleteUsername := flag.String("delete-user", "", "Delete a user with the given username") resetPassUser := flag.String("reset-pass", "", "Reset the given user's password") outputVersion := flag.Bool("v", false, "Output the current version") flag.Parse() @@ -102,6 +104,13 @@ func main() { os.Exit(1) } os.Exit(0) + } else if *deleteUsername != "" { + err := writefreely.DoDeleteAccount(app, *deleteUsername) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + os.Exit(0) } else if *migrate { err := writefreely.Migrate(app) if err != nil { diff --git a/collections.go b/collections.go index cdf3d5c..9688ad9 100644 --- a/collections.go +++ b/collections.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -63,6 +63,7 @@ type ( TotalPosts int `json:"total_posts"` Owner *User `json:"owner,omitempty"` Posts *[]PublicPost `json:"posts,omitempty"` + Format *CollectionFormat } DisplayCollection struct { *CollectionObj @@ -70,7 +71,7 @@ type ( IsTopLevel bool CurrentPage int TotalPages int - Format *CollectionFormat + Silenced bool } SubmittedCollection struct { // Data used for updating a given collection @@ -379,6 +380,7 @@ func newCollection(app *App, w http.ResponseWriter, r *http.Request) error { } var userID int64 + var err error if reqJSON && !c.Web { accessToken = r.Header.Get("Authorization") if accessToken == "" { @@ -395,6 +397,14 @@ func newCollection(app *App, w http.ResponseWriter, r *http.Request) error { } userID = u.ID } + silenced, err := app.db.IsUserSilenced(userID) + if err != nil { + log.Error("new collection: %v", err) + return ErrInternalGeneral + } + if silenced { + return ErrUserSilenced + } if !author.IsValidUsername(app.cfg, c.Alias) { return impart.HTTPError{http.StatusPreconditionFailed, "Collection alias isn't valid."} @@ -477,6 +487,7 @@ func fetchCollection(app *App, w http.ResponseWriter, r *http.Request) error { res.Owner = u } } + // TODO: check status for silenced app.db.GetPostsCount(res, isCollOwner) // Strip non-public information res.Collection.ForPublic() @@ -545,6 +556,13 @@ type CollectionPage struct { CanInvite bool } +func NewCollectionObj(c *Collection) *CollectionObj { + return &CollectionObj{ + Collection: *c, + Format: c.NewFormat(), + } +} + func (c *CollectionObj) ScriptDisplay() template.JS { return template.JS(c.Script) } @@ -637,6 +655,16 @@ func processCollectionPermissions(app *App, cr *collectionReq, u *User, w http.R uname = u.Username } + // TODO: move this to all permission checks? + suspended, err := app.db.IsUserSilenced(c.OwnerID) + if err != nil { + log.Error("process protected collection permissions: %v", err) + return nil, err + } + if suspended { + return nil, ErrCollectionNotFound + } + // See if we've authorized this collection authd := isAuthorizedForCollection(app, c.Alias, r) @@ -684,11 +712,10 @@ func checkUserForCollection(app *App, cr *collectionReq, r *http.Request, isPost func newDisplayCollection(c *Collection, cr *collectionReq, page int) *DisplayCollection { coll := &DisplayCollection{ - CollectionObj: &CollectionObj{Collection: *c}, + CollectionObj: NewCollectionObj(c), CurrentPage: page, Prefix: cr.prefix, IsTopLevel: isSingleUser, - Format: c.NewFormat(), } c.db.GetPostsCount(coll.CollectionObj, cr.isCollOwner) return coll @@ -725,13 +752,19 @@ func handleViewCollection(app *App, w http.ResponseWriter, r *http.Request) erro if c == nil || err != nil { return err } - c.hostName = app.cfg.App.Host + silenced, err := app.db.IsUserSilenced(c.OwnerID) + if err != nil { + log.Error("view collection: %v", err) + return ErrInternalGeneral + } + // Serve ActivityStreams data now, if requested if strings.Contains(r.Header.Get("Accept"), "application/activity+json") { ac := c.PersonObject() ac.Context = []interface{}{activitystreams.Namespace} + setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, ac, http.StatusOK) } @@ -784,6 +817,10 @@ func handleViewCollection(app *App, w http.ResponseWriter, r *http.Request) erro log.Error("Error getting user for collection: %v", err) } } + if !isOwner && silenced { + return ErrCollectionNotFound + } + displayPage.Silenced = isOwner && silenced displayPage.Owner = owner coll.Owner = displayPage.Owner @@ -820,6 +857,19 @@ func handleViewCollection(app *App, w http.ResponseWriter, r *http.Request) erro return err } +func handleViewMention(app *App, w http.ResponseWriter, r *http.Request) error { + vars := mux.Vars(r) + handle := vars["handle"] + + remoteUser, err := app.db.GetProfilePageFromHandle(app, handle) + if err != nil || remoteUser == "" { + log.Error("Couldn't find user %s: %v", handle, err) + return ErrRemoteUserNotFound + } + + return impart.HTTPError{Status: http.StatusFound, Message: remoteUser} +} + func handleViewCollectionTag(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) tag := vars["tag"] @@ -885,7 +935,11 @@ func handleViewCollectionTag(app *App, w http.ResponseWriter, r *http.Request) e // Log the error and just continue log.Error("Error getting user for collection: %v", err) } + if owner.IsSilenced() { + return ErrCollectionNotFound + } } + displayPage.Silenced = owner != nil && owner.IsSilenced() displayPage.Owner = owner coll.Owner = displayPage.Owner // Add more data @@ -924,11 +978,10 @@ func existingCollection(app *App, w http.ResponseWriter, r *http.Request) error collAlias := vars["alias"] isWeb := r.FormValue("web") == "1" - var u *User + u := &User{} if reqJSON && !isWeb { // Ensure an access token was given accessToken := r.Header.Get("Authorization") - u = &User{} u.ID = app.db.GetUserID(accessToken) if u.ID == -1 { return ErrBadAccessToken @@ -940,6 +993,16 @@ func existingCollection(app *App, w http.ResponseWriter, r *http.Request) error } } + silenced, err := app.db.IsUserSilenced(u.ID) + if err != nil { + log.Error("existing collection: %v", err) + return ErrInternalGeneral + } + + if silenced { + return ErrUserSilenced + } + if r.Method == "DELETE" { err := app.db.DeleteCollection(collAlias, u.ID) if err != nil { @@ -952,7 +1015,6 @@ func existingCollection(app *App, w http.ResponseWriter, r *http.Request) error } c := SubmittedCollection{OwnerID: uint64(u.ID)} - var err error if reqJSON { // Decode JSON request diff --git a/config/config.go b/config/config.go index 80e2565..31f62f0 100644 --- a/config/config.go +++ b/config/config.go @@ -43,6 +43,8 @@ type ( PagesParentDir string `ini:"pages_parent_dir"` KeysParentDir string `ini:"keys_parent_dir"` + HashSeed string `ini:"hash_seed"` + Dev bool `ini:"-"` } @@ -57,6 +59,24 @@ type ( Port int `ini:"port"` } + WriteAsOauthCfg struct { + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + AuthLocation string `ini:"auth_location"` + TokenLocation string `ini:"token_location"` + InspectLocation string `ini:"inspect_location"` + CallbackProxy string `ini:"callback_proxy"` + CallbackProxyAPI string `ini:"callback_proxy_api"` + } + + SlackOauthCfg struct { + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + TeamID string `ini:"team_id"` + CallbackProxy string `ini:"callback_proxy"` + CallbackProxyAPI string `ini:"callback_proxy_api"` + } + // AppCfg holds values that affect how the application functions AppCfg struct { SiteName string `ini:"site_name"` @@ -102,9 +122,11 @@ type ( // Config holds the complete configuration for running a writefreely instance Config struct { - Server ServerCfg `ini:"server"` - Database DatabaseCfg `ini:"database"` - App AppCfg `ini:"app"` + Server ServerCfg `ini:"server"` + Database DatabaseCfg `ini:"database"` + App AppCfg `ini:"app"` + SlackOauth SlackOauthCfg `ini:"oauth.slack"` + WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"` } ) diff --git a/config/funcs.go b/config/funcs.go index a9c82ce..9678df0 100644 --- a/config/funcs.go +++ b/config/funcs.go @@ -11,7 +11,9 @@ package config import ( + "net/http" "strings" + "time" ) // FriendlyHost returns the app's Host sans any schema @@ -25,3 +27,16 @@ func (ac AppCfg) CanCreateBlogs(currentlyUsed uint64) bool { } return int(currentlyUsed) < ac.MaxBlogs } + +// OrDefaultString returns input or a default value if input is empty. +func OrDefaultString(input, defaultValue string) string { + if len(input) == 0 { + return defaultValue + } + return input +} + +// DefaultHTTPClient returns a sane default HTTP client. +func DefaultHTTPClient() *http.Client { + return &http.Client{Timeout: 10 * time.Second} +} diff --git a/database-lib.go b/database-lib.go index 58beb05..b6b4be2 100644 --- a/database-lib.go +++ b/database-lib.go @@ -1,7 +1,7 @@ // +build wflib /* - * Copyright © 2019 A Bunch Tell LLC. + * Copyright © 2019-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -18,3 +18,7 @@ package writefreely func (db *datastore) isDuplicateKeyErr(err error) bool { return false } + +func (db *datastore) isIgnorableError(err error) bool { + return false +} diff --git a/database-no-sqlite.go b/database-no-sqlite.go index a3d50fc..03d1a32 100644 --- a/database-no-sqlite.go +++ b/database-no-sqlite.go @@ -1,7 +1,7 @@ // +build !sqlite,!wflib /* - * Copyright © 2019 A Bunch Tell LLC. + * Copyright © 2019-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -28,3 +28,15 @@ func (db *datastore) isDuplicateKeyErr(err error) bool { return false } + +func (db *datastore) isIgnorableError(err error) bool { + if db.driverName == driverMySQL { + if mysqlErr, ok := err.(*mysql.MySQLError); ok { + return mysqlErr.Number == mySQLErrCollationMix + } + } else { + log.Error("isIgnorableError: failed check for unrecognized driver '%s'", db.driverName) + } + + return false +} diff --git a/database-sqlite.go b/database-sqlite.go index 3741169..bd77e6a 100644 --- a/database-sqlite.go +++ b/database-sqlite.go @@ -48,3 +48,15 @@ func (db *datastore) isDuplicateKeyErr(err error) bool { return false } + +func (db *datastore) isIgnorableError(err error) bool { + if db.driverName == driverMySQL { + if mysqlErr, ok := err.(*mysql.MySQLError); ok { + return mysqlErr.Number == mySQLErrCollationMix + } + } else { + log.Error("isIgnorableError: failed check for unrecognized driver '%s'", db.driverName) + } + + return false +} diff --git a/database.go b/database.go index a3235b6..cea7a97 100644 --- a/database.go +++ b/database.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -11,8 +11,10 @@ package writefreely import ( + "context" "database/sql" "fmt" + wf_db "github.com/writeas/writefreely/db" "net/http" "strings" "time" @@ -20,6 +22,7 @@ import ( "github.com/guregu/null" "github.com/guregu/null/zero" uuid "github.com/nu7hatch/gouuid" + "github.com/writeas/activityserve" "github.com/writeas/impart" "github.com/writeas/nerds/store" "github.com/writeas/web-core/activitypub" @@ -35,6 +38,7 @@ import ( const ( mySQLErrDuplicateKey = 1062 + mySQLErrCollationMix = 1267 driverMySQL = "mysql" driverSQLite = "sqlite3" @@ -61,7 +65,7 @@ type writestore interface { GetAccessToken(userID int64) (string, error) GetTemporaryAccessToken(userID int64, validSecs int) (string, error) GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) - DeleteAccount(userID int64) (l *string, err error) + DeleteAccount(userID int64) error ChangeSettings(app *App, u *User, s *userSettings) error ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error @@ -124,6 +128,11 @@ type writestore interface { GetUserLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error) + GetIDForRemoteUser(context.Context, string, string, string) (int64, error) + RecordRemoteUserID(context.Context, int64, string, string, string, string) error + ValidateOAuthState(context.Context, string) (string, string, error) + GenerateOAuthState(context.Context, string, string) (string, error) + DatabaseInitialized() bool } @@ -132,6 +141,8 @@ type datastore struct { driverName string } +var _ writestore = &datastore{} + func (db *datastore) now() string { if db.driverName == driverSQLite { return "strftime('%Y-%m-%d %H:%M:%S','now')" @@ -296,7 +307,7 @@ func (db *datastore) CreateCollection(cfg *config.Config, alias, title string, u func (db *datastore) GetUserByID(id int64) (*User, error) { u := &User{ID: id} - err := db.QueryRow("SELECT username, password, email, created FROM users WHERE id = ?", id).Scan(&u.Username, &u.HashedPass, &u.Email, &u.Created) + err := db.QueryRow("SELECT username, password, email, created, status FROM users WHERE id = ?", id).Scan(&u.Username, &u.HashedPass, &u.Email, &u.Created, &u.Status) switch { case err == sql.ErrNoRows: return nil, ErrUserNotFound @@ -308,6 +319,23 @@ func (db *datastore) GetUserByID(id int64) (*User, error) { return u, nil } +// IsUserSilenced returns true if the user account associated with id is +// currently silenced. +func (db *datastore) IsUserSilenced(id int64) (bool, error) { + u := &User{ID: id} + + err := db.QueryRow("SELECT status FROM users WHERE id = ?", id).Scan(&u.Status) + switch { + case err == sql.ErrNoRows: + return false, fmt.Errorf("is user silenced: %v", ErrUserNotFound) + case err != nil: + log.Error("Couldn't SELECT user status: %v", err) + return false, fmt.Errorf("is user silenced: %v", err) + } + + return u.IsSilenced(), nil +} + // DoesUserNeedAuth returns true if the user hasn't provided any methods for // authenticating with the account, such a passphrase or email address. // Any errors are reported to admin and silently quashed, returning false as the @@ -347,7 +375,7 @@ func (db *datastore) IsUserPassSet(id int64) (bool, error) { func (db *datastore) GetUserForAuth(username string) (*User, error) { u := &User{Username: username} - err := db.QueryRow("SELECT id, password, email, created FROM users WHERE username = ?", username).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created) + err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE username = ?", username).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status) switch { case err == sql.ErrNoRows: // Check if they've entered the wrong, unnormalized username @@ -370,7 +398,7 @@ func (db *datastore) GetUserForAuth(username string) (*User, error) { func (db *datastore) GetUserForAuthByID(userID int64) (*User, error) { u := &User{ID: userID} - err := db.QueryRow("SELECT id, password, email, created FROM users WHERE id = ?", u.ID).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created) + err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE id = ?", u.ID).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status) switch { case err == sql.ErrNoRows: return nil, ErrUserNotFound @@ -1629,7 +1657,11 @@ func (db *datastore) GetMeStats(u *User) userMeStats { } func (db *datastore) GetTotalCollections() (collCount int64, err error) { - err = db.QueryRow(`SELECT COUNT(*) FROM collections`).Scan(&collCount) + err = db.QueryRow(` + SELECT COUNT(*) + FROM collections c + LEFT JOIN users u ON u.id = c.owner_id + WHERE u.status = 0`).Scan(&collCount) if err != nil { log.Error("Unable to fetch collections count: %v", err) } @@ -1637,7 +1669,11 @@ func (db *datastore) GetTotalCollections() (collCount int64, err error) { } func (db *datastore) GetTotalPosts() (postCount int64, err error) { - err = db.QueryRow(`SELECT COUNT(*) FROM posts`).Scan(&postCount) + err = db.QueryRow(` + SELECT COUNT(*) + FROM posts p + LEFT JOIN users u ON u.id = p.owner_id + WHERE u.status = 0`).Scan(&postCount) if err != nil { log.Error("Unable to fetch posts count: %v", err) } @@ -2079,22 +2115,13 @@ func (db *datastore) CollectionHasAttribute(id int64, attr string) bool { return true } -func (db *datastore) DeleteAccount(userID int64) (l *string, err error) { - debug := "" - l = &debug - - t, err := db.Begin() - if err != nil { - stringLogln(l, "Unable to begin: %v", err) - return - } - +// DeleteAccount will delete the entire account for userID +func (db *datastore) DeleteAccount(userID int64) error { // Get all collections rows, err := db.Query("SELECT id, alias FROM collections WHERE owner_id = ?", userID) if err != nil { - t.Rollback() - stringLogln(l, "Unable to get collections: %v", err) - return + log.Error("Unable to get collections: %v", err) + return err } defer rows.Close() colls := []Collection{} @@ -2102,103 +2129,158 @@ func (db *datastore) DeleteAccount(userID int64) (l *string, err error) { for rows.Next() { err = rows.Scan(&c.ID, &c.Alias) if err != nil { - t.Rollback() - stringLogln(l, "Unable to scan collection cols: %v", err) - return + log.Error("Unable to scan collection cols: %v", err) + return err } colls = append(colls, c) } + // Start transaction + t, err := db.Begin() + if err != nil { + log.Error("Unable to begin: %v", err) + return err + } + + // Clean up all collection related information var res sql.Result for _, c := range colls { - // TODO: user deleteCollection() func // Delete tokens res, err = t.Exec("DELETE FROM collectionattributes WHERE collection_id = ?", c.ID) if err != nil { t.Rollback() - stringLogln(l, "Unable to delete attributes on %s: %v", c.Alias, err) - return + log.Error("Unable to delete attributes on %s: %v", c.Alias, err) + return err } rs, _ := res.RowsAffected() - stringLogln(l, "Deleted %d for %s from collectionattributes", rs, c.Alias) + log.Info("Deleted %d for %s from collectionattributes", rs, c.Alias) // Remove any optional collection password res, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID) if err != nil { t.Rollback() - stringLogln(l, "Unable to delete passwords on %s: %v", c.Alias, err) - return + log.Error("Unable to delete passwords on %s: %v", c.Alias, err) + return err } rs, _ = res.RowsAffected() - stringLogln(l, "Deleted %d for %s from collectionpasswords", rs, c.Alias) + log.Info("Deleted %d for %s from collectionpasswords", rs, c.Alias) // Remove redirects to this collection res, err = t.Exec("DELETE FROM collectionredirects WHERE new_alias = ?", c.Alias) if err != nil { t.Rollback() - stringLogln(l, "Unable to delete redirects on %s: %v", c.Alias, err) - return + log.Error("Unable to delete redirects on %s: %v", c.Alias, err) + return err } rs, _ = res.RowsAffected() - stringLogln(l, "Deleted %d for %s from collectionredirects", rs, c.Alias) + log.Info("Deleted %d for %s from collectionredirects", rs, c.Alias) + + // Remove any collection keys + res, err = t.Exec("DELETE FROM collectionkeys WHERE collection_id = ?", c.ID) + if err != nil { + t.Rollback() + log.Error("Unable to delete keys on %s: %v", c.Alias, err) + return err + } + rs, _ = res.RowsAffected() + log.Info("Deleted %d for %s from collectionkeys", rs, c.Alias) + + // TODO: federate delete collection + + // Remove remote follows + res, err = t.Exec("DELETE FROM remotefollows WHERE collection_id = ?", c.ID) + if err != nil { + t.Rollback() + log.Error("Unable to delete remote follows on %s: %v", c.Alias, err) + return err + } + rs, _ = res.RowsAffected() + log.Info("Deleted %d for %s from remotefollows", rs, c.Alias) } // Delete collections res, err = t.Exec("DELETE FROM collections WHERE owner_id = ?", userID) if err != nil { t.Rollback() - stringLogln(l, "Unable to delete collections: %v", err) - return + log.Error("Unable to delete collections: %v", err) + return err } rs, _ := res.RowsAffected() - stringLogln(l, "Deleted %d from collections", rs) + log.Info("Deleted %d from collections", rs) // Delete tokens res, err = t.Exec("DELETE FROM accesstokens WHERE user_id = ?", userID) if err != nil { t.Rollback() - stringLogln(l, "Unable to delete access tokens: %v", err) - return + log.Error("Unable to delete access tokens: %v", err) + return err } rs, _ = res.RowsAffected() - stringLogln(l, "Deleted %d from accesstokens", rs) + log.Info("Deleted %d from accesstokens", rs) + + // Delete user attributes + res, err = t.Exec("DELETE FROM oauth_users WHERE user_id = ?", userID) + if err != nil { + t.Rollback() + log.Error("Unable to delete oauth_users: %v", err) + return err + } + rs, _ = res.RowsAffected() + log.Info("Deleted %d from oauth_users", rs) // Delete posts + // TODO: should maybe get each row so we can federate a delete + // if so needs to be outside of transaction like collections res, err = t.Exec("DELETE FROM posts WHERE owner_id = ?", userID) if err != nil { t.Rollback() - stringLogln(l, "Unable to delete posts: %v", err) - return + log.Error("Unable to delete posts: %v", err) + return err } rs, _ = res.RowsAffected() - stringLogln(l, "Deleted %d from posts", rs) + log.Info("Deleted %d from posts", rs) + // Delete user attributes res, err = t.Exec("DELETE FROM userattributes WHERE user_id = ?", userID) if err != nil { t.Rollback() - stringLogln(l, "Unable to delete attributes: %v", err) - return + log.Error("Unable to delete attributes: %v", err) + return err + } + rs, _ = res.RowsAffected() + log.Info("Deleted %d from userattributes", rs) + + // Delete user invites + res, err = t.Exec("DELETE FROM userinvites WHERE owner_id = ?", userID) + if err != nil { + t.Rollback() + log.Error("Unable to delete invites: %v", err) + return err } rs, _ = res.RowsAffected() - stringLogln(l, "Deleted %d from userattributes", rs) + log.Info("Deleted %d from userinvites", rs) + // Delete the user res, err = t.Exec("DELETE FROM users WHERE id = ?", userID) if err != nil { t.Rollback() - stringLogln(l, "Unable to delete user: %v", err) - return + log.Error("Unable to delete user: %v", err) + return err } rs, _ = res.RowsAffected() - stringLogln(l, "Deleted %d from users", rs) + log.Info("Deleted %d from users", rs) + // Commit all changes to the database err = t.Commit() if err != nil { t.Rollback() - stringLogln(l, "Unable to commit: %v", err) - return + log.Error("Unable to commit: %v", err) + return err } - return + // TODO: federate delete actor + + return nil } func (db *datastore) GetAPActorKeys(collectionID int64) ([]byte, []byte) { @@ -2247,7 +2329,7 @@ func (db *datastore) GetUserInvite(id string) (*Invite, error) { var i Invite err := db.QueryRow("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE id = ?", id).Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive) switch { - case err == sql.ErrNoRows: + case err == sql.ErrNoRows, db.isIgnorableError(err): return nil, impart.HTTPError{http.StatusNotFound, "Invite doesn't exist."} case err != nil: log.Error("Failed selecting invite: %v", err) @@ -2359,17 +2441,17 @@ func (db *datastore) GetAllUsers(page uint) (*[]User, error) { limitStr = fmt.Sprintf("%d, %d", (page-1)*adminUsersPerPage, adminUsersPerPage) } - rows, err := db.Query("SELECT id, username, created FROM users ORDER BY created DESC LIMIT " + limitStr) + rows, err := db.Query("SELECT id, username, created, status FROM users ORDER BY created DESC LIMIT " + limitStr) if err != nil { - log.Error("Failed selecting from posts: %v", err) - return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user posts."} + log.Error("Failed selecting from users: %v", err) + return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve all users."} } defer rows.Close() users := []User{} for rows.Next() { u := User{} - err = rows.Scan(&u.ID, &u.Username, &u.Created) + err = rows.Scan(&u.ID, &u.Username, &u.Created, &u.Status) if err != nil { log.Error("Failed scanning GetAllUsers() row: %v", err) break @@ -2406,6 +2488,15 @@ func (db *datastore) GetUserLastPostTime(id int64) (*time.Time, error) { return &t, nil } +// SetUserStatus changes a user's status in the database. see Users.UserStatus +func (db *datastore) SetUserStatus(id int64, status UserStatus) error { + _, err := db.Exec("UPDATE users SET status = ? WHERE id = ?", status, id) + if err != nil { + return fmt.Errorf("failed to update user status: %v", err) + } + return nil +} + func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { var t time.Time err := db.QueryRow("SELECT created FROM posts WHERE collection_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t) @@ -2419,6 +2510,69 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { return &t, nil } +func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { + state := store.Generate62RandomString(24) + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID) + if err != nil { + return "", fmt.Errorf("unable to record oauth client state: %w", err) + } + return state, nil +} + +func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { + var provider string + var clientID string + err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID) + if err != nil { + return err + } + + res, err := tx.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state) + if err != nil { + return err + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + if rowsAffected != 1 { + return fmt.Errorf("state not found") + } + return nil + }) + if err != nil { + return "", "", nil + } + return provider, clientID, nil +} + +func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { + var err error + if db.driverName == driverSQLite { + _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?)", localUserID, remoteUserID, provider, clientID, accessToken) + } else { + _, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?) "+db.upsert("user")+" access_token = ?", localUserID, remoteUserID, provider, clientID, accessToken, accessToken) + } + if err != nil { + log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err) + } + return err +} + +// GetIDForRemoteUser returns a user ID associated with a remote user ID. +func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { + var userID int64 = -1 + err := db. + QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ? AND provider = ? AND client_id = ?", remoteUserID, provider, clientID). + Scan(&userID) + // Not finding a record is OK. + if err != nil && err != sql.ErrNoRows { + return -1, err + } + return userID, nil +} + // DatabaseInitialized returns whether or not the current datastore has been // initialized with the correct schema. // Currently, it checks to see if the `users` table exists. @@ -2449,3 +2603,40 @@ func handleFailedPostInsert(err error) error { log.Error("Couldn't insert into posts: %v", err) return err } + +func (db *datastore) GetProfilePageFromHandle(app *App, handle string) (string, error) { + actorIRI := "" + remoteUser, err := getRemoteUserFromHandle(app, handle) + if err != nil { + // can't find using handle in the table but the table may already have this user without + // handle from a previous version + // TODO: Make this determination. We should know whether a user exists without a handle, or doesn't exist at all + actorIRI = RemoteLookup(handle) + _, errRemoteUser := getRemoteUser(app, actorIRI) + // if it exists then we need to update the handle + if errRemoteUser == nil { + _, err := app.db.Exec("UPDATE remoteusers SET handle = ? WHERE actor_id = ?", handle, actorIRI) + if err != nil { + log.Error("Can't update handle (" + handle + ") in database for user " + actorIRI) + } + } else { + // this probably means we don't have the user in the table so let's try to insert it + // here we need to ask the server for the inboxes + remoteActor, err := activityserve.NewRemoteActor(actorIRI) + if err != nil { + log.Error("Couldn't fetch remote actor", err) + } + if debugging { + log.Info("%s %s %s %s", actorIRI, remoteActor.GetInbox(), remoteActor.GetSharedInbox(), handle) + } + _, err = app.db.Exec("INSERT INTO remoteusers (actor_id, inbox, shared_inbox, handle) VALUES(?, ?, ?, ?)", actorIRI, remoteActor.GetInbox(), remoteActor.GetSharedInbox(), handle) + if err != nil { + log.Error("Can't insert remote user in database", err) + return "", err + } + } + } else { + actorIRI = remoteUser.ActorID + } + return actorIRI, nil +} diff --git a/database_test.go b/database_test.go new file mode 100644 index 0000000..c4c586a --- /dev/null +++ b/database_test.go @@ -0,0 +1,50 @@ +package writefreely + +import ( + "context" + "database/sql" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestOAuthDatastore(t *testing.T) { + if !runMySQLTests() { + t.Skip("skipping mysql tests") + } + withTestDB(t, func(db *sql.DB) { + ctx := context.Background() + ds := &datastore{ + DB: db, + driverName: "", + } + + state, err := ds.GenerateOAuthState(ctx, "test", "development") + assert.NoError(t, err) + assert.Len(t, state, 24) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) + + _, _, err = ds.ValidateOAuthState(ctx, state) + assert.NoError(t, err) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) + + var localUserID int64 = 99 + var remoteUserID = "100" + err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_a") + assert.NoError(t, err) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'", localUserID, remoteUserID) + + err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_b") + assert.NoError(t, err) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'", localUserID, remoteUserID) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users`") + + foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID, "test", "test") + assert.NoError(t, err) + assert.Equal(t, localUserID, foundUserID) + }) +} diff --git a/db/alter.go b/db/alter.go new file mode 100644 index 0000000..0a4ffdd --- /dev/null +++ b/db/alter.go @@ -0,0 +1,52 @@ +package db + +import ( + "fmt" + "strings" +) + +type AlterTableSqlBuilder struct { + Dialect DialectType + Name string + Changes []string +} + +func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { + if colVal, err := col.String(); err == nil { + b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) + } + return b +} + +func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { + if colVal, err := col.String(); err == nil { + b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) + } + return b +} + +func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { + b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) + return b +} + +func (b *AlterTableSqlBuilder) ToSQL() (string, error) { + var str strings.Builder + + str.WriteString("ALTER TABLE ") + str.WriteString(b.Name) + str.WriteString(" ") + + if len(b.Changes) == 0 { + return "", fmt.Errorf("no changes provide for table: %s", b.Name) + } + changeCount := len(b.Changes) + for i, thing := range b.Changes { + str.WriteString(thing) + if i < changeCount-1 { + str.WriteString(", ") + } + } + + return str.String(), nil +} diff --git a/db/alter_test.go b/db/alter_test.go new file mode 100644 index 0000000..4bd58ac --- /dev/null +++ b/db/alter_test.go @@ -0,0 +1,56 @@ +package db + +import "testing" + +func TestAlterTableSqlBuilder_ToSQL(t *testing.T) { + type fields struct { + Dialect DialectType + Name string + Changes []string + } + tests := []struct { + name string + builder *AlterTableSqlBuilder + want string + wantErr bool + }{ + { + name: "MySQL add int", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)), + want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL", + wantErr: false, + }, + { + name: "MySQL add string", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})), + want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL", + wantErr: false, + }, + + { + name: "MySQL add int and string", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)). + AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})), + want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.builder.ToSQL() + if (err != nil) != tt.wantErr { + t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ToSQL() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/db/create.go b/db/create.go new file mode 100644 index 0000000..c384778 --- /dev/null +++ b/db/create.go @@ -0,0 +1,244 @@ +package db + +import ( + "fmt" + "strings" +) + +type ColumnType int + +type OptionalInt struct { + Set bool + Value int +} + +type OptionalString struct { + Set bool + Value string +} + +type SQLBuilder interface { + ToSQL() (string, error) +} + +type Column struct { + Dialect DialectType + Name string + Nullable bool + Default OptionalString + Type ColumnType + Size OptionalInt + PrimaryKey bool +} + +type CreateTableSqlBuilder struct { + Dialect DialectType + Name string + IfNotExists bool + ColumnOrder []string + Columns map[string]*Column + Constraints []string +} + +const ( + ColumnTypeBool ColumnType = iota + ColumnTypeSmallInt ColumnType = iota + ColumnTypeInteger ColumnType = iota + ColumnTypeChar ColumnType = iota + ColumnTypeVarChar ColumnType = iota + ColumnTypeText ColumnType = iota + ColumnTypeDateTime ColumnType = iota +) + +var _ SQLBuilder = &CreateTableSqlBuilder{} + +var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} +var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} + +func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { + if dialect != DialectMySQL && dialect != DialectSQLite { + return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) + } + switch d { + case ColumnTypeSmallInt: + { + if dialect == DialectSQLite { + return "INTEGER", nil + } + mod := "" + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + return "SMALLINT" + mod, nil + } + case ColumnTypeInteger: + { + if dialect == DialectSQLite { + return "INTEGER", nil + } + mod := "" + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + return "INT" + mod, nil + } + case ColumnTypeChar: + { + if dialect == DialectSQLite { + return "TEXT", nil + } + mod := "" + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + return "CHAR" + mod, nil + } + case ColumnTypeVarChar: + { + if dialect == DialectSQLite { + return "TEXT", nil + } + mod := "" + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + return "VARCHAR" + mod, nil + } + case ColumnTypeBool: + { + if dialect == DialectSQLite { + return "INTEGER", nil + } + return "TINYINT(1)", nil + } + case ColumnTypeDateTime: + return "DATETIME", nil + case ColumnTypeText: + return "TEXT", nil + } + return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) +} + +func (c *Column) SetName(name string) *Column { + c.Name = name + return c +} + +func (c *Column) SetNullable(nullable bool) *Column { + c.Nullable = nullable + return c +} + +func (c *Column) SetPrimaryKey(pk bool) *Column { + c.PrimaryKey = pk + return c +} + +func (c *Column) SetDefault(value string) *Column { + c.Default = OptionalString{Set: true, Value: value} + return c +} + +func (c *Column) SetType(t ColumnType) *Column { + c.Type = t + return c +} + +func (c *Column) SetSize(size int) *Column { + c.Size = OptionalInt{Set: true, Value: size} + return c +} + +func (c *Column) String() (string, error) { + var str strings.Builder + + str.WriteString(c.Name) + + str.WriteString(" ") + typeStr, err := c.Type.Format(c.Dialect, c.Size) + if err != nil { + return "", err + } + + str.WriteString(typeStr) + + if !c.Nullable { + str.WriteString(" NOT NULL") + } + + if c.Default.Set { + str.WriteString(" DEFAULT ") + str.WriteString(c.Default.Value) + } + + if c.PrimaryKey { + str.WriteString(" PRIMARY KEY") + } + + return str.String(), nil +} + +func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { + if b.Columns == nil { + b.Columns = make(map[string]*Column) + } + b.Columns[column.Name] = column + b.ColumnOrder = append(b.ColumnOrder, column.Name) + return b +} + +func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { + for _, column := range columns { + if _, ok := b.Columns[column]; !ok { + // This fails silently. + return b + } + } + b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) + return b +} + +func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { + b.IfNotExists = ine + return b +} + +func (b *CreateTableSqlBuilder) ToSQL() (string, error) { + var str strings.Builder + + str.WriteString("CREATE TABLE ") + if b.IfNotExists { + str.WriteString("IF NOT EXISTS ") + } + str.WriteString(b.Name) + + var things []string + for _, columnName := range b.ColumnOrder { + column, ok := b.Columns[columnName] + if !ok { + return "", fmt.Errorf("column not found: %s", columnName) + } + columnStr, err := column.String() + if err != nil { + return "", err + } + things = append(things, columnStr) + } + for _, constraint := range b.Constraints { + things = append(things, constraint) + } + + if thingLen := len(things); thingLen > 0 { + str.WriteString(" ( ") + for i, thing := range things { + str.WriteString(thing) + if i < thingLen-1 { + str.WriteString(", ") + } + } + str.WriteString(" )") + } + + return str.String(), nil +} + diff --git a/db/create_test.go b/db/create_test.go new file mode 100644 index 0000000..369d5c1 --- /dev/null +++ b/db/create_test.go @@ -0,0 +1,146 @@ +package db + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestDialect_Column(t *testing.T) { + c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize) + assert.Equal(t, DialectSQLite, c1.Dialect) + c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize) + assert.Equal(t, DialectMySQL, c2.Dialect) +} + +func TestColumnType_Format(t *testing.T) { + type args struct { + dialect DialectType + size OptionalInt + } + tests := []struct { + name string + d ColumnType + args args + want string + wantErr bool + }{ + {"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false}, + {"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false}, + {"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false}, + {"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false}, + {"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false}, + {"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false}, + {"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false}, + + {"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false}, + {"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false}, + {"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false}, + {"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false}, + {"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false}, + {"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false}, + {"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false}, + {"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false}, + {"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false}, + {"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false}, + {"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false}, + + {"invalid column type", 10000, args{dialect: DialectMySQL}, "", true}, + {"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.d.Format(tt.args.dialect, tt.args.size) + if (err != nil) != tt.wantErr { + t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Format() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumn_Build(t *testing.T) { + type fields struct { + Dialect DialectType + Name string + Nullable bool + Default OptionalString + Type ColumnType + Size OptionalInt + PrimaryKey bool + } + tests := []struct { + name string + fields fields + want string + wantErr bool + }{ + {"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false}, + {"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false}, + {"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, + {"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false}, + {"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false}, + {"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false}, + {"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, + {"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false}, + {"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, + {"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false}, + {"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, + {"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, + {"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, + {"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, + + {"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false}, + {"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false}, + {"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, + {"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false}, + {"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false}, + {"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false}, + {"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false}, + {"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false}, + {"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false}, + {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false}, + {"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, + {"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, + {"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, + {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Column{ + Dialect: tt.fields.Dialect, + Name: tt.fields.Name, + Nullable: tt.fields.Nullable, + Default: tt.fields.Default, + Type: tt.fields.Type, + Size: tt.fields.Size, + PrimaryKey: tt.fields.PrimaryKey, + } + if got, err := c.String(); got != tt.want { + if (err != nil) != tt.wantErr { + t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("String() got = %v, want %v", got, tt.want) + } + } + }) + } +} + +func TestCreateTableSqlBuilder_ToSQL(t *testing.T) { + sql, err := DialectMySQL. + Table("foo"). + SetIfNotExists(true). + Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)). + Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)). + Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")). + UniqueConstraint("bar"). + UniqueConstraint("bar", "baz"). + ToSQL() + assert.NoError(t, err) + assert.Equal(t, "CREATE TABLE IF NOT EXISTS foo ( bar INT NOT NULL PRIMARY KEY, baz TEXT NOT NULL, qux DATETIME NOT NULL DEFAULT NOW(), UNIQUE(bar), UNIQUE(bar,baz) )", sql) +} diff --git a/db/dialect.go b/db/dialect.go new file mode 100644 index 0000000..4251465 --- /dev/null +++ b/db/dialect.go @@ -0,0 +1,76 @@ +package db + +import "fmt" + +type DialectType int + +const ( + DialectSQLite DialectType = iota + DialectMySQL DialectType = iota +) + +func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { + switch d { + case DialectSQLite: + return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} + case DialectMySQL: + return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) Table(name string) *CreateTableSqlBuilder { + switch d { + case DialectSQLite: + return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} + case DialectMySQL: + return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { + switch d { + case DialectSQLite: + return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} + case DialectMySQL: + return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { + switch d { + case DialectSQLite: + return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} + case DialectMySQL: + return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { + switch d { + case DialectSQLite: + return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} + case DialectMySQL: + return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { + switch d { + case DialectSQLite: + return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} + case DialectMySQL: + return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} diff --git a/db/index.go b/db/index.go new file mode 100644 index 0000000..8180224 --- /dev/null +++ b/db/index.go @@ -0,0 +1,53 @@ +package db + +import ( + "fmt" + "strings" +) + +type CreateIndexSqlBuilder struct { + Dialect DialectType + Name string + Table string + Unique bool + Columns []string +} + +type DropIndexSqlBuilder struct { + Dialect DialectType + Name string + Table string +} + +func (b *CreateIndexSqlBuilder) ToSQL() (string, error) { + var str strings.Builder + + str.WriteString("CREATE ") + if b.Unique { + str.WriteString("UNIQUE ") + } + str.WriteString("INDEX ") + str.WriteString(b.Name) + str.WriteString(" on ") + str.WriteString(b.Table) + + if len(b.Columns) == 0 { + return "", fmt.Errorf("columns provided for this index: %s", b.Name) + } + + str.WriteString(" (") + columnCount := len(b.Columns) + for i, thing := range b.Columns { + str.WriteString(thing) + if i < columnCount-1 { + str.WriteString(", ") + } + } + str.WriteString(")") + + return str.String(), nil +} + +func (b *DropIndexSqlBuilder) ToSQL() (string, error) { + return fmt.Sprintf("DROP INDEX %s on %s", b.Name, b.Table), nil +} diff --git a/db/raw.go b/db/raw.go new file mode 100644 index 0000000..d0301c8 --- /dev/null +++ b/db/raw.go @@ -0,0 +1,9 @@ +package db + +type RawSqlBuilder struct { + Query string +} + +func (b *RawSqlBuilder) ToSQL() (string, error) { + return b.Query, nil +} diff --git a/db/tx.go b/db/tx.go new file mode 100644 index 0000000..5c321af --- /dev/null +++ b/db/tx.go @@ -0,0 +1,26 @@ +package db + +import ( + "context" + "database/sql" +) + +// TransactionScopedWork describes code executed within a database transaction. +type TransactionScopedWork func(ctx context.Context, db *sql.Tx) error + +// RunTransactionWithOptions executes a block of code within a database transaction. +func RunTransactionWithOptions(ctx context.Context, db *sql.DB, txOpts *sql.TxOptions, txWork TransactionScopedWork) error { + tx, err := db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + + if err = txWork(ctx, tx); err != nil { + if txErr := tx.Rollback(); txErr != nil { + return txErr + } + return err + } + return tx.Commit() +} + diff --git a/errors.go b/errors.go index 0092b7f..b62fc9e 100644 --- a/errors.go +++ b/errors.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -11,8 +11,9 @@ package writefreely import ( - "github.com/writeas/impart" "net/http" + + "github.com/writeas/impart" ) // Commonly returned HTTP errors @@ -44,8 +45,11 @@ var ( ErrPostUnpublished = impart.HTTPError{Status: http.StatusGone, Message: "Post unpublished by author."} ErrPostFetchError = impart.HTTPError{Status: http.StatusInternalServerError, Message: "We encountered an error getting the post. The humans have been alerted."} - ErrUserNotFound = impart.HTTPError{http.StatusNotFound, "User doesn't exist."} - ErrUserNotFoundEmail = impart.HTTPError{http.StatusNotFound, "Please enter your username instead of your email address."} + ErrUserNotFound = impart.HTTPError{http.StatusNotFound, "User doesn't exist."} + ErrRemoteUserNotFound = impart.HTTPError{http.StatusNotFound, "Remote user not found."} + ErrUserNotFoundEmail = impart.HTTPError{http.StatusNotFound, "Please enter your username instead of your email address."} + + ErrUserSilenced = impart.HTTPError{http.StatusForbidden, "Account is silenced."} ) // Post operation errors diff --git a/feed.go b/feed.go index 32feb82..4e1f612 100644 --- a/feed.go +++ b/feed.go @@ -12,12 +12,13 @@ package writefreely import ( "fmt" + "net/http" + "time" + . "github.com/gorilla/feeds" "github.com/gorilla/mux" stripmd "github.com/writeas/go-strip-markdown" "github.com/writeas/web-core/log" - "net/http" - "time" ) func ViewFeed(app *App, w http.ResponseWriter, req *http.Request) error { @@ -34,6 +35,15 @@ func ViewFeed(app *App, w http.ResponseWriter, req *http.Request) error { if err != nil { return nil } + + silenced, err := app.db.IsUserSilenced(c.OwnerID) + if err != nil { + log.Error("view feed: get user: %v", err) + return ErrInternalGeneral + } + if silenced { + return ErrCollectionNotFound + } c.hostName = app.cfg.App.Host if c.IsPrivate() || c.IsProtected() { diff --git a/go.mod b/go.mod index 5e27956..5da3da4 100644 --- a/go.mod +++ b/go.mod @@ -6,17 +6,21 @@ require ( github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf // indirect github.com/captncraig/cors v0.0.0-20180620154129-376d45073b49 // indirect github.com/clbanning/mxj v1.8.4 // indirect + github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9 // indirect github.com/dustin/go-humanize v1.0.0 github.com/fatih/color v1.7.0 + github.com/go-fed/httpsig v0.1.1-0.20190924171022-f4c36041199d // indirect github.com/go-sql-driver/mysql v1.4.1 github.com/go-test/deep v1.0.1 // indirect github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1 // indirect + github.com/gologme/log v0.0.0-20181207131047-4e5d8ccb38e8 // indirect github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e // indirect github.com/gorilla/feeds v1.1.0 github.com/gorilla/mux v1.7.0 github.com/gorilla/schema v1.0.2 - github.com/gorilla/sessions v1.1.3 + github.com/gorilla/sessions v1.2.0 github.com/guregu/null v3.4.0+incompatible + github.com/hashicorp/go-multierror v1.0.0 github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 github.com/jtolds/gls v4.2.1+incompatible // indirect github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec @@ -31,30 +35,30 @@ require ( github.com/pelletier/go-toml v1.2.0 // indirect github.com/pkg/errors v0.8.1 // indirect github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be // indirect - github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/smartystreets/assertions v0.0.0-20190116191733-b6c0e53d7304 // indirect github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c // indirect - github.com/stretchr/testify v1.3.0 // indirect + github.com/stretchr/testify v1.3.0 github.com/writeas/activity v0.1.2 + github.com/writeas/activityserve v0.0.0-20191115095800-dd6d19cc8b89 github.com/writeas/go-strip-markdown v2.0.1+incompatible github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2 github.com/writeas/httpsig v1.0.0 - github.com/writeas/impart v1.1.0 + github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d + github.com/writeas/import v0.2.0 github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219 github.com/writeas/nerds v1.0.0 - github.com/writeas/openssl-go v1.0.0 // indirect github.com/writeas/saturday v1.7.1 github.com/writeas/slug v1.2.0 - github.com/writeas/web-core v1.0.0 + github.com/writeas/web-core v1.2.0 github.com/writefreely/go-nodeinfo v1.2.0 - golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f + golang.org/x/crypto v0.0.0-20200109152110-61a87790db17 golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1 // indirect - golang.org/x/net v0.0.0-20190206173232-65e2d4e15006 // indirect - golang.org/x/sys v0.0.0-20190209173611-3b5209105503 // indirect - golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 + golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 // indirect google.golang.org/appengine v1.4.0 // indirect gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20180810215634-df19058c872c // indirect gopkg.in/ini.v1 v1.41.0 - gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0 // indirect gopkg.in/yaml.v2 v2.2.2 // indirect + src.techknowlogick.com/xgo v0.0.0-20200129005940-d0fae26e014b // indirect ) + +go 1.13 diff --git a/go.sum b/go.sum index ec1e19d..2d433ec 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +code.as/core/socks v1.0.0 h1:SPQXNp4SbEwjOAP9VzUahLHak8SDqy5n+9cm9tpjZOs= +code.as/core/socks v1.0.0/go.mod h1:BAXBy5O9s2gmw6UxLqNJcVbWY7C/UPs+801CcSsfWOY= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/alecthomas/gometalinter v2.0.11+incompatible/go.mod h1:qfIpQGGz3d+NmgyPBqv+LSh50emm1pt72EtcX2vKYQk= @@ -23,13 +25,18 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9 h1:74lLNRzvsdIlkTgfDSMuaPjBr4cf6k7pwQQANm/yLKU= +github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/go-fed/httpsig v0.1.0 h1:6F2OxRVnNTN4OPN+Mc2jxs2WEay9/qiHT/jphlvAwIY= github.com/go-fed/httpsig v0.1.0/go.mod h1:T56HUNYZUQ1AGUzhAYPugZfp36sKApVnGBgKlIY+aIE= +github.com/go-fed/httpsig v0.1.1-0.20190924171022-f4c36041199d h1:+uoOvOnNDgsYbWtAij4xP6Rgir3eJGjocFPxBJETU/U= +github.com/go-fed/httpsig v0.1.1-0.20190924171022-f4c36041199d/go.mod h1:T56HUNYZUQ1AGUzhAYPugZfp36sKApVnGBgKlIY+aIE= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg= @@ -38,14 +45,14 @@ github.com/golang/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:tluoj9z5200j github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1 h1:6DVPu65tee05kY0/rciBQ47ue+AnuY8KTayV6VHikIo= github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/gologme/log v0.0.0-20181207131047-4e5d8ccb38e8 h1:WD8iJ37bRNwvETMfVTusVSAi0WdXTpfNVGY2aHycNKY= +github.com/gologme/log v0.0.0-20181207131047-4e5d8ccb38e8/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf h1:7+FW5aGwISbqUtkfmIpZJGRgNFg2ioYPvFaUxdqpDsg= github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf/go.mod h1:RpwtwJQFrIEPstU94h88MWPXP2ektJZ8cZ0YntAmXiE= github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e h1:JKmoR8x90Iww1ks85zJ1lfDGgIiMDuIptTOhJq+zKyg= github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gordonklaus/ineffassign v0.0.0-20180909121442-1003c8bd00dc h1:cJlkeAx1QYgO5N80aF5xRGstVsRQwgLR7uA2FnP1ZjY= github.com/gordonklaus/ineffassign v0.0.0-20180909121442-1003c8bd00dc/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= -github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/feeds v1.1.0 h1:pcgLJhbdYgaUESnj3AmXPcB7cS3vy63+jC/TI14AGXk= github.com/gorilla/feeds v1.1.0/go.mod h1:Nk0jZrvPFZX1OBe5NPiddPw7CfwF6Q9eqzaBbaightA= github.com/gorilla/mux v1.7.0 h1:tOSd0UKHQd6urX6ApfOn4XdBMY6Sh1MfxV3kmaazO+U= @@ -54,10 +61,14 @@ github.com/gorilla/schema v1.0.2 h1:sAgNfOcNYvdDSrzGHVy9nzCQahG+qmsg+nE8dK85QRA= github.com/gorilla/schema v1.0.2/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9RU= -github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= +github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ= +github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/guregu/null v3.4.0+incompatible h1:a4mw37gBO7ypcBlTJeZGuMpSxxFTV9qFfFKgWxQSGaM= github.com/guregu/null v3.4.0+incompatible/go.mod h1:ePGpQaN9cw0tj45IR5E5ehMvsFlLlQZAkkOXZurJ3NM= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 h1:wIdDEle9HEy7vBPjC6oKz6ejs3Ut+jmsYvuOoAW2pSM= github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2/go.mod h1:WtaVKD9TeruTED9ydiaOJU08qGoEPP/LyzTKiD3jEsw= github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpRVWLVmUEE= @@ -115,44 +126,63 @@ github.com/tsenart/deadcode v0.0.0-20160724212837-210d2dc333e9 h1:vY5WqiEon0ZSTG github.com/tsenart/deadcode v0.0.0-20160724212837-210d2dc333e9/go.mod h1:q+QjxYvZ+fpjMXqs+XEriussHjSYqeXVnAdSV1tkMYk= github.com/writeas/activity v0.1.2 h1:Y12B5lIrabfqKE7e7HFCWiXrlfXljr9tlkFm2mp7DgY= github.com/writeas/activity v0.1.2/go.mod h1:mYYgiewmEM+8tlifirK/vl6tmB2EbjYaxwb+ndUw5T0= +github.com/writeas/activityserve v0.0.0-20191008122325-5fc3b48e70c5 h1:nG84xWpxBM8YU/FJchezJqg7yZH8ImSRow6NoYtbSII= +github.com/writeas/activityserve v0.0.0-20191008122325-5fc3b48e70c5/go.mod h1:Kz62mzYsCnrFTSTSFLXFj3fGYBQOntmBWTDDq57b46A= +github.com/writeas/activityserve v0.0.0-20191011072627-3a81f7784d5b h1:rd2wX/bTqD55hxtBjAhwLcUgaQE36c70KX3NzpDAwVI= +github.com/writeas/activityserve v0.0.0-20191011072627-3a81f7784d5b/go.mod h1:Kz62mzYsCnrFTSTSFLXFj3fGYBQOntmBWTDDq57b46A= +github.com/writeas/activityserve v0.0.0-20191115095800-dd6d19cc8b89 h1:NJhzq9aTccL3SSSZMrcnYhkD6sObdY9otNZ1X6/ZKNE= +github.com/writeas/activityserve v0.0.0-20191115095800-dd6d19cc8b89/go.mod h1:Kz62mzYsCnrFTSTSFLXFj3fGYBQOntmBWTDDq57b46A= github.com/writeas/go-strip-markdown v2.0.1+incompatible h1:IIqxTM5Jr7RzhigcL6FkrCNfXkvbR+Nbu1ls48pXYcw= github.com/writeas/go-strip-markdown v2.0.1+incompatible/go.mod h1:Rsyu10ZhbEK9pXdk8V6MVnZmTzRG0alMNLMwa0J01fE= github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2 h1:DUsp4OhdfI+e6iUqcPQlwx8QYXuUDsToTz/x82D3Zuo= github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2/go.mod h1:w2VxyRO/J5vfNjJHYVubsjUGHd3RLDoVciz0DE3ApOc= +github.com/writeas/go-writeas v1.1.0 h1:WHGm6wriBkxYAOGbvriXH8DlMUGOi6jhSZLUZKQ+4mQ= +github.com/writeas/go-writeas v1.1.0/go.mod h1:oh9U1rWaiE0p3kzdKwwvOpNXgp0P0IELI7OLOwV4fkA= +github.com/writeas/go-writeas/v2 v2.0.2 h1:akvdMg89U5oBJiCkBwOXljVLTqP354uN6qnG2oOMrbk= +github.com/writeas/go-writeas/v2 v2.0.2/go.mod h1:9sjczQJKmru925fLzg0usrU1R1tE4vBmQtGnItUMR0M= github.com/writeas/httpsig v1.0.0 h1:peIAoIA3DmlP8IG8tMNZqI4YD1uEnWBmkcC9OFPjt3A= github.com/writeas/httpsig v1.0.0/go.mod h1:7ClMGSrSVXJbmiLa17bZ1LrG1oibGZmUMlh3402flPY= github.com/writeas/impart v1.1.0 h1:nPnoO211VscNkp/gnzir5UwCDEvdHThL5uELU60NFSE= github.com/writeas/impart v1.1.0/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y= +github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d h1:PK7DOj3JE6MGf647esPrKzXEHFjGWX2hl22uX79ixaE= +github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y= +github.com/writeas/import v0.2.0 h1:Ov23JW9Rnjxk06rki1Spar45bNX647HhwhAZj3flJiY= +github.com/writeas/import v0.2.0/go.mod h1:gFe0Pl7ZWYiXbI0TJxeMMyylPGZmhVvCfQxhMEc8CxM= github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219 h1:baEp0631C8sT2r/hqwypIw2snCFZa6h7U6TojoLHu/c= github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219/go.mod h1:NyM35ayknT7lzO6O/1JpfgGyv+0W9Z9q7aE0J8bXxfQ= github.com/writeas/nerds v1.0.0 h1:ZzRcCN+Sr3MWID7o/x1cr1ZbLvdpej9Y1/Ho+JKlqxo= github.com/writeas/nerds v1.0.0/go.mod h1:Gn2bHy1EwRcpXeB7ZhVmuUwiweK0e+JllNf66gvNLdU= github.com/writeas/openssl-go v1.0.0 h1:YXM1tDXeYOlTyJjoMlYLQH1xOloUimSR1WMF8kjFc5o= github.com/writeas/openssl-go v1.0.0/go.mod h1:WsKeK5jYl0B5y8ggOmtVjbmb+3rEGqSD25TppjJnETA= +github.com/writeas/saturday v1.6.0/go.mod h1:ETE1EK6ogxptJpAgUbcJD0prAtX48bSloie80+tvnzQ= github.com/writeas/saturday v1.7.1 h1:lYo1EH6CYyrFObQoA9RNWHVlpZA5iYL5Opxo7PYAnZE= github.com/writeas/saturday v1.7.1/go.mod h1:ETE1EK6ogxptJpAgUbcJD0prAtX48bSloie80+tvnzQ= github.com/writeas/slug v1.2.0 h1:EMQ+cwLiOcA6EtFwUgyw3Ge18x9uflUnOnR6bp/J+/g= github.com/writeas/slug v1.2.0/go.mod h1:RE8shOqQP3YhsfsQe0L3RnuejfQ4Mk+JjY5YJQFubfQ= -github.com/writeas/web-core v1.0.0 h1:5VKkCakQgdKZcbfVKJXtRpc5VHrkflusCl/KRCPzpQ0= -github.com/writeas/web-core v1.0.0/go.mod h1:Si3chV7VWgY8CsV+3gRolMXSO2Vx1ZFAQ/mkrpvmyEE= +github.com/writeas/web-core v1.2.0 h1:CYqvBd+byi1cK4mCr1NZ6CjILuMOFmiFecv+OACcmG0= +github.com/writeas/web-core v1.2.0/go.mod h1:vTYajviuNBAxjctPp2NUYdgjofywVkxUGpeaERF3SfI= github.com/writefreely/go-nodeinfo v1.2.0 h1:La+YbTCvmpTwFhBSlebWDDL81N88Qf/SCAvRLR7F8ss= github.com/writefreely/go-nodeinfo v1.2.0/go.mod h1:UTvE78KpcjYOlRHupZIiSEFcXHioTXuacCbHU+CAcPg= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59 h1:hk3yo72LXLapY9EXVttc3Z1rLOxT9IuAPPX3GpY2+jo= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f h1:ETU2VEl7TnT5bl7IvuKEzTDpplg5wzGYsOCAPhdoEIg= -golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200109152110-61a87790db17 h1:nVJ3guKA9qdkEQ3TUdXI9QSINo2CUPM/cySEvw2w8I0= +golang.org/x/crypto v0.0.0-20200109152110-61a87790db17/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1 h1:rJm0LuqUjoDhSk2zO9ISMSToQxGz7Os2jRiOL8AWu4c= golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis= golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190206173232-65e2d4e15006 h1:bfLnR+k0tq5Lqt6dflRLcZiz6UaXCMt3vhYJ1l4FQ80= -golang.org/x/net v0.0.0-20190206173232-65e2d4e15006/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20180525142821-c11f84a56e43/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190209173611-3b5209105503 h1:5SvYFrOM3W8Mexn9/oA44Ji7vhXAZQ9hiP+1Q/DMrWg= -golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20181122213734-04b5d21e00f1/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 h1:bPP/rGuN1LUM0eaEwo6vnP6OfIWJzJBulzGUiKLjjSY= @@ -170,3 +200,5 @@ gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0 h1:POO/ycCATvegFmVuPpQzZFJ+p gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +src.techknowlogick.com/xgo v0.0.0-20200129005940-d0fae26e014b h1:rPAdjgXks4ToezTjygsnKZroxKVnA1L35DSpsJXPtfc= +src.techknowlogick.com/xgo v0.0.0-20200129005940-d0fae26e014b/go.mod h1:31CE1YKtDOrKTk9PSnjTpe6YbO6W/0LTYZ1VskL09oU= diff --git a/handle.go b/handle.go index 7e410f5..0fcc483 100644 --- a/handle.go +++ b/handle.go @@ -73,7 +73,7 @@ type ( type Handler struct { errors *ErrorPages - sessionStore *sessions.CookieStore + sessionStore sessions.Store app Apper } @@ -96,7 +96,7 @@ func NewHandler(apper Apper) *Handler { InternalServerError: template.Must(template.New("").Parse("{{define \"base\"}}
Internal server error.
{{end}}")), Blank: template.Must(template.New("").Parse("{{define \"base\"}}{{.Content}}
{{end}}")), }, - sessionStore: apper.App().sessionStore, + sessionStore: apper.App().SessionStore(), app: apper, } @@ -549,6 +549,37 @@ func (h *Handler) All(f handlerFunc) http.HandlerFunc { } } +func (h *Handler) OAuth(f handlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + h.handleOAuthError(w, r, func() error { + // TODO: return correct "success" status + status := 200 + start := time.Now() + + defer func() { + if e := recover(); e != nil { + log.Error("%s:\n%s", e, debug.Stack()) + impart.WriteError(w, impart.HTTPError{http.StatusInternalServerError, "Something didn't work quite right."}) + status = 500 + } + + log.Info(h.app.ReqLog(r, status, time.Since(start))) + }() + + err := f(h.app.App(), w, r) + if err != nil { + if err, ok := err.(impart.HTTPError); ok { + status = err.Status + } else { + status = 500 + } + } + + return err + }()) + } +} + func (h *Handler) AllReader(f handlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { h.handleError(w, r, func() error { @@ -779,6 +810,25 @@ func (h *Handler) handleError(w http.ResponseWriter, r *http.Request, err error) h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) } +func (h *Handler) handleOAuthError(w http.ResponseWriter, r *http.Request, err error) { + if err == nil { + return + } + + if err, ok := err.(impart.HTTPError); ok { + if err.Status >= 300 && err.Status < 400 { + sendRedirect(w, err.Status, err.Message) + return + } + + impart.WriteOAuthError(w, err) + return + } + + impart.WriteOAuthError(w, impart.HTTPError{http.StatusInternalServerError, "This is an unhelpful error message for a miscellaneous internal error."}) + return +} + func correctPageFromLoginAttempt(r *http.Request) string { to := r.FormValue("to") if to == "" { diff --git a/invites.go b/invites.go index 4e1f5fa..d5d024a 100644 --- a/invites.go +++ b/invites.go @@ -56,12 +56,19 @@ func handleViewUserInvites(app *App, u *User, w http.ResponseWriter, r *http.Req p := struct { *UserPage - Invites *[]Invite + Invites *[]Invite + Silenced bool }{ UserPage: NewUserPage(app, r, u, "Invite People", f), } var err error + + p.Silenced, err = app.db.IsUserSilenced(u.ID) + if err != nil { + log.Error("view invites: %v", err) + } + p.Invites, err = app.db.GetUserInvites(u.ID) if err != nil { return err @@ -78,6 +85,10 @@ func handleCreateUserInvite(app *App, u *User, w http.ResponseWriter, r *http.Re muVal := r.FormValue("uses") expVal := r.FormValue("expires") + if u.IsSilenced() { + return ErrUserSilenced + } + var err error var maxUses int if muVal != "0" { diff --git a/less/core.less b/less/core.less index f4332a9..fe8a28d 100644 --- a/less/core.less +++ b/less/core.less @@ -516,10 +516,17 @@ abbr { body#collection article p, body#subpage article p { .article-p; } -pre, body#post article, body#collection article, body#subpage article, body#subpage #wrapper h1 { +pre, body#post article, #post .alert, #subpage .alert, body#collection article, body#subpage article, body#subpage #wrapper h1 { max-width: 40rem; margin: 0 auto; } +#collection header .alert, #post .alert, #subpage .alert { + margin-bottom: 1em; + p { + text-align: left; + line-height: 1.4; + } +} textarea, pre, body#post article, body#collection article p { &.norm, &.sans, &.wrap { line-height: 1.4em; @@ -677,18 +684,19 @@ select.inputform, textarea.inputform { border: 1px solid #999; } -input, button, select.inputform, textarea.inputform { +input, button, select.inputform, textarea.inputform, a.btn { padding: 0.5em; font-family: @serifFont; font-size: 100%; .rounded(.25em); - &[type=submit], &.submit { + &[type=submit], &.submit, &.cta { border: 1px solid @primary; background: @primary; color: white; .transition(0.2s); &:hover { background-color: lighten(@primary, 3%); + text-decoration: none; } &:disabled { cursor: default; @@ -1310,6 +1318,24 @@ form { font-size: 0.86em; line-height: 2; } + + &.prominent { + margin: 1em 0; + + label { + font-weight: bold; + } + input, select { + width: 100%; + } + select { + font-size: 1em; + padding: 0.5rem; + display: block; + border-radius: 0.25rem; + margin: 0.5rem 0; + } + } } div.row { display: flex; diff --git a/less/post-temp.less b/less/post-temp.less index 3ec682d..8173864 100644 --- a/less/post-temp.less +++ b/less/post-temp.less @@ -17,6 +17,16 @@ body { font-size: 1.6em; } } + article { + h2#title.dated { + margin-bottom: 0.5em; + } + time.dt-published { + display: block; + color: #666; + margin-bottom: 1em; + } + } } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..d600d83 --- /dev/null +++ b/main_test.go @@ -0,0 +1,153 @@ +package writefreely + +import ( + "context" + "database/sql" + "encoding/gob" + "errors" + "fmt" + uuid "github.com/nu7hatch/gouuid" + "github.com/stretchr/testify/assert" + "math/rand" + "os" + "strings" + "testing" + "time" +) + +var testDB *sql.DB + +type ScopedTestBody func(*sql.DB) + +// TestMain provides testing infrastructure within this package. +func TestMain(m *testing.M) { + rand.Seed(time.Now().UTC().UnixNano()) + gob.Register(&User{}) + + if runMySQLTests() { + var err error + + testDB, err = initMySQL(os.Getenv("WF_USER"), os.Getenv("WF_PASSWORD"), os.Getenv("WF_DB"), os.Getenv("WF_HOST")) + if err != nil { + fmt.Println(err) + return + } + } + + code := m.Run() + if runMySQLTests() { + if closeErr := testDB.Close(); closeErr != nil { + fmt.Println(closeErr) + } + } + os.Exit(code) +} + +func runMySQLTests() bool { + return len(os.Getenv("TEST_MYSQL")) > 0 +} + +func initMySQL(dbUser, dbPassword, dbName, dbHost string) (*sql.DB, error) { + if dbUser == "" || dbPassword == "" { + return nil, errors.New("database user or password not set") + } + if dbHost == "" { + dbHost = "localhost" + } + if dbName == "" { + dbName = "writefreely" + } + + dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8mb4&parseTime=true", dbUser, dbPassword, dbHost, dbName) + db, err := sql.Open("mysql", dsn) + if err != nil { + return nil, err + } + if err := ensureMySQL(db); err != nil { + return nil, err + } + return db, nil +} + +func ensureMySQL(db *sql.DB) error { + if err := db.Ping(); err != nil { + return err + } + db.SetMaxOpenConns(250) + return nil +} + +// withTestDB provides a scoped database connection. +func withTestDB(t *testing.T, testBody ScopedTestBody) { + db, cleanup, err := newTestDatabase(testDB, + os.Getenv("WF_USER"), + os.Getenv("WF_PASSWORD"), + os.Getenv("WF_DB"), + os.Getenv("WF_HOST"), + ) + assert.NoError(t, err) + defer func() { + assert.NoError(t, cleanup()) + }() + + testBody(db) +} + +// newTestDatabase creates a new temporary test database. When a test +// database connection is returned, it will have created a new database and +// initialized it with tables from a reference database. +func newTestDatabase(base *sql.DB, dbUser, dbPassword, dbName, dbHost string) (*sql.DB, func() error, error) { + var err error + var baseName = dbName + + if baseName == "" { + row := base.QueryRow("SELECT DATABASE()") + err := row.Scan(&baseName) + if err != nil { + return nil, nil, err + } + } + tUUID, _ := uuid.NewV4() + suffix := strings.Replace(tUUID.String(), "-", "_", -1) + newDBName := baseName + suffix + _, err = base.Exec("CREATE DATABASE " + newDBName) + if err != nil { + return nil, nil, err + } + newDB, err := initMySQL(dbUser, dbPassword, newDBName, dbHost) + if err != nil { + return nil, nil, err + } + + rows, err := base.Query("SHOW TABLES IN " + baseName) + if err != nil { + return nil, nil, err + } + for rows.Next() { + var tableName string + if err := rows.Scan(&tableName); err != nil { + return nil, nil, err + } + query := fmt.Sprintf("CREATE TABLE %s LIKE %s.%s", tableName, baseName, tableName) + if _, err := newDB.Exec(query); err != nil { + return nil, nil, err + } + } + + cleanup := func() error { + if closeErr := newDB.Close(); closeErr != nil { + fmt.Println(closeErr) + } + + _, err = base.Exec("DROP DATABASE " + newDBName) + return err + } + return newDB, cleanup, nil +} + +func countRows(t *testing.T, ctx context.Context, db *sql.DB, count int, query string, args ...interface{}) { + var returned int + err := db.QueryRowContext(ctx, query, args...).Scan(&returned) + assert.NoError(t, err, "error executing query %s and args %s", query, args) + assert.Equal(t, count, returned, "unexpected return count %d, expected %d from %s and args %s", returned, count, query, args) +} \ No newline at end of file diff --git a/migrations/migrations.go b/migrations/migrations.go index 70e4b7b..41f036f 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -13,6 +13,7 @@ package migrations import ( "database/sql" + "github.com/writeas/web-core/log" ) @@ -55,8 +56,12 @@ func (m *migration) Migrate(db *datastore) error { } var migrations = []Migration{ - New("support user invites", supportUserInvites), // -> V1 (v0.8.0) - New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) + New("support user invites", supportUserInvites), // -> V1 (v0.8.0) + New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) + New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) + New("support oauth", oauth), // V3 -> V4 + New("support slack oauth", oauthSlack), // V4 -> v5 + New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6 (v0.12.0) } // CurrentVer returns the current migration version the application is on diff --git a/migrations/v3.go b/migrations/v3.go new file mode 100644 index 0000000..b5351da --- /dev/null +++ b/migrations/v3.go @@ -0,0 +1,29 @@ +/* + * Copyright © 2019 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package migrations + +func supportUserStatus(db *datastore) error { + t, err := db.Begin() + + _, err = t.Exec(`ALTER TABLE users ADD COLUMN status ` + db.typeInt() + ` DEFAULT '0' NOT NULL`) + if err != nil { + t.Rollback() + return err + } + + err = t.Commit() + if err != nil { + t.Rollback() + return err + } + + return nil +} diff --git a/migrations/v4.go b/migrations/v4.go new file mode 100644 index 0000000..c075dd8 --- /dev/null +++ b/migrations/v4.go @@ -0,0 +1,46 @@ +package migrations + +import ( + "context" + "database/sql" + + wf_db "github.com/writeas/writefreely/db" +) + +func oauth(db *datastore) error { + dialect := wf_db.DialectMySQL + if db.driverName == driverSQLite { + dialect = wf_db.DialectSQLite + } + return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + createTableUsersOauth, err := dialect. + Table("oauth_users"). + SetIfNotExists(true). + Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). + Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). + UniqueConstraint("user_id"). + UniqueConstraint("remote_user_id"). + ToSQL() + if err != nil { + return err + } + createTableOauthClientState, err := dialect. + Table("oauth_client_states"). + SetIfNotExists(true). + Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). + Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). + Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefault("NOW()")). + UniqueConstraint("state"). + ToSQL() + if err != nil { + return err + } + + for _, table := range []string{createTableUsersOauth, createTableOauthClientState} { + if _, err := tx.ExecContext(ctx, table); err != nil { + return err + } + } + return nil + }) +} diff --git a/migrations/v5.go b/migrations/v5.go new file mode 100644 index 0000000..94e3944 --- /dev/null +++ b/migrations/v5.go @@ -0,0 +1,67 @@ +package migrations + +import ( + "context" + "database/sql" + + wf_db "github.com/writeas/writefreely/db" +) + +func oauthSlack(db *datastore) error { + dialect := wf_db.DialectMySQL + if db.driverName == driverSQLite { + dialect = wf_db.DialectSQLite + } + return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + builders := []wf_db.SQLBuilder{ + dialect. + AlterTable("oauth_client_states"). + AddColumn(dialect. + Column( + "provider", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 24,})). + AddColumn(dialect. + Column( + "client_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})), + dialect. + AlterTable("oauth_users"). + ChangeColumn("remote_user_id", + dialect. + Column( + "remote_user_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})). + AddColumn(dialect. + Column( + "provider", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 24,})). + AddColumn(dialect. + Column( + "client_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})). + AddColumn(dialect. + Column( + "access_token", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 512,})), + dialect.DropIndex("remote_user_id", "oauth_users"), + dialect.DropIndex("user_id", "oauth_users"), + dialect.CreateUniqueIndex("oauth_users", "oauth_users", "user_id", "provider", "client_id"), + } + for _, builder := range builders { + query, err := builder.ToSQL() + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return err + } + } + return nil + }) +} diff --git a/migrations/v6.go b/migrations/v6.go new file mode 100644 index 0000000..c6f5012 --- /dev/null +++ b/migrations/v6.go @@ -0,0 +1,29 @@ +/* + * Copyright © 2019 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package migrations + +func supportActivityPubMentions(db *datastore) error { + t, err := db.Begin() + + _, err = t.Exec(`ALTER TABLE remoteusers ADD COLUMN handle ` + db.typeVarChar(255) + ` DEFAULT '' NOT NULL`) + if err != nil { + t.Rollback() + return err + } + + err = t.Commit() + if err != nil { + t.Rollback() + return err + } + + return nil +} diff --git a/oauth.go b/oauth.go new file mode 100644 index 0000000..caf8189 --- /dev/null +++ b/oauth.go @@ -0,0 +1,291 @@ +package writefreely + +import ( + "context" + "encoding/json" + "fmt" + "github.com/gorilla/mux" + "github.com/gorilla/sessions" + "github.com/writeas/impart" + "github.com/writeas/web-core/log" + "github.com/writeas/writefreely/config" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "time" +) + +// TokenResponse contains data returned when a token is created either +// through a code exchange or using a refresh token. +type TokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + Error string `json:"error"` +} + +// InspectResponse contains data returned when an access token is inspected. +type InspectResponse struct { + ClientID string `json:"client_id"` + UserID string `json:"user_id"` + ExpiresAt time.Time `json:"expires_at"` + Username string `json:"username"` + DisplayName string `json:"-"` + Email string `json:"email"` + Error string `json:"error"` +} + +// tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token +// endpoint. One megabyte is plenty. +const tokenRequestMaxLen = 1000000 + +// infoRequestMaxLen is the most bytes that we'll read from the +// /oauth/inspect endpoint. +const infoRequestMaxLen = 1000000 + +// OAuthDatastoreProvider provides a minimal interface of data store, config, +// and session store for use with the oauth handlers. +type OAuthDatastoreProvider interface { + DB() OAuthDatastore + Config() *config.Config + SessionStore() sessions.Store +} + +// OAuthDatastore provides a minimal interface of data store methods used in +// oauth functionality. +type OAuthDatastore interface { + GetIDForRemoteUser(context.Context, string, string, string) (int64, error) + RecordRemoteUserID(context.Context, int64, string, string, string, string) error + ValidateOAuthState(context.Context, string) (string, string, error) + GenerateOAuthState(context.Context, string, string) (string, error) + + CreateUser(*config.Config, *User, string) error + GetUserByID(int64) (*User, error) +} + +type HttpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type oauthClient interface { + GetProvider() string + GetClientID() string + GetCallbackLocation() string + buildLoginURL(state string) (string, error) + exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) + inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) +} + +type callbackProxyClient struct { + server string + callbackLocation string + httpClient HttpClient +} + +type oauthHandler struct { + Config *config.Config + DB OAuthDatastore + Store sessions.Store + EmailKey []byte + oauthClient oauthClient + callbackProxy *callbackProxyClient +} + +func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID()) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} + } + + if h.callbackProxy != nil { + if err := h.callbackProxy.register(ctx, state); err != nil { + return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} + } + } + + location, err := h.oauthClient.buildLoginURL(state) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} + } + return impart.HTTPError{http.StatusTemporaryRedirect, location} +} + +func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { + if app.Config().SlackOauth.ClientID != "" { + callbackLocation := app.Config().App.Host + "/oauth/callback/slack" + + var stateRegisterClient *callbackProxyClient = nil + if app.Config().SlackOauth.CallbackProxyAPI != "" { + stateRegisterClient = &callbackProxyClient{ + server: app.Config().SlackOauth.CallbackProxyAPI, + callbackLocation: app.Config().App.Host + "/oauth/callback/slack", + httpClient: config.DefaultHTTPClient(), + } + callbackLocation = app.Config().SlackOauth.CallbackProxy + } + oauthClient := slackOauthClient{ + ClientID: app.Config().SlackOauth.ClientID, + ClientSecret: app.Config().SlackOauth.ClientSecret, + TeamID: app.Config().SlackOauth.TeamID, + HttpClient: config.DefaultHTTPClient(), + CallbackLocation: callbackLocation, + } + configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient) + } +} + +func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { + if app.Config().WriteAsOauth.ClientID != "" { + callbackLocation := app.Config().App.Host + "/oauth/callback/write.as" + + var callbackProxy *callbackProxyClient = nil + if app.Config().WriteAsOauth.CallbackProxy != "" { + callbackProxy = &callbackProxyClient{ + server: app.Config().WriteAsOauth.CallbackProxyAPI, + callbackLocation: app.Config().App.Host + "/oauth/callback/write.as", + httpClient: config.DefaultHTTPClient(), + } + callbackLocation = app.Config().SlackOauth.CallbackProxy + } + + oauthClient := writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation), + InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), + AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), + HttpClient: config.DefaultHTTPClient(), + CallbackLocation: callbackLocation, + } + configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) + } +} + +func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) { + handler := &oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + oauthClient: oauthClient, + EmailKey: app.keys.EmailKey, + callbackProxy: callbackProxy, + } + r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") + r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") + r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST") +} + +func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + code := r.FormValue("code") + state := r.FormValue("state") + + provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) + if err != nil { + log.Error("Unable to ValidateOAuthState: %s", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) + if err != nil { + log.Error("Unable to exchangeOauthCode: %s", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + // Now that we have the access token, let's use it real quick to make sur + // it really really works. + tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) + if err != nil { + log.Error("Unable to inspectOauthAccessToken: %s", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) + if err != nil { + log.Error("Unable to GetIDForRemoteUser: %s", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + if localUserID != -1 { + user, err := h.DB.GetUserByID(localUserID) + if err != nil { + log.Error("Unable to GetUserByID %d: %s", localUserID, err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + if err = loginOrFail(h.Store, w, r, user); err != nil { + log.Error("Unable to loginOrFail %d: %s", localUserID, err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + return nil + } + + displayName := tokenInfo.DisplayName + if len(displayName) == 0 { + displayName = tokenInfo.Username + } + + tp := &oauthSignupPageParams{ + AccessToken: tokenResponse.AccessToken, + TokenUsername: tokenInfo.Username, + TokenAlias: tokenInfo.DisplayName, + TokenEmail: tokenInfo.Email, + TokenRemoteUser: tokenInfo.UserID, + Provider: provider, + ClientID: clientID, + } + tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) + + return h.showOauthSignupPage(app, w, r, tp, nil) +} + +func (r *callbackProxyClient) register(ctx context.Context, state string) error { + form := url.Values{} + form.Add("state", state) + form.Add("location", r.callbackLocation) + req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode())) + if err != nil { + return err + } + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := r.httpClient.Do(req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusCreated { + return fmt.Errorf("unable register state location: %d", resp.StatusCode) + } + + return nil +} + +func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { + lr := io.LimitReader(body, int64(n+1)) + data, err := ioutil.ReadAll(lr) + if err != nil { + return err + } + if len(data) == n+1 { + return fmt.Errorf("content larger than max read allowance: %d", n) + } + return json.Unmarshal(data, thing) +} + +func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { + // An error may be returned, but a valid session should always be returned. + session, _ := store.Get(r, cookieName) + session.Values[cookieUserVal] = user.Cookie() + if err := session.Save(r, w); err != nil { + fmt.Println("error saving session", err) + return err + } + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return nil +} diff --git a/oauth/state.go b/oauth/state.go new file mode 100644 index 0000000..e8dd154 --- /dev/null +++ b/oauth/state.go @@ -0,0 +1,10 @@ +package oauth + +import "context" + +// ClientStateStore provides state management used by the OAuth client. +type ClientStateStore interface { + Generate(ctx context.Context) (string, error) + Validate(ctx context.Context, state string) error +} + diff --git a/oauth_signup.go b/oauth_signup.go new file mode 100644 index 0000000..220afbd --- /dev/null +++ b/oauth_signup.go @@ -0,0 +1,218 @@ +/* + * Copyright © 2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package writefreely + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "github.com/writeas/impart" + "github.com/writeas/web-core/auth" + "github.com/writeas/web-core/log" + "github.com/writeas/writefreely/page" + "html/template" + "net/http" + "strings" + "time" +) + +type viewOauthSignupVars struct { + page.StaticPage + To string + Message template.HTML + Flashes []template.HTML + + AccessToken string + TokenUsername string + TokenAlias string // TODO: rename this to match the data it represents: the collection title + TokenEmail string + TokenRemoteUser string + Provider string + ClientID string + TokenHash string + + LoginUsername string + Alias string // TODO: rename this to match the data it represents: the collection title + Email string +} + +const ( + oauthParamAccessToken = "access_token" + oauthParamTokenUsername = "token_username" + oauthParamTokenAlias = "token_alias" + oauthParamTokenEmail = "token_email" + oauthParamTokenRemoteUserID = "token_remote_user" + oauthParamClientID = "client_id" + oauthParamProvider = "provider" + oauthParamHash = "signature" + oauthParamUsername = "username" + oauthParamAlias = "alias" + oauthParamEmail = "email" + oauthParamPassword = "password" +) + +type oauthSignupPageParams struct { + AccessToken string + TokenUsername string + TokenAlias string // TODO: rename this to match the data it represents: the collection title + TokenEmail string + TokenRemoteUser string + ClientID string + Provider string + TokenHash string +} + +func (p oauthSignupPageParams) HashTokenParams(key string) string { + hasher := sha256.New() + hasher.Write([]byte(key)) + hasher.Write([]byte(p.AccessToken)) + hasher.Write([]byte(p.TokenUsername)) + hasher.Write([]byte(p.TokenAlias)) + hasher.Write([]byte(p.TokenEmail)) + hasher.Write([]byte(p.TokenRemoteUser)) + hasher.Write([]byte(p.ClientID)) + hasher.Write([]byte(p.Provider)) + return hex.EncodeToString(hasher.Sum(nil)) +} + +func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.Request) error { + tp := &oauthSignupPageParams{ + AccessToken: r.FormValue(oauthParamAccessToken), + TokenUsername: r.FormValue(oauthParamTokenUsername), + TokenAlias: r.FormValue(oauthParamTokenAlias), + TokenEmail: r.FormValue(oauthParamTokenEmail), + TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID), + ClientID: r.FormValue(oauthParamClientID), + Provider: r.FormValue(oauthParamProvider), + } + if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."} + } + tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) + if err := h.validateOauthSignup(r); err != nil { + return h.showOauthSignupPage(app, w, r, tp, err) + } + + var err error + hashedPass := []byte{} + clearPass := r.FormValue(oauthParamPassword) + hasPass := clearPass != "" + if hasPass { + hashedPass, err = auth.HashPass([]byte(clearPass)) + if err != nil { + return h.showOauthSignupPage(app, w, r, tp, fmt.Errorf("unable to hash password")) + } + } + newUser := &User{ + Username: r.FormValue(oauthParamUsername), + HashedPass: hashedPass, + HasPass: hasPass, + Email: prepareUserEmail(r.FormValue(oauthParamEmail), h.EmailKey), + Created: time.Now().Truncate(time.Second).UTC(), + } + displayName := r.FormValue(oauthParamAlias) + if len(displayName) == 0 { + displayName = r.FormValue(oauthParamUsername) + } + + err = h.DB.CreateUser(h.Config, newUser, displayName) + if err != nil { + return h.showOauthSignupPage(app, w, r, tp, err) + } + + err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken)) + if err != nil { + return h.showOauthSignupPage(app, w, r, tp, err) + } + + if err := loginOrFail(h.Store, w, r, newUser); err != nil { + return h.showOauthSignupPage(app, w, r, tp, err) + } + return nil +} + +func (h oauthHandler) validateOauthSignup(r *http.Request) error { + username := r.FormValue(oauthParamUsername) + if len(username) < h.Config.App.MinUsernameLen { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too short."} + } + if len(username) > 100 { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too long."} + } + collTitle := r.FormValue(oauthParamAlias) + if len(collTitle) == 0 { + collTitle = username + } + email := r.FormValue(oauthParamEmail) + if len(email) > 0 { + parts := strings.Split(email, "@") + if len(parts) != 2 || (len(parts[0]) < 1 || len(parts[1]) < 1) { + return impart.HTTPError{Status: http.StatusBadRequest, Message: "Invalid email address"} + } + } + return nil +} + +func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *http.Request, tp *oauthSignupPageParams, errMsg error) error { + username := tp.TokenUsername + collTitle := tp.TokenAlias + email := tp.TokenEmail + + session, err := app.sessionStore.Get(r, cookieName) + if err != nil { + // Ignore this + log.Error("Unable to get session; ignoring: %v", err) + } + + if tmpValue := r.FormValue(oauthParamUsername); len(tmpValue) > 0 { + username = tmpValue + } + if tmpValue := r.FormValue(oauthParamAlias); len(tmpValue) > 0 { + collTitle = tmpValue + } + if tmpValue := r.FormValue(oauthParamEmail); len(tmpValue) > 0 { + email = tmpValue + } + + p := &viewOauthSignupVars{ + StaticPage: pageForReq(app, r), + To: r.FormValue("to"), + Flashes: []template.HTML{}, + + AccessToken: tp.AccessToken, + TokenUsername: tp.TokenUsername, + TokenAlias: tp.TokenAlias, + TokenEmail: tp.TokenEmail, + TokenRemoteUser: tp.TokenRemoteUser, + Provider: tp.Provider, + ClientID: tp.ClientID, + TokenHash: tp.TokenHash, + + LoginUsername: username, + Alias: collTitle, + Email: email, + } + + // Display any error messages + flashes, _ := getSessionFlashes(app, w, r, session) + for _, flash := range flashes { + p.Flashes = append(p.Flashes, template.HTML(flash)) + } + if errMsg != nil { + p.Flashes = append(p.Flashes, template.HTML(errMsg.Error())) + } + err = pages["signup-oauth.tmpl"].ExecuteTemplate(w, "base", p) + if err != nil { + log.Error("Unable to render signup-oauth: %v", err) + return err + } + return nil +} diff --git a/oauth_slack.go b/oauth_slack.go new file mode 100644 index 0000000..35db156 --- /dev/null +++ b/oauth_slack.go @@ -0,0 +1,180 @@ +/* + * Copyright © 2019-2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package writefreely + +import ( + "context" + "errors" + "fmt" + "github.com/writeas/nerds/store" + "github.com/writeas/slug" + "net/http" + "net/url" + "strings" +) + +type slackOauthClient struct { + ClientID string + ClientSecret string + TeamID string + CallbackLocation string + HttpClient HttpClient +} + +type slackExchangeResponse struct { + OK bool `json:"ok"` + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + TeamName string `json:"team_name"` + TeamID string `json:"team_id"` + Error string `json:"error"` +} + +type slackIdentity struct { + Name string `json:"name"` + ID string `json:"id"` + Email string `json:"email"` +} + +type slackTeam struct { + Name string `json:"name"` + ID string `json:"id"` +} + +type slackUserIdentityResponse struct { + OK bool `json:"ok"` + User slackIdentity `json:"user"` + Team slackTeam `json:"team"` + Error string `json:"error"` +} + +const ( + slackAuthLocation = "https://slack.com/oauth/authorize" + slackExchangeLocation = "https://slack.com/api/oauth.access" + slackIdentityLocation = "https://slack.com/api/users.identity" +) + +var _ oauthClient = slackOauthClient{} + +func (c slackOauthClient) GetProvider() string { + return "slack" +} + +func (c slackOauthClient) GetClientID() string { + return c.ClientID +} + +func (c slackOauthClient) GetCallbackLocation() string { + return c.CallbackLocation +} + +func (c slackOauthClient) buildLoginURL(state string) (string, error) { + u, err := url.Parse(slackAuthLocation) + if err != nil { + return "", err + } + q := u.Query() + q.Set("client_id", c.ClientID) + q.Set("scope", "identity.basic identity.email identity.team") + q.Set("redirect_uri", c.CallbackLocation) + q.Set("state", state) + + // If this param is not set, the user can select which team they + // authenticate through and then we'd have to match the configured team + // against the profile get. That is extra work in the post-auth phase + // that we don't want to do. + q.Set("team", c.TeamID) + + // The Slack OAuth docs don't explicitly list this one, but it is part of + // the spec, so we include it anyway. + q.Set("response_type", "code") + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { + form := url.Values{} + // The oauth.access documentation doesn't explicitly mention this + // parameter, but it is part of the spec, so we include it anyway. + // https://api.slack.com/methods/oauth.access + form.Add("grant_type", "authorization_code") + form.Add("redirect_uri", c.CallbackLocation) + form.Add("code", code) + req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(c.ClientID, c.ClientSecret) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } + + var tokenResponse slackExchangeResponse + if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { + return nil, err + } + if !tokenResponse.OK { + return nil, errors.New(tokenResponse.Error) + } + return tokenResponse.TokenResponse(), nil +} + +func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { + req, err := http.NewRequest("GET", slackIdentityLocation, nil) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } + + var inspectResponse slackUserIdentityResponse + if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { + return nil, err + } + if !inspectResponse.OK { + return nil, errors.New(inspectResponse.Error) + } + return inspectResponse.InspectResponse(), nil +} + +func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse { + return &InspectResponse{ + UserID: resp.User.ID, + Username: fmt.Sprintf("%s-%s", slug.Make(resp.User.Name), store.GenerateRandomString("0123456789bcdfghjklmnpqrstvwxyz", 5)), + DisplayName: resp.User.Name, + Email: resp.User.Email, + } +} + +func (resp slackExchangeResponse) TokenResponse() *TokenResponse { + return &TokenResponse{ + AccessToken: resp.AccessToken, + } +} diff --git a/oauth_test.go b/oauth_test.go new file mode 100644 index 0000000..2e293e7 --- /dev/null +++ b/oauth_test.go @@ -0,0 +1,253 @@ +package writefreely + +import ( + "context" + "fmt" + "github.com/gorilla/sessions" + "github.com/stretchr/testify/assert" + "github.com/writeas/impart" + "github.com/writeas/nerds/store" + "github.com/writeas/writefreely/config" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type MockOAuthDatastoreProvider struct { + DoDB func() OAuthDatastore + DoConfig func() *config.Config + DoSessionStore func() sessions.Store +} + +type MockOAuthDatastore struct { + DoGenerateOAuthState func(context.Context, string, string) (string, error) + DoValidateOAuthState func(context.Context, string) (string, string, error) + DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) + DoCreateUser func(*config.Config, *User, string) error + DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error + DoGetUserByID func(int64) (*User, error) +} + +var _ OAuthDatastore = &MockOAuthDatastore{} + +type StringReadCloser struct { + *strings.Reader +} + +func (src *StringReadCloser) Close() error { + return nil +} + +type MockHTTPClient struct { + DoDo func(req *http.Request) (*http.Response, error) +} + +func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { + if m.DoDo != nil { + return m.DoDo(req) + } + return &http.Response{}, nil +} + +func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store { + if m.DoSessionStore != nil { + return m.DoSessionStore() + } + return sessions.NewCookieStore([]byte("secret-key")) +} + +func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore { + if m.DoDB != nil { + return m.DoDB() + } + return &MockOAuthDatastore{} +} + +func (m *MockOAuthDatastoreProvider) Config() *config.Config { + if m.DoConfig != nil { + return m.DoConfig() + } + cfg := config.New() + cfg.UseSQLite(true) + cfg.WriteAsOauth = config.WriteAsOauthCfg{ + ClientID: "development", + ClientSecret: "development", + AuthLocation: "https://write.as/oauth/login", + TokenLocation: "https://write.as/oauth/token", + InspectLocation: "https://write.as/oauth/inspect", + } + cfg.SlackOauth = config.SlackOauthCfg{ + ClientID: "development", + ClientSecret: "development", + TeamID: "development", + } + return cfg +} + +func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { + if m.DoValidateOAuthState != nil { + return m.DoValidateOAuthState(ctx, state) + } + return "", "", nil +} + +func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { + if m.DoGetIDForRemoteUser != nil { + return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID) + } + return -1, nil +} + +func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error { + if m.DoCreateUser != nil { + return m.DoCreateUser(cfg, u, username) + } + u.ID = 1 + return nil +} + +func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { + if m.DoRecordRemoteUserID != nil { + return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken) + } + return nil +} + +func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) { + if m.DoGetUserByID != nil { + return m.DoGetUserByID(userID) + } + user := &User{ + + } + return user, nil +} + +func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) { + if m.DoGenerateOAuthState != nil { + return m.DoGenerateOAuthState(ctx, provider, clientID) + } + return store.Generate62RandomString(14), nil +} + +func TestViewOauthInit(t *testing.T) { + + t.Run("success", func(t *testing.T) { + app := &MockOAuthDatastoreProvider{} + h := oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, + oauthClient: writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: nil, + }, + } + req, err := http.NewRequest("GET", "/oauth/client", nil) + assert.NoError(t, err) + rr := httptest.NewRecorder() + err = h.viewOauthInit(nil, rr, req) + assert.NotNil(t, err) + httpErr, ok := err.(impart.HTTPError) + assert.True(t, ok) + assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status) + assert.NotEmpty(t, httpErr.Message) + locURI, err := url.Parse(httpErr.Message) + assert.NoError(t, err) + assert.Equal(t, "/oauth/login", locURI.Path) + assert.Equal(t, "development", locURI.Query().Get("client_id")) + assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri")) + assert.Equal(t, "code", locURI.Query().Get("response_type")) + assert.NotEmpty(t, locURI.Query().Get("state")) + }) + + t.Run("state failure", func(t *testing.T) { + app := &MockOAuthDatastoreProvider{ + DoDB: func() OAuthDatastore { + return &MockOAuthDatastore{ + DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) { + return "", fmt.Errorf("pretend unable to write state error") + }, + } + }, + } + h := oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, + oauthClient: writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: nil, + }, + } + req, err := http.NewRequest("GET", "/oauth/client", nil) + assert.NoError(t, err) + rr := httptest.NewRecorder() + err = h.viewOauthInit(nil, rr, req) + httpErr, ok := err.(impart.HTTPError) + assert.True(t, ok) + assert.NotEmpty(t, httpErr.Message) + assert.Equal(t, http.StatusInternalServerError, httpErr.Status) + assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message) + }) +} + +func TestViewOauthCallback(t *testing.T) { + t.Run("success", func(t *testing.T) { + app := &MockOAuthDatastoreProvider{} + h := oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, + oauthClient: writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: &MockHTTPClient{ + DoDo: func(req *http.Request) (*http.Response, error) { + switch req.URL.String() { + case "https://write.as/oauth/token": + return &http.Response{ + StatusCode: 200, + Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, + }, nil + case "https://write.as/oauth/inspect": + return &http.Response{ + StatusCode: 200, + Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusNotFound, + }, nil + }, + }, + }, + } + req, err := http.NewRequest("GET", "/oauth/callback", nil) + assert.NoError(t, err) + rr := httptest.NewRecorder() + err = h.viewOauthCallback(nil, rr, req) + assert.NoError(t, err) + assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) + }) +} diff --git a/oauth_writeas.go b/oauth_writeas.go new file mode 100644 index 0000000..6251a16 --- /dev/null +++ b/oauth_writeas.go @@ -0,0 +1,114 @@ +package writefreely + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" +) + +type writeAsOauthClient struct { + ClientID string + ClientSecret string + AuthLocation string + ExchangeLocation string + InspectLocation string + CallbackLocation string + HttpClient HttpClient +} + +var _ oauthClient = writeAsOauthClient{} + +const ( + writeAsAuthLocation = "https://write.as/oauth/login" + writeAsExchangeLocation = "https://write.as/oauth/token" + writeAsIdentityLocation = "https://write.as/oauth/inspect" +) + +func (c writeAsOauthClient) GetProvider() string { + return "write.as" +} + +func (c writeAsOauthClient) GetClientID() string { + return c.ClientID +} + +func (c writeAsOauthClient) GetCallbackLocation() string { + return c.CallbackLocation +} + +func (c writeAsOauthClient) buildLoginURL(state string) (string, error) { + u, err := url.Parse(c.AuthLocation) + if err != nil { + return "", err + } + q := u.Query() + q.Set("client_id", c.ClientID) + q.Set("redirect_uri", c.CallbackLocation) + q.Set("response_type", "code") + q.Set("state", state) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { + form := url.Values{} + form.Add("grant_type", "authorization_code") + form.Add("redirect_uri", c.CallbackLocation) + form.Add("code", code) + req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(c.ClientID, c.ClientSecret) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } + + var tokenResponse TokenResponse + if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { + return nil, err + } + if tokenResponse.Error != "" { + return nil, errors.New(tokenResponse.Error) + } + return &tokenResponse, nil +} + +func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { + req, err := http.NewRequest("GET", c.InspectLocation, nil) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } + + var inspectResponse InspectResponse + if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { + return nil, err + } + if inspectResponse.Error != "" { + return nil, errors.New(inspectResponse.Error) + } + return &inspectResponse, nil +} diff --git a/pad.go b/pad.go index 3cb7f37..0354cd3 100644 --- a/pad.go +++ b/pad.go @@ -35,9 +35,10 @@ func handleViewPad(app *App, w http.ResponseWriter, r *http.Request) error { } appData := &struct { page.StaticPage - Post *RawPost - User *User - Blogs *[]Collection + Post *RawPost + User *User + Blogs *[]Collection + Silenced bool Editing bool // True if we're modifying an existing post EditCollection *Collection // Collection of the post we're editing, if any @@ -52,11 +53,17 @@ func handleViewPad(app *App, w http.ResponseWriter, r *http.Request) error { if err != nil { log.Error("Unable to get user's blogs for Pad: %v", err) } + appData.Silenced, err = app.db.IsUserSilenced(appData.User.ID) + if err != nil { + log.Error("Unable to get user status for Pad: %v", err) + } } padTmpl := app.cfg.App.Editor if templates[padTmpl] == nil { - log.Info("No template '%s' found. Falling back to default 'pad' template.", padTmpl) + if padTmpl != "" { + log.Info("No template '%s' found. Falling back to default 'pad' template.", padTmpl) + } padTmpl = "pad" } @@ -85,6 +92,7 @@ func handleViewPad(app *App, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + appData.EditCollection.hostName = app.cfg.App.Host } else { // Editing a floating article appData.Post = getRawPost(app, action) @@ -119,12 +127,18 @@ func handleViewMeta(app *App, w http.ResponseWriter, r *http.Request) error { EditCollection *Collection // Collection of the post we're editing, if any Flashes []string NeedsToken bool + Silenced bool }{ StaticPage: pageForReq(app, r), Post: &RawPost{Font: "norm"}, User: getUserSession(app, r), } var err error + appData.Silenced, err = app.db.IsUserSilenced(appData.User.ID) + if err != nil { + log.Error("view meta: get user status: %v", err) + return ErrInternalGeneral + } if action == "" && slug == "" { return ErrPostNotFound @@ -148,6 +162,7 @@ func handleViewMeta(app *App, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + appData.EditCollection.hostName = app.cfg.App.Host } else { // Editing a floating article appData.Post = getRawPost(app, action) diff --git a/pages/login.tmpl b/pages/login.tmpl index 1c8e862..345b171 100644 --- a/pages/login.tmpl +++ b/pages/login.tmpl @@ -1,7 +1,38 @@ {{define "head"}}or
+