Compare commits

...

3 Commits

Author SHA1 Message Date
Shani Pathak
94a631fe9c auth/oidc: cache OIDC access token and refresh before expiry (#5175)
* auth/oidc: cache OIDC access token and refresh before expiry

- Use Config.TokenSource(ctx) once at init to create a persistent
  oauth2.TokenSource that caches the token and only refreshes on expiry
- Wrap with oauth2.ReuseTokenSourceWithExpiry for configurable early refresh
- Add tokenRefreshAdvanceDuration config option (default: 300s)
- Add unit test verifying token caching with mock HTTP server

* address review comments

* auth/oidc: fallback to per-request token fetch when expires_in is missing

When an OIDC provider omits the expires_in field, oauth2.ReuseTokenSource
treats the cached token as valid forever and never refreshes it. This causes
server-side OIDC verification to fail once the JWT's exp claim passes.

Add a nonCachingTokenSource fallback: after fetching the initial token, if
its Expiry is the zero value, swap the caching TokenSource for one that
fetches a fresh token on every request, preserving the old behavior for
providers that don't return expires_in.

* auth/oidc: fix gosec lint and add test for zero-expiry fallback

Suppress G101 false positive on test-only dummy token responses.
Add test to verify per-request token fetch when expires_in is missing.
Update caching test to account for eager initial token fetch.

* fix lint
2026-03-12 00:24:46 +08:00
fatedier
6b1be922e1 add AGENTS.md and CLAUDE.md, remove them from .gitignore (#5232) 2026-03-12 00:21:31 +08:00
fatedier
4f584f81d0 test/e2e: replace sleeps with event-driven waits in chaos/group/store tests (#5231)
* test/e2e: replace sleeps with event-driven waits in chaos/group/store tests

Replace 21 time.Sleep calls with deterministic waiting using
WaitForOutput, WaitForTCPReady, and a new WaitForTCPUnreachable helper.
Add CountOutput method for snapshot-based incremental log matching.

* test/e2e: validate interval and cap dial/sleep to remaining deadline in WaitForTCPUnreachable
2026-03-12 00:11:09 +08:00
10 changed files with 224 additions and 41 deletions

3
.gitignore vendored
View File

@@ -29,6 +29,5 @@ client.key
*.swp
# AI
CLAUDE.md
AGENTS.md
.claude/
.sisyphus/

34
AGENTS.md Normal file
View File

@@ -0,0 +1,34 @@
# AGENTS.md
## Development Commands
### Build
- `make build` - Build both frps and frpc binaries
- `make frps` - Build server binary only
- `make frpc` - Build client binary only
- `make all` - Build everything with formatting
### Testing
- `make test` - Run unit tests
- `make e2e` - Run end-to-end tests
- `make e2e-trace` - Run e2e tests with trace logging
- `make alltest` - Run all tests including vet, unit tests, and e2e
### Code Quality
- `make fmt` - Run go fmt
- `make fmt-more` - Run gofumpt for more strict formatting
- `make gci` - Run gci import organizer
- `make vet` - Run go vet
- `golangci-lint run` - Run comprehensive linting (configured in .golangci.yml)
### Assets
- `make web` - Build web dashboards (frps and frpc)
### Cleanup
- `make clean` - Remove built binaries and temporary files
## Testing
- E2E tests using Ginkgo/Gomega framework
- Mock servers in `/test/e2e/mock/`
- Run: `make e2e` or `make alltest`

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

View File

@@ -75,11 +75,23 @@ func createOIDCHTTPClient(trustedCAFile string, insecureSkipVerify bool, proxyUR
return &http.Client{Transport: transport}, nil
}
// nonCachingTokenSource wraps a clientcredentials.Config to fetch a fresh
// token on every call. This is used as a fallback when the OIDC provider
// does not return expires_in, which would cause a caching TokenSource to
// hold onto a stale token forever.
type nonCachingTokenSource struct {
cfg *clientcredentials.Config
ctx context.Context
}
func (s *nonCachingTokenSource) Token() (*oauth2.Token, error) {
return s.cfg.Token(s.ctx)
}
type OidcAuthProvider struct {
additionalAuthScopes []v1.AuthScope
tokenGenerator *clientcredentials.Config
httpClient *http.Client
tokenSource oauth2.TokenSource
}
func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) (*OidcAuthProvider, error) {
@@ -100,30 +112,44 @@ func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClien
EndpointParams: eps,
}
// Create custom HTTP client if needed
var httpClient *http.Client
// Build the context that TokenSource will use for all future HTTP requests.
// context.Background() is appropriate here because the token source is
// long-lived and outlives any single request.
ctx := context.Background()
if cfg.TrustedCaFile != "" || cfg.InsecureSkipVerify || cfg.ProxyURL != "" {
var err error
httpClient, err = createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL)
httpClient, err := createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL)
if err != nil {
return nil, fmt.Errorf("failed to create OIDC HTTP client: %w", err)
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
// Create a persistent TokenSource that caches the token and refreshes
// it before expiry. This avoids making a new HTTP request to the OIDC
// provider on every heartbeat/ping.
tokenSource := tokenGenerator.TokenSource(ctx)
// Fetch the initial token to check if the provider returns an expiry.
// If Expiry is the zero value (provider omitted expires_in), the cached
// TokenSource would treat the token as valid forever and never refresh it,
// even after the JWT's exp claim passes. In that case, fall back to
// fetching a fresh token on every request.
initialToken, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to obtain initial OIDC token: %w", err)
}
if initialToken.Expiry.IsZero() {
tokenSource = &nonCachingTokenSource{cfg: tokenGenerator, ctx: ctx}
}
return &OidcAuthProvider{
additionalAuthScopes: additionalAuthScopes,
tokenGenerator: tokenGenerator,
httpClient: httpClient,
tokenSource: tokenSource,
}, nil
}
func (auth *OidcAuthProvider) generateAccessToken() (accessToken string, err error) {
ctx := context.Background()
if auth.httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, auth.httpClient)
}
tokenObj, err := auth.tokenGenerator.Token(ctx)
tokenObj, err := auth.tokenSource.Token()
if err != nil {
return "", fmt.Errorf("couldn't generate OIDC token for login: %v", err)
}

