From 94a631fe9c22491672b016413bb4d68067adeafb Mon Sep 17 00:00:00 2001 From: Shani Pathak Date: Wed, 11 Mar 2026 21:54:46 +0530 Subject: [PATCH] auth/oidc: cache OIDC access token and refresh before expiry (#5175) * auth/oidc: cache OIDC access token and refresh before expiry - Use Config.TokenSource(ctx) once at init to create a persistent oauth2.TokenSource that caches the token and only refreshes on expiry - Wrap with oauth2.ReuseTokenSourceWithExpiry for configurable early refresh - Add tokenRefreshAdvanceDuration config option (default: 300s) - Add unit test verifying token caching with mock HTTP server * address review comments * auth/oidc: fallback to per-request token fetch when expires_in is missing When an OIDC provider omits the expires_in field, oauth2.ReuseTokenSource treats the cached token as valid forever and never refreshes it. This causes server-side OIDC verification to fail once the JWT's exp claim passes. Add a nonCachingTokenSource fallback: after fetching the initial token, if its Expiry is the zero value, swap the caching TokenSource for one that fetches a fresh token on every request, preserving the old behavior for providers that don't return expires_in. * auth/oidc: fix gosec lint and add test for zero-expiry fallback Suppress G101 false positive on test-only dummy token responses. Add test to verify per-request token fetch when expires_in is missing. Update caching test to account for eager initial token fetch. * fix lint --- pkg/auth/oidc.go | 54 ++++++++++++++++++------- pkg/auth/oidc_test.go | 91 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 14 deletions(-) diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index 9aeaf3c5..e63322f7 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -75,11 +75,23 @@ func createOIDCHTTPClient(trustedCAFile string, insecureSkipVerify bool, proxyUR return &http.Client{Transport: transport}, nil } +// nonCachingTokenSource wraps a clientcredentials.Config to fetch a fresh +// token on every call. This is used as a fallback when the OIDC provider +// does not return expires_in, which would cause a caching TokenSource to +// hold onto a stale token forever. +type nonCachingTokenSource struct { + cfg *clientcredentials.Config + ctx context.Context +} + +func (s *nonCachingTokenSource) Token() (*oauth2.Token, error) { + return s.cfg.Token(s.ctx) +} + type OidcAuthProvider struct { additionalAuthScopes []v1.AuthScope - tokenGenerator *clientcredentials.Config - httpClient *http.Client + tokenSource oauth2.TokenSource } func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) (*OidcAuthProvider, error) { @@ -100,30 +112,44 @@ func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClien EndpointParams: eps, } - // Create custom HTTP client if needed - var httpClient *http.Client + // Build the context that TokenSource will use for all future HTTP requests. + // context.Background() is appropriate here because the token source is + // long-lived and outlives any single request. + ctx := context.Background() if cfg.TrustedCaFile != "" || cfg.InsecureSkipVerify || cfg.ProxyURL != "" { - var err error - httpClient, err = createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL) + httpClient, err := createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL) if err != nil { return nil, fmt.Errorf("failed to create OIDC HTTP client: %w", err) } + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + } + + // Create a persistent TokenSource that caches the token and refreshes + // it before expiry. This avoids making a new HTTP request to the OIDC + // provider on every heartbeat/ping. + tokenSource := tokenGenerator.TokenSource(ctx) + + // Fetch the initial token to check if the provider returns an expiry. + // If Expiry is the zero value (provider omitted expires_in), the cached + // TokenSource would treat the token as valid forever and never refresh it, + // even after the JWT's exp claim passes. In that case, fall back to + // fetching a fresh token on every request. + initialToken, err := tokenSource.Token() + if err != nil { + return nil, fmt.Errorf("failed to obtain initial OIDC token: %w", err) + } + if initialToken.Expiry.IsZero() { + tokenSource = &nonCachingTokenSource{cfg: tokenGenerator, ctx: ctx} } return &OidcAuthProvider{ additionalAuthScopes: additionalAuthScopes, - tokenGenerator: tokenGenerator, - httpClient: httpClient, + tokenSource: tokenSource, }, nil } func (auth *OidcAuthProvider) generateAccessToken() (accessToken string, err error) { - ctx := context.Background() - if auth.httpClient != nil { - ctx = context.WithValue(ctx, oauth2.HTTPClient, auth.httpClient) - } - - tokenObj, err := auth.tokenGenerator.Token(ctx) + tokenObj, err := auth.tokenSource.Token() if err != nil { return "", fmt.Errorf("couldn't generate OIDC token for login: %v", err) } diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go index 58054186..ad4f0246 100644 --- a/pkg/auth/oidc_test.go +++ b/pkg/auth/oidc_test.go @@ -2,6 +2,10 @@ package auth_test import ( "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" "testing" "time" @@ -62,3 +66,90 @@ func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) { r.Error(err) r.Contains(err.Error(), "received different OIDC subject in login and ping") } + +func TestOidcAuthProviderFallsBackWhenNoExpiry(t *testing.T) { + r := require.New(t) + + var requestCount atomic.Int32 + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + requestCount.Add(1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response + "access_token": "fresh-test-token", + "token_type": "Bearer", + }) + })) + defer tokenServer.Close() + + provider, err := auth.NewOidcAuthSetter( + []v1.AuthScope{v1.AuthScopeHeartBeats}, + v1.AuthOIDCClientConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + TokenEndpointURL: tokenServer.URL, + }, + ) + r.NoError(err) + + // Constructor fetches the initial token (1 request). + // Each subsequent call should also fetch a fresh token since there is no expiry. + loginMsg := &msg.Login{} + err = provider.SetLogin(loginMsg) + r.NoError(err) + r.Equal("fresh-test-token", loginMsg.PrivilegeKey) + + for range 3 { + pingMsg := &msg.Ping{} + err = provider.SetPing(pingMsg) + r.NoError(err) + r.Equal("fresh-test-token", pingMsg.PrivilegeKey) + } + + // 1 initial (constructor) + 1 login + 3 pings = 5 requests + r.Equal(int32(5), requestCount.Load(), "each call should fetch a fresh token when expires_in is missing") +} + +func TestOidcAuthProviderCachesToken(t *testing.T) { + r := require.New(t) + + var requestCount atomic.Int32 + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + requestCount.Add(1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response + "access_token": "cached-test-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenServer.Close() + + provider, err := auth.NewOidcAuthSetter( + []v1.AuthScope{v1.AuthScopeHeartBeats}, + v1.AuthOIDCClientConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + TokenEndpointURL: tokenServer.URL, + }, + ) + r.NoError(err) + + // Constructor eagerly fetches the initial token (1 request). + r.Equal(int32(1), requestCount.Load()) + + // SetLogin should reuse the cached token + loginMsg := &msg.Login{} + err = provider.SetLogin(loginMsg) + r.NoError(err) + r.Equal("cached-test-token", loginMsg.PrivilegeKey) + r.Equal(int32(1), requestCount.Load()) + + // Subsequent calls should also reuse the cached token + for range 5 { + pingMsg := &msg.Ping{} + err = provider.SetPing(pingMsg) + r.NoError(err) + r.Equal("cached-test-token", pingMsg.PrivilegeKey) + } + r.Equal(int32(1), requestCount.Load(), "token endpoint should only be called once; cached token should be reused") +}