Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
7993225be1 Initial plan 2026-03-05 17:33:24 +00:00
fatedier
17b27d8d96 pkg/msg: change UDPPacket.Content from string to []byte to avoid redundant base64 encode/decode 2026-03-06 01:25:52 +08:00
69 changed files with 772 additions and 613 deletions

View File

@@ -18,7 +18,6 @@ linters:
- lll
- makezero
- misspell
- modernize
- prealloc
- predeclared
- revive
@@ -48,9 +47,6 @@ linters:
ignore-rules:
- cancelled
- marshalled
modernize:
disable:
- omitzero
unparam:
check-exported: false
exclusions:

View File

@@ -16,7 +16,6 @@ package proxy
import (
"context"
"fmt"
"io"
"net"
"reflect"
@@ -123,33 +122,6 @@ func (pxy *BaseProxy) Close() {
}
}
// wrapWorkConn applies rate limiting, encryption, and compression
// to a work connection based on the proxy's transport configuration.
// The returned recycle function should be called when the stream is no longer in use
// to return compression resources to the pool. It is safe to not call recycle,
// in which case resources will be garbage collected normally.
func (pxy *BaseProxy) wrapWorkConn(conn net.Conn, encKey []byte) (io.ReadWriteCloser, func(), error) {
var rwc io.ReadWriteCloser = conn
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.baseCfg.Transport.UseEncryption {
var err error
rwc, err = libio.WithEncryption(rwc, encKey)
if err != nil {
conn.Close()
return nil, nil, fmt.Errorf("create encryption stream error: %w", err)
}
}
var recycleFn func()
if pxy.baseCfg.Transport.UseCompression {
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
}
return rwc, recycleFn, nil
}
func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
pxy.inWorkConnCallback = cb
}
@@ -167,14 +139,30 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) {
xl := pxy.xl
baseCfg := pxy.baseCfg
var (
remote io.ReadWriteCloser
err error
)
remote = workConn
if pxy.limiter != nil {
remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error {
return workConn.Close()
})
}
xl.Tracef("handle tcp work connection, useEncryption: %t, useCompression: %t",
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
remote, recycleFn, err := pxy.wrapWorkConn(workConn, encKey)
if err != nil {
xl.Errorf("wrap work connection: %v", err)
return
if baseCfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, encKey)
if err != nil {
workConn.Close()
xl.Errorf("create encryption stream error: %v", err)
return
}
}
var compressionResourceRecycleFn func()
if baseCfg.Transport.UseCompression {
remote, compressionResourceRecycleFn = libio.WithCompressionFromPool(remote)
}
// check if we need to send proxy protocol info
@@ -190,6 +178,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
}
if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 {
// Use the common proxy protocol builder function
header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion)
connInfo.ProxyProtocolHeader = header
}
@@ -198,18 +187,12 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if pxy.proxyPlugin != nil {
// if plugin is set, let plugin handle connection first
// Don't recycle compression resources here because plugins may
// retain the connection after Handle returns.
xl.Debugf("handle by plugin: %s", pxy.proxyPlugin.Name())
pxy.proxyPlugin.Handle(pxy.ctx, &connInfo)
xl.Debugf("handle by plugin finished")
return
}
if recycleFn != nil {
defer recycleFn()
}
localConn, err := libnet.Dial(
net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)),
libnet.WithTimeout(10*time.Second),
@@ -226,7 +209,6 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if connInfo.ProxyProtocolHeader != nil {
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
workConn.Close()
localConn.Close()
xl.Errorf("write proxy protocol header to local conn error: %v", err)
return
}
@@ -237,4 +219,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if len(errs) > 0 {
xl.Tracef("join connections errors: %v", errs)
}
if compressionResourceRecycleFn != nil {
compressionResourceRecycleFn()
}
}

View File

@@ -17,6 +17,7 @@
package proxy
import (
"io"
"net"
"reflect"
"strconv"
@@ -24,15 +25,17 @@ import (
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.SUDPProxyConfig](), NewSUDPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.SUDPProxyConfig{}), NewSUDPProxy)
}
type SUDPProxy struct {
@@ -80,13 +83,27 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
xl := pxy.xl
xl.Infof("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String())
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
if err != nil {
xl.Errorf("wrap work connection: %v", err)
return
var rwc io.ReadWriteCloser = conn
var err error
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.cfg.Transport.UseEncryption {
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
if err != nil {
conn.Close()
xl.Errorf("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.Transport.UseCompression {
rwc = libio.WithCompression(rwc)
}
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
workConn := netpkg.WrapReadWriteCloserToConn(remote, conn)
workConn := conn
readCh := make(chan *msg.UDPPacket, 1024)
sendCh := make(chan msg.Message, 1024)
isClose := false

View File

@@ -17,21 +17,24 @@
package proxy
import (
"io"
"net"
"reflect"
"strconv"
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.UDPProxyConfig](), NewUDPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.UDPProxyConfig{}), NewUDPProxy)
}
type UDPProxy struct {
@@ -91,14 +94,28 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
// close resources related with old workConn
pxy.Close()
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
if err != nil {
xl.Errorf("wrap work connection: %v", err)
return
var rwc io.ReadWriteCloser = conn
var err error
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.cfg.Transport.UseEncryption {
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
if err != nil {
conn.Close()
xl.Errorf("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.Transport.UseCompression {
rwc = libio.WithCompression(rwc)
}
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
pxy.mu.Lock()
pxy.workConn = netpkg.WrapReadWriteCloserToConn(remote, conn)
pxy.workConn = conn
pxy.readCh = make(chan *msg.UDPPacket, 1024)
pxy.sendCh = make(chan msg.Message, 1024)
pxy.closed = false

View File

@@ -34,7 +34,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.XTCPProxyConfig](), NewXTCPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.XTCPProxyConfig{}), NewXTCPProxy)
}
type XTCPProxy struct {

View File

@@ -15,12 +15,18 @@
package visitor
import (
"fmt"
"io"
"net"
"strconv"
"time"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
)
@@ -36,10 +42,10 @@ func (sv *STCPVisitor) Run() (err error) {
if err != nil {
return
}
go sv.acceptLoop(sv.l, "stcp local", sv.handleConn)
go sv.worker()
}
go sv.acceptLoop(sv.internalLn, "stcp internal", sv.handleConn)
go sv.internalConnWorker()
if sv.plugin != nil {
sv.plugin.Start()
@@ -51,10 +57,35 @@ func (sv *STCPVisitor) Close() {
sv.BaseVisitor.Close()
}
func (sv *STCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warnf("stcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warnf("stcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx)
var tunnelErr error
defer func() {
// If there was an error and connection supports CloseWithError, use it
if tunnelErr != nil {
if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
_ = eConn.CloseWithError(tunnelErr)
@@ -65,21 +96,62 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
}()
xl.Debugf("get a new stcp user connection")
visitorConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
visitorConn, err := sv.helper.ConnectServer()
if err != nil {
xl.Warnf("dialRawVisitorConn error: %v", err)
tunnelErr = err
return
}
defer visitorConn.Close()
remote, recycleFn, err := wrapVisitorConn(visitorConn, sv.cfg.GetBaseConfig())
now := time.Now().Unix()
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
Timestamp: now,
UseEncryption: sv.cfg.Transport.UseEncryption,
UseCompression: sv.cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
xl.Warnf("wrapVisitorConn error: %v", err)
xl.Warnf("send newVisitorConnMsg to server error: %v", err)
tunnelErr = err
return
}
defer recycleFn()
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
xl.Warnf("get newVisitorConnRespMsg error: %v", err)
tunnelErr = err
return
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error)
return
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
tunnelErr = err
return
}
}
if sv.cfg.Transport.UseCompression {
var recycleFn func()
remote, recycleFn = libio.WithCompressionFromPool(remote)
defer recycleFn()
}
libio.Join(userConn, remote)
}

View File

@@ -16,17 +16,21 @@ package visitor
import (
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/proto/udp"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
)
@@ -72,7 +76,6 @@ func (sv *SUDPVisitor) dispatcher() {
var (
visitorConn net.Conn
recycleFn func()
err error
firstPacket *msg.UDPPacket
@@ -90,17 +93,14 @@ func (sv *SUDPVisitor) dispatcher() {
return
}
visitorConn, recycleFn, err = sv.getNewVisitorConn()
visitorConn, err = sv.getNewVisitorConn()
if err != nil {
xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err)
continue
}
// visitorConn always be closed when worker done.
func() {
defer recycleFn()
sv.worker(visitorConn, firstPacket)
}()
sv.worker(visitorConn, firstPacket)
select {
case <-sv.checkCloseCh:
@@ -198,17 +198,53 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl.Infof("sudp worker is closed")
}
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, func(), error) {
rawConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) {
xl := xlog.FromContextSafe(sv.ctx)
visitorConn, err := sv.helper.ConnectServer()
if err != nil {
return nil, func() {}, err
return nil, fmt.Errorf("frpc connect frps error: %v", err)
}
rwc, recycleFn, err := wrapVisitorConn(rawConn, sv.cfg.GetBaseConfig())
now := time.Now().Unix()
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
Timestamp: now,
UseEncryption: sv.cfg.Transport.UseEncryption,
UseCompression: sv.cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
rawConn.Close()
return nil, func() {}, err
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err)
}
return netpkg.WrapReadWriteCloserToConn(rwc, rawConn), recycleFn, nil
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
return nil, err
}
}
if sv.cfg.Transport.UseCompression {
remote = libio.WithCompression(remote)
}
return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil
}
func (sv *SUDPVisitor) Close() {

View File

@@ -16,21 +16,13 @@ package visitor
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
plugin "github.com/fatedier/frp/pkg/plugin/visitor"
"github.com/fatedier/frp/pkg/transport"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet"
)
@@ -127,18 +119,6 @@ func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
return v.internalLn.PutConn(conn)
}
func (v *BaseVisitor) acceptLoop(l net.Listener, name string, handleConn func(net.Conn)) {
xl := xlog.FromContextSafe(v.ctx)
for {
conn, err := l.Accept()
if err != nil {
xl.Warnf("%s listener closed", name)
return
}
go handleConn(conn)
}
}
func (v *BaseVisitor) Close() {
if v.l != nil {
v.l.Close()
@@ -150,57 +130,3 @@ func (v *BaseVisitor) Close() {
v.plugin.Close()
}
}
func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, error) {
visitorConn, err := v.helper.ConnectServer()
if err != nil {
return nil, fmt.Errorf("connect to server error: %v", err)
}
now := time.Now().Unix()
targetProxyName := naming.BuildTargetServerProxyName(v.clientCfg.User, cfg.ServerUser, cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: v.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(cfg.SecretKey, now),
Timestamp: now,
UseEncryption: cfg.Transport.UseEncryption,
UseCompression: cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
visitorConn.Close()
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
return visitorConn, nil
}
func wrapVisitorConn(conn io.ReadWriteCloser, cfg *v1.VisitorBaseConfig) (io.ReadWriteCloser, func(), error) {
rwc := conn
if cfg.Transport.UseEncryption {
var err error
rwc, err = libio.WithEncryption(rwc, []byte(cfg.SecretKey))
if err != nil {
return nil, func() {}, fmt.Errorf("create encryption stream error: %v", err)
}
}
recycleFn := func() {}
if cfg.Transport.UseCompression {
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
}
return rwc, recycleFn, nil
}

View File

@@ -65,10 +65,10 @@ func (sv *XTCPVisitor) Run() (err error) {
if err != nil {
return
}
go sv.acceptLoop(sv.l, "xtcp local", sv.handleConn)
go sv.worker()
}
go sv.acceptLoop(sv.internalLn, "xtcp internal", sv.handleConn)
go sv.internalConnWorker()
go sv.processTunnelStartEvents()
if sv.cfg.KeepTunnelOpen {
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
@@ -93,6 +93,30 @@ func (sv *XTCPVisitor) Close() {
}
}
func (sv *XTCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warnf("xtcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warnf("xtcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) processTunnelStartEvents() {
for {
select {
@@ -182,14 +206,20 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
return
}
muxConnRWCloser, recycleFn, err := wrapVisitorConn(tunnelConn, sv.cfg.GetBaseConfig())
if err != nil {
xl.Errorf("%v", err)
tunnelConn.Close()
tunnelErr = err
return
var muxConnRWCloser io.ReadWriteCloser = tunnelConn
if sv.cfg.Transport.UseEncryption {
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
tunnelErr = err
return
}
}
if sv.cfg.Transport.UseCompression {
var recycleFn func()
muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser)
defer recycleFn()
}
defer recycleFn()
_, _, errs := libio.Join(userConn, muxConnRWCloser)
xl.Debugf("join connections closed")
@@ -343,7 +373,6 @@ func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) er
}
remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String())
if err != nil {
lConn.Close()
return fmt.Errorf("create kcp connection from udp connection error: %v", err)
}

