server: introduce SessionContext to encapsulate NewControl parameters (#5217)

Replace 10 positional parameters in NewControl() with a single
SessionContext struct, matching the client-side pattern. This also
eliminates the post-construction mutation of clientRegistry and
removes two TODO comments.
This commit is contained in:
fatedier
2026-03-07 20:17:00 +08:00
committed by GitHub
parent bd200b1a3b
commit 017d71717f
2 changed files with 103 additions and 111 deletions

View File

@@ -95,20 +95,33 @@ func (cm *ControlManager) Close() error {
return nil
}
type Control struct {
// SessionContext encapsulates the input parameters for creating a new Control.
type SessionContext struct {
// all resource managers and controllers
rc *controller.ResourceController
RC *controller.ResourceController
// proxy manager
pxyManager *proxy.Manager
PxyManager *proxy.Manager
// plugin manager
pluginManager *plugin.Manager
PluginManager *plugin.Manager
// verifies authentication based on selected method
authVerifier auth.Verifier
AuthVerifier auth.Verifier
// key used for connection encryption
encryptionKey []byte
EncryptionKey []byte
// control connection
Conn net.Conn
// indicates whether the connection is encrypted
ConnEncrypted bool
// login message
LoginMsg *msg.Login
// server configuration
ServerCfg *v1.ServerConfig
// client registry
ClientRegistry *registry.ClientRegistry
}
type Control struct {
// session context
sessionCtx *SessionContext
// other components can use this to communicate with client
msgTransporter transport.MessageTransporter
@@ -117,12 +130,6 @@ type Control struct {
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
msgDispatcher *msg.Dispatcher
// login message
loginMsg *msg.Login
// control connection
conn net.Conn
// work connections
workConnCh chan net.Conn
@@ -145,61 +152,37 @@ type Control struct {
mu sync.RWMutex
// Server configuration information
serverCfg *v1.ServerConfig
clientRegistry *registry.ClientRegistry
xl *xlog.Logger
ctx context.Context
doneCh chan struct{}
}
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext.
func NewControl(
ctx context.Context,
rc *controller.ResourceController,
pxyManager *proxy.Manager,
pluginManager *plugin.Manager,
authVerifier auth.Verifier,
encryptionKey []byte,
ctlConn net.Conn,
ctlConnEncrypted bool,
loginMsg *msg.Login,
serverCfg *v1.ServerConfig,
) (*Control, error) {
poolCount := loginMsg.PoolCount
if poolCount > int(serverCfg.Transport.MaxPoolCount) {
poolCount = int(serverCfg.Transport.MaxPoolCount)
func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
poolCount := sessionCtx.LoginMsg.PoolCount
if poolCount > int(sessionCtx.ServerCfg.Transport.MaxPoolCount) {
poolCount = int(sessionCtx.ServerCfg.Transport.MaxPoolCount)
}
ctl := &Control{
rc: rc,
pxyManager: pxyManager,
pluginManager: pluginManager,
authVerifier: authVerifier,
encryptionKey: encryptionKey,
conn: ctlConn,
loginMsg: loginMsg,
workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
runID: loginMsg.RunID,
serverCfg: serverCfg,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
sessionCtx: sessionCtx,
workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
runID: sessionCtx.LoginMsg.RunID,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
}
ctl.lastPing.Store(time.Now())
if ctlConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey)
if sessionCtx.ConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
if err != nil {
return nil, err
}
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
} else {
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn)
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
}
ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
@@ -213,7 +196,7 @@ func (ctl *Control) Start() {
RunID: ctl.runID,
Error: "",
}
_ = msg.WriteMsg(ctl.conn, loginRespMsg)
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
go func() {
for i := 0; i < ctl.poolCount; i++ {
@@ -225,7 +208,7 @@ func (ctl *Control) Start() {
}
func (ctl *Control) Close() error {
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
return nil
}
@@ -233,7 +216,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
xl := ctl.xl
xl.Infof("replaced by client [%s]", newCtl.runID)
ctl.runID = ""
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
}
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
@@ -291,7 +274,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
return
}
case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second):
case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second):
err = fmt.Errorf("timeout trying to get work connection")
xl.Warnf("%v", err)
return
@@ -304,15 +287,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
}
func (ctl *Control) heartbeatWorker() {
if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 {
if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 {
return
}
xl := ctl.xl
go wait.Until(func() {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second {
xl.Warnf("heartbeat timeout")
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
return
}
}, time.Second, ctl.doneCh)
@@ -330,7 +313,7 @@ func (ctl *Control) worker() {
go ctl.msgDispatcher.Run()
<-ctl.msgDispatcher.Done()
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
ctl.mu.Lock()
defer ctl.mu.Unlock()
@@ -342,26 +325,26 @@ func (ctl *Control) worker() {
for _, pxy := range ctl.proxies {
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
}()
}
metrics.Server.CloseClient()
ctl.clientRegistry.MarkOfflineByRunID(ctl.runID)
ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID)
xl.Infof("client exit success")
close(ctl.doneCh)
}
@@ -381,14 +364,14 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
content := &plugin.NewProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
NewProxy: *inMsg,
}
var remoteAddr string
retContent, err := ctl.pluginManager.NewProxy(content)
retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content)
if err == nil {
inMsg = &retContent.NewProxy
remoteAddr, err = ctl.RegisterProxy(inMsg)
@@ -401,15 +384,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
if err != nil {
xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient))
} else {
resp.RemoteAddr = remoteAddr
xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
clientID := ctl.loginMsg.ClientID
clientID := ctl.sessionCtx.LoginMsg.ClientID
if clientID == "" {
clientID = ctl.loginMsg.RunID
clientID = ctl.sessionCtx.LoginMsg.RunID
}
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID)
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID)
}
_ = ctl.msgDispatcher.Send(resp)
}
@@ -420,21 +403,21 @@ func (ctl *Control) handlePing(m msg.Message) {
content := &plugin.PingContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
Ping: *inMsg,
}
retContent, err := ctl.pluginManager.Ping(content)
retContent, err := ctl.sessionCtx.PluginManager.Ping(content)
if err == nil {
inMsg = &retContent.Ping
err = ctl.authVerifier.VerifyPing(inMsg)
err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg)
}
if err != nil {
xl.Warnf("received invalid ping: %v", err)
_ = ctl.msgDispatcher.Send(&msg.Pong{
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)),
})
return
}
@@ -445,17 +428,17 @@ func (ctl *Control) handlePing(m msg.Message) {
func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
inMsg := m.(*msg.NatHoleVisitor)
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User)
}
func (ctl *Control) handleNatHoleClient(m msg.Message) {
inMsg := m.(*msg.NatHoleClient)
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
}
func (ctl *Control) handleNatHoleReport(m msg.Message) {
inMsg := m.(*msg.NatHoleReport)
ctl.rc.NatHoleController.HandleReport(inMsg)
ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg)
}
func (ctl *Control) handleCloseProxy(m msg.Message) {
@@ -468,15 +451,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) {
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
var pxyConf v1.ProxyConfigurer
// Load configures from NewProxy message and validate.
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg)
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg)
if err != nil {
return
}
// User info
userInfo := plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.runID,
}
@@ -484,22 +467,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
// In fact, it creates different proxies based on the proxy type. We just call run() here.
pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{
UserInfo: userInfo,
LoginMsg: ctl.loginMsg,
LoginMsg: ctl.sessionCtx.LoginMsg,
PoolCount: ctl.poolCount,
ResourceController: ctl.rc,
ResourceController: ctl.sessionCtx.RC,
GetWorkConnFn: ctl.GetWorkConn,
Configurer: pxyConf,
ServerCfg: ctl.serverCfg,
EncryptionKey: ctl.encryptionKey,
ServerCfg: ctl.sessionCtx.ServerCfg,
EncryptionKey: ctl.sessionCtx.EncryptionKey,
})
if err != nil {
return remoteAddr, err
}
// Check ports used number in each client
if ctl.serverCfg.MaxPortsPerClient > 0 {
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
ctl.mu.Lock()
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) {
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) {
ctl.mu.Unlock()
err = fmt.Errorf("exceed the max_ports_per_client")
return
@@ -516,7 +499,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}()
}
if ctl.pxyManager.Exist(pxyMsg.ProxyName) {
if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) {
err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName)
return
}
@@ -531,7 +514,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}
}()
err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy)
err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy)
if err != nil {
return
}
@@ -550,11 +533,11 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
return
}
if ctl.serverCfg.MaxPortsPerClient > 0 {
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
ctl.portsUsedNum -= pxy.GetUsedPortsNum()
}
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
delete(ctl.proxies, closeMsg.ProxyName)
ctl.mu.Unlock()
@@ -562,16 +545,16 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
}()
return
}

