mirror of
https://github.com/fatedier/frp.git
synced 2026-03-09 11:29:11 +08:00
Compare commits
24 Commits
copilot/re
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48e8901466 | ||
|
|
bcd2424c24 | ||
|
|
c7ac12ea0f | ||
|
|
eeb0dacfc1 | ||
|
|
535eb3db35 | ||
|
|
605f3bdece | ||
|
|
764a626b6e | ||
|
|
c2454e7114 | ||
|
|
017d71717f | ||
|
|
bd200b1a3b | ||
|
|
c70ceff370 | ||
|
|
bb3d0e7140 | ||
|
|
cf396563f8 | ||
|
|
0b4f83cd04 | ||
|
|
e9f7a1a9f2 | ||
|
|
d644593342 | ||
|
|
427c4ca3ae | ||
|
|
f2d1f3739a | ||
|
|
c23894f156 | ||
|
|
cb459b02b6 | ||
|
|
8f633fe363 | ||
|
|
c62a1da161 | ||
|
|
f22f7d539c | ||
|
|
462c987f6d |
@@ -18,6 +18,7 @@ linters:
|
||||
- lll
|
||||
- makezero
|
||||
- misspell
|
||||
- modernize
|
||||
- prealloc
|
||||
- predeclared
|
||||
- revive
|
||||
@@ -47,6 +48,9 @@ linters:
|
||||
ignore-rules:
|
||||
- cancelled
|
||||
- marshalled
|
||||
modernize:
|
||||
disable:
|
||||
- omitzero
|
||||
unparam:
|
||||
check-exported: false
|
||||
exclusions:
|
||||
|
||||
@@ -16,6 +16,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
@@ -122,6 +123,33 @@ func (pxy *BaseProxy) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapWorkConn applies rate limiting, encryption, and compression
|
||||
// to a work connection based on the proxy's transport configuration.
|
||||
// The returned recycle function should be called when the stream is no longer in use
|
||||
// to return compression resources to the pool. It is safe to not call recycle,
|
||||
// in which case resources will be garbage collected normally.
|
||||
func (pxy *BaseProxy) wrapWorkConn(conn net.Conn, encKey []byte) (io.ReadWriteCloser, func(), error) {
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
}
|
||||
if pxy.baseCfg.Transport.UseEncryption {
|
||||
var err error
|
||||
rwc, err = libio.WithEncryption(rwc, encKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, fmt.Errorf("create encryption stream error: %w", err)
|
||||
}
|
||||
}
|
||||
var recycleFn func()
|
||||
if pxy.baseCfg.Transport.UseCompression {
|
||||
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
|
||||
}
|
||||
return rwc, recycleFn, nil
|
||||
}
|
||||
|
||||
func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||
pxy.inWorkConnCallback = cb
|
||||
}
|
||||
@@ -139,30 +167,14 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
|
||||
func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) {
|
||||
xl := pxy.xl
|
||||
baseCfg := pxy.baseCfg
|
||||
var (
|
||||
remote io.ReadWriteCloser
|
||||
err error
|
||||
)
|
||||
remote = workConn
|
||||
if pxy.limiter != nil {
|
||||
remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error {
|
||||
return workConn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
xl.Tracef("handle tcp work connection, useEncryption: %t, useCompression: %t",
|
||||
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
|
||||
if baseCfg.Transport.UseEncryption {
|
||||
remote, err = libio.WithEncryption(remote, encKey)
|
||||
if err != nil {
|
||||
workConn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
var compressionResourceRecycleFn func()
|
||||
if baseCfg.Transport.UseCompression {
|
||||
remote, compressionResourceRecycleFn = libio.WithCompressionFromPool(remote)
|
||||
|
||||
remote, recycleFn, err := pxy.wrapWorkConn(workConn, encKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// check if we need to send proxy protocol info
|
||||
@@ -178,7 +190,6 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
}
|
||||
|
||||
if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 {
|
||||
// Use the common proxy protocol builder function
|
||||
header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion)
|
||||
connInfo.ProxyProtocolHeader = header
|
||||
}
|
||||
@@ -187,12 +198,18 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
|
||||
if pxy.proxyPlugin != nil {
|
||||
// if plugin is set, let plugin handle connection first
|
||||
// Don't recycle compression resources here because plugins may
|
||||
// retain the connection after Handle returns.
|
||||
xl.Debugf("handle by plugin: %s", pxy.proxyPlugin.Name())
|
||||
pxy.proxyPlugin.Handle(pxy.ctx, &connInfo)
|
||||
xl.Debugf("handle by plugin finished")
|
||||
return
|
||||
}
|
||||
|
||||
if recycleFn != nil {
|
||||
defer recycleFn()
|
||||
}
|
||||
|
||||
localConn, err := libnet.Dial(
|
||||
net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)),
|
||||
libnet.WithTimeout(10*time.Second),
|
||||
@@ -209,6 +226,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
if connInfo.ProxyProtocolHeader != nil {
|
||||
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
|
||||
workConn.Close()
|
||||
localConn.Close()
|
||||
xl.Errorf("write proxy protocol header to local conn error: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -219,7 +237,4 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
if len(errs) > 0 {
|
||||
xl.Tracef("join connections errors: %v", errs)
|
||||
}
|
||||
if compressionResourceRecycleFn != nil {
|
||||
compressionResourceRecycleFn()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -25,17 +24,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/udp"
|
||||
"github.com/fatedier/frp/pkg/util/limit"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.SUDPProxyConfig{}), NewSUDPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.SUDPProxyConfig](), NewSUDPProxy)
|
||||
}
|
||||
|
||||
type SUDPProxy struct {
|
||||
@@ -83,27 +80,13 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
xl := pxy.xl
|
||||
xl.Infof("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String())
|
||||
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
var err error
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
if pxy.cfg.Transport.UseEncryption {
|
||||
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if pxy.cfg.Transport.UseCompression {
|
||||
rwc = libio.WithCompression(rwc)
|
||||
}
|
||||
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
|
||||
|
||||
workConn := conn
|
||||
workConn := netpkg.WrapReadWriteCloserToConn(remote, conn)
|
||||
readCh := make(chan *msg.UDPPacket, 1024)
|
||||
sendCh := make(chan msg.Message, 1024)
|
||||
isClose := false
|
||||
|
||||
@@ -17,24 +17,21 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/udp"
|
||||
"github.com/fatedier/frp/pkg/util/limit"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.UDPProxyConfig{}), NewUDPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.UDPProxyConfig](), NewUDPProxy)
|
||||
}
|
||||
|
||||
type UDPProxy struct {
|
||||
@@ -94,28 +91,14 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
// close resources related with old workConn
|
||||
pxy.Close()
|
||||
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
var err error
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
if pxy.cfg.Transport.UseEncryption {
|
||||
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if pxy.cfg.Transport.UseCompression {
|
||||
rwc = libio.WithCompression(rwc)
|
||||
}
|
||||
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
|
||||
|
||||
pxy.mu.Lock()
|
||||
pxy.workConn = conn
|
||||
pxy.workConn = netpkg.WrapReadWriteCloserToConn(remote, conn)
|
||||
pxy.readCh = make(chan *msg.UDPPacket, 1024)
|
||||
pxy.sendCh = make(chan msg.Message, 1024)
|
||||
pxy.closed = false
|
||||
@@ -129,7 +112,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
return
|
||||
}
|
||||
if errRet := errors.PanicToError(func() {
|
||||
xl.Tracef("get udp package from workConn: %s", udpMsg.Content)
|
||||
xl.Tracef("get udp package from workConn, len: %d", len(udpMsg.Content))
|
||||
readCh <- &udpMsg
|
||||
}); errRet != nil {
|
||||
xl.Infof("reader goroutine for udp work connection closed: %v", errRet)
|
||||
@@ -145,7 +128,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
for rawMsg := range sendCh {
|
||||
switch m := rawMsg.(type) {
|
||||
case *msg.UDPPacket:
|
||||
xl.Tracef("send udp package to workConn: %s", m.Content)
|
||||
xl.Tracef("send udp package to workConn, len: %d", len(m.Content))
|
||||
case *msg.Ping:
|
||||
xl.Tracef("send ping message to udp workConn")
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.XTCPProxyConfig{}), NewXTCPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.XTCPProxyConfig](), NewXTCPProxy)
|
||||
}
|
||||
|
||||
type XTCPProxy struct {
|
||||
|
||||
@@ -15,18 +15,12 @@
|
||||
package visitor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
||||
@@ -42,10 +36,10 @@ func (sv *STCPVisitor) Run() (err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go sv.worker()
|
||||
go sv.acceptLoop(sv.l, "stcp local", sv.handleConn)
|
||||
}
|
||||
|
||||
go sv.internalConnWorker()
|
||||
go sv.acceptLoop(sv.internalLn, "stcp internal", sv.handleConn)
|
||||
|
||||
if sv.plugin != nil {
|
||||
sv.plugin.Start()
|
||||
@@ -57,35 +51,10 @@ func (sv *STCPVisitor) Close() {
|
||||
sv.BaseVisitor.Close()
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) worker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("stcp local listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) internalConnWorker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.internalLn.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("stcp internal listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) handleConn(userConn net.Conn) {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
var tunnelErr error
|
||||
defer func() {
|
||||
// If there was an error and connection supports CloseWithError, use it
|
||||
if tunnelErr != nil {
|
||||
if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
|
||||
_ = eConn.CloseWithError(tunnelErr)
|
||||
@@ -96,62 +65,21 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
|
||||
}()
|
||||
|
||||
xl.Debugf("get a new stcp user connection")
|
||||
visitorConn, err := sv.helper.ConnectServer()
|
||||
visitorConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
xl.Warnf("dialRawVisitorConn error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
defer visitorConn.Close()
|
||||
|
||||
now := time.Now().Unix()
|
||||
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
|
||||
newVisitorConnMsg := &msg.NewVisitorConn{
|
||||
RunID: sv.helper.RunID(),
|
||||
ProxyName: targetProxyName,
|
||||
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
|
||||
Timestamp: now,
|
||||
UseEncryption: sv.cfg.Transport.UseEncryption,
|
||||
UseCompression: sv.cfg.Transport.UseCompression,
|
||||
}
|
||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||
remote, recycleFn, err := wrapVisitorConn(visitorConn, sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
xl.Warnf("send newVisitorConnMsg to server error: %v", err)
|
||||
xl.Warnf("wrapVisitorConn error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||
if err != nil {
|
||||
xl.Warnf("get newVisitorConnRespMsg error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
_ = visitorConn.SetReadDeadline(time.Time{})
|
||||
|
||||
if newVisitorConnRespMsg.Error != "" {
|
||||
xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
||||
tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error)
|
||||
return
|
||||
}
|
||||
|
||||
var remote io.ReadWriteCloser
|
||||
remote = visitorConn
|
||||
if sv.cfg.Transport.UseEncryption {
|
||||
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
|
||||
if err != nil {
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if sv.cfg.Transport.UseCompression {
|
||||
var recycleFn func()
|
||||
remote, recycleFn = libio.WithCompressionFromPool(remote)
|
||||
defer recycleFn()
|
||||
}
|
||||
defer recycleFn()
|
||||
|
||||
libio.Join(userConn, remote)
|
||||
}
|
||||
|
||||
@@ -16,21 +16,17 @@ package visitor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
"github.com/fatedier/frp/pkg/proto/udp"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
||||
@@ -76,6 +72,7 @@ func (sv *SUDPVisitor) dispatcher() {
|
||||
|
||||
var (
|
||||
visitorConn net.Conn
|
||||
recycleFn func()
|
||||
err error
|
||||
|
||||
firstPacket *msg.UDPPacket
|
||||
@@ -93,14 +90,17 @@ func (sv *SUDPVisitor) dispatcher() {
|
||||
return
|
||||
}
|
||||
|
||||
visitorConn, err = sv.getNewVisitorConn()
|
||||
visitorConn, recycleFn, err = sv.getNewVisitorConn()
|
||||
if err != nil {
|
||||
xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// visitorConn always be closed when worker done.
|
||||
sv.worker(visitorConn, firstPacket)
|
||||
func() {
|
||||
defer recycleFn()
|
||||
sv.worker(visitorConn, firstPacket)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sv.checkCloseCh:
|
||||
@@ -147,7 +147,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
||||
case *msg.UDPPacket:
|
||||
if errRet := errors.PanicToError(func() {
|
||||
sv.readCh <- m
|
||||
xl.Tracef("frpc visitor get udp packet from workConn: %s", m.Content)
|
||||
xl.Tracef("frpc visitor get udp packet from workConn, len: %d", len(m.Content))
|
||||
}); errRet != nil {
|
||||
xl.Infof("reader goroutine for udp work connection closed")
|
||||
return
|
||||
@@ -169,7 +169,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
||||
xl.Warnf("sender goroutine for udp work connection closed: %v", errRet)
|
||||
return
|
||||
}
|
||||
xl.Tracef("send udp package to workConn: %s", firstPacket.Content)
|
||||
xl.Tracef("send udp package to workConn, len: %d", len(firstPacket.Content))
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -184,7 +184,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
||||
xl.Warnf("sender goroutine for udp work connection closed: %v", errRet)
|
||||
return
|
||||
}
|
||||
xl.Tracef("send udp package to workConn: %s", udpMsg.Content)
|
||||
xl.Tracef("send udp package to workConn, len: %d", len(udpMsg.Content))
|
||||
case <-closeCh:
|
||||
return
|
||||
}
|
||||
@@ -198,53 +198,17 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
||||
xl.Infof("sudp worker is closed")
|
||||
}
|
||||
|
||||
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
visitorConn, err := sv.helper.ConnectServer()
|
||||
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, func(), error) {
|
||||
rawConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("frpc connect frps error: %v", err)
|
||||
return nil, func() {}, err
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
|
||||
newVisitorConnMsg := &msg.NewVisitorConn{
|
||||
RunID: sv.helper.RunID(),
|
||||
ProxyName: targetProxyName,
|
||||
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
|
||||
Timestamp: now,
|
||||
UseEncryption: sv.cfg.Transport.UseEncryption,
|
||||
UseCompression: sv.cfg.Transport.UseCompression,
|
||||
}
|
||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||
rwc, recycleFn, err := wrapVisitorConn(rawConn, sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err)
|
||||
rawConn.Close()
|
||||
return nil, func() {}, err
|
||||
}
|
||||
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
|
||||
}
|
||||
_ = visitorConn.SetReadDeadline(time.Time{})
|
||||
|
||||
if newVisitorConnRespMsg.Error != "" {
|
||||
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
||||
}
|
||||
|
||||
var remote io.ReadWriteCloser
|
||||
remote = visitorConn
|
||||
if sv.cfg.Transport.UseEncryption {
|
||||
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
|
||||
if err != nil {
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if sv.cfg.Transport.UseCompression {
|
||||
remote = libio.WithCompression(remote)
|
||||
}
|
||||
return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil
|
||||
return netpkg.WrapReadWriteCloserToConn(rwc, rawConn), recycleFn, nil
|
||||
}
|
||||
|
||||
func (sv *SUDPVisitor) Close() {
|
||||
|
||||
@@ -16,13 +16,21 @@ package visitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
plugin "github.com/fatedier/frp/pkg/plugin/visitor"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
)
|
||||
@@ -119,6 +127,18 @@ func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
|
||||
return v.internalLn.PutConn(conn)
|
||||
}
|
||||
|
||||
func (v *BaseVisitor) acceptLoop(l net.Listener, name string, handleConn func(net.Conn)) {
|
||||
xl := xlog.FromContextSafe(v.ctx)
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("%s listener closed", name)
|
||||
return
|
||||
}
|
||||
go handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *BaseVisitor) Close() {
|
||||
if v.l != nil {
|
||||
v.l.Close()
|
||||
@@ -130,3 +150,57 @@ func (v *BaseVisitor) Close() {
|
||||
v.plugin.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, error) {
|
||||
visitorConn, err := v.helper.ConnectServer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to server error: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
targetProxyName := naming.BuildTargetServerProxyName(v.clientCfg.User, cfg.ServerUser, cfg.ServerName)
|
||||
newVisitorConnMsg := &msg.NewVisitorConn{
|
||||
RunID: v.helper.RunID(),
|
||||
ProxyName: targetProxyName,
|
||||
SignKey: util.GetAuthKey(cfg.SecretKey, now),
|
||||
Timestamp: now,
|
||||
UseEncryption: cfg.Transport.UseEncryption,
|
||||
UseCompression: cfg.Transport.UseCompression,
|
||||
}
|
||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||
if err != nil {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
|
||||
}
|
||||
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||
if err != nil {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
|
||||
}
|
||||
_ = visitorConn.SetReadDeadline(time.Time{})
|
||||
|
||||
if newVisitorConnRespMsg.Error != "" {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
||||
}
|
||||
return visitorConn, nil
|
||||
}
|
||||
|
||||
func wrapVisitorConn(conn io.ReadWriteCloser, cfg *v1.VisitorBaseConfig) (io.ReadWriteCloser, func(), error) {
|
||||
rwc := conn
|
||||
if cfg.Transport.UseEncryption {
|
||||
var err error
|
||||
rwc, err = libio.WithEncryption(rwc, []byte(cfg.SecretKey))
|
||||
if err != nil {
|
||||
return nil, func() {}, fmt.Errorf("create encryption stream error: %v", err)
|
||||
}
|
||||
}
|
||||
recycleFn := func() {}
|
||||
if cfg.Transport.UseCompression {
|
||||
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
|
||||
}
|
||||
return rwc, recycleFn, nil
|
||||
}
|
||||
|
||||
@@ -65,10 +65,10 @@ func (sv *XTCPVisitor) Run() (err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go sv.worker()
|
||||
go sv.acceptLoop(sv.l, "xtcp local", sv.handleConn)
|
||||
}
|
||||
|
||||
go sv.internalConnWorker()
|
||||
go sv.acceptLoop(sv.internalLn, "xtcp internal", sv.handleConn)
|
||||
go sv.processTunnelStartEvents()
|
||||
if sv.cfg.KeepTunnelOpen {
|
||||
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
|
||||
@@ -93,30 +93,6 @@ func (sv *XTCPVisitor) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) worker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("xtcp local listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) internalConnWorker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.internalLn.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("xtcp internal listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) processTunnelStartEvents() {
|
||||
for {
|
||||
select {
|
||||
@@ -206,20 +182,14 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
var muxConnRWCloser io.ReadWriteCloser = tunnelConn
|
||||
if sv.cfg.Transport.UseEncryption {
|
||||
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey))
|
||||
if err != nil {
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
if sv.cfg.Transport.UseCompression {
|
||||
var recycleFn func()
|
||||
muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser)
|
||||
defer recycleFn()
|
||||
muxConnRWCloser, recycleFn, err := wrapVisitorConn(tunnelConn, sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
xl.Errorf("%v", err)
|
||||
tunnelConn.Close()
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
defer recycleFn()
|
||||
|
||||
_, _, errs := libio.Join(userConn, muxConnRWCloser)
|
||||
xl.Debugf("join connections closed")
|
||||
@@ -373,6 +343,7 @@ func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) er
|
||||
}
|
||||
remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String())
|
||||
if err != nil {
|
||||
lConn.Close()
|
||||
return fmt.Errorf("create kcp connection from udp connection error: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ var natholeDiscoveryCmd = &cobra.Command{
|
||||
Use: "discover",
|
||||
Short: "Discover nathole information from stun server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// ignore error here, because we can use command line pameters
|
||||
// ignore error here, because we can use command line parameters
|
||||
cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfigMode)
|
||||
if err != nil {
|
||||
cfg = &v1.ClientCommonConfig{}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -205,7 +206,8 @@ type OidcAuthConsumer struct {
|
||||
additionalAuthScopes []v1.AuthScope
|
||||
|
||||
verifier TokenVerifier
|
||||
subjectsFromLogin []string
|
||||
mu sync.RWMutex
|
||||
subjectsFromLogin map[string]struct{}
|
||||
}
|
||||
|
||||
func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
|
||||
@@ -226,7 +228,7 @@ func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVeri
|
||||
return &OidcAuthConsumer{
|
||||
additionalAuthScopes: additionalAuthScopes,
|
||||
verifier: verifier,
|
||||
subjectsFromLogin: []string{},
|
||||
subjectsFromLogin: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,9 +237,9 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid OIDC token in login: %v", err)
|
||||
}
|
||||
if !slices.Contains(auth.subjectsFromLogin, token.Subject) {
|
||||
auth.subjectsFromLogin = append(auth.subjectsFromLogin, token.Subject)
|
||||
}
|
||||
auth.mu.Lock()
|
||||
auth.subjectsFromLogin[token.Subject] = struct{}{}
|
||||
auth.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -246,11 +248,13 @@ func (auth *OidcAuthConsumer) verifyPostLoginToken(privilegeKey string) (err err
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid OIDC token in ping: %v", err)
|
||||
}
|
||||
if !slices.Contains(auth.subjectsFromLogin, token.Subject) {
|
||||
auth.mu.RLock()
|
||||
_, ok := auth.subjectsFromLogin[token.Subject]
|
||||
auth.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("received different OIDC subject in login and ping. "+
|
||||
"original subjects: %s, "+
|
||||
"new subject: %s",
|
||||
auth.subjectsFromLogin, token.Subject)
|
||||
token.Subject)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -171,15 +171,14 @@ func Convert_ServerCommonConf_To_v1(conf *ServerCommonConf) *v1.ServerConfig {
|
||||
func transformHeadersFromPluginParams(params map[string]string) v1.HeaderOperations {
|
||||
out := v1.HeaderOperations{}
|
||||
for k, v := range params {
|
||||
if !strings.HasPrefix(k, "plugin_header_") {
|
||||
k, ok := strings.CutPrefix(k, "plugin_header_")
|
||||
if !ok || k == "" {
|
||||
continue
|
||||
}
|
||||
if k = strings.TrimPrefix(k, "plugin_header_"); k != "" {
|
||||
if out.Set == nil {
|
||||
out.Set = make(map[string]string)
|
||||
}
|
||||
out.Set[k] = v
|
||||
if out.Set == nil {
|
||||
out.Set = make(map[string]string)
|
||||
}
|
||||
out.Set[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -39,14 +39,14 @@ const (
|
||||
// Proxy
|
||||
var (
|
||||
proxyConfTypeMap = map[ProxyType]reflect.Type{
|
||||
ProxyTypeTCP: reflect.TypeOf(TCPProxyConf{}),
|
||||
ProxyTypeUDP: reflect.TypeOf(UDPProxyConf{}),
|
||||
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConf{}),
|
||||
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConf{}),
|
||||
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConf{}),
|
||||
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConf{}),
|
||||
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConf{}),
|
||||
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConf{}),
|
||||
ProxyTypeTCP: reflect.TypeFor[TCPProxyConf](),
|
||||
ProxyTypeUDP: reflect.TypeFor[UDPProxyConf](),
|
||||
ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConf](),
|
||||
ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConf](),
|
||||
ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConf](),
|
||||
ProxyTypeSTCP: reflect.TypeFor[STCPProxyConf](),
|
||||
ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConf](),
|
||||
ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConf](),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ func GetMapWithoutPrefix(set map[string]string, prefix string) map[string]string
|
||||
m := make(map[string]string)
|
||||
|
||||
for key, value := range set {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
m[strings.TrimPrefix(key, prefix)] = value
|
||||
if trimmed, ok := strings.CutPrefix(key, prefix); ok {
|
||||
m[trimmed] = value
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,9 +32,9 @@ const (
|
||||
// Visitor
|
||||
var (
|
||||
visitorConfTypeMap = map[VisitorType]reflect.Type{
|
||||
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConf{}),
|
||||
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConf{}),
|
||||
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConf{}),
|
||||
VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConf](),
|
||||
VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConf](),
|
||||
VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConf](),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"maps"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
@@ -97,21 +99,11 @@ func (a *Aggregator) mapsToSortedSlices(
|
||||
proxyMap map[string]v1.ProxyConfigurer,
|
||||
visitorMap map[string]v1.VisitorConfigurer,
|
||||
) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer) {
|
||||
proxies := make([]v1.ProxyConfigurer, 0, len(proxyMap))
|
||||
for _, p := range proxyMap {
|
||||
proxies = append(proxies, p)
|
||||
}
|
||||
sort.Slice(proxies, func(i, j int) bool {
|
||||
return proxies[i].GetBaseConfig().Name < proxies[j].GetBaseConfig().Name
|
||||
proxies := slices.SortedFunc(maps.Values(proxyMap), func(x, y v1.ProxyConfigurer) int {
|
||||
return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
|
||||
})
|
||||
|
||||
visitors := make([]v1.VisitorConfigurer, 0, len(visitorMap))
|
||||
for _, v := range visitorMap {
|
||||
visitors = append(visitors, v)
|
||||
}
|
||||
sort.Slice(visitors, func(i, j int) bool {
|
||||
return visitors[i].GetBaseConfig().Name < visitors[j].GetBaseConfig().Name
|
||||
visitors := slices.SortedFunc(maps.Values(visitorMap), func(x, y v1.VisitorConfigurer) int {
|
||||
return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
|
||||
})
|
||||
|
||||
return proxies, visitors
|
||||
}
|
||||
|
||||
@@ -196,6 +196,27 @@ func TestAggregator_VisitorMerge(t *testing.T) {
|
||||
require.Len(visitors, 2)
|
||||
}
|
||||
|
||||
func TestAggregator_Load_ReturnsSortedByName(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
agg := newTestAggregator(t, nil)
|
||||
err := agg.ConfigSource().ReplaceAll(
|
||||
[]v1.ProxyConfigurer{mockProxy("charlie"), mockProxy("alice"), mockProxy("bob")},
|
||||
[]v1.VisitorConfigurer{mockVisitor("zulu"), mockVisitor("alpha")},
|
||||
)
|
||||
require.NoError(err)
|
||||
|
||||
proxies, visitors, err := agg.Load()
|
||||
require.NoError(err)
|
||||
require.Len(proxies, 3)
|
||||
require.Equal("alice", proxies[0].GetBaseConfig().Name)
|
||||
require.Equal("bob", proxies[1].GetBaseConfig().Name)
|
||||
require.Equal("charlie", proxies[2].GetBaseConfig().Name)
|
||||
require.Len(visitors, 2)
|
||||
require.Equal("alpha", visitors[0].GetBaseConfig().Name)
|
||||
require.Equal("zulu", visitors[1].GetBaseConfig().Name)
|
||||
}
|
||||
|
||||
func TestAggregator_Load_ReturnsDefensiveCopies(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ func parseNumberRangePair(firstRangeStr, secondRangeStr string) ([]NumberPair, e
|
||||
return nil, fmt.Errorf("first and second range numbers are not in pairs")
|
||||
}
|
||||
pairs := make([]NumberPair, 0, len(firstRangeNumbers))
|
||||
for i := 0; i < len(firstRangeNumbers); i++ {
|
||||
for i := range firstRangeNumbers {
|
||||
pairs = append(pairs, NumberPair{
|
||||
First: firstRangeNumbers[i],
|
||||
Second: secondRangeNumbers[i],
|
||||
|
||||
@@ -70,24 +70,18 @@ func (q *BandwidthQuantity) UnmarshalString(s string) error {
|
||||
f float64
|
||||
err error
|
||||
)
|
||||
switch {
|
||||
case strings.HasSuffix(s, "MB"):
|
||||
if fstr, ok := strings.CutSuffix(s, "MB"); ok {
|
||||
base = MB
|
||||
fstr := strings.TrimSuffix(s, "MB")
|
||||
f, err = strconv.ParseFloat(fstr, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case strings.HasSuffix(s, "KB"):
|
||||
} else if fstr, ok := strings.CutSuffix(s, "KB"); ok {
|
||||
base = KB
|
||||
fstr := strings.TrimSuffix(s, "KB")
|
||||
f, err = strconv.ParseFloat(fstr, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
} else {
|
||||
return errors.New("unit not support")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.s = s
|
||||
q.i = int64(f * float64(base))
|
||||
@@ -143,8 +137,8 @@ func (p PortsRangeSlice) String() string {
|
||||
func NewPortsRangeSliceFromString(str string) ([]PortsRange, error) {
|
||||
str = strings.TrimSpace(str)
|
||||
out := []PortsRange{}
|
||||
numRanges := strings.Split(str, ",")
|
||||
for _, numRangeStr := range numRanges {
|
||||
numRanges := strings.SplitSeq(str, ",")
|
||||
for numRangeStr := range numRanges {
|
||||
// 1000-2000 or 2001
|
||||
numArray := strings.Split(numRangeStr, "-")
|
||||
// length: only 1 or 2 is correct
|
||||
|
||||
@@ -39,6 +39,31 @@ func TestBandwidthQuantity(t *testing.T) {
|
||||
require.Equal(`{"b":"1KB","int":5}`, string(buf))
|
||||
}
|
||||
|
||||
func TestBandwidthQuantity_MB(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
var w Wrap
|
||||
err := json.Unmarshal([]byte(`{"b":"2MB","int":1}`), &w)
|
||||
require.NoError(err)
|
||||
require.EqualValues(2*MB, w.B.Bytes())
|
||||
|
||||
buf, err := json.Marshal(&w)
|
||||
require.NoError(err)
|
||||
require.Equal(`{"b":"2MB","int":1}`, string(buf))
|
||||
}
|
||||
|
||||
func TestBandwidthQuantity_InvalidUnit(t *testing.T) {
|
||||
var w Wrap
|
||||
err := json.Unmarshal([]byte(`{"b":"1GB","int":1}`), &w)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBandwidthQuantity_InvalidNumber(t *testing.T) {
|
||||
var w Wrap
|
||||
err := json.Unmarshal([]byte(`{"b":"abcKB","int":1}`), &w)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPortsRangeSlice2String(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
|
||||
@@ -239,14 +239,14 @@ const (
|
||||
)
|
||||
|
||||
var proxyConfigTypeMap = map[ProxyType]reflect.Type{
|
||||
ProxyTypeTCP: reflect.TypeOf(TCPProxyConfig{}),
|
||||
ProxyTypeUDP: reflect.TypeOf(UDPProxyConfig{}),
|
||||
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConfig{}),
|
||||
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConfig{}),
|
||||
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConfig{}),
|
||||
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConfig{}),
|
||||
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConfig{}),
|
||||
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConfig{}),
|
||||
ProxyTypeTCP: reflect.TypeFor[TCPProxyConfig](),
|
||||
ProxyTypeUDP: reflect.TypeFor[UDPProxyConfig](),
|
||||
ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConfig](),
|
||||
ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConfig](),
|
||||
ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConfig](),
|
||||
ProxyTypeSTCP: reflect.TypeFor[STCPProxyConfig](),
|
||||
ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConfig](),
|
||||
ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConfig](),
|
||||
}
|
||||
|
||||
func NewProxyConfigurerByType(proxyType ProxyType) ProxyConfigurer {
|
||||
|
||||
@@ -37,16 +37,16 @@ const (
|
||||
)
|
||||
|
||||
var clientPluginOptionsTypeMap = map[string]reflect.Type{
|
||||
PluginHTTP2HTTPS: reflect.TypeOf(HTTP2HTTPSPluginOptions{}),
|
||||
PluginHTTPProxy: reflect.TypeOf(HTTPProxyPluginOptions{}),
|
||||
PluginHTTPS2HTTP: reflect.TypeOf(HTTPS2HTTPPluginOptions{}),
|
||||
PluginHTTPS2HTTPS: reflect.TypeOf(HTTPS2HTTPSPluginOptions{}),
|
||||
PluginHTTP2HTTP: reflect.TypeOf(HTTP2HTTPPluginOptions{}),
|
||||
PluginSocks5: reflect.TypeOf(Socks5PluginOptions{}),
|
||||
PluginStaticFile: reflect.TypeOf(StaticFilePluginOptions{}),
|
||||
PluginUnixDomainSocket: reflect.TypeOf(UnixDomainSocketPluginOptions{}),
|
||||
PluginTLS2Raw: reflect.TypeOf(TLS2RawPluginOptions{}),
|
||||
PluginVirtualNet: reflect.TypeOf(VirtualNetPluginOptions{}),
|
||||
PluginHTTP2HTTPS: reflect.TypeFor[HTTP2HTTPSPluginOptions](),
|
||||
PluginHTTPProxy: reflect.TypeFor[HTTPProxyPluginOptions](),
|
||||
PluginHTTPS2HTTP: reflect.TypeFor[HTTPS2HTTPPluginOptions](),
|
||||
PluginHTTPS2HTTPS: reflect.TypeFor[HTTPS2HTTPSPluginOptions](),
|
||||
PluginHTTP2HTTP: reflect.TypeFor[HTTP2HTTPPluginOptions](),
|
||||
PluginSocks5: reflect.TypeFor[Socks5PluginOptions](),
|
||||
PluginStaticFile: reflect.TypeFor[StaticFilePluginOptions](),
|
||||
PluginUnixDomainSocket: reflect.TypeFor[UnixDomainSocketPluginOptions](),
|
||||
PluginTLS2Raw: reflect.TypeFor[TLS2RawPluginOptions](),
|
||||
PluginVirtualNet: reflect.TypeFor[VirtualNetPluginOptions](),
|
||||
}
|
||||
|
||||
type ClientPluginOptions interface {
|
||||
|
||||
@@ -79,9 +79,9 @@ const (
|
||||
)
|
||||
|
||||
var visitorConfigTypeMap = map[VisitorType]reflect.Type{
|
||||
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConfig{}),
|
||||
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConfig{}),
|
||||
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConfig{}),
|
||||
VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConfig](),
|
||||
VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConfig](),
|
||||
VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConfig](),
|
||||
}
|
||||
|
||||
type TypedVisitorConfig struct {
|
||||
|
||||
@@ -25,7 +25,7 @@ const (
|
||||
)
|
||||
|
||||
var visitorPluginOptionsTypeMap = map[string]reflect.Type{
|
||||
VisitorPluginVirtualNet: reflect.TypeOf(VirtualNetVisitorPluginOptions{}),
|
||||
VisitorPluginVirtualNet: reflect.TypeFor[VirtualNetVisitorPluginOptions](),
|
||||
}
|
||||
|
||||
type VisitorPluginOptions interface {
|
||||
|
||||
@@ -143,7 +143,6 @@ func (m *serverMetrics) OpenConnection(name string, _ string) {
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.CurConns.Inc(1)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +154,6 @@ func (m *serverMetrics) CloseConnection(name string, _ string) {
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.CurConns.Dec(1)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,7 +166,6 @@ func (m *serverMetrics) AddTrafficIn(name string, _ string, trafficBytes int64)
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.TrafficIn.Inc(trafficBytes)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,7 +178,6 @@ func (m *serverMetrics) AddTrafficOut(name string, _ string, trafficBytes int64)
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.TrafficOut.Inc(trafficBytes)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,6 +199,25 @@ func (m *serverMetrics) GetServer() *ServerStats {
|
||||
return s
|
||||
}
|
||||
|
||||
func toProxyStats(name string, proxyStats *ProxyStatistics) *ProxyStats {
|
||||
ps := &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
return ps
|
||||
}
|
||||
|
||||
func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
|
||||
res := make([]*ProxyStats, 0)
|
||||
m.mu.Lock()
|
||||
@@ -212,23 +227,7 @@ func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
|
||||
if proxyStats.ProxyType != proxyType {
|
||||
continue
|
||||
}
|
||||
|
||||
ps := &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
res = append(res, ps)
|
||||
res = append(res, toProxyStats(name, proxyStats))
|
||||
}
|
||||
return res
|
||||
}
|
||||
@@ -237,31 +236,9 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for name, proxyStats := range m.info.ProxyStatistics {
|
||||
if proxyStats.ProxyType != proxyType {
|
||||
continue
|
||||
}
|
||||
|
||||
if name != proxyName {
|
||||
continue
|
||||
}
|
||||
|
||||
res = &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
break
|
||||
proxyStats, ok := m.info.ProxyStatistics[proxyName]
|
||||
if ok && proxyStats.ProxyType == proxyType {
|
||||
res = toProxyStats(proxyName, proxyStats)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -272,21 +249,7 @@ func (m *serverMetrics) GetProxyByName(proxyName string) (res *ProxyStats) {
|
||||
|
||||
proxyStats, ok := m.info.ProxyStatistics[proxyName]
|
||||
if ok {
|
||||
res = &ProxyStats{
|
||||
Name: proxyName,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
res = toProxyStats(proxyName, proxyStats)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ var msgTypeMap = map[byte]any{
|
||||
TypeNatHoleReport: NatHoleReport{},
|
||||
}
|
||||
|
||||
var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name()
|
||||
var TypeNameNatHoleResp = reflect.TypeFor[NatHoleResp]().Name()
|
||||
|
||||
type ClientSpec struct {
|
||||
// Due to the support of VirtualClient, frps needs to know the client type in order to
|
||||
@@ -184,7 +184,7 @@ type Pong struct {
|
||||
}
|
||||
|
||||
type UDPPacket struct {
|
||||
Content string `json:"c,omitempty"`
|
||||
Content []byte `json:"c,omitempty"`
|
||||
LocalAddr *net.UDPAddr `json:"l,omitempty"`
|
||||
RemoteAddr *net.UDPAddr `json:"r,omitempty"`
|
||||
}
|
||||
|
||||
@@ -16,9 +16,8 @@ func StripUserPrefix(user, name string) string {
|
||||
if user == "" {
|
||||
return name
|
||||
}
|
||||
prefix := user + "."
|
||||
if strings.HasPrefix(name, prefix) {
|
||||
return strings.TrimPrefix(name, prefix)
|
||||
if trimmed, ok := strings.CutPrefix(name, user+"."); ok {
|
||||
return trimmed
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
@@ -151,7 +151,7 @@ func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
|
||||
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
|
||||
behaviors := getBehaviorByMode(mode)
|
||||
scores := make([]*BehaviorScore, 0, len(behaviors))
|
||||
for i := 0; i < len(behaviors); i++ {
|
||||
for i := range behaviors {
|
||||
score := receiverScore
|
||||
if behaviors[i].A.Role == DetectRoleSender {
|
||||
score = senderScore
|
||||
|
||||
@@ -70,12 +70,8 @@ func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, err
|
||||
continue
|
||||
}
|
||||
|
||||
if portNum > portMax {
|
||||
portMax = portNum
|
||||
}
|
||||
if portNum < portMin {
|
||||
portMin = portNum
|
||||
}
|
||||
portMax = max(portMax, portNum)
|
||||
portMin = min(portMin, portNum)
|
||||
if baseIP != ip {
|
||||
ipChanged = true
|
||||
}
|
||||
|
||||
@@ -152,7 +152,9 @@ func (c *Controller) GenSid() string {
|
||||
|
||||
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) {
|
||||
if m.PreCheck {
|
||||
c.mu.RLock()
|
||||
cfg, ok := c.clientCfgs[m.ProxyName]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
|
||||
return
|
||||
|
||||
@@ -298,11 +298,13 @@ func waitDetectMessage(
|
||||
n, raddr, err := conn.ReadFromUDP(buf)
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
pool.PutBuf(buf)
|
||||
return nil, err
|
||||
}
|
||||
xl.Debugf("get udp message local %s, from %s", conn.LocalAddr(), raddr)
|
||||
var m msg.NatHoleSid
|
||||
if err := DecodeMessageInto(buf[:n], key, &m); err != nil {
|
||||
pool.PutBuf(buf)
|
||||
xl.Warnf("decode sid message error: %v", err)
|
||||
continue
|
||||
}
|
||||
@@ -408,7 +410,7 @@ func sendSidMessageToRandomPorts(
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
used := sets.New[int]()
|
||||
getUnusedPort := func() int {
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
port := rand.IntN(65535-1024) + 1024
|
||||
if !used.Has(port) {
|
||||
used.Insert(port)
|
||||
@@ -418,7 +420,7 @@ func sendSidMessageToRandomPorts(
|
||||
return 0
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
for range count {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
@@ -68,7 +69,7 @@ func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
ReadHeaderTimeout: 0,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
@@ -77,7 +78,7 @@ func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugi
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
ReadHeaderTimeout: 0,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -62,11 +62,13 @@ func (p *TLS2RawPlugin) Handle(ctx context.Context, connInfo *ConnectionInfo) {
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
xl.Warnf("tls handshake error: %v", err)
|
||||
tlsConn.Close()
|
||||
return
|
||||
}
|
||||
rawConn, err := net.Dial("tcp", p.opts.LocalAddr)
|
||||
if err != nil {
|
||||
xl.Warnf("dial to local addr error: %v", err)
|
||||
tlsConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -54,10 +54,13 @@ func (uds *UnixDomainSocketPlugin) Handle(ctx context.Context, connInfo *Connect
|
||||
localConn, err := net.DialUnix("unix", nil, uds.UnixAddr)
|
||||
if err != nil {
|
||||
xl.Warnf("dial to uds %s error: %v", uds.UnixAddr, err)
|
||||
connInfo.Conn.Close()
|
||||
return
|
||||
}
|
||||
if connInfo.ProxyProtocolHeader != nil {
|
||||
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
|
||||
localConn.Close()
|
||||
connInfo.Conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
@@ -64,12 +65,7 @@ func (p *httpPlugin) Name() string {
|
||||
}
|
||||
|
||||
func (p *httpPlugin) IsSupport(op string) bool {
|
||||
for _, v := range p.options.Ops {
|
||||
if v == op {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return slices.Contains(p.options.Ops, op)
|
||||
}
|
||||
|
||||
func (p *httpPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
|
||||
@@ -153,10 +153,7 @@ func (p *VirtualNetPlugin) run() {
|
||||
|
||||
// Exponential backoff: 60s, 120s, 240s, 300s (capped)
|
||||
baseDelay := 60 * time.Second
|
||||
reconnectDelay = baseDelay * time.Duration(1<<uint(p.consecutiveErrors-1))
|
||||
if reconnectDelay > 300*time.Second {
|
||||
reconnectDelay = 300 * time.Second
|
||||
}
|
||||
reconnectDelay = min(baseDelay*time.Duration(1<<uint(p.consecutiveErrors-1)), 300*time.Second)
|
||||
} else {
|
||||
// Reset consecutive errors on successful connection
|
||||
if p.consecutiveErrors > 0 {
|
||||
|
||||
@@ -16,6 +16,7 @@ package featuregate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -92,10 +93,7 @@ type featureGate struct {
|
||||
|
||||
// NewFeatureGate creates a new feature gate with the default features
|
||||
func NewFeatureGate() MutableFeatureGate {
|
||||
known := map[Feature]FeatureSpec{}
|
||||
for k, v := range defaultFeatures {
|
||||
known[k] = v
|
||||
}
|
||||
known := maps.Clone(defaultFeatures)
|
||||
|
||||
f := &featureGate{}
|
||||
f.known.Store(known)
|
||||
@@ -109,14 +107,8 @@ func (f *featureGate) SetFromMap(m map[string]bool) error {
|
||||
defer f.lock.Unlock()
|
||||
|
||||
// Copy existing state
|
||||
known := map[Feature]FeatureSpec{}
|
||||
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
|
||||
known[k] = v
|
||||
}
|
||||
enabled := map[Feature]bool{}
|
||||
for k, v := range f.enabled.Load().(map[Feature]bool) {
|
||||
enabled[k] = v
|
||||
}
|
||||
known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
|
||||
enabled := maps.Clone(f.enabled.Load().(map[Feature]bool))
|
||||
|
||||
// Apply the new settings
|
||||
for k, v := range m {
|
||||
@@ -147,10 +139,7 @@ func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
|
||||
}
|
||||
|
||||
// Copy existing state
|
||||
known := map[Feature]FeatureSpec{}
|
||||
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
|
||||
known[k] = v
|
||||
}
|
||||
known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
|
||||
|
||||
// Add new features
|
||||
for name, spec := range features {
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -28,16 +27,17 @@ import (
|
||||
)
|
||||
|
||||
func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket {
|
||||
content := make([]byte, len(buf))
|
||||
copy(content, buf)
|
||||
return &msg.UDPPacket{
|
||||
Content: base64.StdEncoding.EncodeToString(buf),
|
||||
Content: content,
|
||||
LocalAddr: laddr,
|
||||
RemoteAddr: raddr,
|
||||
}
|
||||
}
|
||||
|
||||
func GetContent(m *msg.UDPPacket) (buf []byte, err error) {
|
||||
buf, err = base64.StdEncoding.DecodeString(m.Content)
|
||||
return
|
||||
return m.Content, nil
|
||||
}
|
||||
|
||||
func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh chan<- *msg.UDPPacket, bufSize int) {
|
||||
@@ -60,7 +60,7 @@ func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// buf[:n] will be encoded to string, so the bytes can be reused
|
||||
// NewUDPPacket copies buf[:n], so the read buffer can be reused
|
||||
udpMsg := NewUDPPacket(buf[:n], nil, remoteAddr)
|
||||
|
||||
select {
|
||||
@@ -85,6 +85,7 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
|
||||
}()
|
||||
|
||||
buf := pool.GetBuf(bufSize)
|
||||
defer pool.PutBuf(buf)
|
||||
for {
|
||||
_ = udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, _, err := udpConn.ReadFromUDP(buf)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
@@ -85,7 +86,9 @@ func newCertPool(caPath string) (*x509.CertPool, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool.AppendCertsFromPEM(caCrt)
|
||||
if !pool.AppendCertsFromPEM(caCrt) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate from file %q: no valid PEM certificates found", caPath)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
@@ -89,11 +89,11 @@ func ParseBasicAuth(auth string) (username, password string, ok bool) {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
before, after, found := strings.Cut(cs, ":")
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
return cs[:s], cs[s+1:], true
|
||||
return before, after, true
|
||||
}
|
||||
|
||||
func BasicAuth(username, passwd string) string {
|
||||
|
||||
@@ -86,11 +86,7 @@ func (c *FakeUDPConn) Read(b []byte) (n int, err error) {
|
||||
c.lastActive = time.Now()
|
||||
c.mu.Unlock()
|
||||
|
||||
if len(b) < len(content) {
|
||||
n = len(b)
|
||||
} else {
|
||||
n = len(content)
|
||||
}
|
||||
n = min(len(b), len(content))
|
||||
copy(b, content)
|
||||
return n, nil
|
||||
}
|
||||
@@ -168,11 +164,15 @@ func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
|
||||
return l, err
|
||||
}
|
||||
readConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return l, err
|
||||
}
|
||||
|
||||
l = &UDPListener{
|
||||
addr: udpAddr,
|
||||
acceptCh: make(chan net.Conn),
|
||||
writeCh: make(chan *UDPPacket, 1000),
|
||||
readConn: readConn,
|
||||
fakeConns: make(map[string]*FakeUDPConn),
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ type WebsocketListener struct {
|
||||
// ln: tcp listener for websocket connections
|
||||
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
|
||||
wl = &WebsocketListener{
|
||||
ln: ln,
|
||||
acceptCh: make(chan net.Conn),
|
||||
}
|
||||
|
||||
|
||||
@@ -68,8 +68,8 @@ func ParseRangeNumbers(rangeStr string) (numbers []int64, err error) {
|
||||
rangeStr = strings.TrimSpace(rangeStr)
|
||||
numbers = make([]int64, 0)
|
||||
// e.g. 1000-2000,2001,2002,3000-4000
|
||||
numRanges := strings.Split(rangeStr, ",")
|
||||
for _, numRangeStr := range numRanges {
|
||||
numRanges := strings.SplitSeq(rangeStr, ",")
|
||||
for numRangeStr := range numRanges {
|
||||
// 1000-2000 or 2001
|
||||
numArray := strings.Split(numRangeStr, "-")
|
||||
// length: only 1 or 2 is correct
|
||||
|
||||
@@ -266,31 +266,13 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
|
||||
go libio.Join(remote, client)
|
||||
}
|
||||
|
||||
func parseBasicAuth(auth string) (username, password string, ok bool) {
|
||||
const prefix = "Basic "
|
||||
// Case insensitive prefix match. See Issue 22736.
|
||||
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
|
||||
return
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
return
|
||||
}
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request {
|
||||
user := ""
|
||||
// If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header.
|
||||
if req.URL.Host != "" {
|
||||
proxyAuth := req.Header.Get("Proxy-Authorization")
|
||||
if proxyAuth != "" {
|
||||
user, _, _ = parseBasicAuth(proxyAuth)
|
||||
user, _, _ = httppkg.ParseBasicAuth(proxyAuth)
|
||||
}
|
||||
}
|
||||
if user == "" {
|
||||
|
||||
@@ -63,11 +63,12 @@ func (l *Logger) AddPrefix(prefix LogPrefix) *Logger {
|
||||
if prefix.Priority <= 0 {
|
||||
prefix.Priority = 10
|
||||
}
|
||||
for _, p := range l.prefixes {
|
||||
for i, p := range l.prefixes {
|
||||
if p.Name == prefix.Name {
|
||||
found = true
|
||||
p.Value = prefix.Value
|
||||
p.Priority = prefix.Priority
|
||||
l.prefixes[i].Value = prefix.Value
|
||||
l.prefixes[i].Priority = prefix.Priority
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
|
||||
@@ -95,20 +95,33 @@ func (cm *ControlManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
// SessionContext encapsulates the input parameters for creating a new Control.
|
||||
type SessionContext struct {
|
||||
// all resource managers and controllers
|
||||
rc *controller.ResourceController
|
||||
|
||||
RC *controller.ResourceController
|
||||
// proxy manager
|
||||
pxyManager *proxy.Manager
|
||||
|
||||
PxyManager *proxy.Manager
|
||||
// plugin manager
|
||||
pluginManager *plugin.Manager
|
||||
|
||||
PluginManager *plugin.Manager
|
||||
// verifies authentication based on selected method
|
||||
authVerifier auth.Verifier
|
||||
AuthVerifier auth.Verifier
|
||||
// key used for connection encryption
|
||||
encryptionKey []byte
|
||||
EncryptionKey []byte
|
||||
// control connection
|
||||
Conn net.Conn
|
||||
// indicates whether the connection is encrypted
|
||||
ConnEncrypted bool
|
||||
// login message
|
||||
LoginMsg *msg.Login
|
||||
// server configuration
|
||||
ServerCfg *v1.ServerConfig
|
||||
// client registry
|
||||
ClientRegistry *registry.ClientRegistry
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
// session context
|
||||
sessionCtx *SessionContext
|
||||
|
||||
// other components can use this to communicate with client
|
||||
msgTransporter transport.MessageTransporter
|
||||
@@ -117,12 +130,6 @@ type Control struct {
|
||||
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
|
||||
msgDispatcher *msg.Dispatcher
|
||||
|
||||
// login message
|
||||
loginMsg *msg.Login
|
||||
|
||||
// control connection
|
||||
conn net.Conn
|
||||
|
||||
// work connections
|
||||
workConnCh chan net.Conn
|
||||
|
||||
@@ -145,61 +152,34 @@ type Control struct {
|
||||
|
||||
mu sync.RWMutex
|
||||
|
||||
// Server configuration information
|
||||
serverCfg *v1.ServerConfig
|
||||
|
||||
clientRegistry *registry.ClientRegistry
|
||||
|
||||
xl *xlog.Logger
|
||||
ctx context.Context
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext.
|
||||
func NewControl(
|
||||
ctx context.Context,
|
||||
rc *controller.ResourceController,
|
||||
pxyManager *proxy.Manager,
|
||||
pluginManager *plugin.Manager,
|
||||
authVerifier auth.Verifier,
|
||||
encryptionKey []byte,
|
||||
ctlConn net.Conn,
|
||||
ctlConnEncrypted bool,
|
||||
loginMsg *msg.Login,
|
||||
serverCfg *v1.ServerConfig,
|
||||
) (*Control, error) {
|
||||
poolCount := loginMsg.PoolCount
|
||||
if poolCount > int(serverCfg.Transport.MaxPoolCount) {
|
||||
poolCount = int(serverCfg.Transport.MaxPoolCount)
|
||||
}
|
||||
func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
|
||||
poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
|
||||
ctl := &Control{
|
||||
rc: rc,
|
||||
pxyManager: pxyManager,
|
||||
pluginManager: pluginManager,
|
||||
authVerifier: authVerifier,
|
||||
encryptionKey: encryptionKey,
|
||||
conn: ctlConn,
|
||||
loginMsg: loginMsg,
|
||||
workConnCh: make(chan net.Conn, poolCount+10),
|
||||
proxies: make(map[string]proxy.Proxy),
|
||||
poolCount: poolCount,
|
||||
portsUsedNum: 0,
|
||||
runID: loginMsg.RunID,
|
||||
serverCfg: serverCfg,
|
||||
xl: xlog.FromContextSafe(ctx),
|
||||
ctx: ctx,
|
||||
doneCh: make(chan struct{}),
|
||||
sessionCtx: sessionCtx,
|
||||
workConnCh: make(chan net.Conn, poolCount+10),
|
||||
proxies: make(map[string]proxy.Proxy),
|
||||
poolCount: poolCount,
|
||||
portsUsedNum: 0,
|
||||
runID: sessionCtx.LoginMsg.RunID,
|
||||
xl: xlog.FromContextSafe(ctx),
|
||||
ctx: ctx,
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
ctl.lastPing.Store(time.Now())
|
||||
|
||||
if ctlConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey)
|
||||
if sessionCtx.ConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
|
||||
} else {
|
||||
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn)
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
}
|
||||
ctl.registerMsgHandlers()
|
||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||
@@ -213,7 +193,7 @@ func (ctl *Control) Start() {
|
||||
RunID: ctl.runID,
|
||||
Error: "",
|
||||
}
|
||||
_ = msg.WriteMsg(ctl.conn, loginRespMsg)
|
||||
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
|
||||
|
||||
go func() {
|
||||
for i := 0; i < ctl.poolCount; i++ {
|
||||
@@ -225,7 +205,7 @@ func (ctl *Control) Start() {
|
||||
}
|
||||
|
||||
func (ctl *Control) Close() error {
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -233,7 +213,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
|
||||
xl := ctl.xl
|
||||
xl.Infof("replaced by client [%s]", newCtl.runID)
|
||||
ctl.runID = ""
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
}
|
||||
|
||||
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
|
||||
@@ -291,7 +271,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second):
|
||||
case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second):
|
||||
err = fmt.Errorf("timeout trying to get work connection")
|
||||
xl.Warnf("%v", err)
|
||||
return
|
||||
@@ -304,15 +284,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
|
||||
}
|
||||
|
||||
func (ctl *Control) heartbeatWorker() {
|
||||
if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 {
|
||||
if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
xl := ctl.xl
|
||||
go wait.Until(func() {
|
||||
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
|
||||
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second {
|
||||
xl.Warnf("heartbeat timeout")
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
return
|
||||
}
|
||||
}, time.Second, ctl.doneCh)
|
||||
@@ -323,6 +303,30 @@ func (ctl *Control) WaitClosed() {
|
||||
<-ctl.doneCh
|
||||
}
|
||||
|
||||
func (ctl *Control) loginUserInfo() plugin.UserInfo {
|
||||
return plugin.UserInfo{
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.sessionCtx.LoginMsg.RunID,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctl *Control) closeProxy(pxy proxy.Proxy) {
|
||||
pxy.Close()
|
||||
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
|
||||
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
|
||||
|
||||
notifyContent := &plugin.CloseProxyContent{
|
||||
User: ctl.loginUserInfo(),
|
||||
CloseProxy: msg.CloseProxy{
|
||||
ProxyName: pxy.GetName(),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
|
||||
}()
|
||||
}
|
||||
|
||||
func (ctl *Control) worker() {
|
||||
xl := ctl.xl
|
||||
|
||||
@@ -330,38 +334,23 @@ func (ctl *Control) worker() {
|
||||
go ctl.msgDispatcher.Run()
|
||||
|
||||
<-ctl.msgDispatcher.Done()
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
|
||||
close(ctl.workConnCh)
|
||||
for workConn := range ctl.workConnCh {
|
||||
workConn.Close()
|
||||
}
|
||||
proxies := ctl.proxies
|
||||
ctl.proxies = make(map[string]proxy.Proxy)
|
||||
ctl.mu.Unlock()
|
||||
|
||||
for _, pxy := range ctl.proxies {
|
||||
pxy.Close()
|
||||
ctl.pxyManager.Del(pxy.GetName())
|
||||
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
|
||||
|
||||
notifyContent := &plugin.CloseProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
},
|
||||
CloseProxy: msg.CloseProxy{
|
||||
ProxyName: pxy.GetName(),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
_ = ctl.pluginManager.CloseProxy(notifyContent)
|
||||
}()
|
||||
for _, pxy := range proxies {
|
||||
ctl.closeProxy(pxy)
|
||||
}
|
||||
|
||||
metrics.Server.CloseClient()
|
||||
ctl.clientRegistry.MarkOfflineByRunID(ctl.runID)
|
||||
ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID)
|
||||
xl.Infof("client exit success")
|
||||
close(ctl.doneCh)
|
||||
}
|
||||
@@ -380,15 +369,11 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
|
||||
inMsg := m.(*msg.NewProxy)
|
||||
|
||||
content := &plugin.NewProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
},
|
||||
User: ctl.loginUserInfo(),
|
||||
NewProxy: *inMsg,
|
||||
}
|
||||
var remoteAddr string
|
||||
retContent, err := ctl.pluginManager.NewProxy(content)
|
||||
retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content)
|
||||
if err == nil {
|
||||
inMsg = &retContent.NewProxy
|
||||
remoteAddr, err = ctl.RegisterProxy(inMsg)
|
||||
@@ -401,15 +386,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
|
||||
if err != nil {
|
||||
xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
|
||||
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
|
||||
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
|
||||
err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient))
|
||||
} else {
|
||||
resp.RemoteAddr = remoteAddr
|
||||
xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
|
||||
clientID := ctl.loginMsg.ClientID
|
||||
clientID := ctl.sessionCtx.LoginMsg.ClientID
|
||||
if clientID == "" {
|
||||
clientID = ctl.loginMsg.RunID
|
||||
clientID = ctl.sessionCtx.LoginMsg.RunID
|
||||
}
|
||||
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID)
|
||||
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID)
|
||||
}
|
||||
_ = ctl.msgDispatcher.Send(resp)
|
||||
}
|
||||
@@ -419,22 +404,18 @@ func (ctl *Control) handlePing(m msg.Message) {
|
||||
inMsg := m.(*msg.Ping)
|
||||
|
||||
content := &plugin.PingContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
},
|
||||
User: ctl.loginUserInfo(),
|
||||
Ping: *inMsg,
|
||||
}
|
||||
retContent, err := ctl.pluginManager.Ping(content)
|
||||
retContent, err := ctl.sessionCtx.PluginManager.Ping(content)
|
||||
if err == nil {
|
||||
inMsg = &retContent.Ping
|
||||
err = ctl.authVerifier.VerifyPing(inMsg)
|
||||
err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warnf("received invalid ping: %v", err)
|
||||
_ = ctl.msgDispatcher.Send(&msg.Pong{
|
||||
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
|
||||
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)),
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -445,17 +426,17 @@ func (ctl *Control) handlePing(m msg.Message) {
|
||||
|
||||
func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleVisitor)
|
||||
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleNatHoleClient(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleClient)
|
||||
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleNatHoleReport(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleReport)
|
||||
ctl.rc.NatHoleController.HandleReport(inMsg)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleCloseProxy(m msg.Message) {
|
||||
@@ -468,15 +449,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) {
|
||||
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
|
||||
var pxyConf v1.ProxyConfigurer
|
||||
// Load configures from NewProxy message and validate.
|
||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg)
|
||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// User info
|
||||
userInfo := plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.runID,
|
||||
}
|
||||
|
||||
@@ -484,22 +465,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
// In fact, it creates different proxies based on the proxy type. We just call run() here.
|
||||
pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{
|
||||
UserInfo: userInfo,
|
||||
LoginMsg: ctl.loginMsg,
|
||||
LoginMsg: ctl.sessionCtx.LoginMsg,
|
||||
PoolCount: ctl.poolCount,
|
||||
ResourceController: ctl.rc,
|
||||
ResourceController: ctl.sessionCtx.RC,
|
||||
GetWorkConnFn: ctl.GetWorkConn,
|
||||
Configurer: pxyConf,
|
||||
ServerCfg: ctl.serverCfg,
|
||||
EncryptionKey: ctl.encryptionKey,
|
||||
ServerCfg: ctl.sessionCtx.ServerCfg,
|
||||
EncryptionKey: ctl.sessionCtx.EncryptionKey,
|
||||
})
|
||||
if err != nil {
|
||||
return remoteAddr, err
|
||||
}
|
||||
|
||||
// Check ports used number in each client
|
||||
if ctl.serverCfg.MaxPortsPerClient > 0 {
|
||||
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
|
||||
ctl.mu.Lock()
|
||||
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) {
|
||||
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) {
|
||||
ctl.mu.Unlock()
|
||||
err = fmt.Errorf("exceed the max_ports_per_client")
|
||||
return
|
||||
@@ -516,7 +497,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
}()
|
||||
}
|
||||
|
||||
if ctl.pxyManager.Exist(pxyMsg.ProxyName) {
|
||||
if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) {
|
||||
err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName)
|
||||
return
|
||||
}
|
||||
@@ -531,7 +512,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
}
|
||||
}()
|
||||
|
||||
err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy)
|
||||
err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -550,28 +531,12 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
if ctl.serverCfg.MaxPortsPerClient > 0 {
|
||||
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
|
||||
ctl.portsUsedNum -= pxy.GetUsedPortsNum()
|
||||
}
|
||||
pxy.Close()
|
||||
ctl.pxyManager.Del(pxy.GetName())
|
||||
delete(ctl.proxies, closeMsg.ProxyName)
|
||||
ctl.mu.Unlock()
|
||||
|
||||
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
|
||||
|
||||
notifyContent := &plugin.CloseProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
},
|
||||
CloseProxy: msg.CloseProxy{
|
||||
ProxyName: pxy.GetName(),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
_ = ctl.pluginManager.CloseProxy(notifyContent)
|
||||
}()
|
||||
ctl.closeProxy(pxy)
|
||||
return
|
||||
}
|
||||
|
||||
77
server/group/base.go
Normal file
77
server/group/base.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
)
|
||||
|
||||
// baseGroup contains the shared plumbing for listener-based groups
|
||||
// (TCP, HTTPS, TCPMux). Each concrete group embeds this and provides
|
||||
// its own Listen method with protocol-specific validation.
|
||||
type baseGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
|
||||
acceptCh chan net.Conn
|
||||
realLn net.Listener
|
||||
lns []*Listener
|
||||
mu sync.Mutex
|
||||
cleanupFn func()
|
||||
}
|
||||
|
||||
// initBase resets the baseGroup for a fresh listen cycle.
|
||||
// Must be called under mu when len(lns) == 0.
|
||||
func (bg *baseGroup) initBase(group, groupKey string, realLn net.Listener, cleanupFn func()) {
|
||||
bg.group = group
|
||||
bg.groupKey = groupKey
|
||||
bg.realLn = realLn
|
||||
bg.acceptCh = make(chan net.Conn)
|
||||
bg.cleanupFn = cleanupFn
|
||||
}
|
||||
|
||||
// worker reads from the real listener and fans out to acceptCh.
|
||||
// The parameters are captured at creation time so that the worker is
|
||||
// bound to a specific listen cycle and cannot observe a later initBase.
|
||||
func (bg *baseGroup) worker(realLn net.Listener, acceptCh chan<- net.Conn) {
|
||||
for {
|
||||
c, err := realLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newListener creates a new Listener wired to this baseGroup.
|
||||
// Must be called under mu.
|
||||
func (bg *baseGroup) newListener(addr net.Addr) *Listener {
|
||||
ln := newListener(bg.acceptCh, addr, bg.closeListener)
|
||||
bg.lns = append(bg.lns, ln)
|
||||
return ln
|
||||
}
|
||||
|
||||
// closeListener removes ln from the list. When the last listener is removed,
|
||||
// it closes acceptCh, closes the real listener, and calls cleanupFn.
|
||||
func (bg *baseGroup) closeListener(ln *Listener) {
|
||||
bg.mu.Lock()
|
||||
defer bg.mu.Unlock()
|
||||
for i, l := range bg.lns {
|
||||
if l == ln {
|
||||
bg.lns = append(bg.lns[:i], bg.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(bg.lns) == 0 {
|
||||
close(bg.acceptCh)
|
||||
bg.realLn.Close()
|
||||
bg.cleanupFn()
|
||||
}
|
||||
}
|
||||
169
server/group/base_test.go
Normal file
169
server/group/base_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeLn is a controllable net.Listener for tests.
|
||||
type fakeLn struct {
|
||||
connCh chan net.Conn
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newFakeLn() *fakeLn {
|
||||
return &fakeLn{
|
||||
connCh: make(chan net.Conn, 8),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLn) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case c := <-f.connCh:
|
||||
return c, nil
|
||||
case <-f.closed:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLn) Close() error {
|
||||
f.once.Do(func() { close(f.closed) })
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeLn) Addr() net.Addr { return fakeAddr("127.0.0.1:9999") }
|
||||
|
||||
func (f *fakeLn) inject(c net.Conn) {
|
||||
select {
|
||||
case f.connCh <- c:
|
||||
case <-f.closed:
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerFanOut(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
select {
|
||||
case got := <-bg.acceptCh:
|
||||
assert.Equal(t, c1, got)
|
||||
got.Close()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for connection on acceptCh")
|
||||
}
|
||||
|
||||
fl.Close()
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerStopsOnListenerClose(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bg.worker(fl, bg.acceptCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fl.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("worker did not stop after listener close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerClosesConnOnClosedChannel(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
// Close acceptCh before worker sends.
|
||||
close(bg.acceptCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bg.worker(fl, bg.acceptCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("worker did not stop after panic recovery")
|
||||
}
|
||||
|
||||
// c1 should have been closed by worker's panic recovery path.
|
||||
buf := make([]byte, 1)
|
||||
_, err := c1.Read(buf)
|
||||
assert.Error(t, err, "connection should be closed by worker")
|
||||
}
|
||||
|
||||
func TestBaseGroup_CloseLastListenerTriggersCleanup(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
cleanupCalled := 0
|
||||
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
|
||||
|
||||
bg.mu.Lock()
|
||||
ln1 := bg.newListener(fl.Addr())
|
||||
ln2 := bg.newListener(fl.Addr())
|
||||
bg.mu.Unlock()
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
ln1.Close()
|
||||
assert.Equal(t, 0, cleanupCalled, "cleanup should not run while listeners remain")
|
||||
|
||||
ln2.Close()
|
||||
assert.Equal(t, 1, cleanupCalled, "cleanup should run after last listener closed")
|
||||
}
|
||||
|
||||
func TestBaseGroup_CloseOneOfTwoListeners(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
cleanupCalled := 0
|
||||
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
|
||||
|
||||
bg.mu.Lock()
|
||||
ln1 := bg.newListener(fl.Addr())
|
||||
ln2 := bg.newListener(fl.Addr())
|
||||
bg.mu.Unlock()
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
ln1.Close()
|
||||
assert.Equal(t, 0, cleanupCalled)
|
||||
|
||||
// ln2 should still receive connections.
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
got, err := ln2.Accept()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, c1, got)
|
||||
got.Close()
|
||||
|
||||
ln2.Close()
|
||||
assert.Equal(t, 1, cleanupCalled)
|
||||
}
|
||||
@@ -24,4 +24,6 @@ var (
|
||||
ErrListenerClosed = errors.New("group listener closed")
|
||||
ErrGroupDifferentPort = errors.New("group should have same remote port")
|
||||
ErrProxyRepeated = errors.New("group proxy repeated")
|
||||
|
||||
errGroupStale = errors.New("stale group reference")
|
||||
)
|
||||
|
||||
@@ -9,53 +9,42 @@ import (
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
// HTTPGroupController manages HTTP groups that use round-robin
|
||||
// callback routing (fundamentally different from listener-based groups).
|
||||
type HTTPGroupController struct {
|
||||
// groups indexed by group name
|
||||
groups map[string]*HTTPGroup
|
||||
|
||||
// register createConn for each group to vhostRouter.
|
||||
// createConn will get a connection from one proxy of the group
|
||||
groupRegistry[*HTTPGroup]
|
||||
vhostRouter *vhost.Routers
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController {
|
||||
return &HTTPGroupController{
|
||||
groups: make(map[string]*HTTPGroup),
|
||||
vhostRouter: vhostRouter,
|
||||
groupRegistry: newGroupRegistry[*HTTPGroup](),
|
||||
vhostRouter: vhostRouter,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctl *HTTPGroupController) Register(
|
||||
proxyName, group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (err error) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
if !ok {
|
||||
g = NewHTTPGroup(ctl)
|
||||
ctl.groups[indexKey] = g
|
||||
) error {
|
||||
for {
|
||||
g := ctl.getOrCreate(group, func() *HTTPGroup {
|
||||
return NewHTTPGroup(ctl)
|
||||
})
|
||||
err := g.Register(proxyName, group, groupKey, routeConfig)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
ctl.mu.Unlock()
|
||||
|
||||
return g.Register(proxyName, group, groupKey, routeConfig)
|
||||
}
|
||||
|
||||
func (ctl *HTTPGroupController) UnRegister(proxyName, group string, _ vhost.RouteConfig) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
g, ok := ctl.get(group)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
isEmpty := g.UnRegister(proxyName)
|
||||
if isEmpty {
|
||||
delete(ctl.groups, indexKey)
|
||||
}
|
||||
g.UnRegister(proxyName)
|
||||
}
|
||||
|
||||
type HTTPGroup struct {
|
||||
@@ -87,6 +76,9 @@ func (g *HTTPGroup) Register(
|
||||
) (err error) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if !g.ctl.isCurrent(group, func(cur *HTTPGroup) bool { return cur == g }) {
|
||||
return errGroupStale
|
||||
}
|
||||
if len(g.createFuncs) == 0 {
|
||||
// the first proxy in this group
|
||||
tmp := routeConfig // copy object
|
||||
@@ -123,7 +115,7 @@ func (g *HTTPGroup) Register(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
|
||||
func (g *HTTPGroup) UnRegister(proxyName string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
delete(g.createFuncs, proxyName)
|
||||
@@ -135,10 +127,11 @@ func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
|
||||
}
|
||||
|
||||
if len(g.createFuncs) == 0 {
|
||||
isEmpty = true
|
||||
g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser)
|
||||
g.ctl.removeIf(g.group, func(cur *HTTPGroup) bool {
|
||||
return cur == g
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||
@@ -151,7 +144,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||
location := g.location
|
||||
routeByHTTPUser := g.routeByHTTPUser
|
||||
if len(g.pxyNames) > 0 {
|
||||
name := g.pxyNames[int(newIndex)%len(g.pxyNames)]
|
||||
name := g.pxyNames[newIndex%uint64(len(g.pxyNames))]
|
||||
f = g.createFuncs[name]
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
@@ -174,7 +167,7 @@ func (g *HTTPGroup) chooseEndpoint() (string, error) {
|
||||
location := g.location
|
||||
routeByHTTPUser := g.routeByHTTPUser
|
||||
if len(g.pxyNames) > 0 {
|
||||
name = g.pxyNames[int(newIndex)%len(g.pxyNames)]
|
||||
name = g.pxyNames[newIndex%uint64(len(g.pxyNames))]
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
|
||||
|
||||
@@ -17,25 +17,19 @@ package group
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
type HTTPSGroupController struct {
|
||||
groups map[string]*HTTPSGroup
|
||||
|
||||
groupRegistry[*HTTPSGroup]
|
||||
httpsMuxer *vhost.HTTPSMuxer
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHTTPSGroupController(httpsMuxer *vhost.HTTPSMuxer) *HTTPSGroupController {
|
||||
return &HTTPSGroupController{
|
||||
groups: make(map[string]*HTTPSGroup),
|
||||
httpsMuxer: httpsMuxer,
|
||||
groupRegistry: newGroupRegistry[*HTTPSGroup](),
|
||||
httpsMuxer: httpsMuxer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,41 +38,28 @@ func (ctl *HTTPSGroupController) Listen(
|
||||
group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (l net.Listener, err error) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
if !ok {
|
||||
g = NewHTTPSGroup(ctl)
|
||||
ctl.groups[indexKey] = g
|
||||
for {
|
||||
g := ctl.getOrCreate(group, func() *HTTPSGroup {
|
||||
return NewHTTPSGroup(ctl)
|
||||
})
|
||||
l, err = g.Listen(ctx, group, groupKey, routeConfig)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
ctl.mu.Unlock()
|
||||
|
||||
return g.Listen(ctx, group, groupKey, routeConfig)
|
||||
}
|
||||
|
||||
func (ctl *HTTPSGroupController) RemoveGroup(group string) {
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
delete(ctl.groups, group)
|
||||
}
|
||||
|
||||
type HTTPSGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
domain string
|
||||
baseGroup
|
||||
|
||||
acceptCh chan net.Conn
|
||||
httpsLn *vhost.Listener
|
||||
lns []*HTTPSGroupListener
|
||||
ctl *HTTPSGroupController
|
||||
mu sync.Mutex
|
||||
domain string
|
||||
ctl *HTTPSGroupController
|
||||
}
|
||||
|
||||
func NewHTTPSGroup(ctl *HTTPSGroupController) *HTTPSGroup {
|
||||
return &HTTPSGroup{
|
||||
lns: make([]*HTTPSGroupListener, 0),
|
||||
ctl: ctl,
|
||||
acceptCh: make(chan net.Conn),
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,23 +67,27 @@ func (g *HTTPSGroup) Listen(
|
||||
ctx context.Context,
|
||||
group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (ln *HTTPSGroupListener, err error) {
|
||||
) (ln *Listener, err error) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if !g.ctl.isCurrent(group, func(cur *HTTPSGroup) bool { return cur == g }) {
|
||||
return nil, errGroupStale
|
||||
}
|
||||
if len(g.lns) == 0 {
|
||||
// the first listener, listen on the real address
|
||||
httpsLn, errRet := g.ctl.httpsMuxer.Listen(ctx, &routeConfig)
|
||||
if errRet != nil {
|
||||
return nil, errRet
|
||||
}
|
||||
ln = newHTTPSGroupListener(group, g, httpsLn.Addr())
|
||||
|
||||
g.group = group
|
||||
g.groupKey = groupKey
|
||||
g.domain = routeConfig.Domain
|
||||
g.httpsLn = httpsLn
|
||||
g.lns = append(g.lns, ln)
|
||||
go g.worker()
|
||||
g.initBase(group, groupKey, httpsLn, func() {
|
||||
g.ctl.removeIf(g.group, func(cur *HTTPSGroup) bool {
|
||||
return cur == g
|
||||
})
|
||||
})
|
||||
ln = g.newListener(httpsLn.Addr())
|
||||
go g.worker(httpsLn, g.acceptCh)
|
||||
} else {
|
||||
// route config in the same group must be equal
|
||||
if g.group != group || g.domain != routeConfig.Domain {
|
||||
@@ -111,87 +96,7 @@ func (g *HTTPSGroup) Listen(
|
||||
if g.groupKey != groupKey {
|
||||
return nil, ErrGroupAuthFailed
|
||||
}
|
||||
ln = newHTTPSGroupListener(group, g, g.lns[0].Addr())
|
||||
g.lns = append(g.lns, ln)
|
||||
ln = g.newListener(g.lns[0].Addr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) worker() {
|
||||
for {
|
||||
c, err := g.httpsLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
g.acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) Accept() <-chan net.Conn {
|
||||
return g.acceptCh
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) CloseListener(ln *HTTPSGroupListener) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
for i, tmpLn := range g.lns {
|
||||
if tmpLn == ln {
|
||||
g.lns = append(g.lns[:i], g.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(g.lns) == 0 {
|
||||
close(g.acceptCh)
|
||||
if g.httpsLn != nil {
|
||||
g.httpsLn.Close()
|
||||
}
|
||||
g.ctl.RemoveGroup(g.group)
|
||||
}
|
||||
}
|
||||
|
||||
type HTTPSGroupListener struct {
|
||||
groupName string
|
||||
group *HTTPSGroup
|
||||
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newHTTPSGroupListener(name string, group *HTTPSGroup, addr net.Addr) *HTTPSGroupListener {
|
||||
return &HTTPSGroupListener{
|
||||
groupName: name,
|
||||
group: group,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Accept() (c net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok = <-ln.group.Accept():
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Close() (err error) {
|
||||
close(ln.closeCh)
|
||||
|
||||
// remove self from HTTPSGroup
|
||||
ln.group.CloseListener(ln)
|
||||
return
|
||||
}
|
||||
|
||||
49
server/group/listener.go
Normal file
49
server/group/listener.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Listener is a per-proxy virtual listener that receives connections
|
||||
// from a shared group. It implements net.Listener.
|
||||
type Listener struct {
|
||||
acceptCh <-chan net.Conn
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
onClose func(*Listener)
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newListener(acceptCh <-chan net.Conn, addr net.Addr, onClose func(*Listener)) *Listener {
|
||||
return &Listener{
|
||||
acceptCh: acceptCh,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
onClose: onClose,
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *Listener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok := <-ln.acceptCh:
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *Listener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
func (ln *Listener) Close() error {
|
||||
ln.once.Do(func() {
|
||||
close(ln.closeCh)
|
||||
ln.onClose(ln)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
68
server/group/listener_test.go
Normal file
68
server/group/listener_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestListener_Accept(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn, 1)
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
|
||||
acceptCh <- c1
|
||||
got, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, c1, got)
|
||||
}
|
||||
|
||||
func TestListener_AcceptAfterChannelClose(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn)
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
close(acceptCh)
|
||||
_, err := ln.Accept()
|
||||
assert.ErrorIs(t, err, ErrListenerClosed)
|
||||
}
|
||||
|
||||
func TestListener_AcceptAfterListenerClose(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn) // open, not closed
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
ln.Close()
|
||||
_, err := ln.Accept()
|
||||
assert.ErrorIs(t, err, ErrListenerClosed)
|
||||
}
|
||||
|
||||
func TestListener_DoubleClose(t *testing.T) {
|
||||
closeCalls := 0
|
||||
ln := newListener(
|
||||
make(chan net.Conn),
|
||||
fakeAddr("127.0.0.1:1234"),
|
||||
func(*Listener) { closeCalls++ },
|
||||
)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
ln.Close()
|
||||
ln.Close()
|
||||
})
|
||||
assert.Equal(t, 1, closeCalls, "onClose should be called exactly once")
|
||||
}
|
||||
|
||||
func TestListener_Addr(t *testing.T) {
|
||||
addr := fakeAddr("10.0.0.1:5555")
|
||||
ln := newListener(make(chan net.Conn), addr, func(*Listener) {})
|
||||
assert.Equal(t, addr, ln.Addr())
|
||||
}
|
||||
|
||||
// fakeAddr implements net.Addr for testing.
|
||||
type fakeAddr string
|
||||
|
||||
func (a fakeAddr) Network() string { return "tcp" }
|
||||
func (a fakeAddr) String() string { return string(a) }
|
||||
59
server/group/registry.go
Normal file
59
server/group/registry.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// groupRegistry is a concurrent map of named groups with
|
||||
// automatic creation on first access.
|
||||
type groupRegistry[G any] struct {
|
||||
groups map[string]G
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newGroupRegistry[G any]() groupRegistry[G] {
|
||||
return groupRegistry[G]{
|
||||
groups: make(map[string]G),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *groupRegistry[G]) getOrCreate(key string, newFn func() G) G {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
g, ok := r.groups[key]
|
||||
if !ok {
|
||||
g = newFn()
|
||||
r.groups[key] = g
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func (r *groupRegistry[G]) get(key string) (G, bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
g, ok := r.groups[key]
|
||||
return g, ok
|
||||
}
|
||||
|
||||
// isCurrent returns true if key exists in the registry and matchFn
|
||||
// returns true for the stored value.
|
||||
func (r *groupRegistry[G]) isCurrent(key string, matchFn func(G) bool) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
g, ok := r.groups[key]
|
||||
return ok && matchFn(g)
|
||||
}
|
||||
|
||||
// removeIf atomically looks up the group for key, calls fn on it,
|
||||
// and removes the entry if fn returns true.
|
||||
func (r *groupRegistry[G]) removeIf(key string, fn func(G) bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
g, ok := r.groups[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if fn(g) {
|
||||
delete(r.groups, key)
|
||||
}
|
||||
}
|
||||
102
server/group/registry_test.go
Normal file
102
server/group/registry_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetOrCreate_New(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
called := 0
|
||||
v := 42
|
||||
got := r.getOrCreate("k", func() *int { called++; return &v })
|
||||
assert.Equal(t, 1, called)
|
||||
assert.Equal(t, &v, got)
|
||||
}
|
||||
|
||||
func TestGetOrCreate_Existing(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v := 42
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
|
||||
called := 0
|
||||
got := r.getOrCreate("k", func() *int { called++; return nil })
|
||||
assert.Equal(t, 0, called)
|
||||
assert.Equal(t, &v, got)
|
||||
}
|
||||
|
||||
func TestGet_ExistingAndMissing(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v := 1
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
|
||||
got, ok := r.get("k")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, &v, got)
|
||||
|
||||
_, ok = r.get("missing")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestIsCurrent(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v1 := 1
|
||||
v2 := 2
|
||||
r.getOrCreate("k", func() *int { return &v1 })
|
||||
|
||||
assert.True(t, r.isCurrent("k", func(g *int) bool { return g == &v1 }))
|
||||
assert.False(t, r.isCurrent("k", func(g *int) bool { return g == &v2 }))
|
||||
assert.False(t, r.isCurrent("missing", func(g *int) bool { return true }))
|
||||
}
|
||||
|
||||
func TestRemoveIf(t *testing.T) {
|
||||
t.Run("removes when fn returns true", func(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v := 1
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
r.removeIf("k", func(g *int) bool { return g == &v })
|
||||
_, ok := r.get("k")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("keeps when fn returns false", func(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v := 1
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
r.removeIf("k", func(g *int) bool { return false })
|
||||
_, ok := r.get("k")
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("noop on missing key", func(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
r.removeIf("missing", func(g *int) bool { return true }) // should not panic
|
||||
})
|
||||
}
|
||||
|
||||
func TestConcurrentGetOrCreateAndRemoveIf(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
const n = 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n * 2)
|
||||
for i := range n {
|
||||
v := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.removeIf("k", func(*int) bool { return true })
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// After all goroutines finish, accessing the key must not panic.
|
||||
require.NotPanics(t, func() {
|
||||
_, _ = r.get("k")
|
||||
})
|
||||
}
|
||||
@@ -17,107 +17,91 @@ package group
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
|
||||
"github.com/fatedier/frp/server/ports"
|
||||
)
|
||||
|
||||
// TCPGroupCtl manage all TCPGroups
|
||||
// TCPGroupCtl manages all TCPGroups.
|
||||
type TCPGroupCtl struct {
|
||||
groups map[string]*TCPGroup
|
||||
|
||||
// portManager is used to manage port
|
||||
groupRegistry[*TCPGroup]
|
||||
portManager *ports.Manager
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTCPGroupCtl return a new TcpGroupCtl
|
||||
// NewTCPGroupCtl returns a new TCPGroupCtl.
|
||||
func NewTCPGroupCtl(portManager *ports.Manager) *TCPGroupCtl {
|
||||
return &TCPGroupCtl{
|
||||
groups: make(map[string]*TCPGroup),
|
||||
portManager: portManager,
|
||||
groupRegistry: newGroupRegistry[*TCPGroup](),
|
||||
portManager: portManager,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen is the wrapper for TCPGroup's Listen
|
||||
// If there are no group, we will create one here
|
||||
// Listen is the wrapper for TCPGroup's Listen.
|
||||
// If there is no group, one will be created.
|
||||
func (tgc *TCPGroupCtl) Listen(proxyName string, group string, groupKey string,
|
||||
addr string, port int,
|
||||
) (l net.Listener, realPort int, err error) {
|
||||
tgc.mu.Lock()
|
||||
tcpGroup, ok := tgc.groups[group]
|
||||
if !ok {
|
||||
tcpGroup = NewTCPGroup(tgc)
|
||||
tgc.groups[group] = tcpGroup
|
||||
for {
|
||||
tcpGroup := tgc.getOrCreate(group, func() *TCPGroup {
|
||||
return NewTCPGroup(tgc)
|
||||
})
|
||||
l, realPort, err = tcpGroup.Listen(proxyName, group, groupKey, addr, port)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
tgc.mu.Unlock()
|
||||
|
||||
return tcpGroup.Listen(proxyName, group, groupKey, addr, port)
|
||||
}
|
||||
|
||||
// RemoveGroup remove TCPGroup from controller
|
||||
func (tgc *TCPGroupCtl) RemoveGroup(group string) {
|
||||
tgc.mu.Lock()
|
||||
defer tgc.mu.Unlock()
|
||||
delete(tgc.groups, group)
|
||||
}
|
||||
|
||||
// TCPGroup route connections to different proxies
|
||||
// TCPGroup routes connections to different proxies.
|
||||
type TCPGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
baseGroup
|
||||
|
||||
addr string
|
||||
port int
|
||||
realPort int
|
||||
|
||||
acceptCh chan net.Conn
|
||||
tcpLn net.Listener
|
||||
lns []*TCPGroupListener
|
||||
ctl *TCPGroupCtl
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTCPGroup return a new TCPGroup
|
||||
// NewTCPGroup returns a new TCPGroup.
|
||||
func NewTCPGroup(ctl *TCPGroupCtl) *TCPGroup {
|
||||
return &TCPGroup{
|
||||
lns: make([]*TCPGroupListener, 0),
|
||||
ctl: ctl,
|
||||
acceptCh: make(chan net.Conn),
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen will return a new TCPGroupListener
|
||||
// if TCPGroup already has a listener, just add a new TCPGroupListener to the queues
|
||||
// otherwise, listen on the real address
|
||||
func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TCPGroupListener, realPort int, err error) {
|
||||
// Listen will return a new Listener.
|
||||
// If TCPGroup already has a listener, just add a new Listener to the queues,
|
||||
// otherwise listen on the real address.
|
||||
func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *Listener, realPort int, err error) {
|
||||
tg.mu.Lock()
|
||||
defer tg.mu.Unlock()
|
||||
if !tg.ctl.isCurrent(group, func(cur *TCPGroup) bool { return cur == tg }) {
|
||||
return nil, 0, errGroupStale
|
||||
}
|
||||
if len(tg.lns) == 0 {
|
||||
// the first listener, listen on the real address
|
||||
realPort, err = tg.ctl.portManager.Acquire(proxyName, port)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port)))
|
||||
tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(realPort)))
|
||||
if errRet != nil {
|
||||
tg.ctl.portManager.Release(realPort)
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
ln = newTCPGroupListener(group, tg, tcpLn.Addr())
|
||||
|
||||
tg.group = group
|
||||
tg.groupKey = groupKey
|
||||
tg.addr = addr
|
||||
tg.port = port
|
||||
tg.realPort = realPort
|
||||
tg.tcpLn = tcpLn
|
||||
tg.lns = append(tg.lns, ln)
|
||||
if tg.acceptCh == nil {
|
||||
tg.acceptCh = make(chan net.Conn)
|
||||
}
|
||||
go tg.worker()
|
||||
tg.initBase(group, groupKey, tcpLn, func() {
|
||||
tg.ctl.portManager.Release(tg.realPort)
|
||||
tg.ctl.removeIf(tg.group, func(cur *TCPGroup) bool {
|
||||
return cur == tg
|
||||
})
|
||||
})
|
||||
ln = tg.newListener(tcpLn.Addr())
|
||||
go tg.worker(tcpLn, tg.acceptCh)
|
||||
} else {
|
||||
// address and port in the same group must be equal
|
||||
if tg.group != group || tg.addr != addr {
|
||||
@@ -132,92 +116,8 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
|
||||
err = ErrGroupAuthFailed
|
||||
return
|
||||
}
|
||||
ln = newTCPGroupListener(group, tg, tg.lns[0].Addr())
|
||||
ln = tg.newListener(tg.lns[0].Addr())
|
||||
realPort = tg.realPort
|
||||
tg.lns = append(tg.lns, ln)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// worker is called when the real tcp listener has been created
|
||||
func (tg *TCPGroup) worker() {
|
||||
for {
|
||||
c, err := tg.tcpLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
tg.acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tg *TCPGroup) Accept() <-chan net.Conn {
|
||||
return tg.acceptCh
|
||||
}
|
||||
|
||||
// CloseListener remove the TCPGroupListener from the TCPGroup
|
||||
func (tg *TCPGroup) CloseListener(ln *TCPGroupListener) {
|
||||
tg.mu.Lock()
|
||||
defer tg.mu.Unlock()
|
||||
for i, tmpLn := range tg.lns {
|
||||
if tmpLn == ln {
|
||||
tg.lns = append(tg.lns[:i], tg.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(tg.lns) == 0 {
|
||||
close(tg.acceptCh)
|
||||
tg.tcpLn.Close()
|
||||
tg.ctl.portManager.Release(tg.realPort)
|
||||
tg.ctl.RemoveGroup(tg.group)
|
||||
}
|
||||
}
|
||||
|
||||
// TCPGroupListener
|
||||
type TCPGroupListener struct {
|
||||
groupName string
|
||||
group *TCPGroup
|
||||
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newTCPGroupListener(name string, group *TCPGroup, addr net.Addr) *TCPGroupListener {
|
||||
return &TCPGroupListener{
|
||||
groupName: name,
|
||||
group: group,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Accept will accept connections from TCPGroup
|
||||
func (ln *TCPGroupListener) Accept() (c net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok = <-ln.group.Accept():
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *TCPGroupListener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
// Close close the listener
|
||||
func (ln *TCPGroupListener) Close() (err error) {
|
||||
close(ln.closeCh)
|
||||
|
||||
// remove self from TcpGroup
|
||||
ln.group.CloseListener(ln)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -18,118 +18,100 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/tcpmux"
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
// TCPMuxGroupCtl manage all TCPMuxGroups
|
||||
// TCPMuxGroupCtl manages all TCPMuxGroups.
|
||||
type TCPMuxGroupCtl struct {
|
||||
groups map[string]*TCPMuxGroup
|
||||
|
||||
// portManager is used to manage port
|
||||
groupRegistry[*TCPMuxGroup]
|
||||
tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTCPMuxGroupCtl return a new TCPMuxGroupCtl
|
||||
// NewTCPMuxGroupCtl returns a new TCPMuxGroupCtl.
|
||||
func NewTCPMuxGroupCtl(tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer) *TCPMuxGroupCtl {
|
||||
return &TCPMuxGroupCtl{
|
||||
groups: make(map[string]*TCPMuxGroup),
|
||||
groupRegistry: newGroupRegistry[*TCPMuxGroup](),
|
||||
tcpMuxHTTPConnectMuxer: tcpMuxHTTPConnectMuxer,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen is the wrapper for TCPMuxGroup's Listen
|
||||
// If there are no group, we will create one here
|
||||
// Listen is the wrapper for TCPMuxGroup's Listen.
|
||||
// If there is no group, one will be created.
|
||||
func (tmgc *TCPMuxGroupCtl) Listen(
|
||||
ctx context.Context,
|
||||
multiplexer, group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (l net.Listener, err error) {
|
||||
tmgc.mu.Lock()
|
||||
tcpMuxGroup, ok := tmgc.groups[group]
|
||||
if !ok {
|
||||
tcpMuxGroup = NewTCPMuxGroup(tmgc)
|
||||
tmgc.groups[group] = tcpMuxGroup
|
||||
}
|
||||
tmgc.mu.Unlock()
|
||||
for {
|
||||
tcpMuxGroup := tmgc.getOrCreate(group, func() *TCPMuxGroup {
|
||||
return NewTCPMuxGroup(tmgc)
|
||||
})
|
||||
|
||||
switch v1.TCPMultiplexerType(multiplexer) {
|
||||
case v1.TCPMultiplexerHTTPConnect:
|
||||
return tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig)
|
||||
default:
|
||||
err = fmt.Errorf("unknown multiplexer [%s]", multiplexer)
|
||||
return
|
||||
switch v1.TCPMultiplexerType(multiplexer) {
|
||||
case v1.TCPMultiplexerHTTPConnect:
|
||||
l, err = tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown multiplexer [%s]", multiplexer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveGroup remove TCPMuxGroup from controller
|
||||
func (tmgc *TCPMuxGroupCtl) RemoveGroup(group string) {
|
||||
tmgc.mu.Lock()
|
||||
defer tmgc.mu.Unlock()
|
||||
delete(tmgc.groups, group)
|
||||
}
|
||||
|
||||
// TCPMuxGroup route connections to different proxies
|
||||
// TCPMuxGroup routes connections to different proxies.
|
||||
type TCPMuxGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
baseGroup
|
||||
|
||||
domain string
|
||||
routeByHTTPUser string
|
||||
username string
|
||||
password string
|
||||
|
||||
acceptCh chan net.Conn
|
||||
tcpMuxLn net.Listener
|
||||
lns []*TCPMuxGroupListener
|
||||
ctl *TCPMuxGroupCtl
|
||||
mu sync.Mutex
|
||||
ctl *TCPMuxGroupCtl
|
||||
}
|
||||
|
||||
// NewTCPMuxGroup return a new TCPMuxGroup
|
||||
// NewTCPMuxGroup returns a new TCPMuxGroup.
|
||||
func NewTCPMuxGroup(ctl *TCPMuxGroupCtl) *TCPMuxGroup {
|
||||
return &TCPMuxGroup{
|
||||
lns: make([]*TCPMuxGroupListener, 0),
|
||||
ctl: ctl,
|
||||
acceptCh: make(chan net.Conn),
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen will return a new TCPMuxGroupListener
|
||||
// if TCPMuxGroup already has a listener, just add a new TCPMuxGroupListener to the queues
|
||||
// otherwise, listen on the real address
|
||||
// HTTPConnectListen will return a new Listener.
|
||||
// If TCPMuxGroup already has a listener, just add a new Listener to the queues,
|
||||
// otherwise listen on the real address.
|
||||
func (tmg *TCPMuxGroup) HTTPConnectListen(
|
||||
ctx context.Context,
|
||||
group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (ln *TCPMuxGroupListener, err error) {
|
||||
) (ln *Listener, err error) {
|
||||
tmg.mu.Lock()
|
||||
defer tmg.mu.Unlock()
|
||||
if !tmg.ctl.isCurrent(group, func(cur *TCPMuxGroup) bool { return cur == tmg }) {
|
||||
return nil, errGroupStale
|
||||
}
|
||||
if len(tmg.lns) == 0 {
|
||||
// the first listener, listen on the real address
|
||||
tcpMuxLn, errRet := tmg.ctl.tcpMuxHTTPConnectMuxer.Listen(ctx, &routeConfig)
|
||||
if errRet != nil {
|
||||
return nil, errRet
|
||||
}
|
||||
ln = newTCPMuxGroupListener(group, tmg, tcpMuxLn.Addr())
|
||||
|
||||
tmg.group = group
|
||||
tmg.groupKey = groupKey
|
||||
tmg.domain = routeConfig.Domain
|
||||
tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser
|
||||
tmg.username = routeConfig.Username
|
||||
tmg.password = routeConfig.Password
|
||||
tmg.tcpMuxLn = tcpMuxLn
|
||||
tmg.lns = append(tmg.lns, ln)
|
||||
if tmg.acceptCh == nil {
|
||||
tmg.acceptCh = make(chan net.Conn)
|
||||
}
|
||||
go tmg.worker()
|
||||
tmg.initBase(group, groupKey, tcpMuxLn, func() {
|
||||
tmg.ctl.removeIf(tmg.group, func(cur *TCPMuxGroup) bool {
|
||||
return cur == tmg
|
||||
})
|
||||
})
|
||||
ln = tmg.newListener(tcpMuxLn.Addr())
|
||||
go tmg.worker(tcpMuxLn, tmg.acceptCh)
|
||||
} else {
|
||||
// route config in the same group must be equal
|
||||
if tmg.group != group || tmg.domain != routeConfig.Domain ||
|
||||
@@ -141,90 +123,7 @@ func (tmg *TCPMuxGroup) HTTPConnectListen(
|
||||
if tmg.groupKey != groupKey {
|
||||
return nil, ErrGroupAuthFailed
|
||||
}
|
||||
ln = newTCPMuxGroupListener(group, tmg, tmg.lns[0].Addr())
|
||||
tmg.lns = append(tmg.lns, ln)
|
||||
ln = tmg.newListener(tmg.lns[0].Addr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// worker is called when the real TCP listener has been created
|
||||
func (tmg *TCPMuxGroup) worker() {
|
||||
for {
|
||||
c, err := tmg.tcpMuxLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
tmg.acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tmg *TCPMuxGroup) Accept() <-chan net.Conn {
|
||||
return tmg.acceptCh
|
||||
}
|
||||
|
||||
// CloseListener remove the TCPMuxGroupListener from the TCPMuxGroup
|
||||
func (tmg *TCPMuxGroup) CloseListener(ln *TCPMuxGroupListener) {
|
||||
tmg.mu.Lock()
|
||||
defer tmg.mu.Unlock()
|
||||
for i, tmpLn := range tmg.lns {
|
||||
if tmpLn == ln {
|
||||
tmg.lns = append(tmg.lns[:i], tmg.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(tmg.lns) == 0 {
|
||||
close(tmg.acceptCh)
|
||||
tmg.tcpMuxLn.Close()
|
||||
tmg.ctl.RemoveGroup(tmg.group)
|
||||
}
|
||||
}
|
||||
|
||||
// TCPMuxGroupListener
|
||||
type TCPMuxGroupListener struct {
|
||||
groupName string
|
||||
group *TCPMuxGroup
|
||||
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newTCPMuxGroupListener(name string, group *TCPMuxGroup, addr net.Addr) *TCPMuxGroupListener {
|
||||
return &TCPMuxGroupListener{
|
||||
groupName: name,
|
||||
group: group,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Accept will accept connections from TCPMuxGroup
|
||||
func (ln *TCPMuxGroupListener) Accept() (c net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok = <-ln.group.Accept():
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *TCPMuxGroupListener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
// Close close the listener
|
||||
func (ln *TCPMuxGroupListener) Close() (err error) {
|
||||
close(ln.closeCh)
|
||||
|
||||
// remove self from TcpMuxGroup
|
||||
ln.group.CloseListener(ln)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.HTTPProxyConfig{}), NewHTTPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.HTTPProxyConfig](), NewHTTPProxy)
|
||||
}
|
||||
|
||||
type HTTPProxy struct {
|
||||
@@ -75,16 +75,13 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for _, domain := range pxy.cfg.CustomDomains {
|
||||
if domain == "" {
|
||||
continue
|
||||
}
|
||||
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for _, domain := range domains {
|
||||
routeConfig.Domain = domain
|
||||
for _, location := range locations {
|
||||
routeConfig.Location = location
|
||||
|
||||
tmpRouteConfig := routeConfig
|
||||
|
||||
// handle group
|
||||
@@ -93,12 +90,10 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pxy.closeFuncs = append(pxy.closeFuncs, func() {
|
||||
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.LoadBalancer.Group, tmpRouteConfig)
|
||||
})
|
||||
} else {
|
||||
// no group
|
||||
err = pxy.rc.HTTPReverseProxy.Register(routeConfig)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -112,39 +107,6 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
|
||||
routeConfig.Domain, routeConfig.Location, pxy.cfg.LoadBalancer.Group, pxy.cfg.RouteByHTTPUser)
|
||||
}
|
||||
}
|
||||
|
||||
if pxy.cfg.SubDomain != "" {
|
||||
routeConfig.Domain = pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost
|
||||
for _, location := range locations {
|
||||
routeConfig.Location = location
|
||||
|
||||
tmpRouteConfig := routeConfig
|
||||
|
||||
// handle group
|
||||
if pxy.cfg.LoadBalancer.Group != "" {
|
||||
err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.LoadBalancer.Group, pxy.cfg.LoadBalancer.GroupKey, routeConfig)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pxy.closeFuncs = append(pxy.closeFuncs, func() {
|
||||
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.LoadBalancer.Group, tmpRouteConfig)
|
||||
})
|
||||
} else {
|
||||
err = pxy.rc.HTTPReverseProxy.Register(routeConfig)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pxy.closeFuncs = append(pxy.closeFuncs, func() {
|
||||
pxy.rc.HTTPReverseProxy.UnRegister(tmpRouteConfig)
|
||||
})
|
||||
}
|
||||
addrs = append(addrs, util.CanonicalAddr(tmpRouteConfig.Domain, pxy.serverCfg.VhostHTTPPort))
|
||||
|
||||
xl.Infof("http proxy listen for host [%s] location [%s] group [%s], routeByHTTPUser [%s]",
|
||||
routeConfig.Domain, routeConfig.Location, pxy.cfg.LoadBalancer.Group, pxy.cfg.RouteByHTTPUser)
|
||||
}
|
||||
}
|
||||
remoteAddr = strings.Join(addrs, ",")
|
||||
return
|
||||
}
|
||||
@@ -168,6 +130,7 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err
|
||||
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
tmpConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.HTTPSProxyConfig{}), NewHTTPSProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.HTTPSProxyConfig](), NewHTTPSProxy)
|
||||
}
|
||||
|
||||
type HTTPSProxy struct {
|
||||
@@ -53,23 +53,10 @@ func (pxy *HTTPSProxy) Run() (remoteAddr string, err error) {
|
||||
pxy.Close()
|
||||
}
|
||||
}()
|
||||
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for _, domain := range pxy.cfg.CustomDomains {
|
||||
if domain == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
l, err := pxy.listenForDomain(routeConfig, domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pxy.listeners = append(pxy.listeners, l)
|
||||
addrs = append(addrs, util.CanonicalAddr(domain, pxy.serverCfg.VhostHTTPSPort))
|
||||
xl.Infof("https proxy listen for host [%s] group [%s]", domain, pxy.cfg.LoadBalancer.Group)
|
||||
}
|
||||
|
||||
if pxy.cfg.SubDomain != "" {
|
||||
domain := pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost
|
||||
for _, domain := range domains {
|
||||
l, err := pxy.listenForDomain(routeConfig, domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -150,7 +150,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
|
||||
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
|
||||
}
|
||||
err := msg.WriteMsg(workConn, &msg.StartWorkConn{
|
||||
err = msg.WriteMsg(workConn, &msg.StartWorkConn{
|
||||
ProxyName: pxy.GetName(),
|
||||
SrcAddr: srcAddr,
|
||||
SrcPort: uint16(srcPort),
|
||||
@@ -161,6 +161,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
if err != nil {
|
||||
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
|
||||
workConn.Close()
|
||||
workConn = nil
|
||||
} else {
|
||||
break
|
||||
}
|
||||
@@ -173,6 +174,36 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
return
|
||||
}
|
||||
|
||||
// startVisitorListener sets up a VisitorManager listener for visitor-based proxies (STCP, SUDP).
|
||||
func (pxy *BaseProxy) startVisitorListener(secretKey string, allowUsers []string, proxyType string) error {
|
||||
// if allowUsers is empty, only allow same user from proxy
|
||||
if len(allowUsers) == 0 {
|
||||
allowUsers = []string{pxy.GetUserInfo().User}
|
||||
}
|
||||
listener, err := pxy.rc.VisitorManager.Listen(pxy.GetName(), secretKey, allowUsers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pxy.listeners = append(pxy.listeners, listener)
|
||||
pxy.xl.Infof("%s proxy custom listen success", proxyType)
|
||||
pxy.startCommonTCPListenersHandler()
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDomains constructs a list of domains from custom domains and subdomain configuration.
|
||||
func (pxy *BaseProxy) buildDomains(customDomains []string, subDomain string) []string {
|
||||
domains := make([]string, 0, len(customDomains)+1)
|
||||
for _, d := range customDomains {
|
||||
if d != "" {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
}
|
||||
if subDomain != "" {
|
||||
domains = append(domains, subDomain+"."+pxy.serverCfg.SubDomainHost)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// startCommonTCPListenersHandler start a goroutine handler for each listener.
|
||||
func (pxy *BaseProxy) startCommonTCPListenersHandler() {
|
||||
xl := xlog.FromContextSafe(pxy.ctx)
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.STCPProxyConfig{}), NewSTCPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.STCPProxyConfig](), NewSTCPProxy)
|
||||
}
|
||||
|
||||
type STCPProxy struct {
|
||||
@@ -41,21 +41,7 @@ func NewSTCPProxy(baseProxy *BaseProxy) Proxy {
|
||||
}
|
||||
|
||||
func (pxy *STCPProxy) Run() (remoteAddr string, err error) {
|
||||
xl := pxy.xl
|
||||
allowUsers := pxy.cfg.AllowUsers
|
||||
// if allowUsers is empty, only allow same user from proxy
|
||||
if len(allowUsers) == 0 {
|
||||
allowUsers = []string{pxy.GetUserInfo().User}
|
||||
}
|
||||
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
|
||||
if errRet != nil {
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
pxy.listeners = append(pxy.listeners, listener)
|
||||
xl.Infof("stcp proxy custom listen success")
|
||||
|
||||
pxy.startCommonTCPListenersHandler()
|
||||
err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "stcp")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.SUDPProxyConfig{}), NewSUDPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.SUDPProxyConfig](), NewSUDPProxy)
|
||||
}
|
||||
|
||||
type SUDPProxy struct {
|
||||
@@ -41,21 +41,7 @@ func NewSUDPProxy(baseProxy *BaseProxy) Proxy {
|
||||
}
|
||||
|
||||
func (pxy *SUDPProxy) Run() (remoteAddr string, err error) {
|
||||
xl := pxy.xl
|
||||
allowUsers := pxy.cfg.AllowUsers
|
||||
// if allowUsers is empty, only allow same user from proxy
|
||||
if len(allowUsers) == 0 {
|
||||
allowUsers = []string{pxy.GetUserInfo().User}
|
||||
}
|
||||
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
|
||||
if errRet != nil {
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
pxy.listeners = append(pxy.listeners, listener)
|
||||
xl.Infof("sudp proxy custom listen success")
|
||||
|
||||
pxy.startCommonTCPListenersHandler()
|
||||
err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "sudp")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.TCPProxyConfig{}), NewTCPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.TCPProxyConfig](), NewTCPProxy)
|
||||
}
|
||||
|
||||
type TCPProxy struct {
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.TCPMuxProxyConfig{}), NewTCPMuxProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.TCPMuxProxyConfig](), NewTCPMuxProxy)
|
||||
}
|
||||
|
||||
type TCPMuxProxy struct {
|
||||
@@ -72,26 +72,16 @@ func (pxy *TCPMuxProxy) httpConnectListen(
|
||||
}
|
||||
|
||||
func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) {
|
||||
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for _, domain := range pxy.cfg.CustomDomains {
|
||||
if domain == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPassword, addrs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
if pxy.cfg.SubDomain != "" {
|
||||
addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost,
|
||||
pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPassword, addrs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
pxy.startCommonTCPListenersHandler()
|
||||
remoteAddr = strings.Join(addrs, ",")
|
||||
return remoteAddr, err
|
||||
|
||||
@@ -35,7 +35,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.UDPProxyConfig{}), NewUDPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.UDPProxyConfig](), NewUDPProxy)
|
||||
}
|
||||
|
||||
type UDPProxy struct {
|
||||
@@ -136,7 +136,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
|
||||
continue
|
||||
case *msg.UDPPacket:
|
||||
if errRet := errors.PanicToError(func() {
|
||||
xl.Tracef("get udp message from workConn: %s", m.Content)
|
||||
xl.Tracef("get udp message from workConn, len: %d", len(m.Content))
|
||||
pxy.readCh <- m
|
||||
metrics.Server.AddTrafficOut(
|
||||
pxy.GetName(),
|
||||
@@ -167,7 +167,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
xl.Tracef("send message to udp workConn: %s", udpMsg.Content)
|
||||
xl.Tracef("send message to udp workConn, len: %d", len(udpMsg.Content))
|
||||
metrics.Server.AddTrafficIn(
|
||||
pxy.GetName(),
|
||||
pxy.GetConfigurer().GetBaseConfig().Type,
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.XTCPProxyConfig{}), NewXTCPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.XTCPProxyConfig](), NewXTCPProxy)
|
||||
}
|
||||
|
||||
type XTCPProxy struct {
|
||||
|
||||
@@ -193,7 +193,7 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create vhost tcpMuxer error, %v", err)
|
||||
}
|
||||
log.Infof("tcpmux httpconnect multiplexer listen on %s, passthough: %v", address, cfg.TCPMuxPassthrough)
|
||||
log.Infof("tcpmux httpconnect multiplexer listen on %s, passthrough: %v", address, cfg.TCPMuxPassthrough)
|
||||
}
|
||||
|
||||
// Init all plugins
|
||||
@@ -604,8 +604,18 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(fatedier): use SessionContext
|
||||
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg)
|
||||
ctl, err := NewControl(ctx, &SessionContext{
|
||||
RC: svr.rc,
|
||||
PxyManager: svr.pxyManager,
|
||||
PluginManager: svr.pluginManager,
|
||||
AuthVerifier: authVerifier,
|
||||
EncryptionKey: svr.auth.EncryptionKey(),
|
||||
Conn: ctlConn,
|
||||
ConnEncrypted: !internal,
|
||||
LoginMsg: loginMsg,
|
||||
ServerCfg: svr.cfg,
|
||||
ClientRegistry: svr.clientRegistry,
|
||||
})
|
||||
if err != nil {
|
||||
xl.Warnf("create new controller error: %v", err)
|
||||
// don't return detailed errors to client
|
||||
@@ -626,7 +636,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
ctl.Close()
|
||||
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
|
||||
}
|
||||
ctl.clientRegistry = svr.clientRegistry
|
||||
|
||||
ctl.Start()
|
||||
|
||||
@@ -652,9 +661,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
|
||||
// server plugin hook
|
||||
content := &plugin.NewWorkConnContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.sessionCtx.LoginMsg.RunID,
|
||||
},
|
||||
NewWorkConn: *newMsg,
|
||||
}
|
||||
@@ -662,7 +671,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
|
||||
if err == nil {
|
||||
newMsg = &retContent.NewWorkConn
|
||||
// Check auth.
|
||||
err = ctl.authVerifier.VerifyNewWorkConn(newMsg)
|
||||
err = ctl.sessionCtx.AuthVerifier.VerifyNewWorkConn(newMsg)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
||||
@@ -683,7 +692,7 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis
|
||||
if !exist {
|
||||
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
|
||||
}
|
||||
visitorUser = ctl.loginMsg.User
|
||||
visitorUser = ctl.sessionCtx.LoginMsg.User
|
||||
}
|
||||
return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
|
||||
newMsg.UseEncryption, newMsg.UseCompression, visitorUser)
|
||||
|
||||
@@ -26,7 +26,7 @@ var _ = ginkgo.Describe("[Feature: Example]", func() {
|
||||
remotePort = %d
|
||||
`, framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
})
|
||||
|
||||
@@ -2,70 +2,71 @@ package framework
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
flog "github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/test/e2e/framework/consts"
|
||||
"github.com/fatedier/frp/test/e2e/pkg/process"
|
||||
)
|
||||
|
||||
// RunProcesses run multiple processes from templates.
|
||||
// The first template should always be frps.
|
||||
func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []string) ([]*process.Process, []*process.Process) {
|
||||
templates := make([]string, 0, len(serverTemplates)+len(clientTemplates))
|
||||
templates = append(templates, serverTemplates...)
|
||||
templates = append(templates, clientTemplates...)
|
||||
// RunProcesses starts one frps and zero or more frpc processes from templates.
|
||||
func (f *Framework) RunProcesses(serverTemplate string, clientTemplates []string) (*process.Process, []*process.Process) {
|
||||
templates := append([]string{serverTemplate}, clientTemplates...)
|
||||
outs, ports, err := f.RenderTemplates(templates)
|
||||
ExpectNoError(err)
|
||||
ExpectTrue(len(templates) > 0)
|
||||
|
||||
for name, port := range ports {
|
||||
f.usedPorts[name] = port
|
||||
maps.Copy(f.usedPorts, ports)
|
||||
|
||||
// Start frps.
|
||||
serverPath := filepath.Join(f.TempDirectory, "frp-e2e-server-0")
|
||||
err = os.WriteFile(serverPath, []byte(outs[0]), 0o600)
|
||||
ExpectNoError(err)
|
||||
|
||||
if TestContext.Debug {
|
||||
flog.Debugf("[%s] %s", serverPath, outs[0])
|
||||
}
|
||||
|
||||
currentServerProcesses := make([]*process.Process, 0, len(serverTemplates))
|
||||
for i := range serverTemplates {
|
||||
path := filepath.Join(f.TempDirectory, fmt.Sprintf("frp-e2e-server-%d", i))
|
||||
err = os.WriteFile(path, []byte(outs[i]), 0o600)
|
||||
ExpectNoError(err)
|
||||
serverProcess := process.NewWithEnvs(TestContext.FRPServerPath, []string{"-c", serverPath}, f.osEnvs)
|
||||
f.serverConfPaths = append(f.serverConfPaths, serverPath)
|
||||
f.serverProcesses = append(f.serverProcesses, serverProcess)
|
||||
err = serverProcess.Start()
|
||||
ExpectNoError(err)
|
||||
|
||||
if TestContext.Debug {
|
||||
flog.Debugf("[%s] %s", path, outs[i])
|
||||
}
|
||||
|
||||
p := process.NewWithEnvs(TestContext.FRPServerPath, []string{"-c", path}, f.osEnvs)
|
||||
f.serverConfPaths = append(f.serverConfPaths, path)
|
||||
f.serverProcesses = append(f.serverProcesses, p)
|
||||
currentServerProcesses = append(currentServerProcesses, p)
|
||||
err = p.Start()
|
||||
ExpectNoError(err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if port, ok := ports[consts.PortServerName]; ok {
|
||||
ExpectNoError(WaitForTCPReady(net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), 5*time.Second))
|
||||
} else {
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
currentClientProcesses := make([]*process.Process, 0, len(clientTemplates))
|
||||
// Start frpc(s).
|
||||
clientProcesses := make([]*process.Process, 0, len(clientTemplates))
|
||||
for i := range clientTemplates {
|
||||
index := i + len(serverTemplates)
|
||||
path := filepath.Join(f.TempDirectory, fmt.Sprintf("frp-e2e-client-%d", i))
|
||||
err = os.WriteFile(path, []byte(outs[index]), 0o600)
|
||||
err = os.WriteFile(path, []byte(outs[1+i]), 0o600)
|
||||
ExpectNoError(err)
|
||||
|
||||
if TestContext.Debug {
|
||||
flog.Debugf("[%s] %s", path, outs[index])
|
||||
flog.Debugf("[%s] %s", path, outs[1+i])
|
||||
}
|
||||
|
||||
p := process.NewWithEnvs(TestContext.FRPClientPath, []string{"-c", path}, f.osEnvs)
|
||||
f.clientConfPaths = append(f.clientConfPaths, path)
|
||||
f.clientProcesses = append(f.clientProcesses, p)
|
||||
currentClientProcesses = append(currentClientProcesses, p)
|
||||
clientProcesses = append(clientProcesses, p)
|
||||
err = p.Start()
|
||||
ExpectNoError(err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
time.Sleep(3 * time.Second)
|
||||
// frpc needs time to connect and register proxies with frps.
|
||||
if len(clientProcesses) > 0 {
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
}
|
||||
|
||||
return currentServerProcesses, currentClientProcesses
|
||||
return serverProcess, clientProcesses
|
||||
}
|
||||
|
||||
func (f *Framework) RunFrps(args ...string) (*process.Process, string, error) {
|
||||
@@ -73,11 +74,13 @@ func (f *Framework) RunFrps(args ...string) (*process.Process, string, error) {
|
||||
f.serverProcesses = append(f.serverProcesses, p)
|
||||
err := p.Start()
|
||||
if err != nil {
|
||||
return p, p.StdOutput(), err
|
||||
return p, p.Output(), err
|
||||
}
|
||||
// Give frps extra time to finish binding ports before proceeding.
|
||||
time.Sleep(4 * time.Second)
|
||||
return p, p.StdOutput(), nil
|
||||
select {
|
||||
case <-p.Done():
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
return p, p.Output(), nil
|
||||
}
|
||||
|
||||
func (f *Framework) RunFrpc(args ...string) (*process.Process, string, error) {
|
||||
@@ -85,10 +88,13 @@ func (f *Framework) RunFrpc(args ...string) (*process.Process, string, error) {
|
||||
f.clientProcesses = append(f.clientProcesses, p)
|
||||
err := p.Start()
|
||||
if err != nil {
|
||||
return p, p.StdOutput(), err
|
||||
return p, p.Output(), err
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
return p, p.StdOutput(), nil
|
||||
select {
|
||||
case <-p.Done():
|
||||
case <-time.After(1500 * time.Millisecond):
|
||||
}
|
||||
return p, p.Output(), nil
|
||||
}
|
||||
|
||||
func (f *Framework) GenerateConfigFile(content string) string {
|
||||
@@ -98,3 +104,25 @@ func (f *Framework) GenerateConfigFile(content string) string {
|
||||
ExpectNoError(err)
|
||||
return path
|
||||
}
|
||||
|
||||
// WaitForTCPReady polls a TCP address until a connection succeeds or timeout.
|
||||
func WaitForTCPReady(addr string, timeout time.Duration) error {
|
||||
if timeout <= 0 {
|
||||
return fmt.Errorf("invalid timeout for TCP readiness on %s: timeout must be positive", addr)
|
||||
}
|
||||
deadline := time.Now().Add(timeout)
|
||||
var lastErr error
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
if lastErr == nil {
|
||||
return fmt.Errorf("timeout waiting for TCP readiness on %s before any dial attempt", addr)
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for TCP readiness on %s: %w", addr, lastErr)
|
||||
}
|
||||
|
||||
@@ -26,7 +26,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
proxyType := t
|
||||
ginkgo.It(fmt.Sprintf("Expose a %s echo server", strings.ToUpper(proxyType)), func() {
|
||||
serverConf := consts.LegacyDefaultServerConfig
|
||||
clientConf := consts.LegacyDefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.LegacyDefaultClientConfig)
|
||||
|
||||
localPortName := ""
|
||||
protocol := "tcp"
|
||||
@@ -78,10 +79,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
|
||||
// build all client config
|
||||
for _, test := range tests {
|
||||
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
for _, test := range tests {
|
||||
framework.NewRequestExpect(f).
|
||||
@@ -102,7 +103,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
vhost_http_port = %d
|
||||
`, vhostHTTPPort)
|
||||
|
||||
clientConf := consts.LegacyDefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.LegacyDefaultClientConfig)
|
||||
|
||||
getProxyConf := func(proxyName string, customDomains string, extra string) string {
|
||||
return fmt.Sprintf(`
|
||||
@@ -147,13 +149,13 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
if tests[i].customDomains == "" {
|
||||
tests[i].customDomains = test.proxyName + ".example.com"
|
||||
}
|
||||
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
for _, test := range tests {
|
||||
for _, domain := range strings.Split(test.customDomains, ",") {
|
||||
for domain := range strings.SplitSeq(test.customDomains, ",") {
|
||||
domain = strings.TrimSpace(domain)
|
||||
framework.NewRequestExpect(f).
|
||||
Explain(test.proxyName + "-" + domain).
|
||||
@@ -185,7 +187,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
`, vhostHTTPSPort)
|
||||
|
||||
localPort := f.AllocPort()
|
||||
clientConf := consts.LegacyDefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.LegacyDefaultClientConfig)
|
||||
getProxyConf := func(proxyName string, customDomains string, extra string) string {
|
||||
return fmt.Sprintf(`
|
||||
[%s]
|
||||
@@ -229,10 +232,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
if tests[i].customDomains == "" {
|
||||
tests[i].customDomains = test.proxyName + ".example.com"
|
||||
}
|
||||
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
|
||||
framework.ExpectNoError(err)
|
||||
@@ -244,7 +247,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
f.RunServer("", localServer)
|
||||
|
||||
for _, test := range tests {
|
||||
for _, domain := range strings.Split(test.customDomains, ",") {
|
||||
for domain := range strings.SplitSeq(test.customDomains, ",") {
|
||||
domain = strings.TrimSpace(domain)
|
||||
framework.NewRequestExpect(f).
|
||||
Explain(test.proxyName + "-" + domain).
|
||||
@@ -282,9 +285,12 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
proxyType := t
|
||||
ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() {
|
||||
serverConf := consts.LegacyDefaultServerConfig
|
||||
clientServerConf := consts.LegacyDefaultClientConfig + "\nuser = user1"
|
||||
clientVisitorConf := consts.LegacyDefaultClientConfig + "\nuser = user1"
|
||||
clientUser2VisitorConf := consts.LegacyDefaultClientConfig + "\nuser = user2"
|
||||
var clientServerConf strings.Builder
|
||||
clientServerConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user1")
|
||||
var clientVisitorConf strings.Builder
|
||||
clientVisitorConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user1")
|
||||
var clientUser2VisitorConf strings.Builder
|
||||
clientUser2VisitorConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user2")
|
||||
|
||||
localPortName := ""
|
||||
protocol := "tcp"
|
||||
@@ -400,20 +406,20 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
|
||||
// build all client config
|
||||
for _, test := range tests {
|
||||
clientServerConf += getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n"
|
||||
clientServerConf.WriteString(getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n")
|
||||
}
|
||||
for _, test := range tests {
|
||||
config := getProxyVisitorConf(
|
||||
test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig,
|
||||
) + "\n"
|
||||
if test.deployUser2Client {
|
||||
clientUser2VisitorConf += config
|
||||
clientUser2VisitorConf.WriteString(config)
|
||||
} else {
|
||||
clientVisitorConf += config
|
||||
clientVisitorConf.WriteString(config)
|
||||
}
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientServerConf, clientVisitorConf, clientUser2VisitorConf})
|
||||
f.RunProcesses(serverConf, []string{clientServerConf.String(), clientVisitorConf.String(), clientUser2VisitorConf.String()})
|
||||
|
||||
for _, test := range tests {
|
||||
timeout := time.Second
|
||||
@@ -440,7 +446,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
ginkgo.Describe("TCPMUX", func() {
|
||||
ginkgo.It("Type tcpmux", func() {
|
||||
serverConf := consts.LegacyDefaultServerConfig
|
||||
clientConf := consts.LegacyDefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.LegacyDefaultClientConfig)
|
||||
|
||||
tcpmuxHTTPConnectPortName := port.GenName("TCPMUX")
|
||||
serverConf += fmt.Sprintf(`
|
||||
@@ -483,14 +490,14 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
|
||||
// build all client config
|
||||
for _, test := range tests {
|
||||
clientConf += getProxyConf(test.proxyName, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, test.extraConfig) + "\n")
|
||||
|
||||
localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(f.AllocPort()), streamserver.WithRespContent([]byte(test.proxyName)))
|
||||
f.RunServer(port.GenName(test.proxyName), localServer)
|
||||
}
|
||||
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
// Request without HTTP connect should get error
|
||||
framework.NewRequestExpect(f).
|
||||
|
||||
@@ -48,7 +48,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
|
||||
framework.TCPEchoServerPort, p2Port,
|
||||
framework.TCPEchoServerPort, p3Port)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(p1Port).Ensure()
|
||||
framework.NewRequestExpect(f).Port(p2Port).Ensure()
|
||||
@@ -90,7 +90,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
|
||||
admin_pwd = admin
|
||||
`, dashboardPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
|
||||
r.HTTP().HTTPPath("/healthz")
|
||||
@@ -116,7 +116,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
|
||||
remote_port = %d
|
||||
`, adminPort, framework.TCPEchoServerPort, testPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(testPort).Ensure()
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ func runClientServerTest(f *framework.Framework, configures *generalTestConfigur
|
||||
clientConfs = append(clientConfs, client2Conf)
|
||||
}
|
||||
|
||||
f.RunProcesses([]string{serverConf}, clientConfs)
|
||||
f.RunProcesses(serverConf, clientConfs)
|
||||
|
||||
if configures.testDelay > 0 {
|
||||
time.Sleep(configures.testDelay)
|
||||
|
||||
@@ -33,7 +33,7 @@ var _ = ginkgo.Describe("[Feature: Config]", func() {
|
||||
`, "`", "`", framework.TCPEchoServerPort, portName)
|
||||
|
||||
f.SetEnvs([]string{"FRP_TOKEN=123"})
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(portName).Ensure()
|
||||
})
|
||||
|
||||
@@ -56,7 +56,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
locations = /bar
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
@@ -111,7 +111,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
custom_domains = normal.example.com
|
||||
`, fooPort, barPort, otherPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// user1
|
||||
framework.NewRequestExpect(f).Explain("user1").Port(vhostHTTPPort).
|
||||
@@ -152,7 +152,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
http_pwd = test
|
||||
`, framework.HTTPSimpleServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// not set auth header
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
@@ -188,7 +188,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
custom_domains = *.example.com
|
||||
`, framework.HTTPSimpleServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// not match host
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
@@ -238,7 +238,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
subdomain = bar
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// foo
|
||||
framework.NewRequestExpect(f).Explain("foo subdomain").Port(vhostHTTPPort).
|
||||
@@ -279,7 +279,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
header_X-From-Where = frp
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// not set auth header
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
@@ -312,7 +312,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
host_header_rewrite = rewrite.example.com
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
RequestModify(func(r *request.Request) {
|
||||
@@ -360,7 +360,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
custom_domains = 127.0.0.1
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: "127.0.0.1:" + strconv.Itoa(vhostHTTPPort)}
|
||||
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
|
||||
@@ -28,7 +28,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
tcpPortName := port.GenName("TCP", port.WithRangePorts(10000, 11000))
|
||||
udpPortName := port.GenName("UDP", port.WithRangePorts(12000, 13000))
|
||||
clientConf += fmt.Sprintf(`
|
||||
[tcp-allowded-in-range]
|
||||
[tcp-allowed-in-range]
|
||||
type = tcp
|
||||
local_port = {{ .%s }}
|
||||
remote_port = {{ .%s }}
|
||||
@@ -58,7 +58,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
remote_port = 11003
|
||||
`, framework.UDPEchoServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// TCP
|
||||
// Allowed in range
|
||||
@@ -97,7 +97,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
local_port = {{ .%s }}
|
||||
`, adminPort, framework.TCPEchoServerPort, framework.UDPEchoServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
client := f.APIClientForFrpc(adminPort)
|
||||
|
||||
@@ -138,7 +138,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
custom_domains = example.com
|
||||
`, framework.HTTPSimpleServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
|
||||
r.HTTP().HTTPHost("example.com")
|
||||
@@ -165,7 +165,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
custom_domains = example.com
|
||||
`, framework.HTTPSimpleServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
|
||||
r.HTTP().HTTPPath("/healthz")
|
||||
|
||||
@@ -76,7 +76,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
|
||||
custom_domains = normal.example.com
|
||||
`, fooPort, barPort, otherPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// user1
|
||||
framework.NewRequestExpect(f).Explain("user1").
|
||||
@@ -121,7 +121,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
|
||||
http_pwd = test
|
||||
`, fooPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// not set auth header
|
||||
framework.NewRequestExpect(f).Explain("no auth").
|
||||
@@ -204,7 +204,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
|
||||
custom_domains = normal.example.com
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).
|
||||
RequestModify(func(r *request.Request) {
|
||||
|
||||
@@ -41,7 +41,7 @@ var _ = ginkgo.Describe("[Feature: XTCP]", func() {
|
||||
fallback_timeout_ms = 200
|
||||
`, framework.TCPEchoServerPort, bindPortName)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
framework.NewRequestExpect(f).
|
||||
RequestModify(func(r *request.Request) {
|
||||
r.Timeout(time.Second)
|
||||
|
||||
@@ -35,7 +35,7 @@ var _ = ginkgo.Describe("[Feature: Bandwidth Limit]", func() {
|
||||
bandwidth_limit = 10KB
|
||||
`, localPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
content := strings.Repeat("a", 50*1024) // 5KB
|
||||
start := time.Now()
|
||||
@@ -89,7 +89,7 @@ var _ = ginkgo.Describe("[Feature: Bandwidth Limit]", func() {
|
||||
remote_port = %d
|
||||
`, localPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
content := strings.Repeat("a", 50*1024) // 5KB
|
||||
start := time.Now()
|
||||
|
||||
@@ -48,12 +48,10 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
return true
|
||||
})
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
wait.Add(1)
|
||||
go func() {
|
||||
defer wait.Done()
|
||||
for range 10 {
|
||||
wait.Go(func() {
|
||||
expectFn()
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
wait.Wait()
|
||||
@@ -90,11 +88,11 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
group_key = 123
|
||||
`, fooPort, remotePort, barPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
fooCount := 0
|
||||
barCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
framework.NewRequestExpect(f).Explain("times " + strconv.Itoa(i)).Port(remotePort).Ensure(func(resp *request.Response) bool {
|
||||
switch string(resp.Content) {
|
||||
case "foo":
|
||||
@@ -146,11 +144,11 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
health_check_interval_s = 1
|
||||
`, fooPort, remotePort, barPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// check foo and bar is ok
|
||||
results := []string{}
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
|
||||
results = append(results, string(resp.Content))
|
||||
return true
|
||||
@@ -161,7 +159,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
// close bar server, check foo is ok
|
||||
barServer.Close()
|
||||
time.Sleep(2 * time.Second)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
framework.NewRequestExpect(f).Port(remotePort).ExpectResp([]byte("foo")).Ensure()
|
||||
}
|
||||
|
||||
@@ -169,7 +167,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
f.RunServer("", barServer)
|
||||
time.Sleep(2 * time.Second)
|
||||
results = []string{}
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
|
||||
results = append(results, string(resp.Content))
|
||||
return true
|
||||
@@ -215,7 +213,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
health_check_url = /healthz
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// send first HTTP request
|
||||
var contents []string
|
||||
|
||||
@@ -38,7 +38,7 @@ var _ = ginkgo.Describe("[Feature: Heartbeat]", func() {
|
||||
`, serverPort, f.PortByName(framework.TCPEchoServerPort), remotePort)
|
||||
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Protocol("tcp").Port(remotePort).Ensure()
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ var _ = ginkgo.Describe("[Feature: Monitor]", func() {
|
||||
remote_port = %d
|
||||
`, framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
@@ -44,7 +44,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() {
|
||||
custom_domains = normal.example.com
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
RequestModify(func(r *request.Request) {
|
||||
@@ -90,7 +90,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() {
|
||||
proxy_protocol_version = v2
|
||||
`, localPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure(func(resp *request.Response) bool {
|
||||
log.Tracef("proxy protocol get SourceAddr: %s", string(resp.Content))
|
||||
@@ -136,7 +136,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() {
|
||||
proxy_protocol_version = v2
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).RequestModify(func(r *request.Request) {
|
||||
r.HTTP().HTTPHost("normal.example.com")
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/onsi/ginkgo/v2"
|
||||
|
||||
@@ -22,7 +23,8 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
|
||||
ginkgo.Describe("UnixDomainSocket", func() {
|
||||
ginkgo.It("Expose a unix domain socket echo server", func() {
|
||||
serverConf := consts.LegacyDefaultServerConfig
|
||||
clientConf := consts.LegacyDefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.LegacyDefaultClientConfig)
|
||||
|
||||
getProxyConf := func(proxyName string, portName string, extra string) string {
|
||||
return fmt.Sprintf(`
|
||||
@@ -65,10 +67,10 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
|
||||
|
||||
// build all client config
|
||||
for _, test := range tests {
|
||||
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
for _, test := range tests {
|
||||
framework.NewRequestExpect(f).Port(f.PortByName(test.portName)).Ensure()
|
||||
@@ -90,7 +92,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
|
||||
plugin_http_passwd = 123
|
||||
`, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// http proxy, no auth info
|
||||
framework.NewRequestExpect(f).PortName(framework.HTTPSimpleServerPort).RequestModify(func(r *request.Request) {
|
||||
@@ -122,7 +124,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
|
||||
plugin_passwd = 123
|
||||
`, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// http proxy, no auth info
|
||||
framework.NewRequestExpect(f).PortName(framework.TCPEchoServerPort).RequestModify(func(r *request.Request) {
|
||||
@@ -166,7 +168,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
|
||||
plugin_http_passwd = 123
|
||||
`, remotePort, f.TempDirectory, f.TempDirectory, f.TempDirectory)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// from tcp proxy
|
||||
framework.NewRequestExpect(f).Request(
|
||||
@@ -200,7 +202,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
|
||||
plugin_local_addr = 127.0.0.1:%d
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
|
||||
framework.ExpectNoError(err)
|
||||
@@ -244,7 +246,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
|
||||
plugin_key_path = %s
|
||||
`, localPort, crtPath, keyPath)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
localServer := httpserver.New(
|
||||
httpserver.WithBindPort(localPort),
|
||||
@@ -288,7 +290,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
|
||||
plugin_key_path = %s
|
||||
`, localPort, crtPath, keyPath)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
|
||||
framework.ExpectNoError(err)
|
||||
|
||||
@@ -71,7 +71,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
|
||||
remote_port = %d
|
||||
`, framework.TCPEchoServerPort, remotePort2)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf, invalidTokenClientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf, invalidTokenClientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
framework.NewRequestExpect(f).Port(remotePort2).ExpectError(true).Ensure()
|
||||
@@ -119,7 +119,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
|
||||
remote_port = %d
|
||||
`, framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
})
|
||||
@@ -153,7 +153,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
|
||||
remote_port = 0
|
||||
`, framework.TCPEchoServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
})
|
||||
@@ -195,7 +195,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
|
||||
remote_port = %d
|
||||
`, framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
_, clients := f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
_, clients := f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
|
||||
@@ -250,7 +250,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
|
||||
remote_port = %d
|
||||
`, framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
|
||||
@@ -297,7 +297,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
|
||||
remote_port = %d
|
||||
`, framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
|
||||
@@ -342,7 +342,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
|
||||
remote_port = %d
|
||||
`, framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
|
||||
@@ -389,7 +389,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
|
||||
remote_port = %d
|
||||
`, framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ func (pa *Allocator) GetByName(portName string) int {
|
||||
pa.mu.Lock()
|
||||
defer pa.mu.Unlock()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
for range 20 {
|
||||
port := pa.getByRange(builder.rangePortFrom, builder.rangePortTo)
|
||||
if port == 0 {
|
||||
return 0
|
||||
|
||||
@@ -3,7 +3,9 @@ package process
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"os/exec"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
@@ -12,6 +14,11 @@ type Process struct {
|
||||
errorOutput *bytes.Buffer
|
||||
stdOutput *bytes.Buffer
|
||||
|
||||
done chan struct{}
|
||||
closeOne sync.Once
|
||||
waitErr error
|
||||
|
||||
started bool
|
||||
beforeStopHandler func()
|
||||
stopped bool
|
||||
}
|
||||
@@ -27,6 +34,7 @@ func NewWithEnvs(path string, params []string, envs []string) *Process {
|
||||
p := &Process{
|
||||
cmd: cmd,
|
||||
cancel: cancel,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
p.errorOutput = bytes.NewBufferString("")
|
||||
p.stdOutput = bytes.NewBufferString("")
|
||||
@@ -36,11 +44,35 @@ func NewWithEnvs(path string, params []string, envs []string) *Process {
|
||||
}
|
||||
|
||||
func (p *Process) Start() error {
|
||||
return p.cmd.Start()
|
||||
if p.started {
|
||||
return errors.New("process already started")
|
||||
}
|
||||
p.started = true
|
||||
|
||||
err := p.cmd.Start()
|
||||
if err != nil {
|
||||
p.waitErr = err
|
||||
p.closeDone()
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
p.waitErr = p.cmd.Wait()
|
||||
p.closeDone()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) closeDone() {
|
||||
p.closeOne.Do(func() { close(p.done) })
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the process exits.
|
||||
func (p *Process) Done() <-chan struct{} {
|
||||
return p.done
|
||||
}
|
||||
|
||||
func (p *Process) Stop() error {
|
||||
if p.stopped {
|
||||
if p.stopped || !p.started {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
@@ -50,7 +82,8 @@ func (p *Process) Stop() error {
|
||||
p.beforeStopHandler()
|
||||
}
|
||||
p.cancel()
|
||||
return p.cmd.Wait()
|
||||
<-p.done
|
||||
return p.waitErr
|
||||
}
|
||||
|
||||
func (p *Process) ErrorOutput() string {
|
||||
@@ -61,6 +94,10 @@ func (p *Process) StdOutput() string {
|
||||
return p.stdOutput.String()
|
||||
}
|
||||
|
||||
func (p *Process) Output() string {
|
||||
return p.stdOutput.String() + p.errorOutput.String()
|
||||
}
|
||||
|
||||
func (p *Process) SetBeforeStopHandler(fn func()) {
|
||||
p.beforeStopHandler = fn
|
||||
}
|
||||
|
||||
@@ -75,11 +75,11 @@ func (c *TunnelClient) serveListener() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go c.hanldeConn(conn)
|
||||
go c.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TunnelClient) hanldeConn(conn net.Conn) {
|
||||
func (c *TunnelClient) handleConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
local, err := net.Dial("tcp", c.localAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -35,7 +35,7 @@ var _ = ginkgo.Describe("[Feature: Annotations]", func() {
|
||||
"frp.e2e.test/bar" = "value2"
|
||||
`, framework.TCPEchoServerPort, p1Port)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(p1Port).Ensure()
|
||||
|
||||
|
||||
@@ -26,7 +26,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
proxyType := t
|
||||
ginkgo.It(fmt.Sprintf("Expose a %s echo server", strings.ToUpper(proxyType)), func() {
|
||||
serverConf := consts.DefaultServerConfig
|
||||
clientConf := consts.DefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.DefaultClientConfig)
|
||||
|
||||
localPortName := ""
|
||||
protocol := "tcp"
|
||||
@@ -79,10 +80,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
|
||||
// build all client config
|
||||
for _, test := range tests {
|
||||
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
for _, test := range tests {
|
||||
framework.NewRequestExpect(f).
|
||||
@@ -103,7 +104,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
vhostHTTPPort = %d
|
||||
`, vhostHTTPPort)
|
||||
|
||||
clientConf := consts.DefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.DefaultClientConfig)
|
||||
|
||||
getProxyConf := func(proxyName string, customDomains string, extra string) string {
|
||||
return fmt.Sprintf(`
|
||||
@@ -149,13 +151,13 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
if tests[i].customDomains == "" {
|
||||
tests[i].customDomains = fmt.Sprintf(`["%s"]`, test.proxyName+".example.com")
|
||||
}
|
||||
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
for _, test := range tests {
|
||||
for _, domain := range strings.Split(test.customDomains, ",") {
|
||||
for domain := range strings.SplitSeq(test.customDomains, ",") {
|
||||
domain = strings.TrimSpace(domain)
|
||||
domain = strings.TrimLeft(domain, "[\"")
|
||||
domain = strings.TrimRight(domain, "]\"")
|
||||
@@ -189,7 +191,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
`, vhostHTTPSPort)
|
||||
|
||||
localPort := f.AllocPort()
|
||||
clientConf := consts.DefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.DefaultClientConfig)
|
||||
getProxyConf := func(proxyName string, customDomains string, extra string) string {
|
||||
return fmt.Sprintf(`
|
||||
[[proxies]]
|
||||
@@ -234,10 +237,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
if tests[i].customDomains == "" {
|
||||
tests[i].customDomains = fmt.Sprintf(`["%s"]`, test.proxyName+".example.com")
|
||||
}
|
||||
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
|
||||
framework.ExpectNoError(err)
|
||||
@@ -249,7 +252,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
f.RunServer("", localServer)
|
||||
|
||||
for _, test := range tests {
|
||||
for _, domain := range strings.Split(test.customDomains, ",") {
|
||||
for domain := range strings.SplitSeq(test.customDomains, ",") {
|
||||
domain = strings.TrimSpace(domain)
|
||||
domain = strings.TrimLeft(domain, "[\"")
|
||||
domain = strings.TrimRight(domain, "]\"")
|
||||
@@ -289,9 +292,12 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
proxyType := t
|
||||
ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() {
|
||||
serverConf := consts.DefaultServerConfig
|
||||
clientServerConf := consts.DefaultClientConfig + "\nuser = \"user1\""
|
||||
clientVisitorConf := consts.DefaultClientConfig + "\nuser = \"user1\""
|
||||
clientUser2VisitorConf := consts.DefaultClientConfig + "\nuser = \"user2\""
|
||||
var clientServerConf strings.Builder
|
||||
clientServerConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user1\"")
|
||||
var clientVisitorConf strings.Builder
|
||||
clientVisitorConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user1\"")
|
||||
var clientUser2VisitorConf strings.Builder
|
||||
clientUser2VisitorConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user2\"")
|
||||
|
||||
localPortName := ""
|
||||
protocol := "tcp"
|
||||
@@ -407,20 +413,20 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
|
||||
// build all client config
|
||||
for _, test := range tests {
|
||||
clientServerConf += getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n"
|
||||
clientServerConf.WriteString(getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n")
|
||||
}
|
||||
for _, test := range tests {
|
||||
config := getProxyVisitorConf(
|
||||
test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig,
|
||||
) + "\n"
|
||||
if test.deployUser2Client {
|
||||
clientUser2VisitorConf += config
|
||||
clientUser2VisitorConf.WriteString(config)
|
||||
} else {
|
||||
clientVisitorConf += config
|
||||
clientVisitorConf.WriteString(config)
|
||||
}
|
||||
}
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientServerConf, clientVisitorConf, clientUser2VisitorConf})
|
||||
f.RunProcesses(serverConf, []string{clientServerConf.String(), clientVisitorConf.String(), clientUser2VisitorConf.String()})
|
||||
|
||||
for _, test := range tests {
|
||||
timeout := time.Second
|
||||
@@ -447,7 +453,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
ginkgo.Describe("TCPMUX", func() {
|
||||
ginkgo.It("Type tcpmux", func() {
|
||||
serverConf := consts.DefaultServerConfig
|
||||
clientConf := consts.DefaultClientConfig
|
||||
var clientConf strings.Builder
|
||||
clientConf.WriteString(consts.DefaultClientConfig)
|
||||
|
||||
tcpmuxHTTPConnectPortName := port.GenName("TCPMUX")
|
||||
serverConf += fmt.Sprintf(`
|
||||
@@ -491,14 +498,14 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
|
||||
|
||||
// build all client config
|
||||
for _, test := range tests {
|
||||
clientConf += getProxyConf(test.proxyName, test.extraConfig) + "\n"
|
||||
clientConf.WriteString(getProxyConf(test.proxyName, test.extraConfig) + "\n")
|
||||
|
||||
localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(f.AllocPort()), streamserver.WithRespContent([]byte(test.proxyName)))
|
||||
f.RunServer(port.GenName(test.proxyName), localServer)
|
||||
}
|
||||
|
||||
// run frps and frpc
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf.String()})
|
||||
|
||||
// Request without HTTP connect should get error
|
||||
framework.NewRequestExpect(f).
|
||||
|
||||
@@ -51,7 +51,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
|
||||
framework.TCPEchoServerPort, p2Port,
|
||||
framework.TCPEchoServerPort, p3Port)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(p1Port).Ensure()
|
||||
framework.NewRequestExpect(f).Port(p2Port).Ensure()
|
||||
@@ -93,7 +93,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
|
||||
webServer.password = "admin"
|
||||
`, dashboardPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
|
||||
r.HTTP().HTTPPath("/healthz")
|
||||
@@ -120,7 +120,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
|
||||
remotePort = %d
|
||||
`, adminPort, framework.TCPEchoServerPort, testPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(testPort).Ensure()
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ func runClientServerTest(f *framework.Framework, configures *generalTestConfigur
|
||||
clientConfs = append(clientConfs, client2Conf)
|
||||
}
|
||||
|
||||
f.RunProcesses([]string{serverConf}, clientConfs)
|
||||
f.RunProcesses(serverConf, clientConfs)
|
||||
|
||||
if configures.testDelay > 0 {
|
||||
time.Sleep(configures.testDelay)
|
||||
|
||||
@@ -35,7 +35,7 @@ var _ = ginkgo.Describe("[Feature: Config]", func() {
|
||||
`, "`", "`", framework.TCPEchoServerPort, portName)
|
||||
|
||||
f.SetEnvs([]string{"FRP_TOKEN=123"})
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(portName).Ensure()
|
||||
})
|
||||
@@ -69,7 +69,7 @@ var _ = ginkgo.Describe("[Feature: Config]", func() {
|
||||
escapeTemplate("{{- end }}"),
|
||||
)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
client := f.APIClientForFrpc(adminPort)
|
||||
checkProxyFn := func(name string, localPort, remotePort int) {
|
||||
@@ -149,7 +149,7 @@ proxies:
|
||||
remotePort: %d
|
||||
`, port.GenName("Server"), framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
})
|
||||
|
||||
@@ -161,7 +161,7 @@ proxies:
|
||||
"proxies": [{"name": "tcp", "type": "tcp", "localPort": {{ .%s }}, "remotePort": %d}]}`,
|
||||
port.GenName("Server"), framework.TCPEchoServerPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -59,7 +59,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
locations = ["/bar"]
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
@@ -117,7 +117,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
customDomains = ["normal.example.com"]
|
||||
`, fooPort, barPort, otherPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// user1
|
||||
framework.NewRequestExpect(f).Explain("user1").Port(vhostHTTPPort).
|
||||
@@ -159,7 +159,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
httpPassword = "test"
|
||||
`, framework.HTTPSimpleServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// not set auth header
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
@@ -196,7 +196,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
customDomains = ["*.example.com"]
|
||||
`, framework.HTTPSimpleServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// not match host
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
@@ -248,7 +248,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
subdomain = "bar"
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// foo
|
||||
framework.NewRequestExpect(f).Explain("foo subdomain").Port(vhostHTTPPort).
|
||||
@@ -290,7 +290,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
requestHeaders.set.x-from-where = "frp"
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
RequestModify(func(r *request.Request) {
|
||||
@@ -323,7 +323,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
responseHeaders.set.x-from-where = "frp"
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
RequestModify(func(r *request.Request) {
|
||||
@@ -357,7 +357,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
hostHeaderRewrite = "rewrite.example.com"
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
RequestModify(func(r *request.Request) {
|
||||
@@ -406,7 +406,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
customDomains = ["127.0.0.1"]
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: "127.0.0.1:" + strconv.Itoa(vhostHTTPPort)}
|
||||
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
@@ -447,7 +447,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
|
||||
customDomains = ["normal.example.com"]
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).Port(vhostHTTPPort).
|
||||
RequestModify(func(r *request.Request) {
|
||||
|
||||
@@ -33,7 +33,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
udpPortName := port.GenName("UDP", port.WithRangePorts(12000, 13000))
|
||||
clientConf += fmt.Sprintf(`
|
||||
[[proxies]]
|
||||
name = "tcp-allowded-in-range"
|
||||
name = "tcp-allowed-in-range"
|
||||
type = "tcp"
|
||||
localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
@@ -67,7 +67,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
remotePort = 11003
|
||||
`, framework.UDPEchoServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// TCP
|
||||
// Allowed in range
|
||||
@@ -108,7 +108,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
localPort = {{ .%s }}
|
||||
`, adminPort, framework.TCPEchoServerPort, framework.UDPEchoServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
client := f.APIClientForFrpc(adminPort)
|
||||
|
||||
@@ -150,7 +150,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
customDomains = ["example.com"]
|
||||
`, framework.HTTPSimpleServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
|
||||
r.HTTP().HTTPHost("example.com")
|
||||
@@ -178,7 +178,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
|
||||
customDomains = ["example.com"]
|
||||
`, framework.HTTPSimpleServerPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
|
||||
r.HTTP().HTTPPath("/healthz")
|
||||
|
||||
@@ -79,7 +79,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
|
||||
customDomains = ["normal.example.com"]
|
||||
`, fooPort, barPort, otherPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// user1
|
||||
framework.NewRequestExpect(f).Explain("user1").
|
||||
@@ -125,7 +125,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
|
||||
httpPassword = "test"
|
||||
`, fooPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// not set auth header
|
||||
framework.NewRequestExpect(f).Explain("no auth").
|
||||
@@ -209,7 +209,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
|
||||
customDomains = ["normal.example.com"]
|
||||
`, localPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).
|
||||
RequestModify(func(r *request.Request) {
|
||||
|
||||
@@ -16,8 +16,11 @@ package basic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/onsi/ginkgo/v2"
|
||||
|
||||
@@ -73,7 +76,7 @@ localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, tokenContent, framework.TCPEchoServerPort, portName)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(portName).Ensure()
|
||||
})
|
||||
@@ -109,7 +112,7 @@ localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, tokenFile, framework.TCPEchoServerPort, portName)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(portName).Ensure()
|
||||
})
|
||||
@@ -150,7 +153,7 @@ localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, clientTokenFile, framework.TCPEchoServerPort, portName)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(portName).Ensure()
|
||||
})
|
||||
@@ -190,7 +193,7 @@ localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, clientTokenFile, framework.TCPEchoServerPort, portName)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// This should fail due to token mismatch - the client should not be able to connect
|
||||
// We expect the request to fail because the proxy tunnel is not established
|
||||
@@ -198,32 +201,27 @@ remotePort = {{ .%s }}
|
||||
})
|
||||
|
||||
ginkgo.It("should fail with non-existent token file", func() {
|
||||
// This test verifies that server fails to start when tokenSource points to non-existent file
|
||||
// We'll verify this by checking that the configuration loading itself fails
|
||||
|
||||
// Create a config that references a non-existent file
|
||||
tmpDir := f.TempDirectory
|
||||
nonExistentFile := filepath.Join(tmpDir, "non_existent_token")
|
||||
|
||||
serverConf := consts.DefaultServerConfig
|
||||
|
||||
// Server config with non-existent tokenSource file
|
||||
serverConf += fmt.Sprintf(`
|
||||
serverPort := f.AllocPort()
|
||||
serverConf := fmt.Sprintf(`
|
||||
bindAddr = "0.0.0.0"
|
||||
bindPort = %d
|
||||
auth.tokenSource.type = "file"
|
||||
auth.tokenSource.file.path = "%s"
|
||||
`, nonExistentFile)
|
||||
`, serverPort, nonExistentFile)
|
||||
|
||||
// The test expectation is that this will fail during the RunProcesses call
|
||||
// because the server cannot load the configuration due to missing token file
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Expected: server should fail to start due to missing file
|
||||
ginkgo.By(fmt.Sprintf("Server correctly failed to start: %v", r))
|
||||
}
|
||||
}()
|
||||
serverConfigPath := f.GenerateConfigFile(serverConf)
|
||||
|
||||
// This should cause a panic or error during server startup
|
||||
f.RunProcesses([]string{serverConf}, []string{})
|
||||
_, _, _ = f.RunFrps("-c", serverConfigPath)
|
||||
|
||||
// Server should have failed to start, so the port should not be listening.
|
||||
conn, err := net.DialTimeout("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(serverPort)), 1*time.Second)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
}
|
||||
framework.ExpectTrue(err != nil, "server should not be listening on port %d", serverPort)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ var _ = ginkgo.Describe("[Feature: XTCP]", func() {
|
||||
fallbackTimeoutMs = 200
|
||||
`, framework.TCPEchoServerPort, bindPortName)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
framework.NewRequestExpect(f).
|
||||
RequestModify(func(r *request.Request) {
|
||||
r.Timeout(time.Second)
|
||||
|
||||
@@ -36,7 +36,7 @@ var _ = ginkgo.Describe("[Feature: Bandwidth Limit]", func() {
|
||||
transport.bandwidthLimit = "10KB"
|
||||
`, localPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
content := strings.Repeat("a", 50*1024) // 5KB
|
||||
start := time.Now()
|
||||
@@ -92,7 +92,7 @@ var _ = ginkgo.Describe("[Feature: Bandwidth Limit]", func() {
|
||||
remotePort = %d
|
||||
`, localPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
content := strings.Repeat("a", 50*1024) // 5KB
|
||||
start := time.Now()
|
||||
|
||||
@@ -50,12 +50,10 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
return true
|
||||
})
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
wait.Add(1)
|
||||
go func() {
|
||||
defer wait.Done()
|
||||
for range 10 {
|
||||
wait.Go(func() {
|
||||
expectFn()
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
wait.Wait()
|
||||
@@ -94,11 +92,11 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
loadBalancer.groupKey = "123"
|
||||
`, fooPort, remotePort, barPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
fooCount := 0
|
||||
barCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
framework.NewRequestExpect(f).Explain("times " + strconv.Itoa(i)).Port(remotePort).Ensure(func(resp *request.Response) bool {
|
||||
switch string(resp.Content) {
|
||||
case "foo":
|
||||
@@ -159,11 +157,11 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
loadBalancer.groupKey = "123"
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
fooCount := 0
|
||||
barCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
framework.NewRequestExpect(f).
|
||||
Explain("times " + strconv.Itoa(i)).
|
||||
Port(vhostHTTPSPort).
|
||||
@@ -188,6 +186,68 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
|
||||
framework.ExpectTrue(fooCount > 1 && barCount > 1, "fooCount: %d, barCount: %d", fooCount, barCount)
|
||||
})
|
||||
|
||||
ginkgo.It("TCPMux httpconnect", func() {
|
||||
vhostPort := f.AllocPort()
|
||||
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||
tcpmuxHTTPConnectPort = %d
|
||||
`, vhostPort)
|
||||
clientConf := consts.DefaultClientConfig
|
||||
|
||||
fooPort := f.AllocPort()
|
||||
fooServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(fooPort), streamserver.WithRespContent([]byte("foo")))
|
||||
f.RunServer("", fooServer)
|
||||
|
||||
barPort := f.AllocPort()
|
||||
barServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(barPort), streamserver.WithRespContent([]byte("bar")))
|
||||
f.RunServer("", barServer)
|
||||
|
||||
clientConf += fmt.Sprintf(`
|
||||
[[proxies]]
|
||||
name = "foo"
|
||||
type = "tcpmux"
|
||||
multiplexer = "httpconnect"
|
||||
localPort = %d
|
||||
customDomains = ["tcpmux-group.example.com"]
|
||||
loadBalancer.group = "test"
|
||||
loadBalancer.groupKey = "123"
|
||||
|
||||
[[proxies]]
|
||||
name = "bar"
|
||||
type = "tcpmux"
|
||||
multiplexer = "httpconnect"
|
||||
localPort = %d
|
||||
customDomains = ["tcpmux-group.example.com"]
|
||||
loadBalancer.group = "test"
|
||||
loadBalancer.groupKey = "123"
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
proxyURL := fmt.Sprintf("http://127.0.0.1:%d", vhostPort)
|
||||
fooCount := 0
|
||||
barCount := 0
|
||||
for i := range 10 {
|
||||
framework.NewRequestExpect(f).
|
||||
Explain("times " + strconv.Itoa(i)).
|
||||
RequestModify(func(r *request.Request) {
|
||||
r.Addr("tcpmux-group.example.com").Proxy(proxyURL)
|
||||
}).
|
||||
Ensure(func(resp *request.Response) bool {
|
||||
switch string(resp.Content) {
|
||||
case "foo":
|
||||
fooCount++
|
||||
case "bar":
|
||||
barCount++
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
framework.ExpectTrue(fooCount > 1 && barCount > 1, "fooCount: %d, barCount: %d", fooCount, barCount)
|
||||
})
|
||||
})
|
||||
|
||||
ginkgo.Describe("Health Check", func() {
|
||||
@@ -226,11 +286,11 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
healthCheck.intervalSeconds = 1
|
||||
`, fooPort, remotePort, barPort, remotePort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// check foo and bar is ok
|
||||
results := []string{}
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
|
||||
results = append(results, string(resp.Content))
|
||||
return true
|
||||
@@ -241,7 +301,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
// close bar server, check foo is ok
|
||||
barServer.Close()
|
||||
time.Sleep(2 * time.Second)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
framework.NewRequestExpect(f).Port(remotePort).ExpectResp([]byte("foo")).Ensure()
|
||||
}
|
||||
|
||||
@@ -249,7 +309,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
f.RunServer("", barServer)
|
||||
time.Sleep(2 * time.Second)
|
||||
results = []string{}
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
|
||||
results = append(results, string(resp.Content))
|
||||
return true
|
||||
@@ -297,7 +357,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
healthCheck.path = "/healthz"
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
// send first HTTP request
|
||||
var contents []string
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user