diff --git a/server/control.go b/server/control.go index 863104f3..219d6368 100644 --- a/server/control.go +++ b/server/control.go @@ -95,20 +95,33 @@ func (cm *ControlManager) Close() error { return nil } -type Control struct { +// SessionContext encapsulates the input parameters for creating a new Control. +type SessionContext struct { // all resource managers and controllers - rc *controller.ResourceController - + RC *controller.ResourceController // proxy manager - pxyManager *proxy.Manager - + PxyManager *proxy.Manager // plugin manager - pluginManager *plugin.Manager - + PluginManager *plugin.Manager // verifies authentication based on selected method - authVerifier auth.Verifier + AuthVerifier auth.Verifier // key used for connection encryption - encryptionKey []byte + EncryptionKey []byte + // control connection + Conn net.Conn + // indicates whether the connection is encrypted + ConnEncrypted bool + // login message + LoginMsg *msg.Login + // server configuration + ServerCfg *v1.ServerConfig + // client registry + ClientRegistry *registry.ClientRegistry +} + +type Control struct { + // session context + sessionCtx *SessionContext // other components can use this to communicate with client msgTransporter transport.MessageTransporter @@ -117,12 +130,6 @@ type Control struct { // It provides a channel for sending messages, and you can register handlers to process messages based on their respective types. msgDispatcher *msg.Dispatcher - // login message - loginMsg *msg.Login - - // control connection - conn net.Conn - // work connections workConnCh chan net.Conn @@ -145,61 +152,37 @@ type Control struct { mu sync.RWMutex - // Server configuration information - serverCfg *v1.ServerConfig - - clientRegistry *registry.ClientRegistry - xl *xlog.Logger ctx context.Context doneCh chan struct{} } -// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext. -func NewControl( - ctx context.Context, - rc *controller.ResourceController, - pxyManager *proxy.Manager, - pluginManager *plugin.Manager, - authVerifier auth.Verifier, - encryptionKey []byte, - ctlConn net.Conn, - ctlConnEncrypted bool, - loginMsg *msg.Login, - serverCfg *v1.ServerConfig, -) (*Control, error) { - poolCount := loginMsg.PoolCount - if poolCount > int(serverCfg.Transport.MaxPoolCount) { - poolCount = int(serverCfg.Transport.MaxPoolCount) +func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) { + poolCount := sessionCtx.LoginMsg.PoolCount + if poolCount > int(sessionCtx.ServerCfg.Transport.MaxPoolCount) { + poolCount = int(sessionCtx.ServerCfg.Transport.MaxPoolCount) } ctl := &Control{ - rc: rc, - pxyManager: pxyManager, - pluginManager: pluginManager, - authVerifier: authVerifier, - encryptionKey: encryptionKey, - conn: ctlConn, - loginMsg: loginMsg, - workConnCh: make(chan net.Conn, poolCount+10), - proxies: make(map[string]proxy.Proxy), - poolCount: poolCount, - portsUsedNum: 0, - runID: loginMsg.RunID, - serverCfg: serverCfg, - xl: xlog.FromContextSafe(ctx), - ctx: ctx, - doneCh: make(chan struct{}), + sessionCtx: sessionCtx, + workConnCh: make(chan net.Conn, poolCount+10), + proxies: make(map[string]proxy.Proxy), + poolCount: poolCount, + portsUsedNum: 0, + runID: sessionCtx.LoginMsg.RunID, + xl: xlog.FromContextSafe(ctx), + ctx: ctx, + doneCh: make(chan struct{}), } ctl.lastPing.Store(time.Now()) - if ctlConnEncrypted { - cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey) + if sessionCtx.ConnEncrypted { + cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey) if err != nil { return nil, err } ctl.msgDispatcher = msg.NewDispatcher(cryptoRW) } else { - ctl.msgDispatcher = msg.NewDispatcher(ctl.conn) + ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn) } ctl.registerMsgHandlers() ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher) @@ -213,7 +196,7 @@ func (ctl *Control) Start() { RunID: ctl.runID, Error: "", } - _ = msg.WriteMsg(ctl.conn, loginRespMsg) + _ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg) go func() { for i := 0; i < ctl.poolCount; i++ { @@ -225,7 +208,7 @@ func (ctl *Control) Start() { } func (ctl *Control) Close() error { - ctl.conn.Close() + ctl.sessionCtx.Conn.Close() return nil } @@ -233,7 +216,7 @@ func (ctl *Control) Replaced(newCtl *Control) { xl := ctl.xl xl.Infof("replaced by client [%s]", newCtl.runID) ctl.runID = "" - ctl.conn.Close() + ctl.sessionCtx.Conn.Close() } func (ctl *Control) RegisterWorkConn(conn net.Conn) error { @@ -291,7 +274,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) { return } - case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second): + case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second): err = fmt.Errorf("timeout trying to get work connection") xl.Warnf("%v", err) return @@ -304,15 +287,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) { } func (ctl *Control) heartbeatWorker() { - if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 { + if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 { return } xl := ctl.xl go wait.Until(func() { - if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second { + if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second { xl.Warnf("heartbeat timeout") - ctl.conn.Close() + ctl.sessionCtx.Conn.Close() return } }, time.Second, ctl.doneCh) @@ -330,7 +313,7 @@ func (ctl *Control) worker() { go ctl.msgDispatcher.Run() <-ctl.msgDispatcher.Done() - ctl.conn.Close() + ctl.sessionCtx.Conn.Close() ctl.mu.Lock() defer ctl.mu.Unlock() @@ -342,26 +325,26 @@ func (ctl *Control) worker() { for _, pxy := range ctl.proxies { pxy.Close() - ctl.pxyManager.Del(pxy.GetName()) + ctl.sessionCtx.PxyManager.Del(pxy.GetName()) metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type) notifyContent := &plugin.CloseProxyContent{ User: plugin.UserInfo{ - User: ctl.loginMsg.User, - Metas: ctl.loginMsg.Metas, - RunID: ctl.loginMsg.RunID, + User: ctl.sessionCtx.LoginMsg.User, + Metas: ctl.sessionCtx.LoginMsg.Metas, + RunID: ctl.sessionCtx.LoginMsg.RunID, }, CloseProxy: msg.CloseProxy{ ProxyName: pxy.GetName(), }, } go func() { - _ = ctl.pluginManager.CloseProxy(notifyContent) + _ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent) }() } metrics.Server.CloseClient() - ctl.clientRegistry.MarkOfflineByRunID(ctl.runID) + ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID) xl.Infof("client exit success") close(ctl.doneCh) } @@ -381,14 +364,14 @@ func (ctl *Control) handleNewProxy(m msg.Message) { content := &plugin.NewProxyContent{ User: plugin.UserInfo{ - User: ctl.loginMsg.User, - Metas: ctl.loginMsg.Metas, - RunID: ctl.loginMsg.RunID, + User: ctl.sessionCtx.LoginMsg.User, + Metas: ctl.sessionCtx.LoginMsg.Metas, + RunID: ctl.sessionCtx.LoginMsg.RunID, }, NewProxy: *inMsg, } var remoteAddr string - retContent, err := ctl.pluginManager.NewProxy(content) + retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content) if err == nil { inMsg = &retContent.NewProxy remoteAddr, err = ctl.RegisterProxy(inMsg) @@ -401,15 +384,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) { if err != nil { xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err) resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName), - err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)) + err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)) } else { resp.RemoteAddr = remoteAddr xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType) - clientID := ctl.loginMsg.ClientID + clientID := ctl.sessionCtx.LoginMsg.ClientID if clientID == "" { - clientID = ctl.loginMsg.RunID + clientID = ctl.sessionCtx.LoginMsg.RunID } - metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID) + metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID) } _ = ctl.msgDispatcher.Send(resp) } @@ -420,21 +403,21 @@ func (ctl *Control) handlePing(m msg.Message) { content := &plugin.PingContent{ User: plugin.UserInfo{ - User: ctl.loginMsg.User, - Metas: ctl.loginMsg.Metas, - RunID: ctl.loginMsg.RunID, + User: ctl.sessionCtx.LoginMsg.User, + Metas: ctl.sessionCtx.LoginMsg.Metas, + RunID: ctl.sessionCtx.LoginMsg.RunID, }, Ping: *inMsg, } - retContent, err := ctl.pluginManager.Ping(content) + retContent, err := ctl.sessionCtx.PluginManager.Ping(content) if err == nil { inMsg = &retContent.Ping - err = ctl.authVerifier.VerifyPing(inMsg) + err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg) } if err != nil { xl.Warnf("received invalid ping: %v", err) _ = ctl.msgDispatcher.Send(&msg.Pong{ - Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)), + Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)), }) return } @@ -445,17 +428,17 @@ func (ctl *Control) handlePing(m msg.Message) { func (ctl *Control) handleNatHoleVisitor(m msg.Message) { inMsg := m.(*msg.NatHoleVisitor) - ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User) + ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User) } func (ctl *Control) handleNatHoleClient(m msg.Message) { inMsg := m.(*msg.NatHoleClient) - ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter) + ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter) } func (ctl *Control) handleNatHoleReport(m msg.Message) { inMsg := m.(*msg.NatHoleReport) - ctl.rc.NatHoleController.HandleReport(inMsg) + ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg) } func (ctl *Control) handleCloseProxy(m msg.Message) { @@ -468,15 +451,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) { func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { var pxyConf v1.ProxyConfigurer // Load configures from NewProxy message and validate. - pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg) + pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg) if err != nil { return } // User info userInfo := plugin.UserInfo{ - User: ctl.loginMsg.User, - Metas: ctl.loginMsg.Metas, + User: ctl.sessionCtx.LoginMsg.User, + Metas: ctl.sessionCtx.LoginMsg.Metas, RunID: ctl.runID, } @@ -484,22 +467,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err // In fact, it creates different proxies based on the proxy type. We just call run() here. pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{ UserInfo: userInfo, - LoginMsg: ctl.loginMsg, + LoginMsg: ctl.sessionCtx.LoginMsg, PoolCount: ctl.poolCount, - ResourceController: ctl.rc, + ResourceController: ctl.sessionCtx.RC, GetWorkConnFn: ctl.GetWorkConn, Configurer: pxyConf, - ServerCfg: ctl.serverCfg, - EncryptionKey: ctl.encryptionKey, + ServerCfg: ctl.sessionCtx.ServerCfg, + EncryptionKey: ctl.sessionCtx.EncryptionKey, }) if err != nil { return remoteAddr, err } // Check ports used number in each client - if ctl.serverCfg.MaxPortsPerClient > 0 { + if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 { ctl.mu.Lock() - if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) { + if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) { ctl.mu.Unlock() err = fmt.Errorf("exceed the max_ports_per_client") return @@ -516,7 +499,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err }() } - if ctl.pxyManager.Exist(pxyMsg.ProxyName) { + if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) { err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName) return } @@ -531,7 +514,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err } }() - err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy) + err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy) if err != nil { return } @@ -550,11 +533,11 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) { return } - if ctl.serverCfg.MaxPortsPerClient > 0 { + if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 { ctl.portsUsedNum -= pxy.GetUsedPortsNum() } pxy.Close() - ctl.pxyManager.Del(pxy.GetName()) + ctl.sessionCtx.PxyManager.Del(pxy.GetName()) delete(ctl.proxies, closeMsg.ProxyName) ctl.mu.Unlock() @@ -562,16 +545,16 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) { notifyContent := &plugin.CloseProxyContent{ User: plugin.UserInfo{ - User: ctl.loginMsg.User, - Metas: ctl.loginMsg.Metas, - RunID: ctl.loginMsg.RunID, + User: ctl.sessionCtx.LoginMsg.User, + Metas: ctl.sessionCtx.LoginMsg.Metas, + RunID: ctl.sessionCtx.LoginMsg.RunID, }, CloseProxy: msg.CloseProxy{ ProxyName: pxy.GetName(), }, } go func() { - _ = ctl.pluginManager.CloseProxy(notifyContent) + _ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent) }() return } diff --git a/server/service.go b/server/service.go index b0db0327..bc106eb6 100644 --- a/server/service.go +++ b/server/service.go @@ -604,8 +604,18 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter return err } - // TODO(fatedier): use SessionContext - ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg) + ctl, err := NewControl(ctx, &SessionContext{ + RC: svr.rc, + PxyManager: svr.pxyManager, + PluginManager: svr.pluginManager, + AuthVerifier: authVerifier, + EncryptionKey: svr.auth.EncryptionKey(), + Conn: ctlConn, + ConnEncrypted: !internal, + LoginMsg: loginMsg, + ServerCfg: svr.cfg, + ClientRegistry: svr.clientRegistry, + }) if err != nil { xl.Warnf("create new controller error: %v", err) // don't return detailed errors to client @@ -626,7 +636,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter ctl.Close() return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User) } - ctl.clientRegistry = svr.clientRegistry ctl.Start() @@ -652,9 +661,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) // server plugin hook content := &plugin.NewWorkConnContent{ User: plugin.UserInfo{ - User: ctl.loginMsg.User, - Metas: ctl.loginMsg.Metas, - RunID: ctl.loginMsg.RunID, + User: ctl.sessionCtx.LoginMsg.User, + Metas: ctl.sessionCtx.LoginMsg.Metas, + RunID: ctl.sessionCtx.LoginMsg.RunID, }, NewWorkConn: *newMsg, } @@ -662,7 +671,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) if err == nil { newMsg = &retContent.NewWorkConn // Check auth. - err = ctl.authVerifier.VerifyNewWorkConn(newMsg) + err = ctl.sessionCtx.AuthVerifier.VerifyNewWorkConn(newMsg) } if err != nil { xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID) @@ -683,7 +692,7 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis if !exist { return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID) } - visitorUser = ctl.loginMsg.User + visitorUser = ctl.sessionCtx.LoginMsg.User } return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey, newMsg.UseEncryption, newMsg.UseCompression, visitorUser)