View File

@@ -2,6 +2,10 @@ package auth_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
@@ -62,3 +66,90 @@ func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) {
r.Error(err)
r.Contains(err.Error(), "received different OIDC subject in login and ping")
}
func TestOidcAuthProviderFallsBackWhenNoExpiry(t *testing.T) {
r := require.New(t)
var requestCount atomic.Int32
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
requestCount.Add(1)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response
"access_token": "fresh-test-token",
"token_type": "Bearer",
})
}))
defer tokenServer.Close()
provider, err := auth.NewOidcAuthSetter(
[]v1.AuthScope{v1.AuthScopeHeartBeats},
v1.AuthOIDCClientConfig{
ClientID: "test-client",
ClientSecret: "test-secret",
TokenEndpointURL: tokenServer.URL,
},
)
r.NoError(err)
// Constructor fetches the initial token (1 request).
// Each subsequent call should also fetch a fresh token since there is no expiry.
loginMsg := &msg.Login{}
err = provider.SetLogin(loginMsg)
r.NoError(err)
r.Equal("fresh-test-token", loginMsg.PrivilegeKey)
for range 3 {
pingMsg := &msg.Ping{}
err = provider.SetPing(pingMsg)
r.NoError(err)
r.Equal("fresh-test-token", pingMsg.PrivilegeKey)
}
// 1 initial (constructor) + 1 login + 3 pings = 5 requests
r.Equal(int32(5), requestCount.Load(), "each call should fetch a fresh token when expires_in is missing")
}
func TestOidcAuthProviderCachesToken(t *testing.T) {
r := require.New(t)
var requestCount atomic.Int32
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
requestCount.Add(1)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response
"access_token": "cached-test-token",
"token_type": "Bearer",
"expires_in": 3600,
})
}))
defer tokenServer.Close()
provider, err := auth.NewOidcAuthSetter(
[]v1.AuthScope{v1.AuthScopeHeartBeats},
v1.AuthOIDCClientConfig{
ClientID: "test-client",
ClientSecret: "test-secret",
TokenEndpointURL: tokenServer.URL,
},
)
r.NoError(err)
// Constructor eagerly fetches the initial token (1 request).
r.Equal(int32(1), requestCount.Load())
// SetLogin should reuse the cached token
loginMsg := &msg.Login{}
err = provider.SetLogin(loginMsg)
r.NoError(err)
r.Equal("cached-test-token", loginMsg.PrivilegeKey)
r.Equal(int32(1), requestCount.Load())
// Subsequent calls should also reuse the cached token
for range 5 {
pingMsg := &msg.Ping{}
err = provider.SetPing(pingMsg)
r.NoError(err)
r.Equal("cached-test-token", pingMsg.PrivilegeKey)
}
r.Equal(int32(1), requestCount.Load(), "token endpoint should only be called once; cached token should be reused")
}

View File

