mirror of
https://github.com/fatedier/frp.git
synced 2026-03-08 02:49:10 +08:00
client/visitor: deduplicate visitor connection handshake and wrapping (#5219)
Extract two shared helpers to eliminate duplicated code across STCP, SUDP, and XTCP visitors: - dialRawVisitorConn: handles ConnectServer + NewVisitorConn handshake (auth, sign key, 10s read deadline, error check) - wrapVisitorConn: handles encryption + pooled compression wrapping, returning a recycleFn for pool resource cleanup SUDP is upgraded from WithCompression to WithCompressionFromPool, aligning with the pooled compression used by STCP and XTCP.
This commit is contained in:
@@ -15,18 +15,12 @@
|
|||||||
package visitor
|
package visitor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
|
|
||||||
libio "github.com/fatedier/golib/io"
|
libio "github.com/fatedier/golib/io"
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/msg"
|
|
||||||
"github.com/fatedier/frp/pkg/naming"
|
|
||||||
"github.com/fatedier/frp/pkg/util/util"
|
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -61,7 +55,6 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
|
|||||||
xl := xlog.FromContextSafe(sv.ctx)
|
xl := xlog.FromContextSafe(sv.ctx)
|
||||||
var tunnelErr error
|
var tunnelErr error
|
||||||
defer func() {
|
defer func() {
|
||||||
// If there was an error and connection supports CloseWithError, use it
|
|
||||||
if tunnelErr != nil {
|
if tunnelErr != nil {
|
||||||
if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
|
if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
|
||||||
_ = eConn.CloseWithError(tunnelErr)
|
_ = eConn.CloseWithError(tunnelErr)
|
||||||
@@ -72,62 +65,21 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
xl.Debugf("get a new stcp user connection")
|
xl.Debugf("get a new stcp user connection")
|
||||||
visitorConn, err := sv.helper.ConnectServer()
|
visitorConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
xl.Warnf("dialRawVisitorConn error: %v", err)
|
||||||
tunnelErr = err
|
tunnelErr = err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer visitorConn.Close()
|
defer visitorConn.Close()
|
||||||
|
|
||||||
now := time.Now().Unix()
|
remote, recycleFn, err := wrapVisitorConn(visitorConn, sv.cfg.GetBaseConfig())
|
||||||
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
|
|
||||||
newVisitorConnMsg := &msg.NewVisitorConn{
|
|
||||||
RunID: sv.helper.RunID(),
|
|
||||||
ProxyName: targetProxyName,
|
|
||||||
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
|
|
||||||
Timestamp: now,
|
|
||||||
UseEncryption: sv.cfg.Transport.UseEncryption,
|
|
||||||
UseCompression: sv.cfg.Transport.UseCompression,
|
|
||||||
}
|
|
||||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warnf("send newVisitorConnMsg to server error: %v", err)
|
xl.Warnf("wrapVisitorConn error: %v", err)
|
||||||
tunnelErr = err
|
tunnelErr = err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer recycleFn()
|
||||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
|
||||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
|
||||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
|
||||||
if err != nil {
|
|
||||||
xl.Warnf("get newVisitorConnRespMsg error: %v", err)
|
|
||||||
tunnelErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = visitorConn.SetReadDeadline(time.Time{})
|
|
||||||
|
|
||||||
if newVisitorConnRespMsg.Error != "" {
|
|
||||||
xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
|
||||||
tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var remote io.ReadWriteCloser
|
|
||||||
remote = visitorConn
|
|
||||||
if sv.cfg.Transport.UseEncryption {
|
|
||||||
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
|
|
||||||
if err != nil {
|
|
||||||
xl.Errorf("create encryption stream error: %v", err)
|
|
||||||
tunnelErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sv.cfg.Transport.UseCompression {
|
|
||||||
var recycleFn func()
|
|
||||||
remote, recycleFn = libio.WithCompressionFromPool(remote)
|
|
||||||
defer recycleFn()
|
|
||||||
}
|
|
||||||
|
|
||||||
libio.Join(userConn, remote)
|
libio.Join(userConn, remote)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,21 +16,17 @@ package visitor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fatedier/golib/errors"
|
"github.com/fatedier/golib/errors"
|
||||||
libio "github.com/fatedier/golib/io"
|
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/msg"
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
"github.com/fatedier/frp/pkg/naming"
|
|
||||||
"github.com/fatedier/frp/pkg/proto/udp"
|
"github.com/fatedier/frp/pkg/proto/udp"
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||||
"github.com/fatedier/frp/pkg/util/util"
|
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -76,6 +72,7 @@ func (sv *SUDPVisitor) dispatcher() {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
visitorConn net.Conn
|
visitorConn net.Conn
|
||||||
|
recycleFn func()
|
||||||
err error
|
err error
|
||||||
|
|
||||||
firstPacket *msg.UDPPacket
|
firstPacket *msg.UDPPacket
|
||||||
@@ -93,14 +90,17 @@ func (sv *SUDPVisitor) dispatcher() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
visitorConn, err = sv.getNewVisitorConn()
|
visitorConn, recycleFn, err = sv.getNewVisitorConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err)
|
xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// visitorConn always be closed when worker done.
|
// visitorConn always be closed when worker done.
|
||||||
sv.worker(visitorConn, firstPacket)
|
func() {
|
||||||
|
defer recycleFn()
|
||||||
|
sv.worker(visitorConn, firstPacket)
|
||||||
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sv.checkCloseCh:
|
case <-sv.checkCloseCh:
|
||||||
@@ -198,57 +198,17 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
|||||||
xl.Infof("sudp worker is closed")
|
xl.Infof("sudp worker is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) {
|
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, func(), error) {
|
||||||
xl := xlog.FromContextSafe(sv.ctx)
|
rawConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
|
||||||
visitorConn, err := sv.helper.ConnectServer()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("frpc connect frps error: %v", err)
|
return nil, func() {}, err
|
||||||
}
|
}
|
||||||
|
rwc, recycleFn, err := wrapVisitorConn(rawConn, sv.cfg.GetBaseConfig())
|
||||||
now := time.Now().Unix()
|
|
||||||
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
|
|
||||||
newVisitorConnMsg := &msg.NewVisitorConn{
|
|
||||||
RunID: sv.helper.RunID(),
|
|
||||||
ProxyName: targetProxyName,
|
|
||||||
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
|
|
||||||
Timestamp: now,
|
|
||||||
UseEncryption: sv.cfg.Transport.UseEncryption,
|
|
||||||
UseCompression: sv.cfg.Transport.UseCompression,
|
|
||||||
}
|
|
||||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
visitorConn.Close()
|
rawConn.Close()
|
||||||
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err)
|
return nil, func() {}, err
|
||||||
}
|
}
|
||||||
|
return netpkg.WrapReadWriteCloserToConn(rwc, rawConn), recycleFn, nil
|
||||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
|
||||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
|
||||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
|
||||||
if err != nil {
|
|
||||||
visitorConn.Close()
|
|
||||||
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
|
|
||||||
}
|
|
||||||
_ = visitorConn.SetReadDeadline(time.Time{})
|
|
||||||
|
|
||||||
if newVisitorConnRespMsg.Error != "" {
|
|
||||||
visitorConn.Close()
|
|
||||||
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
var remote io.ReadWriteCloser
|
|
||||||
remote = visitorConn
|
|
||||||
if sv.cfg.Transport.UseEncryption {
|
|
||||||
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
|
|
||||||
if err != nil {
|
|
||||||
xl.Errorf("create encryption stream error: %v", err)
|
|
||||||
visitorConn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if sv.cfg.Transport.UseCompression {
|
|
||||||
remote = libio.WithCompression(remote)
|
|
||||||
}
|
|
||||||
return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sv *SUDPVisitor) Close() {
|
func (sv *SUDPVisitor) Close() {
|
||||||
|
|||||||
@@ -16,13 +16,21 @@ package visitor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
libio "github.com/fatedier/golib/io"
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
|
"github.com/fatedier/frp/pkg/naming"
|
||||||
plugin "github.com/fatedier/frp/pkg/plugin/visitor"
|
plugin "github.com/fatedier/frp/pkg/plugin/visitor"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||||
|
"github.com/fatedier/frp/pkg/util/util"
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
"github.com/fatedier/frp/pkg/vnet"
|
"github.com/fatedier/frp/pkg/vnet"
|
||||||
)
|
)
|
||||||
@@ -142,3 +150,57 @@ func (v *BaseVisitor) Close() {
|
|||||||
v.plugin.Close()
|
v.plugin.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, error) {
|
||||||
|
visitorConn, err := v.helper.ConnectServer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connect to server error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
targetProxyName := naming.BuildTargetServerProxyName(v.clientCfg.User, cfg.ServerUser, cfg.ServerName)
|
||||||
|
newVisitorConnMsg := &msg.NewVisitorConn{
|
||||||
|
RunID: v.helper.RunID(),
|
||||||
|
ProxyName: targetProxyName,
|
||||||
|
SignKey: util.GetAuthKey(cfg.SecretKey, now),
|
||||||
|
Timestamp: now,
|
||||||
|
UseEncryption: cfg.Transport.UseEncryption,
|
||||||
|
UseCompression: cfg.Transport.UseCompression,
|
||||||
|
}
|
||||||
|
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||||
|
if err != nil {
|
||||||
|
visitorConn.Close()
|
||||||
|
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||||
|
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||||
|
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||||
|
if err != nil {
|
||||||
|
visitorConn.Close()
|
||||||
|
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
|
||||||
|
}
|
||||||
|
_ = visitorConn.SetReadDeadline(time.Time{})
|
||||||
|
|
||||||
|
if newVisitorConnRespMsg.Error != "" {
|
||||||
|
visitorConn.Close()
|
||||||
|
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
||||||
|
}
|
||||||
|
return visitorConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapVisitorConn(conn io.ReadWriteCloser, cfg *v1.VisitorBaseConfig) (io.ReadWriteCloser, func(), error) {
|
||||||
|
rwc := conn
|
||||||
|
if cfg.Transport.UseEncryption {
|
||||||
|
var err error
|
||||||
|
rwc, err = libio.WithEncryption(rwc, []byte(cfg.SecretKey))
|
||||||
|
if err != nil {
|
||||||
|
return nil, func() {}, fmt.Errorf("create encryption stream error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recycleFn := func() {}
|
||||||
|
if cfg.Transport.UseCompression {
|
||||||
|
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
|
||||||
|
}
|
||||||
|
return rwc, recycleFn, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -182,21 +182,14 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var muxConnRWCloser io.ReadWriteCloser = tunnelConn
|
muxConnRWCloser, recycleFn, err := wrapVisitorConn(tunnelConn, sv.cfg.GetBaseConfig())
|
||||||
if sv.cfg.Transport.UseEncryption {
|
if err != nil {
|
||||||
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey))
|
xl.Errorf("%v", err)
|
||||||
if err != nil {
|
tunnelConn.Close()
|
||||||
xl.Errorf("create encryption stream error: %v", err)
|
tunnelErr = err
|
||||||
tunnelConn.Close()
|
return
|
||||||
tunnelErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if sv.cfg.Transport.UseCompression {
|
|
||||||
var recycleFn func()
|
|
||||||
muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser)
|
|
||||||
defer recycleFn()
|
|
||||||
}
|
}
|
||||||
|
defer recycleFn()
|
||||||
|
|
||||||
_, _, errs := libio.Join(userConn, muxConnRWCloser)
|
_, _, errs := libio.Join(userConn, muxConnRWCloser)
|
||||||
xl.Debugf("join connections closed")
|
xl.Debugf("join connections closed")
|
||||||
|
|||||||
Reference in New Issue
Block a user