diff --git a/Release.md b/Release.md index 87a2cc95..9bee09dc 100644 --- a/Release.md +++ b/Release.md @@ -7,3 +7,4 @@ * Kept proxy/visitor names as raw config names during completion; moved user-prefix handling to explicit wire-level naming logic. * Added `noweb` build tag to allow compiling without frontend assets. `make build` now auto-detects missing `web/*/dist` directories and skips embedding, so a fresh clone can build without running `make web` first. The dashboard gracefully returns 404 when assets are not embedded. * Improved config parsing errors: for `.toml` files, syntax errors now return immediately with parser position details (line/column when available) instead of falling through to YAML/JSON parsing, and TOML type mismatches report field-level errors without misleading line numbers. +* OIDC auth now caches the access token and refreshes it before expiry, avoiding a new token request on every heartbeat. Falls back to per-request fetch when the provider omits `expires_in`. diff --git a/client/service.go b/client/service.go index 26f1db18..f419b068 100644 --- a/client/service.go +++ b/client/service.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "net" + "net/http" "os" "runtime" "sync" @@ -162,15 +163,6 @@ func NewService(options ServiceOptions) (*Service, error) { return nil, err } - var webServer *httppkg.Server - if options.Common.WebServer.Port > 0 { - ws, err := httppkg.NewServer(options.Common.WebServer) - if err != nil { - return nil, err - } - webServer = ws - } - authRuntime, err := auth.BuildClientAuth(&options.Common.Auth) if err != nil { return nil, err @@ -191,6 +183,17 @@ func NewService(options ServiceOptions) (*Service, error) { proxyCfgs = config.CompleteProxyConfigurers(proxyCfgs) visitorCfgs = config.CompleteVisitorConfigurers(visitorCfgs) + // Create the web server after all fallible steps so its listener is not + // leaked when an earlier error causes NewService to return. + var webServer *httppkg.Server + if options.Common.WebServer.Port > 0 { + ws, err := httppkg.NewServer(options.Common.WebServer) + if err != nil { + return nil, err + } + webServer = ws + } + s := &Service{ ctx: context.Background(), auth: authRuntime, @@ -229,22 +232,25 @@ func (svr *Service) Run(ctx context.Context) error { } if svr.vnetController != nil { + vnetController := svr.vnetController if err := svr.vnetController.Init(); err != nil { log.Errorf("init virtual network controller error: %v", err) + svr.stop() return err } go func() { log.Infof("virtual network controller start...") - if err := svr.vnetController.Run(); err != nil { + if err := vnetController.Run(); err != nil && !errors.Is(err, net.ErrClosed) { log.Warnf("virtual network controller exit with error: %v", err) } }() } if svr.webServer != nil { + webServer := svr.webServer go func() { - log.Infof("admin server listen on %s", svr.webServer.Address()) - if err := svr.webServer.Run(); err != nil { + log.Infof("admin server listen on %s", webServer.Address()) + if err := webServer.Run(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Warnf("admin server exit with error: %v", err) } }() @@ -255,6 +261,7 @@ func (svr *Service) Run(ctx context.Context) error { if svr.ctl == nil { cancelCause := cancelErr{} _ = errors.As(context.Cause(svr.ctx), &cancelCause) + svr.stop() return fmt.Errorf("login to the server failed: %v. With loginFailExit enabled, no additional retries will be attempted", cancelCause.Err) } @@ -497,6 +504,10 @@ func (svr *Service) stop() { svr.webServer.Close() svr.webServer = nil } + if svr.vnetController != nil { + _ = svr.vnetController.Stop() + svr.vnetController = nil + } } func (svr *Service) getProxyStatus(name string) (*proxy.WorkingStatus, bool) { diff --git a/client/service_test.go b/client/service_test.go index e1c6b587..29f141a1 100644 --- a/client/service_test.go +++ b/client/service_test.go @@ -1,14 +1,120 @@ package client import ( + "context" + "errors" + "net" "path/filepath" + "strconv" "strings" "testing" + "github.com/samber/lo" + "github.com/fatedier/frp/pkg/config/source" v1 "github.com/fatedier/frp/pkg/config/v1" ) +type failingConnector struct { + err error +} + +func (c *failingConnector) Open() error { + return c.err +} + +func (c *failingConnector) Connect() (net.Conn, error) { + return nil, c.err +} + +func (c *failingConnector) Close() error { + return nil +} + +func getFreeTCPPort(t *testing.T) int { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on ephemeral port: %v", err) + } + defer ln.Close() + + return ln.Addr().(*net.TCPAddr).Port +} + +func TestRunStopsStartedComponentsOnInitialLoginFailure(t *testing.T) { + port := getFreeTCPPort(t) + agg := source.NewAggregator(source.NewConfigSource()) + + svr, err := NewService(ServiceOptions{ + Common: &v1.ClientCommonConfig{ + LoginFailExit: lo.ToPtr(true), + WebServer: v1.WebServerConfig{ + Addr: "127.0.0.1", + Port: port, + }, + }, + ConfigSourceAggregator: agg, + ConnectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector { + return &failingConnector{err: errors.New("login boom")} + }, + }) + if err != nil { + t.Fatalf("new service: %v", err) + } + + err = svr.Run(context.Background()) + if err == nil { + t.Fatal("expected run error, got nil") + } + if !strings.Contains(err.Error(), "login boom") { + t.Fatalf("unexpected error: %v", err) + } + if svr.webServer != nil { + t.Fatal("expected web server to be cleaned up after initial login failure") + } + + ln, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) + if err != nil { + t.Fatalf("expected admin port to be released: %v", err) + } + _ = ln.Close() +} + +func TestNewServiceDoesNotLeakAdminListenerOnAuthBuildFailure(t *testing.T) { + port := getFreeTCPPort(t) + agg := source.NewAggregator(source.NewConfigSource()) + + _, err := NewService(ServiceOptions{ + Common: &v1.ClientCommonConfig{ + Auth: v1.AuthClientConfig{ + Method: v1.AuthMethodOIDC, + OIDC: v1.AuthOIDCClientConfig{ + TokenEndpointURL: "://bad", + }, + }, + WebServer: v1.WebServerConfig{ + Addr: "127.0.0.1", + Port: port, + }, + }, + ConfigSourceAggregator: agg, + }) + if err == nil { + t.Fatal("expected new service error, got nil") + } + if !strings.Contains(err.Error(), "auth.oidc.tokenEndpointURL") { + t.Fatalf("unexpected error: %v", err) + } + + ln, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) + if err != nil { + t.Fatalf("expected admin port to remain free: %v", err) + } + _ = ln.Close() +} + func TestUpdateConfigSourceRollsBackReloadCommonOnReplaceAllFailure(t *testing.T) { prevCommon := &v1.ClientCommonConfig{User: "old-user"} newCommon := &v1.ClientCommonConfig{User: "new-user"} diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index e63322f7..826a6715 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -30,6 +30,7 @@ import ( "golang.org/x/oauth2/clientcredentials" v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/config/v1/validation" "github.com/fatedier/frp/pkg/msg" ) @@ -88,6 +89,40 @@ func (s *nonCachingTokenSource) Token() (*oauth2.Token, error) { return s.cfg.Token(s.ctx) } +// oidcTokenSource wraps a caching oauth2.TokenSource and, on the first +// successful Token() call, checks whether the provider returns an expiry. +// If not, it permanently switches to nonCachingTokenSource so that a fresh +// token is fetched every time. This avoids an eager network call at +// construction time, letting the login retry loop handle transient IdP +// outages. +type oidcTokenSource struct { + mu sync.Mutex + initialized bool + source oauth2.TokenSource + fallbackCfg *clientcredentials.Config + fallbackCtx context.Context +} + +func (s *oidcTokenSource) Token() (*oauth2.Token, error) { + s.mu.Lock() + if !s.initialized { + token, err := s.source.Token() + if err != nil { + s.mu.Unlock() + return nil, err + } + if token.Expiry.IsZero() { + s.source = &nonCachingTokenSource{cfg: s.fallbackCfg, ctx: s.fallbackCtx} + } + s.initialized = true + s.mu.Unlock() + return token, nil + } + source := s.source + s.mu.Unlock() + return source.Token() +} + type OidcAuthProvider struct { additionalAuthScopes []v1.AuthScope @@ -95,6 +130,10 @@ type OidcAuthProvider struct { } func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) (*OidcAuthProvider, error) { + if err := validation.ValidateOIDCClientCredentialsConfig(&cfg); err != nil { + return nil, err + } + eps := make(map[string][]string) for k, v := range cfg.AdditionalEndpointParams { eps[k] = []string{v} @@ -127,24 +166,22 @@ func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClien // Create a persistent TokenSource that caches the token and refreshes // it before expiry. This avoids making a new HTTP request to the OIDC // provider on every heartbeat/ping. - tokenSource := tokenGenerator.TokenSource(ctx) - - // Fetch the initial token to check if the provider returns an expiry. - // If Expiry is the zero value (provider omitted expires_in), the cached - // TokenSource would treat the token as valid forever and never refresh it, - // even after the JWT's exp claim passes. In that case, fall back to - // fetching a fresh token on every request. - initialToken, err := tokenSource.Token() - if err != nil { - return nil, fmt.Errorf("failed to obtain initial OIDC token: %w", err) - } - if initialToken.Expiry.IsZero() { - tokenSource = &nonCachingTokenSource{cfg: tokenGenerator, ctx: ctx} - } + // + // We wrap it in an oidcTokenSource so that the first Token() call + // (deferred to SetLogin inside the login retry loop) probes whether the + // provider returns expires_in. If not, it switches to a non-caching + // source. This avoids an eager network call at construction time, which + // would prevent loopLoginUntilSuccess from retrying on transient IdP + // outages. + cachingSource := tokenGenerator.TokenSource(ctx) return &OidcAuthProvider{ additionalAuthScopes: additionalAuthScopes, - tokenSource: tokenSource, + tokenSource: &oidcTokenSource{ + source: cachingSource, + fallbackCfg: tokenGenerator, + fallbackCtx: ctx, + }, }, nil } diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go index ad4f0246..70e59883 100644 --- a/pkg/auth/oidc_test.go +++ b/pkg/auth/oidc_test.go @@ -91,8 +91,10 @@ func TestOidcAuthProviderFallsBackWhenNoExpiry(t *testing.T) { ) r.NoError(err) - // Constructor fetches the initial token (1 request). - // Each subsequent call should also fetch a fresh token since there is no expiry. + // Constructor no longer fetches a token eagerly. + // The first SetLogin triggers the adaptive probe. + r.Equal(int32(0), requestCount.Load()) + loginMsg := &msg.Login{} err = provider.SetLogin(loginMsg) r.NoError(err) @@ -105,8 +107,8 @@ func TestOidcAuthProviderFallsBackWhenNoExpiry(t *testing.T) { r.Equal("fresh-test-token", pingMsg.PrivilegeKey) } - // 1 initial (constructor) + 1 login + 3 pings = 5 requests - r.Equal(int32(5), requestCount.Load(), "each call should fetch a fresh token when expires_in is missing") + // 1 probe (login) + 3 pings = 4 requests (probe doubles as the login token fetch) + r.Equal(int32(4), requestCount.Load(), "each call should fetch a fresh token when expires_in is missing") } func TestOidcAuthProviderCachesToken(t *testing.T) { @@ -134,10 +136,10 @@ func TestOidcAuthProviderCachesToken(t *testing.T) { ) r.NoError(err) - // Constructor eagerly fetches the initial token (1 request). - r.Equal(int32(1), requestCount.Load()) + // Constructor no longer fetches eagerly; first SetLogin triggers the probe. + r.Equal(int32(0), requestCount.Load()) - // SetLogin should reuse the cached token + // SetLogin triggers the adaptive probe and caches the token. loginMsg := &msg.Login{} err = provider.SetLogin(loginMsg) r.NoError(err) @@ -153,3 +155,99 @@ func TestOidcAuthProviderCachesToken(t *testing.T) { } r.Equal(int32(1), requestCount.Load(), "token endpoint should only be called once; cached token should be reused") } + +func TestOidcAuthProviderRetriesOnInitialFailure(t *testing.T) { + r := require.New(t) + + var requestCount atomic.Int32 + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + n := requestCount.Add(1) + // The oauth2 library retries once internally, so we need two + // consecutive failures to surface an error to the caller. + if n <= 2 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "temporarily_unavailable", + "error_description": "service is starting up", + }) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response + "access_token": "retry-test-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenServer.Close() + + // Constructor succeeds even though the IdP is "down". + provider, err := auth.NewOidcAuthSetter( + []v1.AuthScope{v1.AuthScopeHeartBeats}, + v1.AuthOIDCClientConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + TokenEndpointURL: tokenServer.URL, + }, + ) + r.NoError(err) + r.Equal(int32(0), requestCount.Load()) + + // First SetLogin hits the IdP, which returns an error (after internal retry). + loginMsg := &msg.Login{} + err = provider.SetLogin(loginMsg) + r.Error(err) + r.Equal(int32(2), requestCount.Load()) + + // Second SetLogin retries and succeeds. + err = provider.SetLogin(loginMsg) + r.NoError(err) + r.Equal("retry-test-token", loginMsg.PrivilegeKey) + r.Equal(int32(3), requestCount.Load()) + + // Subsequent calls use cached token. + pingMsg := &msg.Ping{} + err = provider.SetPing(pingMsg) + r.NoError(err) + r.Equal("retry-test-token", pingMsg.PrivilegeKey) + r.Equal(int32(3), requestCount.Load()) +} + +func TestNewOidcAuthSetterRejectsInvalidStaticConfig(t *testing.T) { + r := require.New(t) + tokenServer := httptest.NewServer(http.NotFoundHandler()) + defer tokenServer.Close() + + _, err := auth.NewOidcAuthSetter(nil, v1.AuthOIDCClientConfig{ + ClientID: "test-client", + TokenEndpointURL: "://bad", + }) + r.Error(err) + r.Contains(err.Error(), "auth.oidc.tokenEndpointURL") + + _, err = auth.NewOidcAuthSetter(nil, v1.AuthOIDCClientConfig{ + TokenEndpointURL: tokenServer.URL, + }) + r.Error(err) + r.Contains(err.Error(), "auth.oidc.clientID is required") + + _, err = auth.NewOidcAuthSetter(nil, v1.AuthOIDCClientConfig{ + ClientID: "test-client", + TokenEndpointURL: tokenServer.URL, + AdditionalEndpointParams: map[string]string{ + "scope": "profile", + }, + }) + r.Error(err) + r.Contains(err.Error(), "auth.oidc.additionalEndpointParams.scope is not allowed; use auth.oidc.scope instead") + + _, err = auth.NewOidcAuthSetter(nil, v1.AuthOIDCClientConfig{ + ClientID: "test-client", + TokenEndpointURL: tokenServer.URL, + Audience: "api", + AdditionalEndpointParams: map[string]string{"audience": "override"}, + }) + r.Error(err) + r.Contains(err.Error(), "cannot specify both auth.oidc.audience and auth.oidc.additionalEndpointParams.audience") +} diff --git a/pkg/config/v1/validation/client.go b/pkg/config/v1/validation/client.go index eb4a0253..c90d525d 100644 --- a/pkg/config/v1/validation/client.go +++ b/pkg/config/v1/validation/client.go @@ -88,6 +88,11 @@ func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, e if err := v.validateOIDCConfig(&c.OIDC); err != nil { errs = AppendError(errs, err) } + if c.Method == v1.AuthMethodOIDC && c.OIDC.TokenSource == nil { + if err := ValidateOIDCClientCredentialsConfig(&c.OIDC); err != nil { + errs = AppendError(errs, err) + } + } return nil, errs } diff --git a/pkg/config/v1/validation/oidc.go b/pkg/config/v1/validation/oidc.go new file mode 100644 index 00000000..c905e8e5 --- /dev/null +++ b/pkg/config/v1/validation/oidc.go @@ -0,0 +1,57 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "errors" + "net/url" + "strings" + + v1 "github.com/fatedier/frp/pkg/config/v1" +) + +func ValidateOIDCClientCredentialsConfig(c *v1.AuthOIDCClientConfig) error { + var errs []string + + if c.ClientID == "" { + errs = append(errs, "auth.oidc.clientID is required") + } + + if c.TokenEndpointURL == "" { + errs = append(errs, "auth.oidc.tokenEndpointURL is required") + } else { + tokenURL, err := url.Parse(c.TokenEndpointURL) + if err != nil || !tokenURL.IsAbs() || tokenURL.Host == "" { + errs = append(errs, "auth.oidc.tokenEndpointURL must be an absolute http or https URL") + } else if tokenURL.Scheme != "http" && tokenURL.Scheme != "https" { + errs = append(errs, "auth.oidc.tokenEndpointURL must use http or https") + } + } + + if _, ok := c.AdditionalEndpointParams["scope"]; ok { + errs = append(errs, "auth.oidc.additionalEndpointParams.scope is not allowed; use auth.oidc.scope instead") + } + + if c.Audience != "" { + if _, ok := c.AdditionalEndpointParams["audience"]; ok { + errs = append(errs, "cannot specify both auth.oidc.audience and auth.oidc.additionalEndpointParams.audience") + } + } + + if len(errs) == 0 { + return nil + } + return errors.New(strings.Join(errs, "; ")) +} diff --git a/pkg/config/v1/validation/oidc_test.go b/pkg/config/v1/validation/oidc_test.go new file mode 100644 index 00000000..bc21da6e --- /dev/null +++ b/pkg/config/v1/validation/oidc_test.go @@ -0,0 +1,78 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + v1 "github.com/fatedier/frp/pkg/config/v1" +) + +func TestValidateOIDCClientCredentialsConfig(t *testing.T) { + tokenServer := httptest.NewServer(http.NotFoundHandler()) + defer tokenServer.Close() + + t.Run("valid", func(t *testing.T) { + require.NoError(t, ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{ + ClientID: "test-client", + TokenEndpointURL: tokenServer.URL, + AdditionalEndpointParams: map[string]string{ + "resource": "api", + }, + })) + }) + + t.Run("invalid token endpoint url", func(t *testing.T) { + err := ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{ + ClientID: "test-client", + TokenEndpointURL: "://bad", + }) + require.ErrorContains(t, err, "auth.oidc.tokenEndpointURL") + }) + + t.Run("missing client id", func(t *testing.T) { + err := ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{ + TokenEndpointURL: tokenServer.URL, + }) + require.ErrorContains(t, err, "auth.oidc.clientID is required") + }) + + t.Run("scope endpoint param is not allowed", func(t *testing.T) { + err := ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{ + ClientID: "test-client", + TokenEndpointURL: tokenServer.URL, + AdditionalEndpointParams: map[string]string{ + "scope": "email", + }, + }) + require.ErrorContains(t, err, "auth.oidc.additionalEndpointParams.scope is not allowed; use auth.oidc.scope instead") + }) + + t.Run("audience conflict", func(t *testing.T) { + err := ValidateOIDCClientCredentialsConfig(&v1.AuthOIDCClientConfig{ + ClientID: "test-client", + TokenEndpointURL: tokenServer.URL, + Audience: "api", + AdditionalEndpointParams: map[string]string{ + "audience": "override", + }, + }) + require.ErrorContains(t, err, "cannot specify both auth.oidc.audience and auth.oidc.additionalEndpointParams.audience") + }) +} diff --git a/pkg/util/http/server.go b/pkg/util/http/server.go index 99bed364..0bca8993 100644 --- a/pkg/util/http/server.go +++ b/pkg/util/http/server.go @@ -100,7 +100,11 @@ func (s *Server) Run() error { } func (s *Server) Close() error { - return s.hs.Close() + err := s.hs.Close() + if s.ln != nil { + _ = s.ln.Close() + } + return err } type RouterRegisterHelper struct { diff --git a/pkg/vnet/controller.go b/pkg/vnet/controller.go index ca71a8c3..d5c97c66 100644 --- a/pkg/vnet/controller.go +++ b/pkg/vnet/controller.go @@ -131,6 +131,9 @@ func (c *Controller) handlePacket(buf []byte) { } func (c *Controller) Stop() error { + if c.tun == nil { + return nil + } return c.tun.Close() } diff --git a/test/e2e/mock/server/oidcserver/oidcserver.go b/test/e2e/mock/server/oidcserver/oidcserver.go new file mode 100644 index 00000000..d7aa1329 --- /dev/null +++ b/test/e2e/mock/server/oidcserver/oidcserver.go @@ -0,0 +1,258 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package oidcserver provides a minimal mock OIDC server for e2e testing. +// It implements three endpoints: +// - /.well-known/openid-configuration (discovery) +// - /jwks (JSON Web Key Set) +// - /token (client_credentials grant) +package oidcserver + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net" + "net/http" + "strconv" + "sync/atomic" + "time" +) + +type Server struct { + bindAddr string + bindPort int + l net.Listener + hs *http.Server + + privateKey *rsa.PrivateKey + kid string + + clientID string + clientSecret string + audience string + subject string + expiresIn int // seconds; 0 means omit expires_in from token response + + tokenRequestCount atomic.Int64 +} + +type Option func(*Server) + +func WithBindPort(port int) Option { + return func(s *Server) { s.bindPort = port } +} + +func WithClientCredentials(id, secret string) Option { + return func(s *Server) { + s.clientID = id + s.clientSecret = secret + } +} + +func WithAudience(aud string) Option { + return func(s *Server) { s.audience = aud } +} + +func WithSubject(sub string) Option { + return func(s *Server) { s.subject = sub } +} + +func WithExpiresIn(seconds int) Option { + return func(s *Server) { s.expiresIn = seconds } +} + +func New(options ...Option) *Server { + s := &Server{ + bindAddr: "127.0.0.1", + kid: "test-key-1", + clientID: "test-client", + clientSecret: "test-secret", + audience: "frps", + subject: "test-service", + expiresIn: 3600, + } + for _, opt := range options { + opt(s) + } + return s +} + +func (s *Server) Run() error { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("generate RSA key: %w", err) + } + s.privateKey = key + + s.l, err = net.Listen("tcp", net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort))) + if err != nil { + return err + } + s.bindPort = s.l.Addr().(*net.TCPAddr).Port + + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", s.handleDiscovery) + mux.HandleFunc("/jwks", s.handleJWKS) + mux.HandleFunc("/token", s.handleToken) + + s.hs = &http.Server{ + Handler: mux, + ReadHeaderTimeout: time.Minute, + } + go func() { _ = s.hs.Serve(s.l) }() + return nil +} + +func (s *Server) Close() error { + if s.hs != nil { + return s.hs.Close() + } + return nil +} + +func (s *Server) BindAddr() string { return s.bindAddr } +func (s *Server) BindPort() int { return s.bindPort } + +func (s *Server) Issuer() string { + return fmt.Sprintf("http://%s:%d", s.bindAddr, s.bindPort) +} + +func (s *Server) TokenEndpoint() string { + return s.Issuer() + "/token" +} + +// TokenRequestCount returns the number of successful token requests served. +func (s *Server) TokenRequestCount() int64 { + return s.tokenRequestCount.Load() +} + +func (s *Server) handleDiscovery(w http.ResponseWriter, _ *http.Request) { + issuer := s.Issuer() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": issuer, + "token_endpoint": issuer + "/token", + "jwks_uri": issuer + "/jwks", + "response_types_supported": []string{"code"}, + "subject_types_supported": []string{"public"}, + "id_token_signing_alg_values_supported": []string{"RS256"}, + }) +} + +func (s *Server) handleJWKS(w http.ResponseWriter, _ *http.Request) { + pub := &s.privateKey.PublicKey + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "keys": []map[string]any{ + { + "kty": "RSA", + "alg": "RS256", + "use": "sig", + "kid": s.kid, + "n": base64.RawURLEncoding.EncodeToString(pub.N.Bytes()), + "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()), + }, + }, + }) +} + +func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_request", + }) + return + } + + if r.FormValue("grant_type") != "client_credentials" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "unsupported_grant_type", + }) + return + } + + // Accept credentials from Basic Auth or form body. + clientID, clientSecret, ok := r.BasicAuth() + if !ok { + clientID = r.FormValue("client_id") + clientSecret = r.FormValue("client_secret") + } + if clientID != s.clientID || clientSecret != s.clientSecret { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_client", + }) + return + } + + token, err := s.signJWT() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + resp := map[string]any{ + "access_token": token, + "token_type": "Bearer", + } + if s.expiresIn > 0 { + resp["expires_in"] = s.expiresIn + } + + s.tokenRequestCount.Add(1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) +} + +func (s *Server) signJWT() (string, error) { + now := time.Now() + header, _ := json.Marshal(map[string]string{ + "alg": "RS256", + "kid": s.kid, + "typ": "JWT", + }) + claims, _ := json.Marshal(map[string]any{ + "iss": s.Issuer(), + "sub": s.subject, + "aud": s.audience, + "iat": now.Unix(), + "exp": now.Add(1 * time.Hour).Unix(), + }) + + headerB64 := base64.RawURLEncoding.EncodeToString(header) + claimsB64 := base64.RawURLEncoding.EncodeToString(claims) + signingInput := headerB64 + "." + claimsB64 + + h := sha256.Sum256([]byte(signingInput)) + sig, err := rsa.SignPKCS1v15(rand.Reader, s.privateKey, crypto.SHA256, h[:]) + if err != nil { + return "", err + } + return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig), nil +} diff --git a/test/e2e/v1/basic/oidc.go b/test/e2e/v1/basic/oidc.go new file mode 100644 index 00000000..19539c76 --- /dev/null +++ b/test/e2e/v1/basic/oidc.go @@ -0,0 +1,192 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package basic + +import ( + "fmt" + "time" + + "github.com/onsi/ginkgo/v2" + + "github.com/fatedier/frp/test/e2e/framework" + "github.com/fatedier/frp/test/e2e/framework/consts" + "github.com/fatedier/frp/test/e2e/mock/server/oidcserver" + "github.com/fatedier/frp/test/e2e/pkg/port" +) + +var _ = ginkgo.Describe("[Feature: OIDC]", func() { + f := framework.NewDefaultFramework() + + ginkgo.It("should work with OIDC authentication", func() { + oidcSrv := oidcserver.New(oidcserver.WithBindPort(f.AllocPort())) + f.RunServer("", oidcSrv) + + portName := port.GenName("TCP") + + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` +auth.method = "oidc" +auth.oidc.issuer = "%s" +auth.oidc.audience = "frps" +`, oidcSrv.Issuer()) + + clientConf := consts.DefaultClientConfig + fmt.Sprintf(` +auth.method = "oidc" +auth.oidc.clientID = "test-client" +auth.oidc.clientSecret = "test-secret" +auth.oidc.tokenEndpointURL = "%s" + +[[proxies]] +name = "tcp" +type = "tcp" +localPort = {{ .%s }} +remotePort = {{ .%s }} +`, oidcSrv.TokenEndpoint(), framework.TCPEchoServerPort, portName) + + f.RunProcesses(serverConf, []string{clientConf}) + framework.NewRequestExpect(f).PortName(portName).Ensure() + }) + + ginkgo.It("should authenticate heartbeats with OIDC", func() { + oidcSrv := oidcserver.New(oidcserver.WithBindPort(f.AllocPort())) + f.RunServer("", oidcSrv) + + serverPort := f.AllocPort() + remotePort := f.AllocPort() + + serverConf := fmt.Sprintf(` +bindAddr = "0.0.0.0" +bindPort = %d +log.level = "trace" +auth.method = "oidc" +auth.additionalScopes = ["HeartBeats"] +auth.oidc.issuer = "%s" +auth.oidc.audience = "frps" +`, serverPort, oidcSrv.Issuer()) + + clientConf := fmt.Sprintf(` +serverAddr = "127.0.0.1" +serverPort = %d +loginFailExit = false +log.level = "trace" +auth.method = "oidc" +auth.additionalScopes = ["HeartBeats"] +auth.oidc.clientID = "test-client" +auth.oidc.clientSecret = "test-secret" +auth.oidc.tokenEndpointURL = "%s" +transport.heartbeatInterval = 1 + +[[proxies]] +name = "tcp" +type = "tcp" +localPort = %d +remotePort = %d +`, serverPort, oidcSrv.TokenEndpoint(), f.PortByName(framework.TCPEchoServerPort), remotePort) + + serverConfigPath := f.GenerateConfigFile(serverConf) + clientConfigPath := f.GenerateConfigFile(clientConf) + + _, _, err := f.RunFrps("-c", serverConfigPath) + framework.ExpectNoError(err) + clientProcess, _, err := f.RunFrpc("-c", clientConfigPath) + framework.ExpectNoError(err) + + // Wait for several authenticated heartbeat cycles instead of a fixed sleep. + err = clientProcess.WaitForOutput("send heartbeat to server", 3, 10*time.Second) + framework.ExpectNoError(err) + + // Proxy should still work: heartbeat auth has not failed. + framework.NewRequestExpect(f).Port(remotePort).Ensure() + }) + + ginkgo.It("should work when token has no expires_in", func() { + oidcSrv := oidcserver.New( + oidcserver.WithBindPort(f.AllocPort()), + oidcserver.WithExpiresIn(0), + ) + f.RunServer("", oidcSrv) + + portName := port.GenName("TCP") + + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` +auth.method = "oidc" +auth.oidc.issuer = "%s" +auth.oidc.audience = "frps" +`, oidcSrv.Issuer()) + + clientConf := consts.DefaultClientConfig + fmt.Sprintf(` +auth.method = "oidc" +auth.additionalScopes = ["HeartBeats"] +auth.oidc.clientID = "test-client" +auth.oidc.clientSecret = "test-secret" +auth.oidc.tokenEndpointURL = "%s" +transport.heartbeatInterval = 1 + +[[proxies]] +name = "tcp" +type = "tcp" +localPort = {{ .%s }} +remotePort = {{ .%s }} +`, oidcSrv.TokenEndpoint(), framework.TCPEchoServerPort, portName) + + _, clientProcesses := f.RunProcesses(serverConf, []string{clientConf}) + framework.NewRequestExpect(f).PortName(portName).Ensure() + + countAfterLogin := oidcSrv.TokenRequestCount() + + // Wait for several heartbeat cycles instead of a fixed sleep. + // Each heartbeat fetches a fresh token in non-caching mode. + err := clientProcesses[0].WaitForOutput("send heartbeat to server", 3, 10*time.Second) + framework.ExpectNoError(err) + + framework.NewRequestExpect(f).PortName(portName).Ensure() + + // Each heartbeat should have fetched a new token (non-caching mode). + countAfterHeartbeats := oidcSrv.TokenRequestCount() + framework.ExpectTrue( + countAfterHeartbeats > countAfterLogin, + "expected additional token requests for heartbeats, got %d before and %d after", + countAfterLogin, countAfterHeartbeats, + ) + }) + + ginkgo.It("should reject invalid OIDC credentials", func() { + oidcSrv := oidcserver.New(oidcserver.WithBindPort(f.AllocPort())) + f.RunServer("", oidcSrv) + + portName := port.GenName("TCP") + + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` +auth.method = "oidc" +auth.oidc.issuer = "%s" +auth.oidc.audience = "frps" +`, oidcSrv.Issuer()) + + clientConf := consts.DefaultClientConfig + fmt.Sprintf(` +auth.method = "oidc" +auth.oidc.clientID = "test-client" +auth.oidc.clientSecret = "wrong-secret" +auth.oidc.tokenEndpointURL = "%s" + +[[proxies]] +name = "tcp" +type = "tcp" +localPort = {{ .%s }} +remotePort = {{ .%s }} +`, oidcSrv.TokenEndpoint(), framework.TCPEchoServerPort, portName) + + f.RunProcesses(serverConf, []string{clientConf}) + framework.NewRequestExpect(f).PortName(portName).ExpectError(true).Ensure() + }) +})