View File

@@ -47,7 +47,7 @@ var natholeDiscoveryCmd = &cobra.Command{
Use: "discover",
Short: "Discover nathole information from stun server",
RunE: func(cmd *cobra.Command, args []string) error {
// ignore error here, because we can use command line parameters
// ignore error here, because we can use command line pameters
cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfigMode)
if err != nil {
cfg = &v1.ClientCommonConfig{}

View File

@@ -23,7 +23,6 @@ import (
"net/url"
"os"
"slices"
"sync"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
@@ -206,8 +205,7 @@ type OidcAuthConsumer struct {
additionalAuthScopes []v1.AuthScope
verifier TokenVerifier
mu sync.RWMutex
subjectsFromLogin map[string]struct{}
subjectsFromLogin []string
}
func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
@@ -228,7 +226,7 @@ func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVeri
return &OidcAuthConsumer{
additionalAuthScopes: additionalAuthScopes,
verifier: verifier,
subjectsFromLogin: make(map[string]struct{}),
subjectsFromLogin: []string{},
}
}
@@ -237,9 +235,9 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
if err != nil {
return fmt.Errorf("invalid OIDC token in login: %v", err)
}
auth.mu.Lock()
auth.subjectsFromLogin[token.Subject] = struct{}{}
auth.mu.Unlock()
if !slices.Contains(auth.subjectsFromLogin, token.Subject) {
auth.subjectsFromLogin = append(auth.subjectsFromLogin, token.Subject)
}
return nil
}
@@ -248,13 +246,11 @@ func (auth *OidcAuthConsumer) verifyPostLoginToken(privilegeKey string) (err err
if err != nil {
return fmt.Errorf("invalid OIDC token in ping: %v", err)
}
auth.mu.RLock()
_, ok := auth.subjectsFromLogin[token.Subject]
auth.mu.RUnlock()
if !ok {
if !slices.Contains(auth.subjectsFromLogin, token.Subject) {
return fmt.Errorf("received different OIDC subject in login and ping. "+
"original subjects: %s, "+
"new subject: %s",
token.Subject)
auth.subjectsFromLogin, token.Subject)
}
return nil
}

View File

@@ -171,14 +171,15 @@ func Convert_ServerCommonConf_To_v1(conf *ServerCommonConf) *v1.ServerConfig {
func transformHeadersFromPluginParams(params map[string]string) v1.HeaderOperations {
out := v1.HeaderOperations{}
for k, v := range params {
k, ok := strings.CutPrefix(k, "plugin_header_")
if !ok || k == "" {
if !strings.HasPrefix(k, "plugin_header_") {
continue
}
if out.Set == nil {
out.Set = make(map[string]string)
if k = strings.TrimPrefix(k, "plugin_header_"); k != "" {
if out.Set == nil {
out.Set = make(map[string]string)
}
out.Set[k] = v
}
out.Set[k] = v
}
return out
}

View File

@@ -39,14 +39,14 @@ const (
// Proxy
var (
proxyConfTypeMap = map[ProxyType]reflect.Type{
ProxyTypeTCP: reflect.TypeFor[TCPProxyConf](),
ProxyTypeUDP: reflect.TypeFor[UDPProxyConf](),
ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConf](),
ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConf](),
ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConf](),
ProxyTypeSTCP: reflect.TypeFor[STCPProxyConf](),
ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConf](),
ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConf](),
ProxyTypeTCP: reflect.TypeOf(TCPProxyConf{}),
ProxyTypeUDP: reflect.TypeOf(UDPProxyConf{}),
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConf{}),
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConf{}),
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConf{}),
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConf{}),
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConf{}),
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConf{}),
}
)

View File

@@ -22,8 +22,8 @@ func GetMapWithoutPrefix(set map[string]string, prefix string) map[string]string
m := make(map[string]string)
for key, value := range set {
if trimmed, ok := strings.CutPrefix(key, prefix); ok {
m[trimmed] = value
if strings.HasPrefix(key, prefix) {
m[strings.TrimPrefix(key, prefix)] = value
}
}

View File

@@ -32,9 +32,9 @@ const (
// Visitor
var (
visitorConfTypeMap = map[VisitorType]reflect.Type{
VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConf](),
VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConf](),
VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConf](),
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConf{}),
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConf{}),
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConf{}),
}
)

View File

