Compare commits

..

4 Commits

Author SHA1 Message Date
fatedier
443b9bca66 server: introduce SessionContext to encapsulate NewControl parameters
Replace 10 positional parameters in NewControl() with a single
SessionContext struct, matching the client-side pattern. This also
eliminates the post-construction mutation of clientRegistry and
removes two TODO comments.
2026-03-07 18:26:56 +08:00
fatedier
c70ceff370 fix: three high-severity bugs across nathole, proxy, and udp modules (#5214)
- pkg/nathole: add RLock when reading clientCfgs map in PreCheck path
  to prevent concurrent map read/write crash
- server/proxy: fix error variable shadowing in GetWorkConnFromPool
  that could return a closed connection with nil error
- pkg/util/net: check ListenUDP error before spawning goroutines
  and assign readConn to struct field so Close() works correctly
2026-03-07 13:36:02 +08:00
fatedier
bb3d0e7140 deduplicate common logic across proxy, visitor, and metrics modules (#5213)
- Replace duplicate parseBasicAuth with existing httppkg.ParseBasicAuth
- Extract buildDomains helper in BaseProxy for HTTP/HTTPS/TCPMux proxies
- Extract toProxyStats helper to deduplicate ProxyStats construction
- Extract startVisitorListener helper in BaseProxy for STCP/SUDP proxies
- Extract acceptLoop helper in BaseVisitor for STCP/XTCP visitors
2026-03-07 12:00:27 +08:00
fatedier
cf396563f8 client/proxy: unify work conn wrapping across all proxy types (#5212)
* client/proxy: extract wrapWorkConn to deduplicate UDP/SUDP connection wrapping

Move the repeated rate-limiting, encryption, and compression wrapping
logic from UDPProxy and SUDPProxy into a shared BaseProxy.wrapWorkConn
method, reducing ~18 lines of duplication in each proxy type.

* client/proxy: unify work conn wrapping with pooled compression for all proxy types

Refactor wrapWorkConn to accept encKey parameter and return
(io.ReadWriteCloser, recycleFn, error), enabling HandleTCPWorkConnection
to reuse the same limiter/encryption/compression pipeline.

Switch all proxy types from WithCompression to WithCompressionFromPool.
TCP non-plugin path calls recycleFn via defer after Join; plugin and
UDP/SUDP paths skip recycle (objects are GC'd safely, per golib contract).
2026-03-07 01:33:37 +08:00
15 changed files with 185 additions and 289 deletions

View File

@@ -42,10 +42,10 @@ func (sv *STCPVisitor) Run() (err error) {
if err != nil {
return
}
go sv.worker()
go sv.acceptLoop(sv.l, "stcp local", sv.handleConn)
}
go sv.internalConnWorker()
go sv.acceptLoop(sv.internalLn, "stcp internal", sv.handleConn)
if sv.plugin != nil {
sv.plugin.Start()
@@ -57,30 +57,6 @@ func (sv *STCPVisitor) Close() {
sv.BaseVisitor.Close()
}
func (sv *STCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warnf("stcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warnf("stcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx)
var tunnelErr error

View File

@@ -119,6 +119,18 @@ func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
return v.internalLn.PutConn(conn)
}
func (v *BaseVisitor) acceptLoop(l net.Listener, name string, handleConn func(net.Conn)) {
xl := xlog.FromContextSafe(v.ctx)
for {
conn, err := l.Accept()
if err != nil {
xl.Warnf("%s listener closed", name)
return
}
go handleConn(conn)
}
}
func (v *BaseVisitor) Close() {
if v.l != nil {
v.l.Close()

View File

@@ -65,10 +65,10 @@ func (sv *XTCPVisitor) Run() (err error) {
if err != nil {
return
}
go sv.worker()
go sv.acceptLoop(sv.l, "xtcp local", sv.handleConn)
}
go sv.internalConnWorker()
go sv.acceptLoop(sv.internalLn, "xtcp internal", sv.handleConn)
go sv.processTunnelStartEvents()
if sv.cfg.KeepTunnelOpen {
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
@@ -93,30 +93,6 @@ func (sv *XTCPVisitor) Close() {
}
}
func (sv *XTCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warnf("xtcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warnf("xtcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) processTunnelStartEvents() {
for {
select {

View File

@@ -203,6 +203,25 @@ func (m *serverMetrics) GetServer() *ServerStats {
return s
}
func toProxyStats(name string, proxyStats *ProxyStatistics) *ProxyStats {
ps := &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
return ps
}
func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
res := make([]*ProxyStats, 0)
m.mu.Lock()
@@ -212,23 +231,7 @@ func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
if proxyStats.ProxyType != proxyType {
continue
}
ps := &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
res = append(res, ps)
res = append(res, toProxyStats(name, proxyStats))
}
return res
}
@@ -241,26 +244,10 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri
if proxyStats.ProxyType != proxyType {
continue
}
if name != proxyName {
continue
}
res = &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
res = toProxyStats(name, proxyStats)
break
}
return
@@ -272,21 +259,7 @@ func (m *serverMetrics) GetProxyByName(proxyName string) (res *ProxyStats) {
proxyStats, ok := m.info.ProxyStatistics[proxyName]
if ok {
res = &ProxyStats{
Name: proxyName,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
res = toProxyStats(proxyName, proxyStats)
}
return
}

View File

@@ -152,7 +152,9 @@ func (c *Controller) GenSid() string {
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) {
if m.PreCheck {
c.mu.RLock()
cfg, ok := c.clientCfgs[m.ProxyName]
c.mu.RUnlock()
if !ok {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
return

View File

@@ -168,11 +168,15 @@ func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
return l, err
}
readConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return l, err
}
l = &UDPListener{
addr: udpAddr,
acceptCh: make(chan net.Conn),
writeCh: make(chan *UDPPacket, 1000),
readConn: readConn,
fakeConns: make(map[string]*FakeUDPConn),
}

View File

@@ -266,31 +266,13 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
go libio.Join(remote, client)
}
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request {
user := ""
// If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header.
if req.URL.Host != "" {
proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" {
user, _, _ = parseBasicAuth(proxyAuth)
user, _, _ = httppkg.ParseBasicAuth(proxyAuth)
}
}
if user == "" {

View File

@@ -95,20 +95,33 @@ func (cm *ControlManager) Close() error {
return nil
}
type Control struct {
// SessionContext encapsulates the input parameters for creating a new Control.
type SessionContext struct {
// all resource managers and controllers
rc *controller.ResourceController
RC *controller.ResourceController
// proxy manager
pxyManager *proxy.Manager
PxyManager *proxy.Manager
// plugin manager
pluginManager *plugin.Manager
PluginManager *plugin.Manager
// verifies authentication based on selected method
authVerifier auth.Verifier
AuthVerifier auth.Verifier
// key used for connection encryption
encryptionKey []byte
EncryptionKey []byte
// control connection
Conn net.Conn
// indicates whether the connection is encrypted
ConnEncrypted bool
// login message
LoginMsg *msg.Login
// server configuration
ServerCfg *v1.ServerConfig
// client registry
ClientRegistry *registry.ClientRegistry
}
type Control struct {
// session context
sessionCtx *SessionContext
// other components can use this to communicate with client
msgTransporter transport.MessageTransporter
@@ -117,12 +130,6 @@ type Control struct {
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
msgDispatcher *msg.Dispatcher
// login message
loginMsg *msg.Login
// control connection
conn net.Conn
// work connections
workConnCh chan net.Conn
@@ -145,61 +152,37 @@ type Control struct {
mu sync.RWMutex
// Server configuration information
serverCfg *v1.ServerConfig
clientRegistry *registry.ClientRegistry
xl *xlog.Logger
ctx context.Context
doneCh chan struct{}
}
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext.
func NewControl(
ctx context.Context,
rc *controller.ResourceController,
pxyManager *proxy.Manager,
pluginManager *plugin.Manager,
authVerifier auth.Verifier,
encryptionKey []byte,
ctlConn net.Conn,
ctlConnEncrypted bool,
loginMsg *msg.Login,
serverCfg *v1.ServerConfig,
) (*Control, error) {
poolCount := loginMsg.PoolCount
if poolCount > int(serverCfg.Transport.MaxPoolCount) {
poolCount = int(serverCfg.Transport.MaxPoolCount)
func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
poolCount := sessionCtx.LoginMsg.PoolCount
if poolCount > int(sessionCtx.ServerCfg.Transport.MaxPoolCount) {
poolCount = int(sessionCtx.ServerCfg.Transport.MaxPoolCount)
}
ctl := &Control{
rc: rc,
pxyManager: pxyManager,
pluginManager: pluginManager,
authVerifier: authVerifier,
encryptionKey: encryptionKey,
conn: ctlConn,
loginMsg: loginMsg,
workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
runID: loginMsg.RunID,
serverCfg: serverCfg,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
sessionCtx: sessionCtx,
workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
runID: sessionCtx.LoginMsg.RunID,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
}
ctl.lastPing.Store(time.Now())
if ctlConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey)
if sessionCtx.ConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
if err != nil {
return nil, err
}
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
} else {
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn)
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
}
ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
@@ -213,7 +196,7 @@ func (ctl *Control) Start() {
RunID: ctl.runID,
Error: "",
}
_ = msg.WriteMsg(ctl.conn, loginRespMsg)
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
go func() {
for i := 0; i < ctl.poolCount; i++ {
@@ -225,7 +208,7 @@ func (ctl *Control) Start() {
}
func (ctl *Control) Close() error {
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
return nil
}
@@ -233,7 +216,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
xl := ctl.xl
xl.Infof("replaced by client [%s]", newCtl.runID)
ctl.runID = ""
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
}
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
@@ -291,7 +274,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
return
}
case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second):
case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second):
err = fmt.Errorf("timeout trying to get work connection")
xl.Warnf("%v", err)
return
@@ -304,15 +287,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
}
func (ctl *Control) heartbeatWorker() {
if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 {
if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 {
return
}
xl := ctl.xl
go wait.Until(func() {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second {
xl.Warnf("heartbeat timeout")
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
return
}
}, time.Second, ctl.doneCh)
@@ -330,7 +313,7 @@ func (ctl *Control) worker() {
go ctl.msgDispatcher.Run()
<-ctl.msgDispatcher.Done()
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
ctl.mu.Lock()
defer ctl.mu.Unlock()
@@ -342,26 +325,26 @@ func (ctl *Control) worker() {
for _, pxy := range ctl.proxies {
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
}()
}
metrics.Server.CloseClient()
ctl.clientRegistry.MarkOfflineByRunID(ctl.runID)
ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID)
xl.Infof("client exit success")
close(ctl.doneCh)
}
@@ -381,14 +364,14 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
content := &plugin.NewProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
NewProxy: *inMsg,
}
var remoteAddr string
retContent, err := ctl.pluginManager.NewProxy(content)
retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content)
if err == nil {
inMsg = &retContent.NewProxy
remoteAddr, err = ctl.RegisterProxy(inMsg)
@@ -401,15 +384,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
if err != nil {
xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient))
} else {
resp.RemoteAddr = remoteAddr
xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
clientID := ctl.loginMsg.ClientID
clientID := ctl.sessionCtx.LoginMsg.ClientID
if clientID == "" {
clientID = ctl.loginMsg.RunID
clientID = ctl.sessionCtx.LoginMsg.RunID
}
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID)
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID)
}
_ = ctl.msgDispatcher.Send(resp)
}
@@ -420,21 +403,21 @@ func (ctl *Control) handlePing(m msg.Message) {
content := &plugin.PingContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
Ping: *inMsg,
}
retContent, err := ctl.pluginManager.Ping(content)
retContent, err := ctl.sessionCtx.PluginManager.Ping(content)
if err == nil {
inMsg = &retContent.Ping
err = ctl.authVerifier.VerifyPing(inMsg)
err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg)
}
if err != nil {
xl.Warnf("received invalid ping: %v", err)
_ = ctl.msgDispatcher.Send(&msg.Pong{
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)),
})
return
}
@@ -445,17 +428,17 @@ func (ctl *Control) handlePing(m msg.Message) {
func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
inMsg := m.(*msg.NatHoleVisitor)
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User)
}
func (ctl *Control) handleNatHoleClient(m msg.Message) {
inMsg := m.(*msg.NatHoleClient)
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
}
func (ctl *Control) handleNatHoleReport(m msg.Message) {
inMsg := m.(*msg.NatHoleReport)
ctl.rc.NatHoleController.HandleReport(inMsg)
ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg)
}
func (ctl *Control) handleCloseProxy(m msg.Message) {
@@ -468,15 +451,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) {
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
var pxyConf v1.ProxyConfigurer
// Load configures from NewProxy message and validate.
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg)
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg)
if err != nil {
return
}
// User info
userInfo := plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.runID,
}
@@ -484,22 +467,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
// In fact, it creates different proxies based on the proxy type. We just call run() here.
pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{
UserInfo: userInfo,
LoginMsg: ctl.loginMsg,
LoginMsg: ctl.sessionCtx.LoginMsg,
PoolCount: ctl.poolCount,
ResourceController: ctl.rc,
ResourceController: ctl.sessionCtx.RC,
GetWorkConnFn: ctl.GetWorkConn,
Configurer: pxyConf,
ServerCfg: ctl.serverCfg,
EncryptionKey: ctl.encryptionKey,
ServerCfg: ctl.sessionCtx.ServerCfg,
EncryptionKey: ctl.sessionCtx.EncryptionKey,
})
if err != nil {
return remoteAddr, err
}
// Check ports used number in each client
if ctl.serverCfg.MaxPortsPerClient > 0 {
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
ctl.mu.Lock()
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) {
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) {
ctl.mu.Unlock()
err = fmt.Errorf("exceed the max_ports_per_client")
return
@@ -516,7 +499,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}()
}
if ctl.pxyManager.Exist(pxyMsg.ProxyName) {
if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) {
err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName)
return
}
@@ -531,7 +514,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}
}()
err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy)
err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy)
if err != nil {
return
}
@@ -550,11 +533,11 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
return
}
if ctl.serverCfg.MaxPortsPerClient > 0 {
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
ctl.portsUsedNum -= pxy.GetUsedPortsNum()
}
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
delete(ctl.proxies, closeMsg.ProxyName)
ctl.mu.Unlock()
@@ -562,16 +545,16 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
}()
return
}

