forked from Mxmilu666/frp
auth/oidc: cache OIDC access token and refresh before expiry (#5175)
* auth/oidc: cache OIDC access token and refresh before expiry - Use Config.TokenSource(ctx) once at init to create a persistent oauth2.TokenSource that caches the token and only refreshes on expiry - Wrap with oauth2.ReuseTokenSourceWithExpiry for configurable early refresh - Add tokenRefreshAdvanceDuration config option (default: 300s) - Add unit test verifying token caching with mock HTTP server * address review comments * auth/oidc: fallback to per-request token fetch when expires_in is missing When an OIDC provider omits the expires_in field, oauth2.ReuseTokenSource treats the cached token as valid forever and never refreshes it. This causes server-side OIDC verification to fail once the JWT's exp claim passes. Add a nonCachingTokenSource fallback: after fetching the initial token, if its Expiry is the zero value, swap the caching TokenSource for one that fetches a fresh token on every request, preserving the old behavior for providers that don't return expires_in. * auth/oidc: fix gosec lint and add test for zero-expiry fallback Suppress G101 false positive on test-only dummy token responses. Add test to verify per-request token fetch when expires_in is missing. Update caching test to account for eager initial token fetch. * fix lint
This commit is contained in:
@@ -75,11 +75,23 @@ func createOIDCHTTPClient(trustedCAFile string, insecureSkipVerify bool, proxyUR
|
||||
return &http.Client{Transport: transport}, nil
|
||||
}
|
||||
|
||||
// nonCachingTokenSource wraps a clientcredentials.Config to fetch a fresh
|
||||
// token on every call. This is used as a fallback when the OIDC provider
|
||||
// does not return expires_in, which would cause a caching TokenSource to
|
||||
// hold onto a stale token forever.
|
||||
type nonCachingTokenSource struct {
|
||||
cfg *clientcredentials.Config
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *nonCachingTokenSource) Token() (*oauth2.Token, error) {
|
||||
return s.cfg.Token(s.ctx)
|
||||
}
|
||||
|
||||
type OidcAuthProvider struct {
|
||||
additionalAuthScopes []v1.AuthScope
|
||||
|
||||
tokenGenerator *clientcredentials.Config
|
||||
httpClient *http.Client
|
||||
tokenSource oauth2.TokenSource
|
||||
}
|
||||
|
||||
func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) (*OidcAuthProvider, error) {
|
||||
@@ -100,30 +112,44 @@ func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClien
|
||||
EndpointParams: eps,
|
||||
}
|
||||
|
||||
// Create custom HTTP client if needed
|
||||
var httpClient *http.Client
|
||||
// Build the context that TokenSource will use for all future HTTP requests.
|
||||
// context.Background() is appropriate here because the token source is
|
||||
// long-lived and outlives any single request.
|
||||
ctx := context.Background()
|
||||
if cfg.TrustedCaFile != "" || cfg.InsecureSkipVerify || cfg.ProxyURL != "" {
|
||||
var err error
|
||||
httpClient, err = createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL)
|
||||
httpClient, err := createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC HTTP client: %w", err)
|
||||
}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
}
|
||||
|
||||
// Create a persistent TokenSource that caches the token and refreshes
|
||||
// it before expiry. This avoids making a new HTTP request to the OIDC
|
||||
// provider on every heartbeat/ping.
|
||||
tokenSource := tokenGenerator.TokenSource(ctx)
|
||||
|
||||
// Fetch the initial token to check if the provider returns an expiry.
|
||||
// If Expiry is the zero value (provider omitted expires_in), the cached
|
||||
// TokenSource would treat the token as valid forever and never refresh it,
|
||||
// even after the JWT's exp claim passes. In that case, fall back to
|
||||
// fetching a fresh token on every request.
|
||||
initialToken, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to obtain initial OIDC token: %w", err)
|
||||
}
|
||||
if initialToken.Expiry.IsZero() {
|
||||
tokenSource = &nonCachingTokenSource{cfg: tokenGenerator, ctx: ctx}
|
||||
}
|
||||
|
||||
return &OidcAuthProvider{
|
||||
additionalAuthScopes: additionalAuthScopes,
|
||||
tokenGenerator: tokenGenerator,
|
||||
httpClient: httpClient,
|
||||
tokenSource: tokenSource,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *OidcAuthProvider) generateAccessToken() (accessToken string, err error) {
|
||||
ctx := context.Background()
|
||||
if auth.httpClient != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, auth.httpClient)
|
||||
}
|
||||
|
||||
tokenObj, err := auth.tokenGenerator.Token(ctx)
|
||||
tokenObj, err := auth.tokenSource.Token()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("couldn't generate OIDC token for login: %v", err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,10 @@ package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -62,3 +66,90 @@ func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) {
|
||||
r.Error(err)
|
||||
r.Contains(err.Error(), "received different OIDC subject in login and ping")
|
||||
}
|
||||
|
||||
func TestOidcAuthProviderFallsBackWhenNoExpiry(t *testing.T) {
|
||||
r := require.New(t)
|
||||
|
||||
var requestCount atomic.Int32
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
requestCount.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response
|
||||
"access_token": "fresh-test-token",
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
provider, err := auth.NewOidcAuthSetter(
|
||||
[]v1.AuthScope{v1.AuthScopeHeartBeats},
|
||||
v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
},
|
||||
)
|
||||
r.NoError(err)
|
||||
|
||||
// Constructor fetches the initial token (1 request).
|
||||
// Each subsequent call should also fetch a fresh token since there is no expiry.
|
||||
loginMsg := &msg.Login{}
|
||||
err = provider.SetLogin(loginMsg)
|
||||
r.NoError(err)
|
||||
r.Equal("fresh-test-token", loginMsg.PrivilegeKey)
|
||||
|
||||
for range 3 {
|
||||
pingMsg := &msg.Ping{}
|
||||
err = provider.SetPing(pingMsg)
|
||||
r.NoError(err)
|
||||
r.Equal("fresh-test-token", pingMsg.PrivilegeKey)
|
||||
}
|
||||
|
||||
// 1 initial (constructor) + 1 login + 3 pings = 5 requests
|
||||
r.Equal(int32(5), requestCount.Load(), "each call should fetch a fresh token when expires_in is missing")
|
||||
}
|
||||
|
||||
func TestOidcAuthProviderCachesToken(t *testing.T) {
|
||||
r := require.New(t)
|
||||
|
||||
var requestCount atomic.Int32
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
requestCount.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response
|
||||
"access_token": "cached-test-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
provider, err := auth.NewOidcAuthSetter(
|
||||
[]v1.AuthScope{v1.AuthScopeHeartBeats},
|
||||
v1.AuthOIDCClientConfig{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
TokenEndpointURL: tokenServer.URL,
|
||||
},
|
||||
)
|
||||
r.NoError(err)
|
||||
|
||||
// Constructor eagerly fetches the initial token (1 request).
|
||||
r.Equal(int32(1), requestCount.Load())
|
||||
|
||||
// SetLogin should reuse the cached token
|
||||
loginMsg := &msg.Login{}
|
||||
err = provider.SetLogin(loginMsg)
|
||||
r.NoError(err)
|
||||
r.Equal("cached-test-token", loginMsg.PrivilegeKey)
|
||||
r.Equal(int32(1), requestCount.Load())
|
||||
|
||||
// Subsequent calls should also reuse the cached token
|
||||
for range 5 {
|
||||
pingMsg := &msg.Ping{}
|
||||
err = provider.SetPing(pingMsg)
|
||||
r.NoError(err)
|
||||
r.Equal("cached-test-token", pingMsg.PrivilegeKey)
|
||||
}
|
||||
r.Equal(int32(1), requestCount.Load(), "token endpoint should only be called once; cached token should be reused")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user