@@ -15,11 +15,9 @@
package source
import (
"cmp"
"errors"
"fmt"
"maps"
"slices"
"sort"
"sync"
v1 "github.com/fatedier/frp/pkg/config/v1"
@@ -99,11 +97,21 @@ func (a *Aggregator) mapsToSortedSlices(
proxyMap map[string]v1.ProxyConfigurer,
visitorMap map[string]v1.VisitorConfigurer,
) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer) {
proxies := slices.SortedFunc(maps.Values(proxyMap), func(x, y v1.ProxyConfigurer) int {
return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
proxies := make([]v1.ProxyConfigurer, 0, len(proxyMap))
for _, p := range proxyMap {
proxies = append(proxies, p)
}
sort.Slice(proxies, func(i, j int) bool {
return proxies[i].GetBaseConfig().Name < proxies[j].GetBaseConfig().Name
})
visitors := slices.SortedFunc(maps.Values(visitorMap), func(x, y v1.VisitorConfigurer) int {
return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
visitors := make([]v1.VisitorConfigurer, 0, len(visitorMap))
for _, v := range visitorMap {
visitors = append(visitors, v)
}
sort.Slice(visitors, func(i, j int) bool {
return visitors[i].GetBaseConfig().Name < visitors[j].GetBaseConfig().Name
})
return proxies, visitors
}

View File

@@ -196,27 +196,6 @@ func TestAggregator_VisitorMerge(t *testing.T) {
require.Len(visitors, 2)
}
func TestAggregator_Load_ReturnsSortedByName(t *testing.T) {
require := require.New(t)
agg := newTestAggregator(t, nil)
err := agg.ConfigSource().ReplaceAll(
[]v1.ProxyConfigurer{mockProxy("charlie"), mockProxy("alice"), mockProxy("bob")},
[]v1.VisitorConfigurer{mockVisitor("zulu"), mockVisitor("alpha")},
)
require.NoError(err)
proxies, visitors, err := agg.Load()
require.NoError(err)
require.Len(proxies, 3)
require.Equal("alice", proxies[0].GetBaseConfig().Name)
require.Equal("bob", proxies[1].GetBaseConfig().Name)
require.Equal("charlie", proxies[2].GetBaseConfig().Name)
require.Len(visitors, 2)
require.Equal("alpha", visitors[0].GetBaseConfig().Name)
require.Equal("zulu", visitors[1].GetBaseConfig().Name)
}
func TestAggregator_Load_ReturnsDefensiveCopies(t *testing.T) {
require := require.New(t)

View File

@@ -38,7 +38,7 @@ func parseNumberRangePair(firstRangeStr, secondRangeStr string) ([]NumberPair, e
return nil, fmt.Errorf("first and second range numbers are not in pairs")
}
pairs := make([]NumberPair, 0, len(firstRangeNumbers))
for i := range firstRangeNumbers {
for i := 0; i < len(firstRangeNumbers); i++ {
pairs = append(pairs, NumberPair{
First: firstRangeNumbers[i],
Second: secondRangeNumbers[i],

View File

@@ -70,18 +70,24 @@ func (q *BandwidthQuantity) UnmarshalString(s string) error {
f float64
err error
)
if fstr, ok := strings.CutSuffix(s, "MB"); ok {
switch {
case strings.HasSuffix(s, "MB"):
base = MB
fstr := strings.TrimSuffix(s, "MB")
f, err = strconv.ParseFloat(fstr, 64)
} else if fstr, ok := strings.CutSuffix(s, "KB"); ok {
if err != nil {
return err
}
case strings.HasSuffix(s, "KB"):
base = KB
fstr := strings.TrimSuffix(s, "KB")
f, err = strconv.ParseFloat(fstr, 64)
} else {
if err != nil {
return err
}
default:
return errors.New("unit not support")
}
if err != nil {
return err
}
q.s = s
q.i = int64(f * float64(base))
@@ -137,8 +143,8 @@ func (p PortsRangeSlice) String() string {
func NewPortsRangeSliceFromString(str string) ([]PortsRange, error) {
str = strings.TrimSpace(str)
out := []PortsRange{}
numRanges := strings.SplitSeq(str, ",")
for numRangeStr := range numRanges {
numRanges := strings.Split(str, ",")
for _, numRangeStr := range numRanges {
// 1000-2000 or 2001
numArray := strings.Split(numRangeStr, "-")
// length: only 1 or 2 is correct

View File

@@ -39,31 +39,6 @@ func TestBandwidthQuantity(t *testing.T) {
require.Equal(`{"b":"1KB","int":5}`, string(buf))
}
func TestBandwidthQuantity_MB(t *testing.T) {
require := require.New(t)
var w Wrap
err := json.Unmarshal([]byte(`{"b":"2MB","int":1}`), &w)
require.NoError(err)
require.EqualValues(2*MB, w.B.Bytes())
buf, err := json.Marshal(&w)
require.NoError(err)
require.Equal(`{"b":"2MB","int":1}`, string(buf))
}
func TestBandwidthQuantity_InvalidUnit(t *testing.T) {
var w Wrap
err := json.Unmarshal([]byte(`{"b":"1GB","int":1}`), &w)
require.Error(t, err)
}
func TestBandwidthQuantity_InvalidNumber(t *testing.T) {
var w Wrap
err := json.Unmarshal([]byte(`{"b":"abcKB","int":1}`), &w)
require.Error(t, err)
}
func TestPortsRangeSlice2String(t *testing.T) {
require := require.New(t)

View File

@@ -239,14 +239,14 @@ const (
)
var proxyConfigTypeMap = map[ProxyType]reflect.Type{
ProxyTypeTCP: reflect.TypeFor[TCPProxyConfig](),
ProxyTypeUDP: reflect.TypeFor[UDPProxyConfig](),
ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConfig](),
ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConfig](),
ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConfig](),
ProxyTypeSTCP: reflect.TypeFor[STCPProxyConfig](),
ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConfig](),
ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConfig](),
ProxyTypeTCP: reflect.TypeOf(TCPProxyConfig{}),
ProxyTypeUDP: reflect.TypeOf(UDPProxyConfig{}),
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConfig{}),
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConfig{}),
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConfig{}),
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConfig{}),
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConfig{}),
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConfig{}),
}
func NewProxyConfigurerByType(proxyType ProxyType) ProxyConfigurer {

View File

@@ -37,16 +37,16 @@ const (
)
var clientPluginOptionsTypeMap = map[string]reflect.Type{
PluginHTTP2HTTPS: reflect.TypeFor[HTTP2HTTPSPluginOptions](),
PluginHTTPProxy: reflect.TypeFor[HTTPProxyPluginOptions](),
PluginHTTPS2HTTP: reflect.TypeFor[HTTPS2HTTPPluginOptions](),
PluginHTTPS2HTTPS: reflect.TypeFor[HTTPS2HTTPSPluginOptions](),
PluginHTTP2HTTP: reflect.TypeFor[HTTP2HTTPPluginOptions](),
PluginSocks5: reflect.TypeFor[Socks5PluginOptions](),
PluginStaticFile: reflect.TypeFor[StaticFilePluginOptions](),
PluginUnixDomainSocket: reflect.TypeFor[UnixDomainSocketPluginOptions](),
PluginTLS2Raw: reflect.TypeFor[TLS2RawPluginOptions](),
PluginVirtualNet: reflect.TypeFor[VirtualNetPluginOptions](),
PluginHTTP2HTTPS: reflect.TypeOf(HTTP2HTTPSPluginOptions{}),
PluginHTTPProxy: reflect.TypeOf(HTTPProxyPluginOptions{}),
PluginHTTPS2HTTP: reflect.TypeOf(HTTPS2HTTPPluginOptions{}),
PluginHTTPS2HTTPS: reflect.TypeOf(HTTPS2HTTPSPluginOptions{}),
PluginHTTP2HTTP: reflect.TypeOf(HTTP2HTTPPluginOptions{}),
PluginSocks5: reflect.TypeOf(Socks5PluginOptions{}),
PluginStaticFile: reflect.TypeOf(StaticFilePluginOptions{}),
PluginUnixDomainSocket: reflect.TypeOf(UnixDomainSocketPluginOptions{}),
PluginTLS2Raw: reflect.TypeOf(TLS2RawPluginOptions{}),
PluginVirtualNet: reflect.TypeOf(VirtualNetPluginOptions{}),
}
type ClientPluginOptions interface {

View File

@@ -79,9 +79,9 @@ const (
)
var visitorConfigTypeMap = map[VisitorType]reflect.Type{
VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConfig](),
VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConfig](),
VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConfig](),
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConfig{}),
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConfig{}),
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConfig{}),
}
type TypedVisitorConfig struct {

View File

@@ -25,7 +25,7 @@ const (
)
var visitorPluginOptionsTypeMap = map[string]reflect.Type{
VisitorPluginVirtualNet: reflect.TypeFor[VirtualNetVisitorPluginOptions](),
VisitorPluginVirtualNet: reflect.TypeOf(VirtualNetVisitorPluginOptions{}),
}
type VisitorPluginOptions interface {

View File

@@ -203,25 +203,6 @@ func (m *serverMetrics) GetServer() *ServerStats {
return s
}
func toProxyStats(name string, proxyStats *ProxyStatistics) *ProxyStats {
ps := &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
return ps
}
func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
res := make([]*ProxyStats, 0)
m.mu.Lock()
@@ -231,7 +212,23 @@ func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
if proxyStats.ProxyType != proxyType {
continue
}
res = append(res, toProxyStats(name, proxyStats))
ps := &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
res = append(res, ps)
}
return res
}
@@ -244,10 +241,26 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri
if proxyStats.ProxyType != proxyType {
continue
}
if name != proxyName {
continue
}
res = toProxyStats(name, proxyStats)
res = &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
break
}
return
@@ -259,7 +272,21 @@ func (m *serverMetrics) GetProxyByName(proxyName string) (res *ProxyStats) {
proxyStats, ok := m.info.ProxyStatistics[proxyName]
if ok {
res = toProxyStats(proxyName, proxyStats)
res = &ProxyStats{
Name: proxyName,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
}
return
}

View File

@@ -61,7 +61,7 @@ var msgTypeMap = map[byte]any{
TypeNatHoleReport: NatHoleReport{},
}
var TypeNameNatHoleResp = reflect.TypeFor[NatHoleResp]().Name()
var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name()
type ClientSpec struct {
// Due to the support of VirtualClient, frps needs to know the client type in order to

View File

@@ -16,8 +16,9 @@ func StripUserPrefix(user, name string) string {
if user == "" {
return name
}
if trimmed, ok := strings.CutPrefix(name, user+"."); ok {
return trimmed
prefix := user + "."
if strings.HasPrefix(name, prefix) {
return strings.TrimPrefix(name, prefix)
}
return name
}

View File

@@ -151,7 +151,7 @@ func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
behaviors := getBehaviorByMode(mode)
scores := make([]*BehaviorScore, 0, len(behaviors))
for i := range behaviors {
for i := 0; i < len(behaviors); i++ {
score := receiverScore
if behaviors[i].A.Role == DetectRoleSender {
score = senderScore

View File

@@ -70,8 +70,12 @@ func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, err
continue
}
portMax = max(portMax, portNum)
portMin = min(portMin, portNum)
if portNum > portMax {
portMax = portNum
}
if portNum < portMin {
portMin = portNum
}
if baseIP != ip {
ipChanged = true
}

View File

@@ -152,9 +152,7 @@ func (c *Controller) GenSid() string {
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) {
if m.PreCheck {
c.mu.RLock()
cfg, ok := c.clientCfgs[m.ProxyName]
c.mu.RUnlock()
if !ok {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
return

View File

@@ -298,13 +298,11 @@ func waitDetectMessage(
n, raddr, err := conn.ReadFromUDP(buf)
_ = conn.SetReadDeadline(time.Time{})
if err != nil {
pool.PutBuf(buf)
return nil, err
}
xl.Debugf("get udp message local %s, from %s", conn.LocalAddr(), raddr)
var m msg.NatHoleSid
if err := DecodeMessageInto(buf[:n], key, &m); err != nil {
pool.PutBuf(buf)
xl.Warnf("decode sid message error: %v", err)
continue
}
@@ -410,7 +408,7 @@ func sendSidMessageToRandomPorts(
xl := xlog.FromContextSafe(ctx)
used := sets.New[int]()
getUnusedPort := func() int {
for range 10 {
for i := 0; i < 10; i++ {
port := rand.IntN(65535-1024) + 1024
if !used.Has(port) {
used.Insert(port)
@@ -420,7 +418,7 @@ func sendSidMessageToRandomPorts(
return 0
}
for range count {
for i := 0; i < count; i++ {
select {
case <-ctx.Done():
return

View File

@@ -21,7 +21,6 @@ import (
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
@@ -69,7 +68,7 @@ func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin
p.s = &http.Server{
Handler: rp,
ReadHeaderTimeout: 60 * time.Second,
ReadHeaderTimeout: 0,
}
go func() {

View File

@@ -22,7 +22,6 @@ import (
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
@@ -78,7 +77,7 @@ func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugi
p.s = &http.Server{
Handler: rp,
ReadHeaderTimeout: 60 * time.Second,
ReadHeaderTimeout: 0,
}
go func() {

View File

@@ -62,13 +62,11 @@ func (p *TLS2RawPlugin) Handle(ctx context.Context, connInfo *ConnectionInfo) {
if err := tlsConn.Handshake(); err != nil {
xl.Warnf("tls handshake error: %v", err)
tlsConn.Close()
return
}
rawConn, err := net.Dial("tcp", p.opts.LocalAddr)
if err != nil {
xl.Warnf("dial to local addr error: %v", err)
tlsConn.Close()
return
}

View File

@@ -54,13 +54,10 @@ func (uds *UnixDomainSocketPlugin) Handle(ctx context.Context, connInfo *Connect
localConn, err := net.DialUnix("unix", nil, uds.UnixAddr)
if err != nil {
xl.Warnf("dial to uds %s error: %v", uds.UnixAddr, err)
connInfo.Conn.Close()
return
}
if connInfo.ProxyProtocolHeader != nil {
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
localConn.Close()
connInfo.Conn.Close()
return
}
}

View File

@@ -24,7 +24,6 @@ import (
"net/http"
"net/url"
"reflect"
"slices"
"strings"
v1 "github.com/fatedier/frp/pkg/config/v1"
@@ -65,7 +64,12 @@ func (p *httpPlugin) Name() string {
}
func (p *httpPlugin) IsSupport(op string) bool {
return slices.Contains(p.options.Ops, op)
for _, v := range p.options.Ops {
if v == op {
return true
}
}
return false
}
func (p *httpPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {

View File

@@ -153,7 +153,10 @@ func (p *VirtualNetPlugin) run() {
// Exponential backoff: 60s, 120s, 240s, 300s (capped)
baseDelay := 60 * time.Second
reconnectDelay = min(baseDelay*time.Duration(1<<uint(p.consecutiveErrors-1)), 300*time.Second)
reconnectDelay = baseDelay * time.Duration(1<<uint(p.consecutiveErrors-1))
if reconnectDelay > 300*time.Second {
reconnectDelay = 300 * time.Second
}
} else {
// Reset consecutive errors on successful connection
if p.consecutiveErrors > 0 {

View File

@@ -16,7 +16,6 @@ package featuregate
import (
"fmt"
"maps"
"sort"
"strings"
"sync"
@@ -94,7 +93,9 @@ type featureGate struct {
// NewFeatureGate creates a new feature gate with the default features
func NewFeatureGate() MutableFeatureGate {
known := map[Feature]FeatureSpec{}
maps.Copy(known, defaultFeatures)
for k, v := range defaultFeatures {
known[k] = v
}
f := &featureGate{}
f.known.Store(known)
@@ -109,9 +110,13 @@ func (f *featureGate) SetFromMap(m map[string]bool) error {
// Copy existing state
known := map[Feature]FeatureSpec{}
maps.Copy(known, f.known.Load().(map[Feature]FeatureSpec))
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
known[k] = v
}
enabled := map[Feature]bool{}
maps.Copy(enabled, f.enabled.Load().(map[Feature]bool))
for k, v := range f.enabled.Load().(map[Feature]bool) {
enabled[k] = v
}
// Apply the new settings
for k, v := range m {
@@ -143,7 +148,9 @@ func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
// Copy existing state
known := map[Feature]FeatureSpec{}
maps.Copy(known, f.known.Load().(map[Feature]FeatureSpec))
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
known[k] = v
}
// Add new features
for name, spec := range features {

View File

@@ -85,7 +85,6 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
}()
buf := pool.GetBuf(bufSize)
defer pool.PutBuf(buf)
for {
_ = udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
n, _, err := udpConn.ReadFromUDP(buf)

View File

@@ -20,7 +20,6 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"os"
"time"
@@ -86,9 +85,7 @@ func newCertPool(caPath string) (*x509.CertPool, error) {
return nil, err
}
if !pool.AppendCertsFromPEM(caCrt) {
return nil, fmt.Errorf("failed to parse CA certificate from file %q: no valid PEM certificates found", caPath)
}
pool.AppendCertsFromPEM(caCrt)
return pool, nil
}

View File

@@ -89,11 +89,11 @@ func ParseBasicAuth(auth string) (username, password string, ok bool) {
return
}
cs := string(c)
before, after, found := strings.Cut(cs, ":")
if !found {
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return before, after, true
return cs[:s], cs[s+1:], true
}
func BasicAuth(username, passwd string) string {

View File

@@ -86,7 +86,11 @@ func (c *FakeUDPConn) Read(b []byte) (n int, err error) {
c.lastActive = time.Now()
c.mu.Unlock()
n = min(len(b), len(content))
if len(b) < len(content) {
n = len(b)
} else {
n = len(content)
}
copy(b, content)
return n, nil
}
@@ -164,15 +168,11 @@ func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
return l, err
}
readConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return l, err
}
l = &UDPListener{
addr: udpAddr,
acceptCh: make(chan net.Conn),
writeCh: make(chan *UDPPacket, 1000),
readConn: readConn,
fakeConns: make(map[string]*FakeUDPConn),
}

View File

@@ -26,7 +26,6 @@ type WebsocketListener struct {
// ln: tcp listener for websocket connections
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
wl = &WebsocketListener{
ln: ln,
acceptCh: make(chan net.Conn),
}

View File

@@ -68,8 +68,8 @@ func ParseRangeNumbers(rangeStr string) (numbers []int64, err error) {
rangeStr = strings.TrimSpace(rangeStr)
numbers = make([]int64, 0)
// e.g. 1000-2000,2001,2002,3000-4000
numRanges := strings.SplitSeq(rangeStr, ",")
for numRangeStr := range numRanges {
numRanges := strings.Split(rangeStr, ",")
for _, numRangeStr := range numRanges {
// 1000-2000 or 2001
numArray := strings.Split(numRangeStr, "-")
// length: only 1 or 2 is correct

View File

@@ -266,13 +266,31 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
go libio.Join(remote, client)
}
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request {
user := ""
// If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header.
if req.URL.Host != "" {
proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" {
user, _, _ = httppkg.ParseBasicAuth(proxyAuth)
user, _, _ = parseBasicAuth(proxyAuth)
}
}
if user == "" {

View File

@@ -63,12 +63,11 @@ func (l *Logger) AddPrefix(prefix LogPrefix) *Logger {
if prefix.Priority <= 0 {
prefix.Priority = 10
}
for i, p := range l.prefixes {
for _, p := range l.prefixes {
if p.Name == prefix.Name {
found = true
l.prefixes[i].Value = prefix.Value
l.prefixes[i].Priority = prefix.Priority
break
p.Value = prefix.Value
p.Priority = prefix.Priority
}
}
if !found {

View File

@@ -95,33 +95,20 @@ func (cm *ControlManager) Close() error {
return nil
}
// SessionContext encapsulates the input parameters for creating a new Control.
type SessionContext struct {
// all resource managers and controllers
RC *controller.ResourceController
// proxy manager
PxyManager *proxy.Manager
// plugin manager
PluginManager *plugin.Manager
// verifies authentication based on selected method
AuthVerifier auth.Verifier
// key used for connection encryption
EncryptionKey []byte
// control connection
Conn net.Conn
// indicates whether the connection is encrypted
ConnEncrypted bool
// login message
LoginMsg *msg.Login
// server configuration
ServerCfg *v1.ServerConfig
// client registry
ClientRegistry *registry.ClientRegistry
}
type Control struct {
// session context
sessionCtx *SessionContext
// all resource managers and controllers
rc *controller.ResourceController
// proxy manager
pxyManager *proxy.Manager
// plugin manager
pluginManager *plugin.Manager
// verifies authentication based on selected method
authVerifier auth.Verifier
// key used for connection encryption
encryptionKey []byte
// other components can use this to communicate with client
msgTransporter transport.MessageTransporter
@@ -130,6 +117,12 @@ type Control struct {
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
msgDispatcher *msg.Dispatcher
// login message
loginMsg *msg.Login
// control connection
conn net.Conn
// work connections
workConnCh chan net.Conn
@@ -152,34 +145,61 @@ type Control struct {
mu sync.RWMutex
// Server configuration information
serverCfg *v1.ServerConfig
clientRegistry *registry.ClientRegistry
xl *xlog.Logger
ctx context.Context
doneCh chan struct{}
}
func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext.
func NewControl(
ctx context.Context,
rc *controller.ResourceController,
pxyManager *proxy.Manager,
pluginManager *plugin.Manager,
authVerifier auth.Verifier,
encryptionKey []byte,
ctlConn net.Conn,
ctlConnEncrypted bool,
loginMsg *msg.Login,
serverCfg *v1.ServerConfig,
) (*Control, error) {
poolCount := loginMsg.PoolCount
if poolCount > int(serverCfg.Transport.MaxPoolCount) {
poolCount = int(serverCfg.Transport.MaxPoolCount)
}
ctl := &Control{
sessionCtx: sessionCtx,
workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
runID: sessionCtx.LoginMsg.RunID,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
rc: rc,
pxyManager: pxyManager,
pluginManager: pluginManager,
authVerifier: authVerifier,
encryptionKey: encryptionKey,
conn: ctlConn,
loginMsg: loginMsg,
workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
runID: loginMsg.RunID,
serverCfg: serverCfg,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
}
ctl.lastPing.Store(time.Now())
if sessionCtx.ConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
if ctlConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey)
if err != nil {
return nil, err
}
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
} else {
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn)
}
ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
@@ -193,7 +213,7 @@ func (ctl *Control) Start() {
RunID: ctl.runID,
Error: "",
}
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
_ = msg.WriteMsg(ctl.conn, loginRespMsg)
go func() {
for i := 0; i < ctl.poolCount; i++ {
@@ -205,7 +225,7 @@ func (ctl *Control) Start() {
}
func (ctl *Control) Close() error {
ctl.sessionCtx.Conn.Close()
ctl.conn.Close()
return nil
}
@@ -213,7 +233,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
xl := ctl.xl
xl.Infof("replaced by client [%s]", newCtl.runID)
ctl.runID = ""
ctl.sessionCtx.Conn.Close()
ctl.conn.Close()
}
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
@@ -271,7 +291,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
return
}
case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second):
case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second):
err = fmt.Errorf("timeout trying to get work connection")
xl.Warnf("%v", err)
return
@@ -284,15 +304,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
}
func (ctl *Control) heartbeatWorker() {
if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 {
if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 {
return
}
xl := ctl.xl
go wait.Until(func() {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
xl.Warnf("heartbeat timeout")
ctl.sessionCtx.Conn.Close()
ctl.conn.Close()
return
}
}, time.Second, ctl.doneCh)
@@ -303,30 +323,6 @@ func (ctl *Control) WaitClosed() {
<-ctl.doneCh
}
func (ctl *Control) loginUserInfo() plugin.UserInfo {
return plugin.UserInfo{
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
}
}
func (ctl *Control) closeProxy(pxy proxy.Proxy) {
pxy.Close()
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: ctl.loginUserInfo(),
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
}()
}
func (ctl *Control) worker() {
xl := ctl.xl
@@ -334,23 +330,38 @@ func (ctl *Control) worker() {
go ctl.msgDispatcher.Run()
<-ctl.msgDispatcher.Done()
ctl.sessionCtx.Conn.Close()
ctl.conn.Close()
ctl.mu.Lock()
defer ctl.mu.Unlock()
close(ctl.workConnCh)
for workConn := range ctl.workConnCh {
workConn.Close()
}
proxies := ctl.proxies
ctl.proxies = make(map[string]proxy.Proxy)
ctl.mu.Unlock()
for _, pxy := range proxies {
ctl.closeProxy(pxy)
for _, pxy := range ctl.proxies {
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
}()
}
metrics.Server.CloseClient()
ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID)
ctl.clientRegistry.MarkOfflineByRunID(ctl.runID)
xl.Infof("client exit success")
close(ctl.doneCh)
}
@@ -369,11 +380,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
inMsg := m.(*msg.NewProxy)
content := &plugin.NewProxyContent{
User: ctl.loginUserInfo(),
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
NewProxy: *inMsg,
}
var remoteAddr string
retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content)
retContent, err := ctl.pluginManager.NewProxy(content)
if err == nil {
inMsg = &retContent.NewProxy
remoteAddr, err = ctl.RegisterProxy(inMsg)
@@ -386,15 +401,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
if err != nil {
xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient))
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
} else {
resp.RemoteAddr = remoteAddr
xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
clientID := ctl.sessionCtx.LoginMsg.ClientID
clientID := ctl.loginMsg.ClientID
if clientID == "" {
clientID = ctl.sessionCtx.LoginMsg.RunID
clientID = ctl.loginMsg.RunID
}
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID)
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID)
}
_ = ctl.msgDispatcher.Send(resp)
}
@@ -404,18 +419,22 @@ func (ctl *Control) handlePing(m msg.Message) {
inMsg := m.(*msg.Ping)
content := &plugin.PingContent{
User: ctl.loginUserInfo(),
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
Ping: *inMsg,
}
retContent, err := ctl.sessionCtx.PluginManager.Ping(content)
retContent, err := ctl.pluginManager.Ping(content)
if err == nil {
inMsg = &retContent.Ping
err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg)
err = ctl.authVerifier.VerifyPing(inMsg)
}
if err != nil {
xl.Warnf("received invalid ping: %v", err)
_ = ctl.msgDispatcher.Send(&msg.Pong{
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)),
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
})
return
}
@@ -426,17 +445,17 @@ func (ctl *Control) handlePing(m msg.Message) {
func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
inMsg := m.(*msg.NatHoleVisitor)
ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User)
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
}
func (ctl *Control) handleNatHoleClient(m msg.Message) {
inMsg := m.(*msg.NatHoleClient)
ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
}
func (ctl *Control) handleNatHoleReport(m msg.Message) {
inMsg := m.(*msg.NatHoleReport)
ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg)
ctl.rc.NatHoleController.HandleReport(inMsg)
}
func (ctl *Control) handleCloseProxy(m msg.Message) {
@@ -449,15 +468,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) {
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
var pxyConf v1.ProxyConfigurer
// Load configures from NewProxy message and validate.
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg)
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg)
if err != nil {
return
}
// User info
userInfo := plugin.UserInfo{
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.runID,
}
@@ -465,22 +484,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
// In fact, it creates different proxies based on the proxy type. We just call run() here.
pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{
UserInfo: userInfo,
LoginMsg: ctl.sessionCtx.LoginMsg,
LoginMsg: ctl.loginMsg,
PoolCount: ctl.poolCount,
ResourceController: ctl.sessionCtx.RC,
ResourceController: ctl.rc,
GetWorkConnFn: ctl.GetWorkConn,
Configurer: pxyConf,
ServerCfg: ctl.sessionCtx.ServerCfg,
EncryptionKey: ctl.sessionCtx.EncryptionKey,
ServerCfg: ctl.serverCfg,
EncryptionKey: ctl.encryptionKey,
})
if err != nil {
return remoteAddr, err
}
// Check ports used number in each client
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
if ctl.serverCfg.MaxPortsPerClient > 0 {
ctl.mu.Lock()
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) {
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) {
ctl.mu.Unlock()
err = fmt.Errorf("exceed the max_ports_per_client")
return
@@ -497,7 +516,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}()
}
if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) {
if ctl.pxyManager.Exist(pxyMsg.ProxyName) {
err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName)
return
}
@@ -512,7 +531,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}
}()
err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy)
err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy)
if err != nil {
return
}
@@ -531,12 +550,28 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
return
}
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
if ctl.serverCfg.MaxPortsPerClient > 0 {
ctl.portsUsedNum -= pxy.GetUsedPortsNum()
}
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
delete(ctl.proxies, closeMsg.ProxyName)
ctl.mu.Unlock()
ctl.closeProxy(pxy)
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
}()
return
}

View File

@@ -100,9 +100,8 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
if err != nil {
return
}
tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(realPort)))
tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port)))
if errRet != nil {
tg.ctl.portManager.Release(realPort)
err = errRet
return
}

View File

@@ -31,7 +31,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.HTTPProxyConfig](), NewHTTPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.HTTPProxyConfig{}), NewHTTPProxy)
}
type HTTPProxy struct {
@@ -75,13 +75,16 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
}
}()
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
addrs := make([]string, 0)
for _, domain := range domains {
for _, domain := range pxy.cfg.CustomDomains {
if domain == "" {
continue
}
routeConfig.Domain = domain
for _, location := range locations {
routeConfig.Location = location
tmpRouteConfig := routeConfig
// handle group
@@ -90,6 +93,40 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
if err != nil {
return
}
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.LoadBalancer.Group, tmpRouteConfig)
})
} else {
// no group
err = pxy.rc.HTTPReverseProxy.Register(routeConfig)
if err != nil {
return
}
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPReverseProxy.UnRegister(tmpRouteConfig)
})
}
addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, pxy.serverCfg.VhostHTTPPort))
xl.Infof("http proxy listen for host [%s] location [%s] group [%s], routeByHTTPUser [%s]",
routeConfig.Domain, routeConfig.Location, pxy.cfg.LoadBalancer.Group, pxy.cfg.RouteByHTTPUser)
}
}
if pxy.cfg.SubDomain != "" {
routeConfig.Domain = pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost
for _, location := range locations {
routeConfig.Location = location
tmpRouteConfig := routeConfig
// handle group
if pxy.cfg.LoadBalancer.Group != "" {
err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.LoadBalancer.Group, pxy.cfg.LoadBalancer.GroupKey, routeConfig)
if err != nil {
return
}
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.LoadBalancer.Group, tmpRouteConfig)
})
@@ -102,7 +139,8 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
pxy.rc.HTTPReverseProxy.UnRegister(tmpRouteConfig)
})
}
addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, pxy.serverCfg.VhostHTTPPort))
addrs = append(addrs, util.CanonicalAddr(tmpRouteConfig.Domain, pxy.serverCfg.VhostHTTPPort))
xl.Infof("http proxy listen for host [%s] location [%s] group [%s], routeByHTTPUser [%s]",
routeConfig.Domain, routeConfig.Location, pxy.cfg.LoadBalancer.Group, pxy.cfg.RouteByHTTPUser)
}
@@ -130,7 +168,6 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
tmpConn.Close()
return
}
}

