2024-09-10 19:46:16 +02:00
|
|
|
package handlers
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2024-09-16 06:55:39 +02:00
|
|
|
"crypto/rand"
|
|
|
|
"crypto/sha256"
|
|
|
|
"encoding/base64"
|
|
|
|
"io"
|
2024-09-10 19:46:16 +02:00
|
|
|
"log/slog"
|
|
|
|
"net/http"
|
|
|
|
"os"
|
|
|
|
|
2024-09-16 06:55:39 +02:00
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
2024-09-25 15:33:28 +02:00
|
|
|
"github.com/gorilla/sessions"
|
2024-09-16 11:45:24 +02:00
|
|
|
"github.com/thanhpk/randstr"
|
2024-09-16 06:55:39 +02:00
|
|
|
"golang.org/x/oauth2"
|
2024-09-10 19:46:16 +02:00
|
|
|
|
|
|
|
"gitlab.unjx.de/flohoss/godash/internal/env"
|
2024-09-16 11:45:24 +02:00
|
|
|
"gitlab.unjx.de/flohoss/godash/services"
|
2024-09-10 19:46:16 +02:00
|
|
|
)
|
|
|
|
|
2024-09-25 15:33:28 +02:00
|
|
|
// contextKey is a distinct string type for the user-detail keys below. Being
// its own named type, its values cannot collide with plain-string keys from
// other packages if they are ever used with context.WithValue. The handlers in
// this file convert them with string(...) when using them as session keys.
type contextKey string

// Keys under which the logged-in user's details are stored.
const (
	// NameKey identifies the user's display name.
	NameKey contextKey = "name"
	// EmailKey identifies the user's email address.
	EmailKey contextKey = "email"
	// GravatarKey identifies the URL of the user's Gravatar image.
	GravatarKey contextKey = "gravatar"
)

