// Copyright 2020 guylewin, guy@lewin.co.il // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "crypto/tls" "crypto/x509" "fmt" "net/http" "net/url" "os" "slices" "sync" "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/msg" ) // createOIDCHTTPClient creates an HTTP client with custom TLS and proxy configuration for OIDC token requests func createOIDCHTTPClient(trustedCAFile string, insecureSkipVerify bool, proxyURL string) (*http.Client, error) { // Clone the default transport to get all reasonable defaults transport := http.DefaultTransport.(*http.Transport).Clone() // Configure TLS settings if trustedCAFile != "" || insecureSkipVerify { tlsConfig := &tls.Config{ InsecureSkipVerify: insecureSkipVerify, } if trustedCAFile != "" && !insecureSkipVerify { caCert, err := os.ReadFile(trustedCAFile) if err != nil { return nil, fmt.Errorf("failed to read OIDC CA certificate file %q: %w", trustedCAFile, err) } caCertPool := x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCert) { return nil, fmt.Errorf("failed to parse OIDC CA certificate from file %q", trustedCAFile) } tlsConfig.RootCAs = caCertPool } transport.TLSClientConfig = tlsConfig } // Configure proxy settings if proxyURL != "" { parsedURL, err := url.Parse(proxyURL) if err != nil { return nil, fmt.Errorf("failed to parse OIDC proxy URL %q: %w", proxyURL, err) } transport.Proxy = http.ProxyURL(parsedURL) } else { // Explicitly disable proxy to override DefaultTransport's ProxyFromEnvironment transport.Proxy = nil } return &http.Client{Transport: transport}, nil } // nonCachingTokenSource wraps a clientcredentials.Config to fetch a fresh // token on every call. This is used as a fallback when the OIDC provider // does not return expires_in, which would cause a caching TokenSource to // hold onto a stale token forever. type nonCachingTokenSource struct { cfg *clientcredentials.Config ctx context.Context } func (s *nonCachingTokenSource) Token() (*oauth2.Token, error) { return s.cfg.Token(s.ctx) } type OidcAuthProvider struct { additionalAuthScopes []v1.AuthScope tokenSource oauth2.TokenSource } func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) (*OidcAuthProvider, error) { eps := make(map[string][]string) for k, v := range cfg.AdditionalEndpointParams { eps[k] = []string{v} } if cfg.Audience != "" { eps["audience"] = []string{cfg.Audience} } tokenGenerator := &clientcredentials.Config{ ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, Scopes: []string{cfg.Scope}, TokenURL: cfg.TokenEndpointURL, EndpointParams: eps, } // Build the context that TokenSource will use for all future HTTP requests. // context.Background() is appropriate here because the token source is // long-lived and outlives any single request. ctx := context.Background() if cfg.TrustedCaFile != "" || cfg.InsecureSkipVerify || cfg.ProxyURL != "" { httpClient, err := createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL) if err != nil { return nil, fmt.Errorf("failed to create OIDC HTTP client: %w", err) } ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) } // Create a persistent TokenSource that caches the token and refreshes // it before expiry. This avoids making a new HTTP request to the OIDC // provider on every heartbeat/ping. tokenSource := tokenGenerator.TokenSource(ctx) // Fetch the initial token to check if the provider returns an expiry. // If Expiry is the zero value (provider omitted expires_in), the cached // TokenSource would treat the token as valid forever and never refresh it, // even after the JWT's exp claim passes. In that case, fall back to // fetching a fresh token on every request. initialToken, err := tokenSource.Token() if err != nil { return nil, fmt.Errorf("failed to obtain initial OIDC token: %w", err) } if initialToken.Expiry.IsZero() { tokenSource = &nonCachingTokenSource{cfg: tokenGenerator, ctx: ctx} } return &OidcAuthProvider{ additionalAuthScopes: additionalAuthScopes, tokenSource: tokenSource, }, nil } func (auth *OidcAuthProvider) generateAccessToken() (accessToken string, err error) { tokenObj, err := auth.tokenSource.Token() if err != nil { return "", fmt.Errorf("couldn't generate OIDC token for login: %v", err) } return tokenObj.AccessToken, nil } func (auth *OidcAuthProvider) SetLogin(loginMsg *msg.Login) (err error) { loginMsg.PrivilegeKey, err = auth.generateAccessToken() return err } func (auth *OidcAuthProvider) SetPing(pingMsg *msg.Ping) (err error) { if !slices.Contains(auth.additionalAuthScopes, v1.AuthScopeHeartBeats) { return nil } pingMsg.PrivilegeKey, err = auth.generateAccessToken() return err } func (auth *OidcAuthProvider) SetNewWorkConn(newWorkConnMsg *msg.NewWorkConn) (err error) { if !slices.Contains(auth.additionalAuthScopes, v1.AuthScopeNewWorkConns) { return nil } newWorkConnMsg.PrivilegeKey, err = auth.generateAccessToken() return err } type OidcTokenSourceAuthProvider struct { additionalAuthScopes []v1.AuthScope valueSource *v1.ValueSource } func NewOidcTokenSourceAuthSetter(additionalAuthScopes []v1.AuthScope, valueSource *v1.ValueSource) *OidcTokenSourceAuthProvider { return &OidcTokenSourceAuthProvider{ additionalAuthScopes: additionalAuthScopes, valueSource: valueSource, } } func (auth *OidcTokenSourceAuthProvider) generateAccessToken() (accessToken string, err error) { ctx := context.Background() accessToken, err = auth.valueSource.Resolve(ctx) if err != nil { return "", fmt.Errorf("couldn't acquire OIDC token for login: %v", err) } return } func (auth *OidcTokenSourceAuthProvider) SetLogin(loginMsg *msg.Login) (err error) { loginMsg.PrivilegeKey, err = auth.generateAccessToken() return err } func (auth *OidcTokenSourceAuthProvider) SetPing(pingMsg *msg.Ping) (err error) { if !slices.Contains(auth.additionalAuthScopes, v1.AuthScopeHeartBeats) { return nil } pingMsg.PrivilegeKey, err = auth.generateAccessToken() return err } func (auth *OidcTokenSourceAuthProvider) SetNewWorkConn(newWorkConnMsg *msg.NewWorkConn) (err error) { if !slices.Contains(auth.additionalAuthScopes, v1.AuthScopeNewWorkConns) { return nil } newWorkConnMsg.PrivilegeKey, err = auth.generateAccessToken() return err } type TokenVerifier interface { Verify(context.Context, string) (*oidc.IDToken, error) } type OidcAuthConsumer struct { additionalAuthScopes []v1.AuthScope verifier TokenVerifier mu sync.RWMutex subjectsFromLogin map[string]struct{} } func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier { provider, err := oidc.NewProvider(context.Background(), cfg.Issuer) if err != nil { panic(err) } verifierConf := oidc.Config{ ClientID: cfg.Audience, SkipClientIDCheck: cfg.Audience == "", SkipExpiryCheck: cfg.SkipExpiryCheck, SkipIssuerCheck: cfg.SkipIssuerCheck, } return provider.Verifier(&verifierConf) } func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier) *OidcAuthConsumer { return &OidcAuthConsumer{ additionalAuthScopes: additionalAuthScopes, verifier: verifier, subjectsFromLogin: make(map[string]struct{}), } } func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) { token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey) if err != nil { return fmt.Errorf("invalid OIDC token in login: %v", err) } auth.mu.Lock() auth.subjectsFromLogin[token.Subject] = struct{}{} auth.mu.Unlock() return nil } func (auth *OidcAuthConsumer) verifyPostLoginToken(privilegeKey string) (err error) { token, err := auth.verifier.Verify(context.Background(), privilegeKey) if err != nil { return fmt.Errorf("invalid OIDC token in ping: %v", err) } auth.mu.RLock() _, ok := auth.subjectsFromLogin[token.Subject] auth.mu.RUnlock() if !ok { return fmt.Errorf("received different OIDC subject in login and ping. "+ "new subject: %s", token.Subject) } return nil } func (auth *OidcAuthConsumer) VerifyPing(pingMsg *msg.Ping) (err error) { if !slices.Contains(auth.additionalAuthScopes, v1.AuthScopeHeartBeats) { return nil } return auth.verifyPostLoginToken(pingMsg.PrivilegeKey) } func (auth *OidcAuthConsumer) VerifyNewWorkConn(newWorkConnMsg *msg.NewWorkConn) (err error) { if !slices.Contains(auth.additionalAuthScopes, v1.AuthScopeNewWorkConns) { return nil } return auth.verifyPostLoginToken(newWorkConnMsg.PrivilegeKey) }