diff --git a/integrations/repofiles_update_test.go b/integrations/repofiles_update_test.go index a7beec49553..c422483bf87 100644 --- a/integrations/repofiles_update_test.go +++ b/integrations/repofiles_update_test.go @@ -207,11 +207,14 @@ func TestCreateOrUpdateRepoFileForCreate(t *testing.T) { commitID, _ := gitRepo.GetBranchCommitID(opts.NewBranch) expectedFileResponse := getExpectedFileResponseForRepofilesCreate(commitID) - assert.EqualValues(t, expectedFileResponse.Content, fileResponse.Content) - assert.EqualValues(t, expectedFileResponse.Commit.SHA, fileResponse.Commit.SHA) - assert.EqualValues(t, expectedFileResponse.Commit.HTMLURL, fileResponse.Commit.HTMLURL) - assert.EqualValues(t, expectedFileResponse.Commit.Author.Email, fileResponse.Commit.Author.Email) - assert.EqualValues(t, expectedFileResponse.Commit.Author.Name, fileResponse.Commit.Author.Name) + assert.NotNil(t, expectedFileResponse) + if expectedFileResponse != nil { + assert.EqualValues(t, expectedFileResponse.Content, fileResponse.Content) + assert.EqualValues(t, expectedFileResponse.Commit.SHA, fileResponse.Commit.SHA) + assert.EqualValues(t, expectedFileResponse.Commit.HTMLURL, fileResponse.Commit.HTMLURL) + assert.EqualValues(t, expectedFileResponse.Commit.Author.Email, fileResponse.Commit.Author.Email) + assert.EqualValues(t, expectedFileResponse.Commit.Author.Name, fileResponse.Commit.Author.Name) + } }) } diff --git a/models/issue_watch.go b/models/issue_watch.go index c4732d784e1..9046e4d2f75 100644 --- a/models/issue_watch.go +++ b/models/issue_watch.go @@ -68,10 +68,14 @@ func getIssueWatch(e Engine, userID, issueID int64) (iw *IssueWatch, exists bool // but avoids joining with `user` for performance reasons // User permissions must be verified elsewhere if required func GetIssueWatchersIDs(issueID int64) ([]int64, error) { + return getIssueWatchersIDs(x, issueID, true) +} + +func getIssueWatchersIDs(e Engine, issueID int64, watching bool) ([]int64, error) { ids := make([]int64, 0, 64) - return ids, x.Table("issue_watch"). + return ids, e.Table("issue_watch"). Where("issue_id=?", issueID). - And("is_watching = ?", true). + And("is_watching = ?", watching). Select("user_id"). Find(&ids) } @@ -99,39 +103,9 @@ func getIssueWatchers(e Engine, issueID int64, listOptions ListOptions) (IssueWa } func removeIssueWatchersByRepoID(e Engine, userID int64, repoID int64) error { - iw := &IssueWatch{ - IsWatching: false, - } _, err := e. Join("INNER", "issue", "`issue`.id = `issue_watch`.issue_id AND `issue`.repo_id = ?", repoID). - Cols("is_watching", "updated_unix"). Where("`issue_watch`.user_id = ?", userID). - Update(iw) + Delete(new(IssueWatch)) return err } - -// LoadWatchUsers return watching users -func (iwl IssueWatchList) LoadWatchUsers() (users UserList, err error) { - return iwl.loadWatchUsers(x) -} - -func (iwl IssueWatchList) loadWatchUsers(e Engine) (users UserList, err error) { - if len(iwl) == 0 { - return []*User{}, nil - } - - var userIDs = make([]int64, 0, len(iwl)) - for _, iw := range iwl { - if iw.IsWatching { - userIDs = append(userIDs, iw.UserID) - } - } - - if len(userIDs) == 0 { - return []*User{}, nil - } - - err = e.In("id", userIDs).Find(&users) - - return -} diff --git a/models/notification.go b/models/notification.go index e7217a6e047..c52d6c557a5 100644 --- a/models/notification.go +++ b/models/notification.go @@ -133,55 +133,42 @@ func CreateOrUpdateIssueNotifications(issueID, commentID int64, notificationAuth } func createOrUpdateIssueNotifications(e Engine, issueID, commentID int64, notificationAuthorID int64) error { - issueWatches, err := getIssueWatchers(e, issueID, ListOptions{}) + // init + toNotify := make(map[int64]struct{}, 32) + notifications, err := getNotificationsByIssueID(e, issueID) if err != nil { return err } - issue, err := getIssueByID(e, issueID) if err != nil { return err } - watches, err := getWatchers(e, issue.RepoID) + issueWatches, err := getIssueWatchersIDs(e, issueID, true) if err != nil { return err } + for _, id := range issueWatches { + toNotify[id] = struct{}{} + } - notifications, err := getNotificationsByIssueID(e, issueID) + repoWatches, err := getRepoWatchersIDs(e, issue.RepoID) if err != nil { return err } - - alreadyNotified := make(map[int64]struct{}, len(issueWatches)+len(watches)) - - notifyUser := func(userID int64) error { - // do not send notification for the own issuer/commenter - if userID == notificationAuthorID { - return nil - } - - if _, ok := alreadyNotified[userID]; ok { - return nil - } - alreadyNotified[userID] = struct{}{} - - if notificationExists(notifications, issue.ID, userID) { - return updateIssueNotification(e, userID, issue.ID, commentID, notificationAuthorID) - } - return createIssueNotification(e, userID, issue, commentID, notificationAuthorID) + for _, id := range repoWatches { + toNotify[id] = struct{}{} } - for _, issueWatch := range issueWatches { - // ignore if user unwatched the issue - if !issueWatch.IsWatching { - alreadyNotified[issueWatch.UserID] = struct{}{} - continue - } - - if err := notifyUser(issueWatch.UserID); err != nil { - return err - } + // dont notify user who cause notification + delete(toNotify, notificationAuthorID) + // explicit unwatch on issue + issueUnWatches, err := getIssueWatchersIDs(e, issueID, false) + if err != nil { + return err + } + for _, id := range issueUnWatches { + delete(toNotify, id) } err = issue.loadRepo(e) @@ -189,16 +176,23 @@ func createOrUpdateIssueNotifications(e Engine, issueID, commentID int64, notifi return err } - for _, watch := range watches { + // notify + for userID := range toNotify { issue.Repo.Units = nil - if issue.IsPull && !issue.Repo.checkUnitUser(e, watch.UserID, false, UnitTypePullRequests) { + if issue.IsPull && !issue.Repo.checkUnitUser(e, userID, false, UnitTypePullRequests) { continue } - if !issue.IsPull && !issue.Repo.checkUnitUser(e, watch.UserID, false, UnitTypeIssues) { + if !issue.IsPull && !issue.Repo.checkUnitUser(e, userID, false, UnitTypeIssues) { continue } - if err := notifyUser(watch.UserID); err != nil { + if notificationExists(notifications, issue.ID, userID) { + if err = updateIssueNotification(e, userID, issue.ID, commentID, notificationAuthorID); err != nil { + return err + } + continue + } + if err = createIssueNotification(e, userID, issue, commentID, notificationAuthorID); err != nil { return err } } diff --git a/models/repo_watch.go b/models/repo_watch.go index a9d56eff03d..11cfa889184 100644 --- a/models/repo_watch.go +++ b/models/repo_watch.go @@ -144,8 +144,12 @@ func GetWatchers(repoID int64) ([]*Watch, error) { // but avoids joining with `user` for performance reasons // User permissions must be verified elsewhere if required func GetRepoWatchersIDs(repoID int64) ([]int64, error) { + return getRepoWatchersIDs(x, repoID) +} + +func getRepoWatchersIDs(e Engine, repoID int64) ([]int64, error) { ids := make([]int64, 0, 64) - return ids, x.Table("watch"). + return ids, e.Table("watch"). Where("watch.repo_id=?", repoID). And("watch.mode<>?", RepoWatchModeDont). Select("user_id"). diff --git a/models/user.go b/models/user.go index bf59c1240bc..8be15ba6df8 100644 --- a/models/user.go +++ b/models/user.go @@ -1409,7 +1409,7 @@ func GetUserNamesByIDs(ids []int64) ([]string, error) { } // GetUsersByIDs returns all resolved users from a list of Ids. -func GetUsersByIDs(ids []int64) ([]*User, error) { +func GetUsersByIDs(ids []int64) (UserList, error) { ous := make([]*User, 0, len(ids)) if len(ids) == 0 { return ous, nil diff --git a/modules/git/repo_branch.go b/modules/git/repo_branch.go index e79bab76a6f..3d0e6497edc 100644 --- a/modules/git/repo_branch.go +++ b/modules/git/repo_branch.go @@ -48,6 +48,9 @@ type Branch struct { // GetHEADBranch returns corresponding branch of HEAD. func (repo *Repository) GetHEADBranch() (*Branch, error) { + if repo == nil { + return nil, fmt.Errorf("nil repo") + } stdout, err := NewCommand("symbolic-ref", "HEAD").RunInDir(repo.Path) if err != nil { return nil, err diff --git a/modules/test/context_tests.go b/modules/test/context_tests.go index cf9c5fbc548..f9f0ec5d42c 100644 --- a/modules/test/context_tests.go +++ b/modules/test/context_tests.go @@ -58,8 +58,11 @@ func LoadRepoCommit(t *testing.T, ctx *context.Context) { defer gitRepo.Close() branch, err := gitRepo.GetHEADBranch() assert.NoError(t, err) - ctx.Repo.Commit, err = gitRepo.GetBranchCommit(branch.Name) - assert.NoError(t, err) + assert.NotNil(t, branch) + if branch != nil { + ctx.Repo.Commit, err = gitRepo.GetBranchCommit(branch.Name) + assert.NoError(t, err) + } } // LoadUser load a user into a test context. diff --git a/routers/api/v1/repo/issue_subscription.go b/routers/api/v1/repo/issue_subscription.go index 274da966fda..0406edd2078 100644 --- a/routers/api/v1/repo/issue_subscription.go +++ b/routers/api/v1/repo/issue_subscription.go @@ -190,9 +190,14 @@ func GetIssueSubscribers(ctx *context.APIContext) { return } - users, err := iwl.LoadWatchUsers() + var userIDs = make([]int64, 0, len(iwl)) + for _, iw := range iwl { + userIDs = append(userIDs, iw.UserID) + } + + users, err := models.GetUsersByIDs(userIDs) if err != nil { - ctx.Error(http.StatusInternalServerError, "LoadWatchUsers", err) + ctx.Error(http.StatusInternalServerError, "GetUsersByIDs", err) return }