godash/handlers/auth.handlers.go

156 lines
4.1 KiB
Go
Raw Normal View History

2024-09-10 19:46:16 +02:00
package handlers
import (
"context"
2024-09-16 06:55:39 +02:00
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"io"
2024-09-10 19:46:16 +02:00
"log/slog"
"net/http"
"os"
"time"
2024-09-16 11:45:24 +02:00
"github.com/alexedwards/scs/v2"
2024-09-16 06:55:39 +02:00
"github.com/coreos/go-oidc/v3/oidc"
2024-09-16 11:45:24 +02:00
"github.com/thanhpk/randstr"
2024-09-16 06:55:39 +02:00
"golang.org/x/oauth2"
2024-09-10 19:46:16 +02:00
"gitlab.unjx.de/flohoss/godash/internal/env"
2024-09-16 11:45:24 +02:00
"gitlab.unjx.de/flohoss/godash/services"
2024-09-10 19:46:16 +02:00
)
2024-09-16 06:55:39 +02:00
// setCallbackCookie stores a short-lived value (such as the OAuth2 state) in a
// cookie that has to survive the round trip to the identity provider and back.
//
// The cookie is:
//   - scoped to Path "/", so it reaches the callback route regardless of which
//     page triggered the login (without an explicit Path the browser defaults
//     to the directory of the current request URL, e.g. "/a" for "/a/b");
//   - HttpOnly, so page scripts cannot read it;
//   - Secure only when the current request arrived over TLS;
//   - SameSite=Lax, so it is still sent on the top-level GET redirect back
//     from the identity provider;
//   - valid for one hour.
func setCallbackCookie(w http.ResponseWriter, r *http.Request, name, value string) {
	http.SetCookie(w, &http.Cookie{
		Name:     name,
		Value:    value,
		Path:     "/",
		MaxAge:   int(time.Hour.Seconds()),
		Secure:   r.TLS != nil,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
}
// generateCodeVerifier returns a new PKCE code verifier: 64 bytes read from
// crypto/rand, encoded as unpadded base64url (86 characters). Any error from
// the random source is returned unchanged with an empty string.
func generateCodeVerifier() (string, error) {
	const numRandomBytes = 64

	entropy := make([]byte, numRandomBytes)
	if _, err := rand.Read(entropy); err != nil {
		return "", err
	}

	encoded := base64.RawURLEncoding.EncodeToString(entropy)
	return encoded, nil
}
// generateCodeChallenge derives the PKCE "S256" code challenge for verifier:
// the SHA-256 digest of the verifier's bytes, encoded as unpadded base64url.
func generateCodeChallenge(verifier string) string {
	hasher := sha256.New()
	// Writing to a hash.Hash never returns an error, so the result is discarded.
	_, _ = io.WriteString(hasher, verifier)

	digest := hasher.Sum(nil)
	challenge := base64.RawURLEncoding.EncodeToString(digest)
	return challenge
}
func NewAuthHandler(env *env.Config) *AuthHandler {
2024-09-10 19:46:16 +02:00
ctx := context.Background()
2024-09-16 06:55:39 +02:00
oidcProvider, err := oidc.NewProvider(ctx, env.OIDCIssuer)
2024-09-10 19:46:16 +02:00
if err != nil {
2024-09-16 06:55:39 +02:00
slog.Error("Failed to get oidc provider", "err", err.Error())
2024-09-10 19:46:16 +02:00
os.Exit(1)
}
2024-09-16 06:55:39 +02:00
oauth2Config := &oauth2.Config{
ClientID: env.OIDCClientID,
ClientSecret: env.OIDCClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: env.OIDCRedirectURI,
Scopes: env.OIDCScopes,
2024-09-10 19:46:16 +02:00
}
2024-09-16 06:55:39 +02:00
codeVerifier, err := generateCodeVerifier()
if err != nil {
slog.Error("Error generating code verifier", "err", err.Error())
os.Exit(1)
}
codeChallenge := generateCodeChallenge(codeVerifier)
authCodeOptions := []oauth2.AuthCodeOption{
2024-09-16 11:45:24 +02:00
oauth2.SetAuthURLParam("redirect_uri", env.OIDCRedirectURI),
2024-09-16 06:55:39 +02:00
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
2024-09-16 11:45:24 +02:00
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
2024-09-10 19:46:16 +02:00
}
2024-09-16 11:45:24 +02:00
sessionManager := scs.New()
sessionManager.Lifetime = 24 * time.Hour
2024-09-10 19:46:16 +02:00
return &AuthHandler{
2024-09-16 06:55:39 +02:00
ctx: ctx,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
authCodeOptions: authCodeOptions,
2024-09-16 11:45:24 +02:00
SessionManager: sessionManager,
2024-09-10 19:46:16 +02:00
}
}
type AuthHandler struct {
2024-09-16 06:55:39 +02:00
ctx context.Context
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
authCodeOptions []oauth2.AuthCodeOption
2024-09-16 11:45:24 +02:00
SessionManager *scs.SessionManager
2024-09-16 06:55:39 +02:00
}
func (ah *AuthHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
state, err := r.Cookie("state")
if err != nil {
http.Error(w, "state not found", http.StatusBadRequest)
return
}
if r.URL.Query().Get("state") != state.Value {
http.Error(w, "state did not match", http.StatusBadRequest)
return
}
2024-09-16 11:45:24 +02:00
oauth2Token, err := ah.oauth2Config.Exchange(ah.ctx, r.URL.Query().Get("code"), ah.authCodeOptions...)
2024-09-16 06:55:39 +02:00
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusInternalServerError)
return
}
2024-09-16 11:45:24 +02:00
ah.SessionManager.Put(r.Context(), "access_token", oauth2Token.AccessToken)
2024-09-16 06:55:39 +02:00
2024-09-16 11:45:24 +02:00
http.Redirect(w, r, "/", http.StatusFound)
}
func (ah *AuthHandler) handleLogout(w http.ResponseWriter, r *http.Request) {
ah.SessionManager.Clear(r.Context())
http.Redirect(w, r, "/", http.StatusFound)
}
func (ah *AuthHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
state := randstr.String(16)
setCallbackCookie(w, r, "state", state)
http.Redirect(w, r, ah.oauth2Config.AuthCodeURL(state, ah.authCodeOptions...), http.StatusFound)
2024-09-16 06:55:39 +02:00
}
func (ah *AuthHandler) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2024-09-16 11:45:24 +02:00
accessToken := ah.SessionManager.GetString(r.Context(), "access_token")
if accessToken == "" {
ah.handleLogin(w, r)
2024-09-16 06:55:39 +02:00
return
}
2024-09-16 11:45:24 +02:00
userInfo, err := ah.oidcProvider.UserInfo(ah.ctx, oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}))
2024-09-16 06:55:39 +02:00
if err != nil {
2024-09-16 11:45:24 +02:00
ah.handleLogin(w, r)
2024-09-16 06:55:39 +02:00
return
}
2024-09-16 11:45:24 +02:00
var userClaims services.User
userInfo.Claims(&userClaims)
w.Header().Set("X-User-Name", userClaims.Name)
w.Header().Set("X-User-Email", userClaims.Email)
2024-09-16 06:55:39 +02:00
next.ServeHTTP(w, r)
})
2024-09-10 19:46:16 +02:00
}