// StoreSessionKey is the session name passed to the cookie store's Get when
// loading or saving the login session.
const StoreSessionKey = "godash_session"
|
2024-09-16 06:55:39 +02:00
|
|
|
|
|
|
|
// generateCodeVerifier returns a new PKCE code verifier. It reads 64 bytes from
// crypto/rand and encodes them as unpadded base64url, giving an 86-character
// string made only of [A-Za-z0-9_-]. That fits the 43–128 character range
// allowed by RFC 7636. Any error from the random source is returned unchanged,
// together with an empty string.
func generateCodeVerifier() (string, error) {
	var entropy [64]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(entropy[:]), nil
}
|
|
|
|
|
|
|
|
// generateCodeChallenge derives the PKCE "S256" code challenge for verifier:
// the unpadded base64url encoding of SHA-256(verifier), as in RFC 7636 §4.2.
func generateCodeChallenge(verifier string) string {
	digest := sha256.New()
	// Writing to a hash.Hash never fails, so the result is deliberately ignored.
	_, _ = io.WriteString(digest, verifier)
	return base64.RawURLEncoding.EncodeToString(digest.Sum(nil))
}
|
|
|
|
|
2024-09-25 15:33:28 +02:00
|
|
|
func NewAuthHandler(env *env.Config, store *sessions.CookieStore) *AuthHandler {
|
2024-09-10 19:46:16 +02:00
|
|
|
ctx := context.Background()
|
2024-09-25 15:33:28 +02:00
|
|
|
provider, err := oidc.NewProvider(ctx, env.AuthIssuer)
|
2024-09-10 19:46:16 +02:00
|
|
|
if err != nil {
|
2024-09-16 06:55:39 +02:00
|
|
|
slog.Error("Failed to get oidc provider", "err", err.Error())
|
2024-09-10 19:46:16 +02:00
|
|
|
os.Exit(1)
|
|
|
|
}
|
|
|
|
|
2024-09-25 15:33:28 +02:00
|
|
|
config := &oauth2.Config{
|
|
|
|
ClientID: env.AuthClientID,
|
|
|
|
ClientSecret: env.AuthClientSecret,
|
|
|
|
RedirectURL: env.PublicUrl + "/callback",
|
|
|
|
Scopes: env.AuthScopes,
|
|
|
|
Endpoint: provider.Endpoint(),
|
2024-09-10 19:46:16 +02:00
|
|
|
}
|
|
|
|
|
2024-09-16 06:55:39 +02:00
|
|
|
codeVerifier, err := generateCodeVerifier()
|
|
|
|
if err != nil {
|
|
|
|
slog.Error("Error generating code verifier", "err", err.Error())
|
|
|
|
os.Exit(1)
|
|
|
|
}
|
|
|
|
codeChallenge := generateCodeChallenge(codeVerifier)
|
|
|
|
authCodeOptions := []oauth2.AuthCodeOption{
|
|
|
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
|
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
2024-09-16 11:45:24 +02:00
|
|
|
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
|
2024-09-10 19:46:16 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
return &AuthHandler{
|
2024-09-25 15:33:28 +02:00
|
|
|
provider: provider,
|
|
|
|
config: config,
|
2024-09-16 06:55:39 +02:00
|
|
|
authCodeOptions: authCodeOptions,
|
2024-09-25 15:33:28 +02:00
|
|
|
store: store,
|
2024-09-10 19:46:16 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
type AuthHandler struct {
|
2024-09-25 15:33:28 +02:00
|
|
|
provider *oidc.Provider
|
|
|
|
config *oauth2.Config
|
2024-09-16 06:55:39 +02:00
|
|
|
authCodeOptions []oauth2.AuthCodeOption
|
2024-09-25 15:33:28 +02:00
|
|
|
store *sessions.CookieStore
|
2024-09-16 06:55:39 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func (ah *AuthHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
2024-09-25 15:33:28 +02:00
|
|
|
session, _ := ah.store.Get(r, StoreSessionKey)
|
|
|
|
state, ok := session.Values["state"].(string)
|
|
|
|
if !ok || state == "" {
|
2024-09-16 06:55:39 +02:00
|
|
|
http.Error(w, "state not found", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
2024-09-25 15:33:28 +02:00
|
|
|
if r.URL.Query().Get("state") != state {
|
2024-09-16 06:55:39 +02:00
|
|
|
http.Error(w, "state did not match", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2024-09-25 15:33:28 +02:00
|
|
|
oauth2Token, err := ah.config.Exchange(r.Context(), r.URL.Query().Get("code"), ah.authCodeOptions...)
|
2024-09-16 06:55:39 +02:00
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2024-09-25 15:33:28 +02:00
|
|
|
userInfo, err := ah.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(oauth2Token))
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
user := &services.User{}
|
|
|
|
userInfo.Claims(user)
|
|
|
|
|
|
|
|
session.Values[string(NameKey)] = user.Name
|
|
|
|
session.Values[string(EmailKey)] = user.Email
|
|
|
|
session.Values[string(GravatarKey)] = services.NewGravatarFromEmail(user.Email).GetURL()
|
|
|
|
err = session.Save(r, w)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2024-09-16 11:45:24 +02:00
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ah *AuthHandler) handleLogout(w http.ResponseWriter, r *http.Request) {
|
2024-09-25 15:33:28 +02:00
|
|
|
session, _ := ah.store.Get(r, StoreSessionKey)
|
|
|
|
session.Values = make(map[interface{}]interface{})
|
|
|
|
err := session.Save(r, w)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
2024-09-16 11:45:24 +02:00
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
|
|
}
|
|
|
|
|
2024-09-16 06:55:39 +02:00
|
|
|
func (ah *AuthHandler) AuthMiddleware(next http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
2024-09-25 15:33:28 +02:00
|
|
|
session, _ := ah.store.Get(r, StoreSessionKey)
|
|
|
|
name, ok := session.Values[string(NameKey)].(string)
|
|
|
|
if !ok || name == "" {
|
|
|
|
state := randstr.String(16)
|
|
|
|
session.Values["state"] = state
|
|
|
|
err := session.Save(r, w)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
http.Redirect(w, r, ah.config.AuthCodeURL(state, ah.authCodeOptions...), http.StatusFound)
|
2024-09-16 06:55:39 +02:00
|
|
|
return
|
|
|
|
}
|
2024-09-16 17:37:57 +02:00
|
|
|
|
2024-09-16 06:55:39 +02:00
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
})
|
2024-09-10 19:46:16 +02:00
|
|
|
}
|