diff --git a/account.go b/account.go index 5dba924..ad58235 100644 --- a/account.go +++ b/account.go @@ -308,6 +308,7 @@ func viewLogin(app *App, w http.ResponseWriter, r *http.Request) error { LoginUsername string OauthSlack bool OauthWriteAs bool + OauthGitlab bool }{ pageForReq(app, r), r.FormValue("to"), @@ -316,6 +317,7 @@ func viewLogin(app *App, w http.ResponseWriter, r *http.Request) error { getTempInfo(app, "login-user", r, w), app.Config().SlackOauth.ClientID != "", app.Config().WriteAsOauth.ClientID != "", + app.Config().GitlabOauth.ClientID != "", } if earlyError != "" { diff --git a/config/config.go b/config/config.go index 78892bf..1d82a0e 100644 --- a/config/config.go +++ b/config/config.go @@ -69,6 +69,16 @@ type ( CallbackProxyAPI string `ini:"callback_proxy_api"` } + GitlabOauthCfg struct { + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + AuthLocation string `ini:"auth_location"` + TokenLocation string `ini:"token_location"` + InspectLocation string `ini:"inspect_location"` + CallbackProxy string `ini:"callback_proxy"` + CallbackProxyAPI string `ini:"callback_proxy_api"` + } + SlackOauthCfg struct { ClientID string `ini:"client_id"` ClientSecret string `ini:"client_secret"` @@ -128,6 +138,7 @@ type ( App AppCfg `ini:"app"` SlackOauth SlackOauthCfg `ini:"oauth.slack"` WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"` + GitlabOauth GitlabOauthCfg `ini:"oauth.gitlab"` } ) diff --git a/database.go b/database.go index cea7a97..f5e4564 100644 --- a/database.go +++ b/database.go @@ -2512,7 +2512,7 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { state := store.Generate62RandomString(24) - _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID) + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, " + db.now() + ")", state, provider, clientID) if err != nil { return "", fmt.Errorf("unable to record oauth client state: %w", err) } diff --git a/oauth.go b/oauth.go index caf8189..0893fcd 100644 --- a/oauth.go +++ b/oauth.go @@ -149,7 +149,7 @@ func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { callbackLocation: app.Config().App.Host + "/oauth/callback/write.as", httpClient: config.DefaultHTTPClient(), } - callbackLocation = app.Config().SlackOauth.CallbackProxy + callbackLocation = app.Config().WriteAsOauth.CallbackProxy } oauthClient := writeAsOauthClient{ @@ -165,6 +165,33 @@ func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { } } +func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) { + if app.Config().GitlabOauth.ClientID != "" { + callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab" + + var callbackProxy *callbackProxyClient = nil + if app.Config().GitlabOauth.CallbackProxy != "" { + callbackProxy = &callbackProxyClient{ + server: app.Config().GitlabOauth.CallbackProxyAPI, + callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab", + httpClient: config.DefaultHTTPClient(), + } + callbackLocation = app.Config().GitlabOauth.CallbackProxy + } + + oauthClient := gitlabOauthClient{ + ClientID: app.Config().GitlabOauth.ClientID, + ClientSecret: app.Config().GitlabOauth.ClientSecret, + ExchangeLocation: config.OrDefaultString(app.Config().GitlabOauth.TokenLocation, gitlabExchangeLocation), + InspectLocation: config.OrDefaultString(app.Config().GitlabOauth.InspectLocation, gitlabIdentityLocation), + AuthLocation: config.OrDefaultString(app.Config().GitlabOauth.AuthLocation, gitlabAuthLocation), + HttpClient: config.DefaultHTTPClient(), + CallbackLocation: callbackLocation, + } + configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) + } +} + func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) { handler := &oauthHandler{ Config: app.Config(), diff --git a/oauth_gitlab.go b/oauth_gitlab.go new file mode 100644 index 0000000..e5138d4 --- /dev/null +++ b/oauth_gitlab.go @@ -0,0 +1,116 @@ +package writefreely + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" +) + +type gitlabOauthClient struct { + ClientID string + ClientSecret string + AuthLocation string + ExchangeLocation string + InspectLocation string + CallbackLocation string + HttpClient HttpClient +} + +var _ oauthClient = gitlabOauthClient{} + +const ( + gitlabAuthLocation = "https://gitlab.com/oauth/authorize" + gitlabExchangeLocation = "https://gitlab.com/oauth/token" + gitlabIdentityLocation = "https://gitlab.com/api/v4/user" +) + +func (c gitlabOauthClient) GetProvider() string { + return "gitlab" +} + +func (c gitlabOauthClient) GetClientID() string { + return c.ClientID +} + +func (c gitlabOauthClient) GetCallbackLocation() string { + return c.CallbackLocation +} + +func (c gitlabOauthClient) buildLoginURL(state string) (string, error) { + u, err := url.Parse(c.AuthLocation) + if err != nil { + return "", err + } + q := u.Query() + q.Set("client_id", c.ClientID) + q.Set("redirect_uri", c.CallbackLocation) + q.Set("response_type", "code") + q.Set("state", state) + q.Set("scope", "read_user") + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (c gitlabOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { + form := url.Values{} + form.Add("grant_type", "authorization_code") + form.Add("redirect_uri", c.CallbackLocation) + form.Add("scope", "read_user") + form.Add("code", code) + req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(c.ClientID, c.ClientSecret) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } + + var tokenResponse TokenResponse + if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { + return nil, err + } + if tokenResponse.Error != "" { + return nil, errors.New(tokenResponse.Error) + } + return &tokenResponse, nil +} + +func (c gitlabOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { + req, err := http.NewRequest("GET", c.InspectLocation, nil) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } + + var inspectResponse InspectResponse + if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { + return nil, err + } + if inspectResponse.Error != "" { + return nil, errors.New(inspectResponse.Error) + } + return &inspectResponse, nil +} diff --git a/pages/login.tmpl b/pages/login.tmpl index 345b171..a988615 100644 --- a/pages/login.tmpl +++ b/pages/login.tmpl @@ -32,6 +32,10 @@ hr.short { box-sizing: border-box; font-size: 17px; } +#gitlab-login { + box-sizing: border-box; + font-size: 17px; +} {{end}} {{define "content"}} @@ -42,7 +46,7 @@ hr.short { {{range .Flashes}}