View File

@@ -604,8 +604,18 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
return err
}
// TODO(fatedier): use SessionContext
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg)
ctl, err := NewControl(ctx, &SessionContext{
RC: svr.rc,
PxyManager: svr.pxyManager,
PluginManager: svr.pluginManager,
AuthVerifier: authVerifier,
EncryptionKey: svr.auth.EncryptionKey(),
Conn: ctlConn,
ConnEncrypted: !internal,
LoginMsg: loginMsg,
ServerCfg: svr.cfg,
ClientRegistry: svr.clientRegistry,
})
if err != nil {
xl.Warnf("create new controller error: %v", err)
// don't return detailed errors to client
@@ -626,7 +636,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
ctl.Close()
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
}
ctl.clientRegistry = svr.clientRegistry
ctl.Start()
@@ -652,9 +661,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
// server plugin hook
content := &plugin.NewWorkConnContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
NewWorkConn: *newMsg,
}
@@ -662,7 +671,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
if err == nil {
newMsg = &retContent.NewWorkConn
// Check auth.
err = ctl.authVerifier.VerifyNewWorkConn(newMsg)
err = ctl.sessionCtx.AuthVerifier.VerifyNewWorkConn(newMsg)
}
if err != nil {
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
@@ -683,7 +692,7 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis
if !exist {
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
}
visitorUser = ctl.loginMsg.User
visitorUser = ctl.sessionCtx.LoginMsg.User
}
return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
newMsg.UseEncryption, newMsg.UseCompression, visitorUser)