View File

@@ -25,7 +25,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.HTTPSProxyConfig](), NewHTTPSProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.HTTPSProxyConfig{}), NewHTTPSProxy)
}
type HTTPSProxy struct {
@@ -53,10 +53,23 @@ func (pxy *HTTPSProxy) Run() (remoteAddr string, err error) {
pxy.Close()
}
}()
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
addrs := make([]string, 0)
for _, domain := range domains {
for _, domain := range pxy.cfg.CustomDomains {
if domain == "" {
continue
}
l, err := pxy.listenForDomain(routeConfig, domain)
if err != nil {
return "", err
}
pxy.listeners = append(pxy.listeners, l)
addrs = append(addrs, util.CanonicalAddr(domain, pxy.serverCfg.VhostHTTPSPort))
xl.Infof("https proxy listen for host [%s] group [%s]", domain, pxy.cfg.LoadBalancer.Group)
}
if pxy.cfg.SubDomain != "" {
domain := pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost
l, err := pxy.listenForDomain(routeConfig, domain)
if err != nil {
return "", err

View File

@@ -150,7 +150,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
}
err = msg.WriteMsg(workConn, &msg.StartWorkConn{
err := msg.WriteMsg(workConn, &msg.StartWorkConn{
ProxyName: pxy.GetName(),
SrcAddr: srcAddr,
SrcPort: uint16(srcPort),
@@ -161,7 +161,6 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
if err != nil {
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
workConn.Close()
workConn = nil
} else {
break
}
@@ -174,36 +173,6 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
return
}
// startVisitorListener sets up a VisitorManager listener for visitor-based proxies (STCP, SUDP).
func (pxy *BaseProxy) startVisitorListener(secretKey string, allowUsers []string, proxyType string) error {
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, err := pxy.rc.VisitorManager.Listen(pxy.GetName(), secretKey, allowUsers)
if err != nil {
return err
}
pxy.listeners = append(pxy.listeners, listener)
pxy.xl.Infof("%s proxy custom listen success", proxyType)
pxy.startCommonTCPListenersHandler()
return nil
}
// buildDomains constructs a list of domains from custom domains and subdomain configuration.
func (pxy *BaseProxy) buildDomains(customDomains []string, subDomain string) []string {
domains := make([]string, 0, len(customDomains)+1)
for _, d := range customDomains {
if d != "" {
domains = append(domains, d)
}
}
if subDomain != "" {
domains = append(domains, subDomain+"."+pxy.serverCfg.SubDomainHost)
}
return domains
}
// startCommonTCPListenersHandler start a goroutine handler for each listener.
func (pxy *BaseProxy) startCommonTCPListenersHandler() {
xl := xlog.FromContextSafe(pxy.ctx)

View File

@@ -21,7 +21,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.STCPProxyConfig](), NewSTCPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.STCPProxyConfig{}), NewSTCPProxy)
}
type STCPProxy struct {
@@ -41,7 +41,21 @@ func NewSTCPProxy(baseProxy *BaseProxy) Proxy {
}
func (pxy *STCPProxy) Run() (remoteAddr string, err error) {
err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "stcp")
xl := pxy.xl
allowUsers := pxy.cfg.AllowUsers
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
if errRet != nil {
err = errRet
return
}
pxy.listeners = append(pxy.listeners, listener)
xl.Infof("stcp proxy custom listen success")
pxy.startCommonTCPListenersHandler()
return
}

View File

@@ -21,7 +21,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.SUDPProxyConfig](), NewSUDPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.SUDPProxyConfig{}), NewSUDPProxy)
}
type SUDPProxy struct {
@@ -41,7 +41,21 @@ func NewSUDPProxy(baseProxy *BaseProxy) Proxy {
}
func (pxy *SUDPProxy) Run() (remoteAddr string, err error) {
err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "sudp")
xl := pxy.xl
allowUsers := pxy.cfg.AllowUsers
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
if errRet != nil {
err = errRet
return
}
pxy.listeners = append(pxy.listeners, listener)
xl.Infof("sudp proxy custom listen success")
pxy.startCommonTCPListenersHandler()
return
}

