diff --git a/handlers/auth.handlers.go b/handlers/auth.handlers.go index dbf6096..801a394 100644 --- a/handlers/auth.handlers.go +++ b/handlers/auth.handlers.go @@ -49,6 +49,23 @@ func generateCodeChallenge(verifier string) string { return base64.RawURLEncoding.EncodeToString(sha) } +func (ah *AuthHandler) saveTokenToSession(r *http.Request, oauth2Token *oauth2.Token) { + ah.SessionManager.Put(r.Context(), "access_token", oauth2Token.AccessToken) + ah.SessionManager.Put(r.Context(), "refresh_token", oauth2Token.RefreshToken) + ah.SessionManager.Put(r.Context(), "token_type", oauth2Token.TokenType) + ah.SessionManager.Put(r.Context(), "expiry", oauth2Token.Expiry.Unix()) +} + +func (ah *AuthHandler) loadTokenFromSession(r *http.Request) *oauth2.Token { + ex := ah.SessionManager.GetInt64(r.Context(), "expiry") + return &oauth2.Token{ + AccessToken: ah.SessionManager.GetString(r.Context(), "access_token"), + RefreshToken: ah.SessionManager.GetString(r.Context(), "refresh_token"), + TokenType: ah.SessionManager.GetString(r.Context(), "token_type"), + Expiry: time.Unix(ex, 0), + } +} + func NewAuthHandler(env *env.Config) *AuthHandler { ctx := context.Background() @@ -80,7 +97,7 @@ func NewAuthHandler(env *env.Config) *AuthHandler { } sessionManager := scs.New() - sessionManager.Lifetime = 24 * time.Hour + sessionManager.Lifetime = 24 * 31 * time.Hour return &AuthHandler{ ctx: ctx, @@ -116,8 +133,7 @@ func (ah *AuthHandler) handleCallback(w http.ResponseWriter, r *http.Request) { return } - ah.SessionManager.Put(r.Context(), "access_token", oauth2Token.AccessToken) - + ah.saveTokenToSession(r, oauth2Token) http.Redirect(w, r, "/", http.StatusFound) } @@ -134,19 +150,25 @@ func (ah *AuthHandler) handleLogin(w http.ResponseWriter, r *http.Request) { func (ah *AuthHandler) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - accessToken := ah.SessionManager.GetString(r.Context(), "access_token") - if accessToken == "" { + exists := ah.SessionManager.Exists(r.Context(), "access_token") + if !exists { ah.handleLogin(w, r) return } - userInfo, err := ah.oidcProvider.UserInfo(ah.ctx, oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken})) + token := ah.loadTokenFromSession(r) + ah.oauth2Config.Client(ah.ctx, token) + + tokenInfo, err := ah.oidcProvider.Verifier(&oidc.Config{ClientID: ah.oauth2Config.ClientID}).Verify(ah.ctx, token.AccessToken) if err != nil { ah.handleLogin(w, r) return } + + ah.saveTokenToSession(r, token) + var userClaims services.User - userInfo.Claims(&userClaims) + tokenInfo.Claims(&userClaims) w.Header().Set("X-User-Name", userClaims.Name) w.Header().Set("X-User-Email", userClaims.Email)