@@ -144,6 +144,30 @@ func waitForClientProxyReady(configPath string, p *process.Process, timeout time
return true
}
// WaitForTCPUnreachable polls a TCP address until a connection fails or timeout.
func WaitForTCPUnreachable(addr string, interval, timeout time.Duration) error {
if interval <= 0 {
return fmt.Errorf("invalid interval for TCP unreachable on %s: interval must be positive", addr)
}
if timeout <= 0 {
return fmt.Errorf("invalid timeout for TCP unreachable on %s: timeout must be positive", addr)
}
deadline := time.Now().Add(timeout)
for {
remaining := time.Until(deadline)
if remaining <= 0 {
return fmt.Errorf("timeout waiting for TCP unreachable on %s", addr)
}
dialTimeout := min(interval, remaining)
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
if err != nil {
return nil
}
conn.Close()
time.Sleep(min(interval, time.Until(deadline)))
}
}
// WaitForTCPReady polls a TCP address until a connection succeeds or timeout.
func WaitForTCPReady(addr string, timeout time.Duration) error {
if timeout <= 0 {

View File

@@ -120,6 +120,11 @@ func (p *Process) Output() string {
return p.stdOutput.String() + p.errorOutput.String()
}
// CountOutput returns how many times pattern appears in the current accumulated output.
func (p *Process) CountOutput(pattern string) int {
return strings.Count(p.Output(), pattern)
}
func (p *Process) SetBeforeStopHandler(fn func()) {
p.beforeStopHandler = fn
}

View File

@@ -41,24 +41,24 @@ var _ = ginkgo.Describe("[Feature: Chaos]", func() {
// 2. stop frps, expect request failed
_ = ps.Stop()
time.Sleep(200 * time.Millisecond)
framework.NewRequestExpect(f).Port(remotePort).ExpectError(true).Ensure()
// 3. restart frps, expect request success
successCount := pc.CountOutput("[tcp] start proxy success")
_, _, err = f.RunFrps("-c", serverConfigPath)
framework.ExpectNoError(err)
time.Sleep(2 * time.Second)
framework.ExpectNoError(pc.WaitForOutput("[tcp] start proxy success", successCount+1, 5*time.Second))
framework.NewRequestExpect(f).Port(remotePort).Ensure()
// 4. stop frpc, expect request failed
_ = pc.Stop()
time.Sleep(200 * time.Millisecond)
framework.ExpectNoError(framework.WaitForTCPUnreachable(fmt.Sprintf("127.0.0.1:%d", remotePort), 100*time.Millisecond, 5*time.Second))
framework.NewRequestExpect(f).Port(remotePort).ExpectError(true).Ensure()
// 5. restart frpc, expect request success
_, _, err = f.RunFrpc("-c", clientConfigPath)
newPc, _, err := f.RunFrpc("-c", clientConfigPath)
framework.ExpectNoError(err)
time.Sleep(time.Second)
framework.ExpectNoError(newPc.WaitForOutput("[tcp] start proxy success", 1, 5*time.Second))
framework.NewRequestExpect(f).Port(remotePort).Ensure()
})
})

View File

@@ -286,7 +286,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
healthCheck.intervalSeconds = 1
`, fooPort, remotePort, barPort, remotePort)
f.RunProcesses(serverConf, []string{clientConf})
_, clientProcesses := f.RunProcesses(serverConf, []string{clientConf})
// check foo and bar is ok
results := []string{}
@@ -299,15 +299,17 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
framework.ExpectContainElements(results, []string{"foo", "bar"})
// close bar server, check foo is ok
failedCount := clientProcesses[0].CountOutput("[bar] health check failed")
barServer.Close()
time.Sleep(2 * time.Second)
framework.ExpectNoError(clientProcesses[0].WaitForOutput("[bar] health check failed", failedCount+1, 5*time.Second))
for range 10 {
framework.NewRequestExpect(f).Port(remotePort).ExpectResp([]byte("foo")).Ensure()
}
// resume bar server, check foo and bar is ok
successCount := clientProcesses[0].CountOutput("[bar] health check success")
f.RunServer("", barServer)
time.Sleep(2 * time.Second)
framework.ExpectNoError(clientProcesses[0].WaitForOutput("[bar] health check success", successCount+1, 5*time.Second))
results = []string{}
for range 10 {
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
@@ -357,7 +359,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
healthCheck.path = "/healthz"
`, fooPort, barPort)
f.RunProcesses(serverConf, []string{clientConf})
_, clientProcesses := f.RunProcesses(serverConf, []string{clientConf})
// send first HTTP request
var contents []string
@@ -387,15 +389,17 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
framework.ExpectContainElements(results, []string{"foo", "bar"})
// close bar server, check foo is ok
failedCount := clientProcesses[0].CountOutput("[bar] health check failed")
barServer.Close()
time.Sleep(2 * time.Second)
framework.ExpectNoError(clientProcesses[0].WaitForOutput("[bar] health check failed", failedCount+1, 5*time.Second))
results = doFooBarHTTPRequest(vhostPort, "example.com")
framework.ExpectContainElements(results, []string{"foo"})
framework.ExpectNotContainElements(results, []string{"bar"})
// resume bar server, check foo and bar is ok
successCount := clientProcesses[0].CountOutput("[bar] health check success")
f.RunServer("", barServer)
time.Sleep(2 * time.Second)
framework.ExpectNoError(clientProcesses[0].WaitForOutput("[bar] health check success", successCount+1, 5*time.Second))
results = doFooBarHTTPRequest(vhostPort, "example.com")
framework.ExpectContainElements(results, []string{"foo", "bar"})
})

