client/visitor: deduplicate visitor connection handshake and wrapping (#5219)

Extract two shared helpers to eliminate duplicated code across STCP,
SUDP, and XTCP visitors:

- dialRawVisitorConn: handles ConnectServer + NewVisitorConn handshake
  (auth, sign key, 10s read deadline, error check)
- wrapVisitorConn: handles encryption + pooled compression wrapping,
  returning a recycleFn for pool resource cleanup

SUDP is upgraded from WithCompression to WithCompressionFromPool,
aligning with the pooled compression used by STCP and XTCP.
This commit is contained in:
fatedier
2026-03-08 01:03:40 +08:00
committed by GitHub
parent 764a626b6e
commit 605f3bdece
4 changed files with 87 additions and 120 deletions

View File

@@ -15,18 +15,12 @@
package visitor package visitor
import ( import (
"fmt"
"io"
"net" "net"
"strconv" "strconv"
"time"
libio "github.com/fatedier/golib/io" libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@@ -61,7 +55,6 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx) xl := xlog.FromContextSafe(sv.ctx)
var tunnelErr error var tunnelErr error
defer func() { defer func() {
// If there was an error and connection supports CloseWithError, use it
if tunnelErr != nil { if tunnelErr != nil {
if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok { if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
_ = eConn.CloseWithError(tunnelErr) _ = eConn.CloseWithError(tunnelErr)
@@ -72,62 +65,21 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
}() }()
xl.Debugf("get a new stcp user connection") xl.Debugf("get a new stcp user connection")
visitorConn, err := sv.helper.ConnectServer() visitorConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
if err != nil { if err != nil {
xl.Warnf("dialRawVisitorConn error: %v", err)
tunnelErr = err tunnelErr = err
return return
} }
defer visitorConn.Close() defer visitorConn.Close()
now := time.Now().Unix() remote, recycleFn, err := wrapVisitorConn(visitorConn, sv.cfg.GetBaseConfig())
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
Timestamp: now,
UseEncryption: sv.cfg.Transport.UseEncryption,
UseCompression: sv.cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil { if err != nil {
xl.Warnf("send newVisitorConnMsg to server error: %v", err) xl.Warnf("wrapVisitorConn error: %v", err)
tunnelErr = err tunnelErr = err
return return
} }
defer recycleFn()
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
xl.Warnf("get newVisitorConnRespMsg error: %v", err)
tunnelErr = err
return
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error)
return
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
tunnelErr = err
return
}
}
if sv.cfg.Transport.UseCompression {
var recycleFn func()
remote, recycleFn = libio.WithCompressionFromPool(remote)
defer recycleFn()
}
libio.Join(userConn, remote) libio.Join(userConn, remote)
} }

View File

@@ -16,21 +16,17 @@ package visitor
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"time" "time"
"github.com/fatedier/golib/errors" "github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
netpkg "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@@ -76,6 +72,7 @@ func (sv *SUDPVisitor) dispatcher() {
var ( var (
visitorConn net.Conn visitorConn net.Conn
recycleFn func()
err error err error
firstPacket *msg.UDPPacket firstPacket *msg.UDPPacket
@@ -93,14 +90,17 @@ func (sv *SUDPVisitor) dispatcher() {
return return
} }
visitorConn, err = sv.getNewVisitorConn() visitorConn, recycleFn, err = sv.getNewVisitorConn()
if err != nil { if err != nil {
xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err) xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err)
continue continue
} }
// visitorConn always be closed when worker done. // visitorConn always be closed when worker done.
sv.worker(visitorConn, firstPacket) func() {
defer recycleFn()
sv.worker(visitorConn, firstPacket)
}()
select { select {
case <-sv.checkCloseCh: case <-sv.checkCloseCh:
@@ -198,57 +198,17 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl.Infof("sudp worker is closed") xl.Infof("sudp worker is closed")
} }
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) { func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, func(), error) {
xl := xlog.FromContextSafe(sv.ctx) rawConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
visitorConn, err := sv.helper.ConnectServer()
if err != nil { if err != nil {
return nil, fmt.Errorf("frpc connect frps error: %v", err) return nil, func() {}, err
} }
rwc, recycleFn, err := wrapVisitorConn(rawConn, sv.cfg.GetBaseConfig())
now := time.Now().Unix()
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
Timestamp: now,
UseEncryption: sv.cfg.Transport.UseEncryption,
UseCompression: sv.cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil { if err != nil {
visitorConn.Close() rawConn.Close()
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err) return nil, func() {}, err
} }
return netpkg.WrapReadWriteCloserToConn(rwc, rawConn), recycleFn, nil
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
visitorConn.Close()
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
visitorConn.Close()
return nil, err
}
}
if sv.cfg.Transport.UseCompression {
remote = libio.WithCompression(remote)
}
return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil
} }
func (sv *SUDPVisitor) Close() { func (sv *SUDPVisitor) Close() {

View File

@@ -16,13 +16,21 @@ package visitor
import ( import (
"context" "context"
"fmt"
"io"
"net" "net"
"sync" "sync"
"time"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
plugin "github.com/fatedier/frp/pkg/plugin/visitor" plugin "github.com/fatedier/frp/pkg/plugin/visitor"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
netpkg "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet" "github.com/fatedier/frp/pkg/vnet"
) )
@@ -142,3 +150,57 @@ func (v *BaseVisitor) Close() {
v.plugin.Close() v.plugin.Close()
} }
} }
func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, error) {
visitorConn, err := v.helper.ConnectServer()
if err != nil {
return nil, fmt.Errorf("connect to server error: %v", err)
}
now := time.Now().Unix()
targetProxyName := naming.BuildTargetServerProxyName(v.clientCfg.User, cfg.ServerUser, cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: v.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(cfg.SecretKey, now),
Timestamp: now,
UseEncryption: cfg.Transport.UseEncryption,
UseCompression: cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
visitorConn.Close()
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
return visitorConn, nil
}
func wrapVisitorConn(conn io.ReadWriteCloser, cfg *v1.VisitorBaseConfig) (io.ReadWriteCloser, func(), error) {
rwc := conn
if cfg.Transport.UseEncryption {
var err error
rwc, err = libio.WithEncryption(rwc, []byte(cfg.SecretKey))
if err != nil {
return nil, func() {}, fmt.Errorf("create encryption stream error: %v", err)
}
}
recycleFn := func() {}
if cfg.Transport.UseCompression {
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
}
return rwc, recycleFn, nil
}

View File

@@ -182,21 +182,14 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
return return
} }
var muxConnRWCloser io.ReadWriteCloser = tunnelConn muxConnRWCloser, recycleFn, err := wrapVisitorConn(tunnelConn, sv.cfg.GetBaseConfig())
if sv.cfg.Transport.UseEncryption { if err != nil {
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey)) xl.Errorf("%v", err)
if err != nil { tunnelConn.Close()
xl.Errorf("create encryption stream error: %v", err) tunnelErr = err
tunnelConn.Close() return
tunnelErr = err
return
}
}
if sv.cfg.Transport.UseCompression {
var recycleFn func()
muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser)
defer recycleFn()
} }
defer recycleFn()
_, _, errs := libio.Join(userConn, muxConnRWCloser) _, _, errs := libio.Join(userConn, muxConnRWCloser)
xl.Debugf("join connections closed") xl.Debugf("join connections closed")