View File

@@ -24,7 +24,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.TCPProxyConfig](), NewTCPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.TCPProxyConfig{}), NewTCPProxy)
}
type TCPProxy struct {

View File

@@ -26,7 +26,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.TCPMuxProxyConfig](), NewTCPMuxProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.TCPMuxProxyConfig{}), NewTCPMuxProxy)
}
type TCPMuxProxy struct {
@@ -72,16 +72,26 @@ func (pxy *TCPMuxProxy) httpConnectListen(
}
func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) {
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
addrs := make([]string, 0)
for _, domain := range domains {
for _, domain := range pxy.cfg.CustomDomains {
if domain == "" {
continue
}
addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPassword, addrs)
if err != nil {
return "", err
}
}
if pxy.cfg.SubDomain != "" {
addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost,
pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPassword, addrs)
if err != nil {
return "", err
}
}
pxy.startCommonTCPListenersHandler()
remoteAddr = strings.Join(addrs, ",")
return remoteAddr, err

View File

@@ -35,7 +35,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.UDPProxyConfig](), NewUDPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.UDPProxyConfig{}), NewUDPProxy)
}
type UDPProxy struct {

View File

@@ -24,7 +24,7 @@ import (
)
func init() {
RegisterProxyFactory(reflect.TypeFor[*v1.XTCPProxyConfig](), NewXTCPProxy)
RegisterProxyFactory(reflect.TypeOf(&v1.XTCPProxyConfig{}), NewXTCPProxy)
}
type XTCPProxy struct {

View File

@@ -193,7 +193,7 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) {
if err != nil {
return nil, fmt.Errorf("create vhost tcpMuxer error, %v", err)
}
log.Infof("tcpmux httpconnect multiplexer listen on %s, passthrough: %v", address, cfg.TCPMuxPassthrough)
log.Infof("tcpmux httpconnect multiplexer listen on %s, passthough: %v", address, cfg.TCPMuxPassthrough)
}
// Init all plugins
@@ -604,18 +604,8 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
return err
}
ctl, err := NewControl(ctx, &SessionContext{
RC: svr.rc,
PxyManager: svr.pxyManager,
PluginManager: svr.pluginManager,
AuthVerifier: authVerifier,
EncryptionKey: svr.auth.EncryptionKey(),
Conn: ctlConn,
ConnEncrypted: !internal,
LoginMsg: loginMsg,
ServerCfg: svr.cfg,
ClientRegistry: svr.clientRegistry,
})
// TODO(fatedier): use SessionContext
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg)
if err != nil {
xl.Warnf("create new controller error: %v", err)
// don't return detailed errors to client
@@ -636,6 +626,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
ctl.Close()
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
}
ctl.clientRegistry = svr.clientRegistry
ctl.Start()
@@ -661,9 +652,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
// server plugin hook
content := &plugin.NewWorkConnContent{
User: plugin.UserInfo{
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
NewWorkConn: *newMsg,
}
@@ -671,7 +662,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
if err == nil {
newMsg = &retContent.NewWorkConn
// Check auth.
err = ctl.sessionCtx.AuthVerifier.VerifyNewWorkConn(newMsg)
err = ctl.authVerifier.VerifyNewWorkConn(newMsg)
}
if err != nil {
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
@@ -692,7 +683,7 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis
if !exist {
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
}
visitorUser = ctl.sessionCtx.LoginMsg.User
visitorUser = ctl.loginMsg.User
}
return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
newMsg.UseEncryption, newMsg.UseCompression, visitorUser)

