auth/oidc: fix eager token fetch at startup, add validation and e2e tests (#5234)

This commit is contained in:
fatedier
2026-03-15 22:29:45 +08:00
committed by GitHub
parent 94a631fe9c
commit ff4ad2f907
12 changed files with 885 additions and 35 deletions

View File

@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net"
"net/http"
"os"
"runtime"
"sync"
@@ -162,15 +163,6 @@ func NewService(options ServiceOptions) (*Service, error) {
return nil, err
}
var webServer *httppkg.Server
if options.Common.WebServer.Port > 0 {
ws, err := httppkg.NewServer(options.Common.WebServer)
if err != nil {
return nil, err
}
webServer = ws
}
authRuntime, err := auth.BuildClientAuth(&options.Common.Auth)
if err != nil {
return nil, err
@@ -191,6 +183,17 @@ func NewService(options ServiceOptions) (*Service, error) {
proxyCfgs = config.CompleteProxyConfigurers(proxyCfgs)
visitorCfgs = config.CompleteVisitorConfigurers(visitorCfgs)
// Create the web server after all fallible steps so its listener is not
// leaked when an earlier error causes NewService to return.
var webServer *httppkg.Server
if options.Common.WebServer.Port > 0 {
ws, err := httppkg.NewServer(options.Common.WebServer)
if err != nil {
return nil, err
}
webServer = ws
}
s := &Service{
ctx: context.Background(),
auth: authRuntime,
@@ -229,22 +232,25 @@ func (svr *Service) Run(ctx context.Context) error {
}
if svr.vnetController != nil {
vnetController := svr.vnetController
if err := svr.vnetController.Init(); err != nil {
log.Errorf("init virtual network controller error: %v", err)
svr.stop()
return err
}
go func() {
log.Infof("virtual network controller start...")
if err := svr.vnetController.Run(); err != nil {
if err := vnetController.Run(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Warnf("virtual network controller exit with error: %v", err)
}
}()
}
if svr.webServer != nil {
webServer := svr.webServer
go func() {
log.Infof("admin server listen on %s", svr.webServer.Address())
if err := svr.webServer.Run(); err != nil {
log.Infof("admin server listen on %s", webServer.Address())
if err := webServer.Run(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Warnf("admin server exit with error: %v", err)
}
}()
@@ -255,6 +261,7 @@ func (svr *Service) Run(ctx context.Context) error {
if svr.ctl == nil {
cancelCause := cancelErr{}
_ = errors.As(context.Cause(svr.ctx), &cancelCause)
svr.stop()
return fmt.Errorf("login to the server failed: %v. With loginFailExit enabled, no additional retries will be attempted", cancelCause.Err)
}
@@ -497,6 +504,10 @@ func (svr *Service) stop() {
svr.webServer.Close()
svr.webServer = nil
}
if svr.vnetController != nil {
_ = svr.vnetController.Stop()
svr.vnetController = nil
}
}
func (svr *Service) getProxyStatus(name string) (*proxy.WorkingStatus, bool) {

View File

@@ -1,14 +1,120 @@
package client
import (
"context"
"errors"
"net"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/samber/lo"
"github.com/fatedier/frp/pkg/config/source"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
type failingConnector struct {
err error
}
func (c *failingConnector) Open() error {
return c.err
}
func (c *failingConnector) Connect() (net.Conn, error) {
return nil, c.err
}
func (c *failingConnector) Close() error {
return nil
}
func getFreeTCPPort(t *testing.T) int {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on ephemeral port: %v", err)
}
defer ln.Close()
return ln.Addr().(*net.TCPAddr).Port
}
func TestRunStopsStartedComponentsOnInitialLoginFailure(t *testing.T) {
port := getFreeTCPPort(t)
agg := source.NewAggregator(source.NewConfigSource())
svr, err := NewService(ServiceOptions{
Common: &v1.ClientCommonConfig{
LoginFailExit: lo.ToPtr(true),
WebServer: v1.WebServerConfig{
Addr: "127.0.0.1",
Port: port,
},
},
ConfigSourceAggregator: agg,
ConnectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector {
return &failingConnector{err: errors.New("login boom")}
},
})
if err != nil {
t.Fatalf("new service: %v", err)
}
err = svr.Run(context.Background())
if err == nil {
t.Fatal("expected run error, got nil")
}
if !strings.Contains(err.Error(), "login boom") {
t.Fatalf("unexpected error: %v", err)
}
if svr.webServer != nil {
t.Fatal("expected web server to be cleaned up after initial login failure")
}
ln, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err != nil {
t.Fatalf("expected admin port to be released: %v", err)
}
_ = ln.Close()
}
func TestNewServiceDoesNotLeakAdminListenerOnAuthBuildFailure(t *testing.T) {
port := getFreeTCPPort(t)
agg := source.NewAggregator(source.NewConfigSource())
_, err := NewService(ServiceOptions{
Common: &v1.ClientCommonConfig{
Auth: v1.AuthClientConfig{
Method: v1.AuthMethodOIDC,
OIDC: v1.AuthOIDCClientConfig{
TokenEndpointURL: "://bad",
},
},
WebServer: v1.WebServerConfig{
Addr: "127.0.0.1",
Port: port,
},
},
ConfigSourceAggregator: agg,
})
if err == nil {
t.Fatal("expected new service error, got nil")
}
if !strings.Contains(err.Error(), "auth.oidc.tokenEndpointURL") {
t.Fatalf("unexpected error: %v", err)
}
ln, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err != nil {
t.Fatalf("expected admin port to remain free: %v", err)
}
_ = ln.Close()
}
func TestUpdateConfigSourceRollsBackReloadCommonOnReplaceAllFailure(t *testing.T) {
prevCommon := &v1.ClientCommonConfig{User: "old-user"}
newCommon := &v1.ClientCommonConfig{User: "new-user"}