View File

@@ -75,15 +75,7 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
}
}()
domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1)
for _, d := range pxy.cfg.CustomDomains {
if d != "" {
domains = append(domains, d)
}
}
if pxy.cfg.SubDomain != "" {
domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost)
}
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
addrs := make([]string, 0)
for _, domain := range domains {

View File

@@ -53,15 +53,7 @@ func (pxy *HTTPSProxy) Run() (remoteAddr string, err error) {
pxy.Close()
}
}()
domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1)
for _, d := range pxy.cfg.CustomDomains {
if d != "" {
domains = append(domains, d)
}
}
if pxy.cfg.SubDomain != "" {
domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost)
}
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
addrs := make([]string, 0)
for _, domain := range domains {

View File

@@ -150,7 +150,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
}
err := msg.WriteMsg(workConn, &msg.StartWorkConn{
err = msg.WriteMsg(workConn, &msg.StartWorkConn{
ProxyName: pxy.GetName(),
SrcAddr: srcAddr,
SrcPort: uint16(srcPort),
@@ -161,6 +161,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
if err != nil {
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
workConn.Close()
workConn = nil
} else {
break
}
@@ -173,6 +174,36 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
return
}
// startVisitorListener sets up a VisitorManager listener for visitor-based proxies (STCP, SUDP).
func (pxy *BaseProxy) startVisitorListener(secretKey string, allowUsers []string, proxyType string) error {
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, err := pxy.rc.VisitorManager.Listen(pxy.GetName(), secretKey, allowUsers)
if err != nil {
return err
}
pxy.listeners = append(pxy.listeners, listener)
pxy.xl.Infof("%s proxy custom listen success", proxyType)
pxy.startCommonTCPListenersHandler()
return nil
}
// buildDomains constructs a list of domains from custom domains and subdomain configuration.
func (pxy *BaseProxy) buildDomains(customDomains []string, subDomain string) []string {
domains := make([]string, 0, len(customDomains)+1)
for _, d := range customDomains {
if d != "" {
domains = append(domains, d)
}
}
if subDomain != "" {
domains = append(domains, subDomain+"."+pxy.serverCfg.SubDomainHost)
}
return domains
}
// startCommonTCPListenersHandler start a goroutine handler for each listener.
func (pxy *BaseProxy) startCommonTCPListenersHandler() {
xl := xlog.FromContextSafe(pxy.ctx)

View File

@@ -41,21 +41,7 @@ func NewSTCPProxy(baseProxy *BaseProxy) Proxy {
}
func (pxy *STCPProxy) Run() (remoteAddr string, err error) {
xl := pxy.xl
allowUsers := pxy.cfg.AllowUsers
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
if errRet != nil {
err = errRet
return
}
pxy.listeners = append(pxy.listeners, listener)
xl.Infof("stcp proxy custom listen success")
pxy.startCommonTCPListenersHandler()
err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "stcp")
return
}

View File

@@ -41,21 +41,7 @@ func NewSUDPProxy(baseProxy *BaseProxy) Proxy {
}
func (pxy *SUDPProxy) Run() (remoteAddr string, err error) {
xl := pxy.xl
allowUsers := pxy.cfg.AllowUsers
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
if errRet != nil {
err = errRet
return
}
pxy.listeners = append(pxy.listeners, listener)
xl.Infof("sudp proxy custom listen success")
pxy.startCommonTCPListenersHandler()
err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "sudp")
return
}

View File

@@ -72,15 +72,7 @@ func (pxy *TCPMuxProxy) httpConnectListen(
}
func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) {
domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1)
for _, d := range pxy.cfg.CustomDomains {
if d != "" {
domains = append(domains, d)
}
}
if pxy.cfg.SubDomain != "" {
domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost)
}
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
addrs := make([]string, 0)
for _, domain := range domains {

View File

@@ -604,8 +604,18 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
return err
}
// TODO(fatedier): use SessionContext
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg)
ctl, err := NewControl(ctx, &SessionContext{
RC: svr.rc,
PxyManager: svr.pxyManager,
PluginManager: svr.pluginManager,
AuthVerifier: authVerifier,
EncryptionKey: svr.auth.EncryptionKey(),
Conn: ctlConn,
ConnEncrypted: !internal,
LoginMsg: loginMsg,
ServerCfg: svr.cfg,
ClientRegistry: svr.clientRegistry,
})
if err != nil {
xl.Warnf("create new controller error: %v", err)
// don't return detailed errors to client
@@ -626,7 +636,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
ctl.Close()
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
}
ctl.clientRegistry = svr.clientRegistry
ctl.Start()
@@ -652,9 +661,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
// server plugin hook
content := &plugin.NewWorkConnContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
},
NewWorkConn: *newMsg,
}
@@ -662,7 +671,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
if err == nil {
newMsg = &retContent.NewWorkConn
// Check auth.
err = ctl.authVerifier.VerifyNewWorkConn(newMsg)
err = ctl.sessionCtx.AuthVerifier.VerifyNewWorkConn(newMsg)
}
if err != nil {
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
@@ -683,7 +692,7 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis
if !exist {
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
}
visitorUser = ctl.loginMsg.User
visitorUser = ctl.sessionCtx.LoginMsg.User
}
return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
newMsg.UseEncryption, newMsg.UseCompression, visitorUser)