View File

@@ -2,7 +2,6 @@ package framework
import (
"fmt"
"maps"
"os"
"path/filepath"
"time"
@@ -21,7 +20,9 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str
ExpectNoError(err)
ExpectTrue(len(templates) > 0)
maps.Copy(f.usedPorts, ports)
for name, port := range ports {
f.usedPorts[name] = port
}
currentServerProcesses := make([]*process.Process, 0, len(serverTemplates))
for i := range serverTemplates {

View File

@@ -26,8 +26,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
proxyType := t
ginkgo.It(fmt.Sprintf("Expose a %s echo server", strings.ToUpper(proxyType)), func() {
serverConf := consts.LegacyDefaultServerConfig
var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
clientConf := consts.LegacyDefaultClientConfig
localPortName := ""
protocol := "tcp"
@@ -79,10 +78,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config
for _, test := range tests {
clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n"
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
for _, test := range tests {
framework.NewRequestExpect(f).
@@ -103,8 +102,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
vhost_http_port = %d
`, vhostHTTPPort)
var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
clientConf := consts.LegacyDefaultClientConfig
getProxyConf := func(proxyName string, customDomains string, extra string) string {
return fmt.Sprintf(`
@@ -149,13 +147,13 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
if tests[i].customDomains == "" {
tests[i].customDomains = test.proxyName + ".example.com"
}
clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n"
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
for _, test := range tests {
for domain := range strings.SplitSeq(test.customDomains, ",") {
for _, domain := range strings.Split(test.customDomains, ",") {
domain = strings.TrimSpace(domain)
framework.NewRequestExpect(f).
Explain(test.proxyName + "-" + domain).
@@ -187,8 +185,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
`, vhostHTTPSPort)
localPort := f.AllocPort()
var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
clientConf := consts.LegacyDefaultClientConfig
getProxyConf := func(proxyName string, customDomains string, extra string) string {
return fmt.Sprintf(`
[%s]
@@ -232,10 +229,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
if tests[i].customDomains == "" {
tests[i].customDomains = test.proxyName + ".example.com"
}
clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n"
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
framework.ExpectNoError(err)
@@ -247,7 +244,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
f.RunServer("", localServer)
for _, test := range tests {
for domain := range strings.SplitSeq(test.customDomains, ",") {
for _, domain := range strings.Split(test.customDomains, ",") {
domain = strings.TrimSpace(domain)
framework.NewRequestExpect(f).
Explain(test.proxyName + "-" + domain).
@@ -285,12 +282,9 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
proxyType := t
ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() {
serverConf := consts.LegacyDefaultServerConfig
var clientServerConf strings.Builder
clientServerConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user1")
var clientVisitorConf strings.Builder
clientVisitorConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user1")
var clientUser2VisitorConf strings.Builder
clientUser2VisitorConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user2")
clientServerConf := consts.LegacyDefaultClientConfig + "\nuser = user1"
clientVisitorConf := consts.LegacyDefaultClientConfig + "\nuser = user1"
clientUser2VisitorConf := consts.LegacyDefaultClientConfig + "\nuser = user2"
localPortName := ""
protocol := "tcp"
@@ -406,20 +400,20 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config
for _, test := range tests {
clientServerConf.WriteString(getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n")
clientServerConf += getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n"
}
for _, test := range tests {
config := getProxyVisitorConf(
test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig,
) + "\n"
if test.deployUser2Client {
clientUser2VisitorConf.WriteString(config)
clientUser2VisitorConf += config
} else {
clientVisitorConf.WriteString(config)
clientVisitorConf += config
}
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientServerConf.String(), clientVisitorConf.String(), clientUser2VisitorConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientServerConf, clientVisitorConf, clientUser2VisitorConf})
for _, test := range tests {
timeout := time.Second
@@ -446,8 +440,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
ginkgo.Describe("TCPMUX", func() {
ginkgo.It("Type tcpmux", func() {
serverConf := consts.LegacyDefaultServerConfig
var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
clientConf := consts.LegacyDefaultClientConfig
tcpmuxHTTPConnectPortName := port.GenName("TCPMUX")
serverConf += fmt.Sprintf(`
@@ -490,14 +483,14 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config
for _, test := range tests {
clientConf.WriteString(getProxyConf(test.proxyName, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, test.extraConfig) + "\n"
localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(f.AllocPort()), streamserver.WithRespContent([]byte(test.proxyName)))
f.RunServer(port.GenName(test.proxyName), localServer)
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
// Request without HTTP connect should get error
framework.NewRequestExpect(f).

View File

@@ -28,7 +28,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
tcpPortName := port.GenName("TCP", port.WithRangePorts(10000, 11000))
udpPortName := port.GenName("UDP", port.WithRangePorts(12000, 13000))
clientConf += fmt.Sprintf(`
[tcp-allowed-in-range]
[tcp-allowded-in-range]
type = tcp
local_port = {{ .%s }}
remote_port = {{ .%s }}

View File

@@ -48,10 +48,12 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
return true
})
}
for range 10 {
wait.Go(func() {
for i := 0; i < 10; i++ {
wait.Add(1)
go func() {
defer wait.Done()
expectFn()
})
}()
}
wait.Wait()
@@ -92,7 +94,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
fooCount := 0
barCount := 0
for i := range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).Explain("times " + strconv.Itoa(i)).Port(remotePort).Ensure(func(resp *request.Response) bool {
switch string(resp.Content) {
case "foo":
@@ -148,7 +150,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
// check foo and bar is ok
results := []string{}
for range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
results = append(results, string(resp.Content))
return true
@@ -159,7 +161,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
// close bar server, check foo is ok
barServer.Close()
time.Sleep(2 * time.Second)
for range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).Port(remotePort).ExpectResp([]byte("foo")).Ensure()
}
@@ -167,7 +169,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
f.RunServer("", barServer)
time.Sleep(2 * time.Second)
results = []string{}
for range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
results = append(results, string(resp.Content))
return true

View File

@@ -4,7 +4,6 @@ import (
"crypto/tls"
"fmt"
"strconv"
"strings"
"github.com/onsi/ginkgo/v2"
@@ -23,8 +22,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
ginkgo.Describe("UnixDomainSocket", func() {
ginkgo.It("Expose a unix domain socket echo server", func() {
serverConf := consts.LegacyDefaultServerConfig
var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
clientConf := consts.LegacyDefaultClientConfig
getProxyConf := func(proxyName string, portName string, extra string) string {
return fmt.Sprintf(`
@@ -67,10 +65,10 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
// build all client config
for _, test := range tests {
clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n"
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
for _, test := range tests {
framework.NewRequestExpect(f).Port(f.PortByName(test.portName)).Ensure()

View File

@@ -52,7 +52,7 @@ func (pa *Allocator) GetByName(portName string) int {
pa.mu.Lock()
defer pa.mu.Unlock()
for range 20 {
for i := 0; i < 20; i++ {
port := pa.getByRange(builder.rangePortFrom, builder.rangePortTo)
if port == 0 {
return 0

View File

@@ -75,11 +75,11 @@ func (c *TunnelClient) serveListener() {
if err != nil {
return
}
go c.handleConn(conn)
go c.hanldeConn(conn)
}
}
func (c *TunnelClient) handleConn(conn net.Conn) {
func (c *TunnelClient) hanldeConn(conn net.Conn) {
defer conn.Close()
local, err := net.Dial("tcp", c.localAddr)
if err != nil {

View File

@@ -26,8 +26,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
proxyType := t
ginkgo.It(fmt.Sprintf("Expose a %s echo server", strings.ToUpper(proxyType)), func() {
serverConf := consts.DefaultServerConfig
var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
clientConf := consts.DefaultClientConfig
localPortName := ""
protocol := "tcp"
@@ -80,10 +79,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config
for _, test := range tests {
clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n"
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
for _, test := range tests {
framework.NewRequestExpect(f).
@@ -104,8 +103,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
vhostHTTPPort = %d
`, vhostHTTPPort)
var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
clientConf := consts.DefaultClientConfig
getProxyConf := func(proxyName string, customDomains string, extra string) string {
return fmt.Sprintf(`
@@ -151,13 +149,13 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
if tests[i].customDomains == "" {
tests[i].customDomains = fmt.Sprintf(`["%s"]`, test.proxyName+".example.com")
}
clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n"
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
for _, test := range tests {
for domain := range strings.SplitSeq(test.customDomains, ",") {
for _, domain := range strings.Split(test.customDomains, ",") {
domain = strings.TrimSpace(domain)
domain = strings.TrimLeft(domain, "[\"")
domain = strings.TrimRight(domain, "]\"")
@@ -191,8 +189,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
`, vhostHTTPSPort)
localPort := f.AllocPort()
var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
clientConf := consts.DefaultClientConfig
getProxyConf := func(proxyName string, customDomains string, extra string) string {
return fmt.Sprintf(`
[[proxies]]
@@ -237,10 +234,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
if tests[i].customDomains == "" {
tests[i].customDomains = fmt.Sprintf(`["%s"]`, test.proxyName+".example.com")
}
clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n"
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
framework.ExpectNoError(err)
@@ -252,7 +249,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
f.RunServer("", localServer)
for _, test := range tests {
for domain := range strings.SplitSeq(test.customDomains, ",") {
for _, domain := range strings.Split(test.customDomains, ",") {
domain = strings.TrimSpace(domain)
domain = strings.TrimLeft(domain, "[\"")
domain = strings.TrimRight(domain, "]\"")
@@ -292,12 +289,9 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
proxyType := t
ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() {
serverConf := consts.DefaultServerConfig
var clientServerConf strings.Builder
clientServerConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user1\"")
var clientVisitorConf strings.Builder
clientVisitorConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user1\"")
var clientUser2VisitorConf strings.Builder
clientUser2VisitorConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user2\"")
clientServerConf := consts.DefaultClientConfig + "\nuser = \"user1\""
clientVisitorConf := consts.DefaultClientConfig + "\nuser = \"user1\""
clientUser2VisitorConf := consts.DefaultClientConfig + "\nuser = \"user2\""
localPortName := ""
protocol := "tcp"
@@ -413,20 +407,20 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config
for _, test := range tests {
clientServerConf.WriteString(getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n")
clientServerConf += getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n"
}
for _, test := range tests {
config := getProxyVisitorConf(
test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig,
) + "\n"
if test.deployUser2Client {
clientUser2VisitorConf.WriteString(config)
clientUser2VisitorConf += config
} else {
clientVisitorConf.WriteString(config)
clientVisitorConf += config
}
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientServerConf.String(), clientVisitorConf.String(), clientUser2VisitorConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientServerConf, clientVisitorConf, clientUser2VisitorConf})
for _, test := range tests {
timeout := time.Second
@@ -453,8 +447,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
ginkgo.Describe("TCPMUX", func() {
ginkgo.It("Type tcpmux", func() {
serverConf := consts.DefaultServerConfig
var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
clientConf := consts.DefaultClientConfig
tcpmuxHTTPConnectPortName := port.GenName("TCPMUX")
serverConf += fmt.Sprintf(`
@@ -498,14 +491,14 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config
for _, test := range tests {
clientConf.WriteString(getProxyConf(test.proxyName, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, test.extraConfig) + "\n"
localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(f.AllocPort()), streamserver.WithRespContent([]byte(test.proxyName)))
f.RunServer(port.GenName(test.proxyName), localServer)
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
// Request without HTTP connect should get error
framework.NewRequestExpect(f).

View File

@@ -33,7 +33,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
udpPortName := port.GenName("UDP", port.WithRangePorts(12000, 13000))
clientConf += fmt.Sprintf(`
[[proxies]]
name = "tcp-allowed-in-range"
name = "tcp-allowded-in-range"
type = "tcp"
localPort = {{ .%s }}
remotePort = {{ .%s }}

View File

@@ -50,10 +50,12 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
return true
})
}
for range 10 {
wait.Go(func() {
for i := 0; i < 10; i++ {
wait.Add(1)
go func() {
defer wait.Done()
expectFn()
})
}()
}
wait.Wait()
@@ -96,7 +98,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
fooCount := 0
barCount := 0
for i := range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).Explain("times " + strconv.Itoa(i)).Port(remotePort).Ensure(func(resp *request.Response) bool {
switch string(resp.Content) {
case "foo":
@@ -161,7 +163,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
fooCount := 0
barCount := 0
for i := range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).
Explain("times " + strconv.Itoa(i)).
Port(vhostHTTPSPort).
@@ -228,7 +230,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
// check foo and bar is ok
results := []string{}
for range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
results = append(results, string(resp.Content))
return true
@@ -239,7 +241,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
// close bar server, check foo is ok
barServer.Close()
time.Sleep(2 * time.Second)
for range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).Port(remotePort).ExpectResp([]byte("foo")).Ensure()
}
@@ -247,7 +249,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
f.RunServer("", barServer)
time.Sleep(2 * time.Second)
results = []string{}
for range 10 {
for i := 0; i < 10; i++ {
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
results = append(results, string(resp.Content))
return true

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/onsi/ginkgo/v2"
@@ -24,8 +23,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
ginkgo.Describe("UnixDomainSocket", func() {
ginkgo.It("Expose a unix domain socket echo server", func() {
serverConf := consts.DefaultServerConfig
var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
clientConf := consts.DefaultClientConfig
getProxyConf := func(proxyName string, portName string, extra string) string {
return fmt.Sprintf(`
@@ -71,10 +69,10 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
// build all client config
for _, test := range tests {
clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n"
}
// run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf.String()})
f.RunProcesses([]string{serverConf}, []string{clientConf})
for _, test := range tests {
framework.NewRequestExpect(f).Port(f.PortByName(test.portName)).Ensure()