View File

@@ -31,7 +31,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
`, adminPort, f.TempDirectory)
f.RunProcesses(serverConf, []string{clientConf})
time.Sleep(500 * time.Millisecond)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", adminPort), 5*time.Second))
proxyConfig := map[string]any{
"name": "test-tcp",
@@ -52,7 +52,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
return resp.Code == 200
})
time.Sleep(time.Second)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", remotePort), 5*time.Second))
framework.NewRequestExpect(f).Port(remotePort).Ensure()
})
@@ -72,7 +72,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
`, adminPort, f.TempDirectory)
f.RunProcesses(serverConf, []string{clientConf})
time.Sleep(500 * time.Millisecond)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", adminPort), 5*time.Second))
proxyConfig := map[string]any{
"name": "test-tcp",
@@ -93,7 +93,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
return resp.Code == 200
})
time.Sleep(time.Second)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", remotePort1), 5*time.Second))
framework.NewRequestExpect(f).Port(remotePort1).Ensure()
proxyConfig["tcp"].(map[string]any)["remotePort"] = remotePort2
@@ -107,7 +107,8 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
return resp.Code == 200
})
time.Sleep(time.Second)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", remotePort2), 5*time.Second))
framework.ExpectNoError(framework.WaitForTCPUnreachable(fmt.Sprintf("127.0.0.1:%d", remotePort1), 100*time.Millisecond, 5*time.Second))
framework.NewRequestExpect(f).Port(remotePort2).Ensure()
framework.NewRequestExpect(f).Port(remotePort1).ExpectError(true).Ensure()
})
@@ -126,7 +127,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
`, adminPort, f.TempDirectory)
f.RunProcesses(serverConf, []string{clientConf})
time.Sleep(500 * time.Millisecond)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", adminPort), 5*time.Second))
proxyConfig := map[string]any{
"name": "test-tcp",
@@ -147,7 +148,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
return resp.Code == 200
})
time.Sleep(time.Second)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", remotePort), 5*time.Second))
framework.NewRequestExpect(f).Port(remotePort).Ensure()
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
@@ -156,7 +157,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
return resp.Code == 200
})
time.Sleep(time.Second)
framework.ExpectNoError(framework.WaitForTCPUnreachable(fmt.Sprintf("127.0.0.1:%d", remotePort), 100*time.Millisecond, 5*time.Second))
framework.NewRequestExpect(f).Port(remotePort).ExpectError(true).Ensure()
})
@@ -174,7 +175,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
`, adminPort, f.TempDirectory)
f.RunProcesses(serverConf, []string{clientConf})
time.Sleep(500 * time.Millisecond)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", adminPort), 5*time.Second))
proxyConfig := map[string]any{
"name": "test-tcp",
@@ -195,8 +196,6 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
return resp.Code == 200
})
time.Sleep(500 * time.Millisecond)
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.HTTP().Port(adminPort).HTTPPath("/api/store/proxies")
}).Ensure(func(resp *request.Response) bool {
@@ -226,7 +225,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
`, adminPort)
f.RunProcesses(serverConf, []string{clientConf})
time.Sleep(500 * time.Millisecond)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", adminPort), 5*time.Second))
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.HTTP().Port(adminPort).HTTPPath("/api/store/proxies")
@@ -248,7 +247,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
`, adminPort, f.TempDirectory)
f.RunProcesses(serverConf, []string{clientConf})
time.Sleep(500 * time.Millisecond)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", adminPort), 5*time.Second))
invalidBody, _ := json.Marshal(map[string]any{
"name": "bad-proxy",
@@ -281,7 +280,7 @@ var _ = ginkgo.Describe("[Feature: Store]", func() {
`, adminPort, f.TempDirectory)
f.RunProcesses(serverConf, []string{clientConf})
time.Sleep(500 * time.Millisecond)
framework.ExpectNoError(framework.WaitForTCPReady(fmt.Sprintf("127.0.0.1:%d", adminPort), 5*time.Second))
createBody, _ := json.Marshal(map[string]any{
"name": "proxy-a",