From f948c595f48624f791c534067773b7aa0c92b319 Mon Sep 17 00:00:00 2001 From: fatedier Date: Sun, 8 Mar 2026 00:58:30 +0800 Subject: [PATCH] client/visitor: deduplicate visitor connection handshake and wrapping Extract two shared helpers to eliminate duplicated code across STCP, SUDP, and XTCP visitors: - dialRawVisitorConn: handles ConnectServer + NewVisitorConn handshake (auth, sign key, 10s read deadline, error check) - wrapVisitorConn: handles encryption + pooled compression wrapping, returning a recycleFn for pool resource cleanup SUDP is upgraded from WithCompression to WithCompressionFromPool, aligning with the pooled compression used by STCP and XTCP. --- client/visitor/stcp.go | 58 +++------------------------------- client/visitor/sudp.go | 66 ++++++++------------------------------- client/visitor/visitor.go | 62 ++++++++++++++++++++++++++++++++++++ client/visitor/xtcp.go | 21 +++++-------- 4 files changed, 87 insertions(+), 120 deletions(-) diff --git a/client/visitor/stcp.go b/client/visitor/stcp.go index 6529142c..03ec51fe 100644 --- a/client/visitor/stcp.go +++ b/client/visitor/stcp.go @@ -15,18 +15,12 @@ package visitor import ( - "fmt" - "io" "net" "strconv" - "time" libio "github.com/fatedier/golib/io" v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/msg" - "github.com/fatedier/frp/pkg/naming" - "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/xlog" ) @@ -61,7 +55,6 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) { xl := xlog.FromContextSafe(sv.ctx) var tunnelErr error defer func() { - // If there was an error and connection supports CloseWithError, use it if tunnelErr != nil { if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok { _ = eConn.CloseWithError(tunnelErr) @@ -72,62 +65,21 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) { }() xl.Debugf("get a new stcp user connection") - visitorConn, err := sv.helper.ConnectServer() + visitorConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig()) if err != nil { + xl.Warnf("dialRawVisitorConn error: %v", err) tunnelErr = err return } defer visitorConn.Close() - now := time.Now().Unix() - targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName) - newVisitorConnMsg := &msg.NewVisitorConn{ - RunID: sv.helper.RunID(), - ProxyName: targetProxyName, - SignKey: util.GetAuthKey(sv.cfg.SecretKey, now), - Timestamp: now, - UseEncryption: sv.cfg.Transport.UseEncryption, - UseCompression: sv.cfg.Transport.UseCompression, - } - err = msg.WriteMsg(visitorConn, newVisitorConnMsg) + remote, recycleFn, err := wrapVisitorConn(visitorConn, sv.cfg.GetBaseConfig()) if err != nil { - xl.Warnf("send newVisitorConnMsg to server error: %v", err) + xl.Warnf("wrapVisitorConn error: %v", err) tunnelErr = err return } - - var newVisitorConnRespMsg msg.NewVisitorConnResp - _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) - err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) - if err != nil { - xl.Warnf("get newVisitorConnRespMsg error: %v", err) - tunnelErr = err - return - } - _ = visitorConn.SetReadDeadline(time.Time{}) - - if newVisitorConnRespMsg.Error != "" { - xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error) - tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error) - return - } - - var remote io.ReadWriteCloser - remote = visitorConn - if sv.cfg.Transport.UseEncryption { - remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey)) - if err != nil { - xl.Errorf("create encryption stream error: %v", err) - tunnelErr = err - return - } - } - - if sv.cfg.Transport.UseCompression { - var recycleFn func() - remote, recycleFn = libio.WithCompressionFromPool(remote) - defer recycleFn() - } + defer recycleFn() libio.Join(userConn, remote) } diff --git a/client/visitor/sudp.go b/client/visitor/sudp.go index a341da8a..6014161c 100644 --- a/client/visitor/sudp.go +++ b/client/visitor/sudp.go @@ -16,21 +16,17 @@ package visitor import ( "fmt" - "io" "net" "strconv" "sync" "time" "github.com/fatedier/golib/errors" - libio "github.com/fatedier/golib/io" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/msg" - "github.com/fatedier/frp/pkg/naming" "github.com/fatedier/frp/pkg/proto/udp" netpkg "github.com/fatedier/frp/pkg/util/net" - "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/xlog" ) @@ -76,6 +72,7 @@ func (sv *SUDPVisitor) dispatcher() { var ( visitorConn net.Conn + recycleFn func() err error firstPacket *msg.UDPPacket @@ -93,14 +90,17 @@ func (sv *SUDPVisitor) dispatcher() { return } - visitorConn, err = sv.getNewVisitorConn() + visitorConn, recycleFn, err = sv.getNewVisitorConn() if err != nil { xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err) continue } // visitorConn always be closed when worker done. - sv.worker(visitorConn, firstPacket) + func() { + defer recycleFn() + sv.worker(visitorConn, firstPacket) + }() select { case <-sv.checkCloseCh: @@ -198,57 +198,17 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) { xl.Infof("sudp worker is closed") } -func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) { - xl := xlog.FromContextSafe(sv.ctx) - visitorConn, err := sv.helper.ConnectServer() +func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, func(), error) { + rawConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig()) if err != nil { - return nil, fmt.Errorf("frpc connect frps error: %v", err) + return nil, func() {}, err } - - now := time.Now().Unix() - targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName) - newVisitorConnMsg := &msg.NewVisitorConn{ - RunID: sv.helper.RunID(), - ProxyName: targetProxyName, - SignKey: util.GetAuthKey(sv.cfg.SecretKey, now), - Timestamp: now, - UseEncryption: sv.cfg.Transport.UseEncryption, - UseCompression: sv.cfg.Transport.UseCompression, - } - err = msg.WriteMsg(visitorConn, newVisitorConnMsg) + rwc, recycleFn, err := wrapVisitorConn(rawConn, sv.cfg.GetBaseConfig()) if err != nil { - visitorConn.Close() - return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err) + rawConn.Close() + return nil, func() {}, err } - - var newVisitorConnRespMsg msg.NewVisitorConnResp - _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) - err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) - if err != nil { - visitorConn.Close() - return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err) - } - _ = visitorConn.SetReadDeadline(time.Time{}) - - if newVisitorConnRespMsg.Error != "" { - visitorConn.Close() - return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error) - } - - var remote io.ReadWriteCloser - remote = visitorConn - if sv.cfg.Transport.UseEncryption { - remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey)) - if err != nil { - xl.Errorf("create encryption stream error: %v", err) - visitorConn.Close() - return nil, err - } - } - if sv.cfg.Transport.UseCompression { - remote = libio.WithCompression(remote) - } - return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil + return netpkg.WrapReadWriteCloserToConn(rwc, rawConn), recycleFn, nil } func (sv *SUDPVisitor) Close() { diff --git a/client/visitor/visitor.go b/client/visitor/visitor.go index 51999499..14a7aa37 100644 --- a/client/visitor/visitor.go +++ b/client/visitor/visitor.go @@ -16,13 +16,21 @@ package visitor import ( "context" + "fmt" + "io" "net" "sync" + "time" + + libio "github.com/fatedier/golib/io" v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/naming" plugin "github.com/fatedier/frp/pkg/plugin/visitor" "github.com/fatedier/frp/pkg/transport" netpkg "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/vnet" ) @@ -142,3 +150,57 @@ func (v *BaseVisitor) Close() { v.plugin.Close() } } + +func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, error) { + visitorConn, err := v.helper.ConnectServer() + if err != nil { + return nil, fmt.Errorf("connect to server error: %v", err) + } + + now := time.Now().Unix() + targetProxyName := naming.BuildTargetServerProxyName(v.clientCfg.User, cfg.ServerUser, cfg.ServerName) + newVisitorConnMsg := &msg.NewVisitorConn{ + RunID: v.helper.RunID(), + ProxyName: targetProxyName, + SignKey: util.GetAuthKey(cfg.SecretKey, now), + Timestamp: now, + UseEncryption: cfg.Transport.UseEncryption, + UseCompression: cfg.Transport.UseCompression, + } + err = msg.WriteMsg(visitorConn, newVisitorConnMsg) + if err != nil { + visitorConn.Close() + return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err) + } + + var newVisitorConnRespMsg msg.NewVisitorConnResp + _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) + if err != nil { + visitorConn.Close() + return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err) + } + _ = visitorConn.SetReadDeadline(time.Time{}) + + if newVisitorConnRespMsg.Error != "" { + visitorConn.Close() + return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error) + } + return visitorConn, nil +} + +func wrapVisitorConn(conn io.ReadWriteCloser, cfg *v1.VisitorBaseConfig) (io.ReadWriteCloser, func(), error) { + rwc := conn + if cfg.Transport.UseEncryption { + var err error + rwc, err = libio.WithEncryption(rwc, []byte(cfg.SecretKey)) + if err != nil { + return nil, func() {}, fmt.Errorf("create encryption stream error: %v", err) + } + } + recycleFn := func() {} + if cfg.Transport.UseCompression { + rwc, recycleFn = libio.WithCompressionFromPool(rwc) + } + return rwc, recycleFn, nil +} diff --git a/client/visitor/xtcp.go b/client/visitor/xtcp.go index b2f4ef37..e7a60895 100644 --- a/client/visitor/xtcp.go +++ b/client/visitor/xtcp.go @@ -182,21 +182,14 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) { return } - var muxConnRWCloser io.ReadWriteCloser = tunnelConn - if sv.cfg.Transport.UseEncryption { - muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey)) - if err != nil { - xl.Errorf("create encryption stream error: %v", err) - tunnelConn.Close() - tunnelErr = err - return - } - } - if sv.cfg.Transport.UseCompression { - var recycleFn func() - muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser) - defer recycleFn() + muxConnRWCloser, recycleFn, err := wrapVisitorConn(tunnelConn, sv.cfg.GetBaseConfig()) + if err != nil { + xl.Errorf("%v", err) + tunnelConn.Close() + tunnelErr = err + return } + defer recycleFn() _, _, errs := libio.Join(userConn, muxConnRWCloser) xl.Debugf("join connections closed")