Restrict API read access based on Private setting

This verifies that a user is authenticated before getting to the actual
handler on API endpoints where a user is reading content.

Ref T576
pull/123/head
Matt Baer 6 years ago
parent b3a36a3be7
commit a2088c1646
  1. 111
      handle.go
  2. 18
      routes.go

@ -237,22 +237,49 @@ func (h *Handler) AdminApper(f userApperHandlerFunc) http.HandlerFunc {
} }
} }
func apiAuth(app *App, r *http.Request) (*User, error) {
// Authorize user from Authorization header
t := r.Header.Get("Authorization")
if t == "" {
return nil, ErrNoAccessToken
}
u := &User{ID: app.db.GetUserID(t)}
if u.ID == -1 {
return nil, ErrBadAccessToken
}
return u, nil
}
// optionaAPIAuth is used for endpoints that accept authenticated requests via
// Authorization header or cookie, unlike apiAuth. It returns a different err
// in the case where no Authorization header is present.
func optionalAPIAuth(app *App, r *http.Request) (*User, error) {
// Authorize user from Authorization header
t := r.Header.Get("Authorization")
if t == "" {
return nil, ErrNotLoggedIn
}
u := &User{ID: app.db.GetUserID(t)}
if u.ID == -1 {
return nil, ErrBadAccessToken
}
return u, nil
}
func webAuth(app *App, r *http.Request) (*User, error) {
u := getUserSession(app, r)
if u == nil {
return nil, ErrNotLoggedIn
}
return u, nil
}
// UserAPI handles requests made in the API by the authenticated user. // UserAPI handles requests made in the API by the authenticated user.
// This provides user-friendly HTML pages and actions that work in the browser. // This provides user-friendly HTML pages and actions that work in the browser.
func (h *Handler) UserAPI(f userHandlerFunc) http.HandlerFunc { func (h *Handler) UserAPI(f userHandlerFunc) http.HandlerFunc {
return h.UserAll(false, f, func(app *App, r *http.Request) (*User, error) { return h.UserAll(false, f, apiAuth)
// Authorize user from Authorization header
t := r.Header.Get("Authorization")
if t == "" {
return nil, ErrNoAccessToken
}
u := &User{ID: app.db.GetUserID(t)}
if u.ID == -1 {
return nil, ErrBadAccessToken
}
return u, nil
})
} }
func (h *Handler) UserAll(web bool, f userHandlerFunc, a authFunc) http.HandlerFunc { func (h *Handler) UserAll(web bool, f userHandlerFunc, a authFunc) http.HandlerFunc {
@ -515,6 +542,64 @@ func (h *Handler) All(f handlerFunc) http.HandlerFunc {
} }
} }
func (h *Handler) AllReader(f handlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.handleError(w, r, func() error {
status := 200
start := time.Now()
defer func() {
if e := recover(); e != nil {
log.Error("%s:\n%s", e, debug.Stack())
impart.WriteError(w, impart.HTTPError{http.StatusInternalServerError, "Something didn't work quite right."})
status = 500
}
log.Info("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, time.Since(start), r.UserAgent())
}()
if h.app.App().cfg.App.Private {
// This instance is private, so ensure it's being accessed by a valid user
// Check if authenticated with an access token
_, apiErr := optionalAPIAuth(h.app.App(), r)
if apiErr != nil {
if err, ok := apiErr.(impart.HTTPError); ok {
status = err.Status
} else {
status = 500
}
if apiErr == ErrNotLoggedIn {
// Fall back to web auth since there was no access token given
_, err := webAuth(h.app.App(), r)
if err != nil {
if err, ok := apiErr.(impart.HTTPError); ok {
status = err.Status
} else {
status = 500
}
return err
}
} else {
return apiErr
}
}
}
err := f(h.app.App(), w, r)
if err != nil {
if err, ok := err.(impart.HTTPError); ok {
status = err.Status
} else {
status = 500
}
}
return err
}())
}
}
func (h *Handler) Download(f dataHandlerFunc, ul UserLevelFunc) http.HandlerFunc { func (h *Handler) Download(f dataHandlerFunc, ul UserLevelFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
h.handleHTTPError(w, r, func() error { h.handleHTTPError(w, r, func() error {

@ -113,28 +113,28 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
// Handle collections // Handle collections
write.HandleFunc("/api/collections", handler.All(newCollection)).Methods("POST") write.HandleFunc("/api/collections", handler.All(newCollection)).Methods("POST")
apiColls := write.PathPrefix("/api/collections/").Subrouter() apiColls := write.PathPrefix("/api/collections/").Subrouter()
apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.All(fetchCollection)).Methods("GET") apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.AllReader(fetchCollection)).Methods("GET")
apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.All(existingCollection)).Methods("POST", "DELETE") apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.All(existingCollection)).Methods("POST", "DELETE")
apiColls.HandleFunc("/{alias}/posts", handler.All(fetchCollectionPosts)).Methods("GET") apiColls.HandleFunc("/{alias}/posts", handler.AllReader(fetchCollectionPosts)).Methods("GET")
apiColls.HandleFunc("/{alias}/posts", handler.All(newPost)).Methods("POST") apiColls.HandleFunc("/{alias}/posts", handler.All(newPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/posts/{post}", handler.All(fetchPost)).Methods("GET") apiColls.HandleFunc("/{alias}/posts/{post}", handler.AllReader(fetchPost)).Methods("GET")
apiColls.HandleFunc("/{alias}/posts/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST") apiColls.HandleFunc("/{alias}/posts/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/posts/{post}/{property}", handler.All(fetchPostProperty)).Methods("GET") apiColls.HandleFunc("/{alias}/posts/{post}/{property}", handler.AllReader(fetchPostProperty)).Methods("GET")
apiColls.HandleFunc("/{alias}/collect", handler.All(addPost)).Methods("POST") apiColls.HandleFunc("/{alias}/collect", handler.All(addPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/pin", handler.All(pinPost)).Methods("POST") apiColls.HandleFunc("/{alias}/pin", handler.All(pinPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/unpin", handler.All(pinPost)).Methods("POST") apiColls.HandleFunc("/{alias}/unpin", handler.All(pinPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/inbox", handler.All(handleFetchCollectionInbox)).Methods("POST") apiColls.HandleFunc("/{alias}/inbox", handler.All(handleFetchCollectionInbox)).Methods("POST")
apiColls.HandleFunc("/{alias}/outbox", handler.All(handleFetchCollectionOutbox)).Methods("GET") apiColls.HandleFunc("/{alias}/outbox", handler.AllReader(handleFetchCollectionOutbox)).Methods("GET")
apiColls.HandleFunc("/{alias}/following", handler.All(handleFetchCollectionFollowing)).Methods("GET") apiColls.HandleFunc("/{alias}/following", handler.AllReader(handleFetchCollectionFollowing)).Methods("GET")
apiColls.HandleFunc("/{alias}/followers", handler.All(handleFetchCollectionFollowers)).Methods("GET") apiColls.HandleFunc("/{alias}/followers", handler.AllReader(handleFetchCollectionFollowers)).Methods("GET")
// Handle posts // Handle posts
write.HandleFunc("/api/posts", handler.All(newPost)).Methods("POST") write.HandleFunc("/api/posts", handler.All(newPost)).Methods("POST")
posts := write.PathPrefix("/api/posts/").Subrouter() posts := write.PathPrefix("/api/posts/").Subrouter()
posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(fetchPost)).Methods("GET") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.AllReader(fetchPost)).Methods("GET")
posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST", "PUT") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST", "PUT")
posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(deletePost)).Methods("DELETE") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(deletePost)).Methods("DELETE")
posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}/{property}", handler.All(fetchPostProperty)).Methods("GET") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}/{property}", handler.AllReader(fetchPostProperty)).Methods("GET")
posts.HandleFunc("/claim", handler.All(addPost)).Methods("POST") posts.HandleFunc("/claim", handler.All(addPost)).Methods("POST")
posts.HandleFunc("/disperse", handler.All(dispersePost)).Methods("POST") posts.HandleFunc("/disperse", handler.All(dispersePost)).Methods("POST")

Loading…
Cancel
Save