Compare commits

...

28 Commits

Author SHA1 Message Date
Shani Pathak
94a631fe9c auth/oidc: cache OIDC access token and refresh before expiry (#5175)
* auth/oidc: cache OIDC access token and refresh before expiry

- Use Config.TokenSource(ctx) once at init to create a persistent
  oauth2.TokenSource that caches the token and only refreshes on expiry
- Wrap with oauth2.ReuseTokenSourceWithExpiry for configurable early refresh
- Add tokenRefreshAdvanceDuration config option (default: 300s)
- Add unit test verifying token caching with mock HTTP server

* address review comments

* auth/oidc: fallback to per-request token fetch when expires_in is missing

When an OIDC provider omits the expires_in field, oauth2.ReuseTokenSource
treats the cached token as valid forever and never refreshes it. This causes
server-side OIDC verification to fail once the JWT's exp claim passes.

Add a nonCachingTokenSource fallback: after fetching the initial token, if
its Expiry is the zero value, swap the caching TokenSource for one that
fetches a fresh token on every request, preserving the old behavior for
providers that don't return expires_in.

* auth/oidc: fix gosec lint and add test for zero-expiry fallback

Suppress G101 false positive on test-only dummy token responses.
Add test to verify per-request token fetch when expires_in is missing.
Update caching test to account for eager initial token fetch.

* fix lint
2026-03-12 00:24:46 +08:00
fatedier
6b1be922e1 add AGENTS.md and CLAUDE.md, remove them from .gitignore (#5232) 2026-03-12 00:21:31 +08:00
fatedier
4f584f81d0 test/e2e: replace sleeps with event-driven waits in chaos/group/store tests (#5231)
* test/e2e: replace sleeps with event-driven waits in chaos/group/store tests

Replace 21 time.Sleep calls with deterministic waiting using
WaitForOutput, WaitForTCPReady, and a new WaitForTCPUnreachable helper.
Add CountOutput method for snapshot-based incremental log matching.

* test/e2e: validate interval and cap dial/sleep to remaining deadline in WaitForTCPUnreachable
2026-03-12 00:11:09 +08:00
fatedier
9669e1ca0c test/e2e: replace RunProcesses client sleep with log-based proxy readiness detection (#5226)
* test/e2e: replace RunProcesses client sleep with log-based proxy readiness detection

Replace the fixed 1500ms sleep in RunProcesses with event-driven proxy
registration detection by monitoring frpc log output for "start proxy
success" messages.

Key changes:
- Add thread-safe SafeBuffer to replace bytes.Buffer in Process, enabling
  concurrent read/write of process output during execution
- Add Process.WaitForOutput() to poll process output for pattern matches
  with timeout and early exit on process termination
- Add waitForClientProxyReady() that uses config.LoadClientConfig() to
  extract proxy names, then waits for each proxy's success log
- For visitor-only clients (no deterministic readiness signal), fall back
  to the original sleep with elapsed time deducted

* test/e2e: use shared deadline for proxy readiness and fix doc comment

- Use a single deadline in waitForClientProxyReady so total wait across
  all proxies does not exceed the given timeout
- Fix WaitForOutput doc comment to accurately describe single pattern
  with count semantics
2026-03-09 22:28:23 +08:00
fatedier
48e8901466 test/e2e: optimize RunFrps/RunFrpc with process exit detection (#5225)
* test/e2e: optimize RunFrps/RunFrpc with process exit detection

Refactor Process to track subprocess lifecycle via a done channel,
replacing direct cmd.Wait() in Stop() to avoid double-Wait races.
RunFrps/RunFrpc now use select on the done channel instead of fixed
sleeps, allowing short-lived processes (verify, startup failures) to
return immediately while preserving existing timeout behavior for
long-running daemons.

* test/e2e: guard Process against double-Start and Stop-before-Start

Add started flag to prevent double-Start panics and allow Stop to
return immediately when the process was never started. Use sync.Once
for closing the done channel as defense-in-depth against double close.
2026-03-09 10:28:47 +08:00
fatedier
bcd2424c24 test/e2e: optimize e2e test time by replacing sleeps with TCP readiness checks (#5223)
Replace the fixed 500ms sleep after each frps startup in RunProcesses
with a TCP dial-based readiness check that polls the server bind port.
This reduces the e2e suite wall time from ~97s to ~43s.

Also simplify the RunProcesses API to accept a single server template
string instead of a slice, matching how every call site uses it.
2026-03-08 23:41:33 +08:00
fatedier
c7ac12ea0f server/group: refactor with shared abstractions and fix concurrency issues (#5222)
* server/group: refactor group package with shared abstractions and fix concurrency issues

Extract common patterns into reusable components:
- groupRegistry[G]: generic concurrent map for group lifecycle management
- baseGroup: shared plumbing for listener-based groups (TCP, HTTPS, TCPMux)
- Listener: unified virtual listener replacing 3 identical implementations

Fix concurrency issues:
- Stale-pointer race: isCurrent check + errGroupStale + controller retry loops
- Worker generation safety: pass realLn and acceptCh as params instead of reading mutable fields
- Connection leak: close conn on worker panic recovery path
- ABBA deadlock in HTTP UnRegister: consistent lock ordering (group.mu -> registry.mu)
- Round-robin overflow in HTTPGroup: use unsigned modulo

Add unit tests (17 tests) for registry, listener, and baseGroup.
Add TCPMux group load balancing e2e test.

* server/group: replace tautological assertion with require.NotPanics

* server/group: remove blank line between doc comment and type declaration
2026-03-08 18:57:21 +08:00
fatedier
eeb0dacfc1 pkg/metrics/mem: remove redundant map write-backs and optimize proxy lookup (#5221)
Remove 4 redundant pointer map write-backs in OpenConnection,
CloseConnection, AddTrafficIn, and AddTrafficOut since the map stores
pointers and mutations are already visible without reassignment.

Optimize GetProxiesByTypeAndName from O(n) full map scan to O(1) direct
map lookup by proxy name.
2026-03-08 10:40:39 +08:00
Oleksandr Redko
535eb3db35 refactor: use maps.Clone and slices.Concat (#5220) 2026-03-08 10:38:16 +08:00
fatedier
605f3bdece client/visitor: deduplicate visitor connection handshake and wrapping (#5219)
Extract two shared helpers to eliminate duplicated code across STCP,
SUDP, and XTCP visitors:

- dialRawVisitorConn: handles ConnectServer + NewVisitorConn handshake
  (auth, sign key, 10s read deadline, error check)
- wrapVisitorConn: handles encryption + pooled compression wrapping,
  returning a recycleFn for pool resource cleanup

SUDP is upgraded from WithCompression to WithCompressionFromPool,
aligning with the pooled compression used by STCP and XTCP.
2026-03-08 01:03:40 +08:00
fatedier
764a626b6e server/control: deduplicate close-proxy logic and UserInfo construction (#5218)
Extract closeProxy() helper to eliminate duplicated 4-step cleanup
sequence (Close, PxyManager.Del, metrics, plugin notify) between
worker() and CloseProxy().

Extract loginUserInfo() helper to eliminate 4 repeated plugin.UserInfo
constructions using LoginMsg fields.

Optimize worker() to snapshot and clear the proxies map under lock,
then perform cleanup outside the lock to reduce lock hold time.
2026-03-08 00:02:14 +08:00
Oleksandr Redko
c2454e7114 refactor: fix modernize lint issues (#5215) 2026-03-07 23:10:19 +08:00
fatedier
017d71717f server: introduce SessionContext to encapsulate NewControl parameters (#5217)
Replace 10 positional parameters in NewControl() with a single
SessionContext struct, matching the client-side pattern. This also
eliminates the post-construction mutation of clientRegistry and
removes two TODO comments.
2026-03-07 20:17:00 +08:00
Oleksandr Redko
bd200b1a3b fix: typos in comments, tests, functions (#5216) 2026-03-07 18:43:04 +08:00
fatedier
c70ceff370 fix: three high-severity bugs across nathole, proxy, and udp modules (#5214)
- pkg/nathole: add RLock when reading clientCfgs map in PreCheck path
  to prevent concurrent map read/write crash
- server/proxy: fix error variable shadowing in GetWorkConnFromPool
  that could return a closed connection with nil error
- pkg/util/net: check ListenUDP error before spawning goroutines
  and assign readConn to struct field so Close() works correctly
2026-03-07 13:36:02 +08:00
fatedier
bb3d0e7140 deduplicate common logic across proxy, visitor, and metrics modules (#5213)
- Replace duplicate parseBasicAuth with existing httppkg.ParseBasicAuth
- Extract buildDomains helper in BaseProxy for HTTP/HTTPS/TCPMux proxies
- Extract toProxyStats helper to deduplicate ProxyStats construction
- Extract startVisitorListener helper in BaseProxy for STCP/SUDP proxies
- Extract acceptLoop helper in BaseVisitor for STCP/XTCP visitors
2026-03-07 12:00:27 +08:00
fatedier
cf396563f8 client/proxy: unify work conn wrapping across all proxy types (#5212)
* client/proxy: extract wrapWorkConn to deduplicate UDP/SUDP connection wrapping

Move the repeated rate-limiting, encryption, and compression wrapping
logic from UDPProxy and SUDPProxy into a shared BaseProxy.wrapWorkConn
method, reducing ~18 lines of duplication in each proxy type.

* client/proxy: unify work conn wrapping with pooled compression for all proxy types

Refactor wrapWorkConn to accept encKey parameter and return
(io.ReadWriteCloser, recycleFn, error), enabling HandleTCPWorkConnection
to reuse the same limiter/encryption/compression pipeline.

Switch all proxy types from WithCompression to WithCompressionFromPool.
TCP non-plugin path calls recycleFn via defer after Join; plugin and
UDP/SUDP paths skip recycle (objects are GC'd safely, per golib contract).
2026-03-07 01:33:37 +08:00
fatedier
0b4f83cd04 pkg/config: use modern Go stdlib for sorting and string operations (#5210)
- slices.SortedFunc + maps.Values + cmp.Compare instead of manual
  map-to-slice collection + sort.Slice (source/aggregator.go)
- strings.CutSuffix instead of HasSuffix+TrimSuffix, and deduplicate
  error handling in BandwidthQuantity.UnmarshalString (types/types.go)
2026-03-06 23:13:29 +08:00
fatedier
e9f7a1a9f2 pkg: use modern Go stdlib functions to simplify code (#5209)
- strings.CutPrefix instead of HasPrefix+TrimPrefix (naming, legacy)
- slices.Contains instead of manual loop (plugin/server)
- min/max builtins instead of manual comparisons (nathole)
2026-03-06 22:14:46 +08:00
fatedier
d644593342 server/proxy: simplify HTTPS and TCPMux proxy domain registration (#5208)
Consolidate the separate custom-domain loop and subdomain block into a
single unified loop, matching the pattern already applied to HTTPProxy
in PR #5207. No behavioral change.
2026-03-06 21:31:29 +08:00
fatedier
427c4ca3ae server/proxy: simplify HTTP proxy domain registration by removing duplicate loop (#5207)
The Run() method had two nearly identical loop blocks for registering
custom domains and subdomain, with the same group/non-group registration
logic copy-pasted (~30 lines of duplication).

Consolidate by collecting all domains into a single slice first, then
iterating once with the shared registration logic. Also fixes a minor
inconsistency where the custom domain block used routeConfig.Domain in
CanonicalAddr but the subdomain block used tmpRouteConfig.Domain.
2026-03-06 21:17:30 +08:00
fatedier
f2d1f3739a pkg/util/xlog: fix AddPrefix not updating existing prefix due to range value copy (#5206)
In AddPrefix, the loop `for _, p := range l.prefixes` creates a copy
of each element. Assignments to p.Value and p.Priority only modify
the local copy, not the original slice element, causing updates to
existing prefixes to be silently lost.

This affects client/service.go where AddPrefix is called with
Name:"runID" on reconnection — the old runID value would persist
in log output instead of being updated to the new one.

Fix by using index-based access `l.prefixes[i]` to modify the
original slice element, and add break since prefix names are unique.
2026-03-06 20:44:40 +08:00
fatedier
c23894f156 fix: validate CA cert parsing and add missing ReadHeaderTimeout (#5205)
- pkg/transport/tls.go: check AppendCertsFromPEM return value and
  return clear error when CA file contains no valid PEM certificates
- pkg/plugin/client/http2http.go: set ReadHeaderTimeout to 60s to
  match other plugins and prevent slow header attacks
- pkg/plugin/client/http2https.go: same ReadHeaderTimeout fix
2026-03-06 17:59:41 +08:00
fatedier
cb459b02b6 fix: WebsocketListener nil panic and OIDC auth data race (#5204)
- pkg/util/net/websocket.go: store ln parameter in struct to prevent
  nil pointer panic when Addr() is called
- pkg/auth/oidc.go: replace unsynchronized []string with map + RWMutex
  for subjectsFromLogin to fix data race across concurrent connections
2026-03-06 16:51:52 +08:00
fatedier
8f633fe363 fix: return buffers to pool on error paths to reduce GC pressure (#5203)
- pkg/nathole/nathole.go: add pool.PutBuf(buf) on ReadFromUDP error
  and DecodeMessageInto error paths in waitDetectMessage
- pkg/proto/udp/udp.go: add defer pool.PutBuf(buf) in writerFn to
  ensure buffer is returned when the goroutine exits
2026-03-06 15:55:22 +08:00
fatedier
c62a1da161 fix: close connections on error paths to prevent resource leaks (#5202)
Fix connection leaks in multiple error paths across client and server:
- server/proxy/http: close tmpConn when WithEncryption fails
- client/proxy: close localConn when ProxyProtocol WriteTo fails
- client/visitor/sudp: close visitorConn on all error paths in getNewVisitorConn
- client/visitor/xtcp: close tunnelConn when WithEncryption fails
- client/visitor/xtcp: close lConn when NewKCPConnFromUDP fails
- pkg/plugin/client/unix_domain_socket: close localConn and connInfo.Conn when WriteTo fails, close connInfo.Conn when DialUnix fails
- pkg/plugin/client/tls2raw: close tlsConn when Handshake or Dial fails
2026-03-06 15:18:38 +08:00
fatedier
f22f7d539c server/group: fix port leak and incorrect Listen port in TCPGroup (#5200)
Fix two bugs in TCPGroup.Listen():
- Release acquired port when net.Listen fails to prevent port leak
- Use realPort instead of port for net.Listen to ensure consistency
  between port manager records and actual listening port
2026-03-06 02:25:47 +08:00
fatedier
462c987f6d pkg/msg: change UDPPacket.Content from string to []byte to avoid redundant base64 encode/decode (#5198) 2026-03-06 01:38:24 +08:00
112 changed files with 1905 additions and 1456 deletions

3
.gitignore vendored
View File

@@ -29,6 +29,5 @@ client.key
*.swp *.swp
# AI # AI
CLAUDE.md .claude/
AGENTS.md
.sisyphus/ .sisyphus/

View File

@@ -18,6 +18,7 @@ linters:
- lll - lll
- makezero - makezero
- misspell - misspell
- modernize
- prealloc - prealloc
- predeclared - predeclared
- revive - revive
@@ -47,6 +48,9 @@ linters:
ignore-rules: ignore-rules:
- cancelled - cancelled
- marshalled - marshalled
modernize:
disable:
- omitzero
unparam: unparam:
check-exported: false check-exported: false
exclusions: exclusions:

34
AGENTS.md Normal file
View File

@@ -0,0 +1,34 @@
# AGENTS.md
## Development Commands
### Build
- `make build` - Build both frps and frpc binaries
- `make frps` - Build server binary only
- `make frpc` - Build client binary only
- `make all` - Build everything with formatting
### Testing
- `make test` - Run unit tests
- `make e2e` - Run end-to-end tests
- `make e2e-trace` - Run e2e tests with trace logging
- `make alltest` - Run all tests including vet, unit tests, and e2e
### Code Quality
- `make fmt` - Run go fmt
- `make fmt-more` - Run gofumpt for more strict formatting
- `make gci` - Run gci import organizer
- `make vet` - Run go vet
- `golangci-lint run` - Run comprehensive linting (configured in .golangci.yml)
### Assets
- `make web` - Build web dashboards (frps and frpc)
### Cleanup
- `make clean` - Remove built binaries and temporary files
## Testing
- E2E tests using Ginkgo/Gomega framework
- Mock servers in `/test/e2e/mock/`
- Run: `make e2e` or `make alltest`

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

View File

@@ -16,6 +16,7 @@ package proxy
import ( import (
"context" "context"
"fmt"
"io" "io"
"net" "net"
"reflect" "reflect"
@@ -122,6 +123,33 @@ func (pxy *BaseProxy) Close() {
} }
} }
// wrapWorkConn applies rate limiting, encryption, and compression
// to a work connection based on the proxy's transport configuration.
// The returned recycle function should be called when the stream is no longer in use
// to return compression resources to the pool. It is safe to not call recycle,
// in which case resources will be garbage collected normally.
func (pxy *BaseProxy) wrapWorkConn(conn net.Conn, encKey []byte) (io.ReadWriteCloser, func(), error) {
var rwc io.ReadWriteCloser = conn
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.baseCfg.Transport.UseEncryption {
var err error
rwc, err = libio.WithEncryption(rwc, encKey)
if err != nil {
conn.Close()
return nil, nil, fmt.Errorf("create encryption stream error: %w", err)
}
}
var recycleFn func()
if pxy.baseCfg.Transport.UseCompression {
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
}
return rwc, recycleFn, nil
}
func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
pxy.inWorkConnCallback = cb pxy.inWorkConnCallback = cb
} }
@@ -139,30 +167,14 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) { func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) {
xl := pxy.xl xl := pxy.xl
baseCfg := pxy.baseCfg baseCfg := pxy.baseCfg
var (
remote io.ReadWriteCloser
err error
)
remote = workConn
if pxy.limiter != nil {
remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error {
return workConn.Close()
})
}
xl.Tracef("handle tcp work connection, useEncryption: %t, useCompression: %t", xl.Tracef("handle tcp work connection, useEncryption: %t, useCompression: %t",
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression) baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
if baseCfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, encKey) remote, recycleFn, err := pxy.wrapWorkConn(workConn, encKey)
if err != nil { if err != nil {
workConn.Close() xl.Errorf("wrap work connection: %v", err)
xl.Errorf("create encryption stream error: %v", err) return
return
}
}
var compressionResourceRecycleFn func()
if baseCfg.Transport.UseCompression {
remote, compressionResourceRecycleFn = libio.WithCompressionFromPool(remote)
} }
// check if we need to send proxy protocol info // check if we need to send proxy protocol info
@@ -178,7 +190,6 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
} }
if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 { if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 {
// Use the common proxy protocol builder function
header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion) header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion)
connInfo.ProxyProtocolHeader = header connInfo.ProxyProtocolHeader = header
} }
@@ -187,12 +198,18 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if pxy.proxyPlugin != nil { if pxy.proxyPlugin != nil {
// if plugin is set, let plugin handle connection first // if plugin is set, let plugin handle connection first
// Don't recycle compression resources here because plugins may
// retain the connection after Handle returns.
xl.Debugf("handle by plugin: %s", pxy.proxyPlugin.Name()) xl.Debugf("handle by plugin: %s", pxy.proxyPlugin.Name())
pxy.proxyPlugin.Handle(pxy.ctx, &connInfo) pxy.proxyPlugin.Handle(pxy.ctx, &connInfo)
xl.Debugf("handle by plugin finished") xl.Debugf("handle by plugin finished")
return return
} }
if recycleFn != nil {
defer recycleFn()
}
localConn, err := libnet.Dial( localConn, err := libnet.Dial(
net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)), net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)),
libnet.WithTimeout(10*time.Second), libnet.WithTimeout(10*time.Second),
@@ -209,6 +226,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if connInfo.ProxyProtocolHeader != nil { if connInfo.ProxyProtocolHeader != nil {
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil { if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
workConn.Close() workConn.Close()
localConn.Close()
xl.Errorf("write proxy protocol header to local conn error: %v", err) xl.Errorf("write proxy protocol header to local conn error: %v", err)
return return
} }
@@ -219,7 +237,4 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if len(errs) > 0 { if len(errs) > 0 {
xl.Tracef("join connections errors: %v", errs) xl.Tracef("join connections errors: %v", errs)
} }
if compressionResourceRecycleFn != nil {
compressionResourceRecycleFn()
}
} }

View File

@@ -17,7 +17,6 @@
package proxy package proxy
import ( import (
"io"
"net" "net"
"reflect" "reflect"
"strconv" "strconv"
@@ -25,17 +24,15 @@ import (
"time" "time"
"github.com/fatedier/golib/errors" "github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit"
netpkg "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.SUDPProxyConfig{}), NewSUDPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.SUDPProxyConfig](), NewSUDPProxy)
} }
type SUDPProxy struct { type SUDPProxy struct {
@@ -83,27 +80,13 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
xl := pxy.xl xl := pxy.xl
xl.Infof("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String()) xl.Infof("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String())
var rwc io.ReadWriteCloser = conn remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
var err error if err != nil {
if pxy.limiter != nil { xl.Errorf("wrap work connection: %v", err)
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { return
return conn.Close()
})
} }
if pxy.cfg.Transport.UseEncryption {
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
if err != nil {
conn.Close()
xl.Errorf("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.Transport.UseCompression {
rwc = libio.WithCompression(rwc)
}
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
workConn := conn workConn := netpkg.WrapReadWriteCloserToConn(remote, conn)
readCh := make(chan *msg.UDPPacket, 1024) readCh := make(chan *msg.UDPPacket, 1024)
sendCh := make(chan msg.Message, 1024) sendCh := make(chan msg.Message, 1024)
isClose := false isClose := false

View File

@@ -17,24 +17,21 @@
package proxy package proxy
import ( import (
"io"
"net" "net"
"reflect" "reflect"
"strconv" "strconv"
"time" "time"
"github.com/fatedier/golib/errors" "github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit"
netpkg "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.UDPProxyConfig{}), NewUDPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.UDPProxyConfig](), NewUDPProxy)
} }
type UDPProxy struct { type UDPProxy struct {
@@ -94,28 +91,14 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
// close resources related with old workConn // close resources related with old workConn
pxy.Close() pxy.Close()
var rwc io.ReadWriteCloser = conn remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
var err error if err != nil {
if pxy.limiter != nil { xl.Errorf("wrap work connection: %v", err)
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { return
return conn.Close()
})
} }
if pxy.cfg.Transport.UseEncryption {
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
if err != nil {
conn.Close()
xl.Errorf("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.Transport.UseCompression {
rwc = libio.WithCompression(rwc)
}
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
pxy.mu.Lock() pxy.mu.Lock()
pxy.workConn = conn pxy.workConn = netpkg.WrapReadWriteCloserToConn(remote, conn)
pxy.readCh = make(chan *msg.UDPPacket, 1024) pxy.readCh = make(chan *msg.UDPPacket, 1024)
pxy.sendCh = make(chan msg.Message, 1024) pxy.sendCh = make(chan msg.Message, 1024)
pxy.closed = false pxy.closed = false
@@ -129,7 +112,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
return return
} }
if errRet := errors.PanicToError(func() { if errRet := errors.PanicToError(func() {
xl.Tracef("get udp package from workConn: %s", udpMsg.Content) xl.Tracef("get udp package from workConn, len: %d", len(udpMsg.Content))
readCh <- &udpMsg readCh <- &udpMsg
}); errRet != nil { }); errRet != nil {
xl.Infof("reader goroutine for udp work connection closed: %v", errRet) xl.Infof("reader goroutine for udp work connection closed: %v", errRet)
@@ -145,7 +128,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
for rawMsg := range sendCh { for rawMsg := range sendCh {
switch m := rawMsg.(type) { switch m := rawMsg.(type) {
case *msg.UDPPacket: case *msg.UDPPacket:
xl.Tracef("send udp package to workConn: %s", m.Content) xl.Tracef("send udp package to workConn, len: %d", len(m.Content))
case *msg.Ping: case *msg.Ping:
xl.Tracef("send ping message to udp workConn") xl.Tracef("send ping message to udp workConn")
} }

View File

@@ -34,7 +34,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.XTCPProxyConfig{}), NewXTCPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.XTCPProxyConfig](), NewXTCPProxy)
} }
type XTCPProxy struct { type XTCPProxy struct {

View File

@@ -15,18 +15,12 @@
package visitor package visitor
import ( import (
"fmt"
"io"
"net" "net"
"strconv" "strconv"
"time"
libio "github.com/fatedier/golib/io" libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@@ -42,10 +36,10 @@ func (sv *STCPVisitor) Run() (err error) {
if err != nil { if err != nil {
return return
} }
go sv.worker() go sv.acceptLoop(sv.l, "stcp local", sv.handleConn)
} }
go sv.internalConnWorker() go sv.acceptLoop(sv.internalLn, "stcp internal", sv.handleConn)
if sv.plugin != nil { if sv.plugin != nil {
sv.plugin.Start() sv.plugin.Start()
@@ -57,35 +51,10 @@ func (sv *STCPVisitor) Close() {
sv.BaseVisitor.Close() sv.BaseVisitor.Close()
} }
func (sv *STCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warnf("stcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warnf("stcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) handleConn(userConn net.Conn) { func (sv *STCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx) xl := xlog.FromContextSafe(sv.ctx)
var tunnelErr error var tunnelErr error
defer func() { defer func() {
// If there was an error and connection supports CloseWithError, use it
if tunnelErr != nil { if tunnelErr != nil {
if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok { if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
_ = eConn.CloseWithError(tunnelErr) _ = eConn.CloseWithError(tunnelErr)
@@ -96,62 +65,21 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
}() }()
xl.Debugf("get a new stcp user connection") xl.Debugf("get a new stcp user connection")
visitorConn, err := sv.helper.ConnectServer() visitorConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
if err != nil { if err != nil {
xl.Warnf("dialRawVisitorConn error: %v", err)
tunnelErr = err tunnelErr = err
return return
} }
defer visitorConn.Close() defer visitorConn.Close()
now := time.Now().Unix() remote, recycleFn, err := wrapVisitorConn(visitorConn, sv.cfg.GetBaseConfig())
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
Timestamp: now,
UseEncryption: sv.cfg.Transport.UseEncryption,
UseCompression: sv.cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil { if err != nil {
xl.Warnf("send newVisitorConnMsg to server error: %v", err) xl.Warnf("wrapVisitorConn error: %v", err)
tunnelErr = err tunnelErr = err
return return
} }
defer recycleFn()
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
xl.Warnf("get newVisitorConnRespMsg error: %v", err)
tunnelErr = err
return
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error)
return
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
tunnelErr = err
return
}
}
if sv.cfg.Transport.UseCompression {
var recycleFn func()
remote, recycleFn = libio.WithCompressionFromPool(remote)
defer recycleFn()
}
libio.Join(userConn, remote) libio.Join(userConn, remote)
} }

View File

@@ -16,21 +16,17 @@ package visitor
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"time" "time"
"github.com/fatedier/golib/errors" "github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
netpkg "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@@ -76,6 +72,7 @@ func (sv *SUDPVisitor) dispatcher() {
var ( var (
visitorConn net.Conn visitorConn net.Conn
recycleFn func()
err error err error
firstPacket *msg.UDPPacket firstPacket *msg.UDPPacket
@@ -93,14 +90,17 @@ func (sv *SUDPVisitor) dispatcher() {
return return
} }
visitorConn, err = sv.getNewVisitorConn() visitorConn, recycleFn, err = sv.getNewVisitorConn()
if err != nil { if err != nil {
xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err) xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err)
continue continue
} }
// visitorConn always be closed when worker done. // visitorConn always be closed when worker done.
sv.worker(visitorConn, firstPacket) func() {
defer recycleFn()
sv.worker(visitorConn, firstPacket)
}()
select { select {
case <-sv.checkCloseCh: case <-sv.checkCloseCh:
@@ -147,7 +147,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
case *msg.UDPPacket: case *msg.UDPPacket:
if errRet := errors.PanicToError(func() { if errRet := errors.PanicToError(func() {
sv.readCh <- m sv.readCh <- m
xl.Tracef("frpc visitor get udp packet from workConn: %s", m.Content) xl.Tracef("frpc visitor get udp packet from workConn, len: %d", len(m.Content))
}); errRet != nil { }); errRet != nil {
xl.Infof("reader goroutine for udp work connection closed") xl.Infof("reader goroutine for udp work connection closed")
return return
@@ -169,7 +169,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl.Warnf("sender goroutine for udp work connection closed: %v", errRet) xl.Warnf("sender goroutine for udp work connection closed: %v", errRet)
return return
} }
xl.Tracef("send udp package to workConn: %s", firstPacket.Content) xl.Tracef("send udp package to workConn, len: %d", len(firstPacket.Content))
} }
for { for {
@@ -184,7 +184,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl.Warnf("sender goroutine for udp work connection closed: %v", errRet) xl.Warnf("sender goroutine for udp work connection closed: %v", errRet)
return return
} }
xl.Tracef("send udp package to workConn: %s", udpMsg.Content) xl.Tracef("send udp package to workConn, len: %d", len(udpMsg.Content))
case <-closeCh: case <-closeCh:
return return
} }
@@ -198,53 +198,17 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl.Infof("sudp worker is closed") xl.Infof("sudp worker is closed")
} }
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) { func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, func(), error) {
xl := xlog.FromContextSafe(sv.ctx) rawConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
visitorConn, err := sv.helper.ConnectServer()
if err != nil { if err != nil {
return nil, fmt.Errorf("frpc connect frps error: %v", err) return nil, func() {}, err
} }
rwc, recycleFn, err := wrapVisitorConn(rawConn, sv.cfg.GetBaseConfig())
now := time.Now().Unix()
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
Timestamp: now,
UseEncryption: sv.cfg.Transport.UseEncryption,
UseCompression: sv.cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil { if err != nil {
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err) rawConn.Close()
return nil, func() {}, err
} }
return netpkg.WrapReadWriteCloserToConn(rwc, rawConn), recycleFn, nil
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
return nil, err
}
}
if sv.cfg.Transport.UseCompression {
remote = libio.WithCompression(remote)
}
return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil
} }
func (sv *SUDPVisitor) Close() { func (sv *SUDPVisitor) Close() {

View File

@@ -16,13 +16,21 @@ package visitor
import ( import (
"context" "context"
"fmt"
"io"
"net" "net"
"sync" "sync"
"time"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
plugin "github.com/fatedier/frp/pkg/plugin/visitor" plugin "github.com/fatedier/frp/pkg/plugin/visitor"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
netpkg "github.com/fatedier/frp/pkg/util/net" netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet" "github.com/fatedier/frp/pkg/vnet"
) )
@@ -119,6 +127,18 @@ func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
return v.internalLn.PutConn(conn) return v.internalLn.PutConn(conn)
} }
func (v *BaseVisitor) acceptLoop(l net.Listener, name string, handleConn func(net.Conn)) {
xl := xlog.FromContextSafe(v.ctx)
for {
conn, err := l.Accept()
if err != nil {
xl.Warnf("%s listener closed", name)
return
}
go handleConn(conn)
}
}
func (v *BaseVisitor) Close() { func (v *BaseVisitor) Close() {
if v.l != nil { if v.l != nil {
v.l.Close() v.l.Close()
@@ -130,3 +150,57 @@ func (v *BaseVisitor) Close() {
v.plugin.Close() v.plugin.Close()
} }
} }
func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, error) {
visitorConn, err := v.helper.ConnectServer()
if err != nil {
return nil, fmt.Errorf("connect to server error: %v", err)
}
now := time.Now().Unix()
targetProxyName := naming.BuildTargetServerProxyName(v.clientCfg.User, cfg.ServerUser, cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: v.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(cfg.SecretKey, now),
Timestamp: now,
UseEncryption: cfg.Transport.UseEncryption,
UseCompression: cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
visitorConn.Close()
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
return visitorConn, nil
}
func wrapVisitorConn(conn io.ReadWriteCloser, cfg *v1.VisitorBaseConfig) (io.ReadWriteCloser, func(), error) {
rwc := conn
if cfg.Transport.UseEncryption {
var err error
rwc, err = libio.WithEncryption(rwc, []byte(cfg.SecretKey))
if err != nil {
return nil, func() {}, fmt.Errorf("create encryption stream error: %v", err)
}
}
recycleFn := func() {}
if cfg.Transport.UseCompression {
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
}
return rwc, recycleFn, nil
}

View File

@@ -65,10 +65,10 @@ func (sv *XTCPVisitor) Run() (err error) {
if err != nil { if err != nil {
return return
} }
go sv.worker() go sv.acceptLoop(sv.l, "xtcp local", sv.handleConn)
} }
go sv.internalConnWorker() go sv.acceptLoop(sv.internalLn, "xtcp internal", sv.handleConn)
go sv.processTunnelStartEvents() go sv.processTunnelStartEvents()
if sv.cfg.KeepTunnelOpen { if sv.cfg.KeepTunnelOpen {
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour) sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
@@ -93,30 +93,6 @@ func (sv *XTCPVisitor) Close() {
} }
} }
func (sv *XTCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warnf("xtcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warnf("xtcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) processTunnelStartEvents() { func (sv *XTCPVisitor) processTunnelStartEvents() {
for { for {
select { select {
@@ -206,20 +182,14 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
return return
} }
var muxConnRWCloser io.ReadWriteCloser = tunnelConn muxConnRWCloser, recycleFn, err := wrapVisitorConn(tunnelConn, sv.cfg.GetBaseConfig())
if sv.cfg.Transport.UseEncryption { if err != nil {
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey)) xl.Errorf("%v", err)
if err != nil { tunnelConn.Close()
xl.Errorf("create encryption stream error: %v", err) tunnelErr = err
tunnelErr = err return
return
}
}
if sv.cfg.Transport.UseCompression {
var recycleFn func()
muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser)
defer recycleFn()
} }
defer recycleFn()
_, _, errs := libio.Join(userConn, muxConnRWCloser) _, _, errs := libio.Join(userConn, muxConnRWCloser)
xl.Debugf("join connections closed") xl.Debugf("join connections closed")
@@ -373,6 +343,7 @@ func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) er
} }
remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String()) remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String())
if err != nil { if err != nil {
lConn.Close()
return fmt.Errorf("create kcp connection from udp connection error: %v", err) return fmt.Errorf("create kcp connection from udp connection error: %v", err)
} }

View File

@@ -47,7 +47,7 @@ var natholeDiscoveryCmd = &cobra.Command{
Use: "discover", Use: "discover",
Short: "Discover nathole information from stun server", Short: "Discover nathole information from stun server",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
// ignore error here, because we can use command line pameters // ignore error here, because we can use command line parameters
cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfigMode) cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfigMode)
if err != nil { if err != nil {
cfg = &v1.ClientCommonConfig{} cfg = &v1.ClientCommonConfig{}

View File

@@ -23,6 +23,7 @@ import (
"net/url" "net/url"
"os" "os"
"slices" "slices"
"sync"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -74,11 +75,23 @@ func createOIDCHTTPClient(trustedCAFile string, insecureSkipVerify bool, proxyUR
return &http.Client{Transport: transport}, nil return &http.Client{Transport: transport}, nil
} }
// nonCachingTokenSource wraps a clientcredentials.Config to fetch a fresh
// token on every call. This is used as a fallback when the OIDC provider
// does not return expires_in, which would cause a caching TokenSource to
// hold onto a stale token forever.
type nonCachingTokenSource struct {
cfg *clientcredentials.Config
ctx context.Context
}
func (s *nonCachingTokenSource) Token() (*oauth2.Token, error) {
return s.cfg.Token(s.ctx)
}
type OidcAuthProvider struct { type OidcAuthProvider struct {
additionalAuthScopes []v1.AuthScope additionalAuthScopes []v1.AuthScope
tokenGenerator *clientcredentials.Config tokenSource oauth2.TokenSource
httpClient *http.Client
} }
func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) (*OidcAuthProvider, error) { func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClientConfig) (*OidcAuthProvider, error) {
@@ -99,30 +112,44 @@ func NewOidcAuthSetter(additionalAuthScopes []v1.AuthScope, cfg v1.AuthOIDCClien
EndpointParams: eps, EndpointParams: eps,
} }
// Create custom HTTP client if needed // Build the context that TokenSource will use for all future HTTP requests.
var httpClient *http.Client // context.Background() is appropriate here because the token source is
// long-lived and outlives any single request.
ctx := context.Background()
if cfg.TrustedCaFile != "" || cfg.InsecureSkipVerify || cfg.ProxyURL != "" { if cfg.TrustedCaFile != "" || cfg.InsecureSkipVerify || cfg.ProxyURL != "" {
var err error httpClient, err := createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL)
httpClient, err = createOIDCHTTPClient(cfg.TrustedCaFile, cfg.InsecureSkipVerify, cfg.ProxyURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create OIDC HTTP client: %w", err) return nil, fmt.Errorf("failed to create OIDC HTTP client: %w", err)
} }
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
// Create a persistent TokenSource that caches the token and refreshes
// it before expiry. This avoids making a new HTTP request to the OIDC
// provider on every heartbeat/ping.
tokenSource := tokenGenerator.TokenSource(ctx)
// Fetch the initial token to check if the provider returns an expiry.
// If Expiry is the zero value (provider omitted expires_in), the cached
// TokenSource would treat the token as valid forever and never refresh it,
// even after the JWT's exp claim passes. In that case, fall back to
// fetching a fresh token on every request.
initialToken, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to obtain initial OIDC token: %w", err)
}
if initialToken.Expiry.IsZero() {
tokenSource = &nonCachingTokenSource{cfg: tokenGenerator, ctx: ctx}
} }
return &OidcAuthProvider{ return &OidcAuthProvider{
additionalAuthScopes: additionalAuthScopes, additionalAuthScopes: additionalAuthScopes,
tokenGenerator: tokenGenerator, tokenSource: tokenSource,
httpClient: httpClient,
}, nil }, nil
} }
func (auth *OidcAuthProvider) generateAccessToken() (accessToken string, err error) { func (auth *OidcAuthProvider) generateAccessToken() (accessToken string, err error) {
ctx := context.Background() tokenObj, err := auth.tokenSource.Token()
if auth.httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, auth.httpClient)
}
tokenObj, err := auth.tokenGenerator.Token(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("couldn't generate OIDC token for login: %v", err) return "", fmt.Errorf("couldn't generate OIDC token for login: %v", err)
} }
@@ -205,7 +232,8 @@ type OidcAuthConsumer struct {
additionalAuthScopes []v1.AuthScope additionalAuthScopes []v1.AuthScope
verifier TokenVerifier verifier TokenVerifier
subjectsFromLogin []string mu sync.RWMutex
subjectsFromLogin map[string]struct{}
} }
func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier { func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
@@ -226,7 +254,7 @@ func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVeri
return &OidcAuthConsumer{ return &OidcAuthConsumer{
additionalAuthScopes: additionalAuthScopes, additionalAuthScopes: additionalAuthScopes,
verifier: verifier, verifier: verifier,
subjectsFromLogin: []string{}, subjectsFromLogin: make(map[string]struct{}),
} }
} }
@@ -235,9 +263,9 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
if err != nil { if err != nil {
return fmt.Errorf("invalid OIDC token in login: %v", err) return fmt.Errorf("invalid OIDC token in login: %v", err)
} }
if !slices.Contains(auth.subjectsFromLogin, token.Subject) { auth.mu.Lock()
auth.subjectsFromLogin = append(auth.subjectsFromLogin, token.Subject) auth.subjectsFromLogin[token.Subject] = struct{}{}
} auth.mu.Unlock()
return nil return nil
} }
@@ -246,11 +274,13 @@ func (auth *OidcAuthConsumer) verifyPostLoginToken(privilegeKey string) (err err
if err != nil { if err != nil {
return fmt.Errorf("invalid OIDC token in ping: %v", err) return fmt.Errorf("invalid OIDC token in ping: %v", err)
} }
if !slices.Contains(auth.subjectsFromLogin, token.Subject) { auth.mu.RLock()
_, ok := auth.subjectsFromLogin[token.Subject]
auth.mu.RUnlock()
if !ok {
return fmt.Errorf("received different OIDC subject in login and ping. "+ return fmt.Errorf("received different OIDC subject in login and ping. "+
"original subjects: %s, "+
"new subject: %s", "new subject: %s",
auth.subjectsFromLogin, token.Subject) token.Subject)
} }
return nil return nil
} }

View File

@@ -2,6 +2,10 @@ package auth_test
import ( import (
"context" "context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -62,3 +66,90 @@ func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) {
r.Error(err) r.Error(err)
r.Contains(err.Error(), "received different OIDC subject in login and ping") r.Contains(err.Error(), "received different OIDC subject in login and ping")
} }
func TestOidcAuthProviderFallsBackWhenNoExpiry(t *testing.T) {
r := require.New(t)
var requestCount atomic.Int32
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
requestCount.Add(1)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response
"access_token": "fresh-test-token",
"token_type": "Bearer",
})
}))
defer tokenServer.Close()
provider, err := auth.NewOidcAuthSetter(
[]v1.AuthScope{v1.AuthScopeHeartBeats},
v1.AuthOIDCClientConfig{
ClientID: "test-client",
ClientSecret: "test-secret",
TokenEndpointURL: tokenServer.URL,
},
)
r.NoError(err)
// Constructor fetches the initial token (1 request).
// Each subsequent call should also fetch a fresh token since there is no expiry.
loginMsg := &msg.Login{}
err = provider.SetLogin(loginMsg)
r.NoError(err)
r.Equal("fresh-test-token", loginMsg.PrivilegeKey)
for range 3 {
pingMsg := &msg.Ping{}
err = provider.SetPing(pingMsg)
r.NoError(err)
r.Equal("fresh-test-token", pingMsg.PrivilegeKey)
}
// 1 initial (constructor) + 1 login + 3 pings = 5 requests
r.Equal(int32(5), requestCount.Load(), "each call should fetch a fresh token when expires_in is missing")
}
func TestOidcAuthProviderCachesToken(t *testing.T) {
r := require.New(t)
var requestCount atomic.Int32
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
requestCount.Add(1)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{ //nolint:gosec // test-only dummy token response
"access_token": "cached-test-token",
"token_type": "Bearer",
"expires_in": 3600,
})
}))
defer tokenServer.Close()
provider, err := auth.NewOidcAuthSetter(
[]v1.AuthScope{v1.AuthScopeHeartBeats},
v1.AuthOIDCClientConfig{
ClientID: "test-client",
ClientSecret: "test-secret",
TokenEndpointURL: tokenServer.URL,
},
)
r.NoError(err)
// Constructor eagerly fetches the initial token (1 request).
r.Equal(int32(1), requestCount.Load())
// SetLogin should reuse the cached token
loginMsg := &msg.Login{}
err = provider.SetLogin(loginMsg)
r.NoError(err)
r.Equal("cached-test-token", loginMsg.PrivilegeKey)
r.Equal(int32(1), requestCount.Load())
// Subsequent calls should also reuse the cached token
for range 5 {
pingMsg := &msg.Ping{}
err = provider.SetPing(pingMsg)
r.NoError(err)
r.Equal("cached-test-token", pingMsg.PrivilegeKey)
}
r.Equal(int32(1), requestCount.Load(), "token endpoint should only be called once; cached token should be reused")
}

View File

@@ -171,15 +171,14 @@ func Convert_ServerCommonConf_To_v1(conf *ServerCommonConf) *v1.ServerConfig {
func transformHeadersFromPluginParams(params map[string]string) v1.HeaderOperations { func transformHeadersFromPluginParams(params map[string]string) v1.HeaderOperations {
out := v1.HeaderOperations{} out := v1.HeaderOperations{}
for k, v := range params { for k, v := range params {
if !strings.HasPrefix(k, "plugin_header_") { k, ok := strings.CutPrefix(k, "plugin_header_")
if !ok || k == "" {
continue continue
} }
if k = strings.TrimPrefix(k, "plugin_header_"); k != "" { if out.Set == nil {
if out.Set == nil { out.Set = make(map[string]string)
out.Set = make(map[string]string)
}
out.Set[k] = v
} }
out.Set[k] = v
} }
return out return out
} }

View File

@@ -39,14 +39,14 @@ const (
// Proxy // Proxy
var ( var (
proxyConfTypeMap = map[ProxyType]reflect.Type{ proxyConfTypeMap = map[ProxyType]reflect.Type{
ProxyTypeTCP: reflect.TypeOf(TCPProxyConf{}), ProxyTypeTCP: reflect.TypeFor[TCPProxyConf](),
ProxyTypeUDP: reflect.TypeOf(UDPProxyConf{}), ProxyTypeUDP: reflect.TypeFor[UDPProxyConf](),
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConf{}), ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConf](),
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConf{}), ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConf](),
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConf{}), ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConf](),
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConf{}), ProxyTypeSTCP: reflect.TypeFor[STCPProxyConf](),
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConf{}), ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConf](),
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConf{}), ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConf](),
} }
) )

View File

@@ -22,8 +22,8 @@ func GetMapWithoutPrefix(set map[string]string, prefix string) map[string]string
m := make(map[string]string) m := make(map[string]string)
for key, value := range set { for key, value := range set {
if strings.HasPrefix(key, prefix) { if trimmed, ok := strings.CutPrefix(key, prefix); ok {
m[strings.TrimPrefix(key, prefix)] = value m[trimmed] = value
} }
} }

View File

@@ -32,9 +32,9 @@ const (
// Visitor // Visitor
var ( var (
visitorConfTypeMap = map[VisitorType]reflect.Type{ visitorConfTypeMap = map[VisitorType]reflect.Type{
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConf{}), VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConf](),
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConf{}), VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConf](),
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConf{}), VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConf](),
} }
) )

View File

@@ -15,9 +15,11 @@
package source package source
import ( import (
"cmp"
"errors" "errors"
"fmt" "fmt"
"sort" "maps"
"slices"
"sync" "sync"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
@@ -97,21 +99,11 @@ func (a *Aggregator) mapsToSortedSlices(
proxyMap map[string]v1.ProxyConfigurer, proxyMap map[string]v1.ProxyConfigurer,
visitorMap map[string]v1.VisitorConfigurer, visitorMap map[string]v1.VisitorConfigurer,
) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer) { ) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer) {
proxies := make([]v1.ProxyConfigurer, 0, len(proxyMap)) proxies := slices.SortedFunc(maps.Values(proxyMap), func(x, y v1.ProxyConfigurer) int {
for _, p := range proxyMap { return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
proxies = append(proxies, p)
}
sort.Slice(proxies, func(i, j int) bool {
return proxies[i].GetBaseConfig().Name < proxies[j].GetBaseConfig().Name
}) })
visitors := slices.SortedFunc(maps.Values(visitorMap), func(x, y v1.VisitorConfigurer) int {
visitors := make([]v1.VisitorConfigurer, 0, len(visitorMap)) return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
for _, v := range visitorMap {
visitors = append(visitors, v)
}
sort.Slice(visitors, func(i, j int) bool {
return visitors[i].GetBaseConfig().Name < visitors[j].GetBaseConfig().Name
}) })
return proxies, visitors return proxies, visitors
} }

View File

@@ -196,6 +196,27 @@ func TestAggregator_VisitorMerge(t *testing.T) {
require.Len(visitors, 2) require.Len(visitors, 2)
} }
func TestAggregator_Load_ReturnsSortedByName(t *testing.T) {
require := require.New(t)
agg := newTestAggregator(t, nil)
err := agg.ConfigSource().ReplaceAll(
[]v1.ProxyConfigurer{mockProxy("charlie"), mockProxy("alice"), mockProxy("bob")},
[]v1.VisitorConfigurer{mockVisitor("zulu"), mockVisitor("alpha")},
)
require.NoError(err)
proxies, visitors, err := agg.Load()
require.NoError(err)
require.Len(proxies, 3)
require.Equal("alice", proxies[0].GetBaseConfig().Name)
require.Equal("bob", proxies[1].GetBaseConfig().Name)
require.Equal("charlie", proxies[2].GetBaseConfig().Name)
require.Len(visitors, 2)
require.Equal("alpha", visitors[0].GetBaseConfig().Name)
require.Equal("zulu", visitors[1].GetBaseConfig().Name)
}
func TestAggregator_Load_ReturnsDefensiveCopies(t *testing.T) { func TestAggregator_Load_ReturnsDefensiveCopies(t *testing.T) {
require := require.New(t) require := require.New(t)

View File

@@ -38,7 +38,7 @@ func parseNumberRangePair(firstRangeStr, secondRangeStr string) ([]NumberPair, e
return nil, fmt.Errorf("first and second range numbers are not in pairs") return nil, fmt.Errorf("first and second range numbers are not in pairs")
} }
pairs := make([]NumberPair, 0, len(firstRangeNumbers)) pairs := make([]NumberPair, 0, len(firstRangeNumbers))
for i := 0; i < len(firstRangeNumbers); i++ { for i := range firstRangeNumbers {
pairs = append(pairs, NumberPair{ pairs = append(pairs, NumberPair{
First: firstRangeNumbers[i], First: firstRangeNumbers[i],
Second: secondRangeNumbers[i], Second: secondRangeNumbers[i],

View File

@@ -70,24 +70,18 @@ func (q *BandwidthQuantity) UnmarshalString(s string) error {
f float64 f float64
err error err error
) )
switch { if fstr, ok := strings.CutSuffix(s, "MB"); ok {
case strings.HasSuffix(s, "MB"):
base = MB base = MB
fstr := strings.TrimSuffix(s, "MB")
f, err = strconv.ParseFloat(fstr, 64) f, err = strconv.ParseFloat(fstr, 64)
if err != nil { } else if fstr, ok := strings.CutSuffix(s, "KB"); ok {
return err
}
case strings.HasSuffix(s, "KB"):
base = KB base = KB
fstr := strings.TrimSuffix(s, "KB")
f, err = strconv.ParseFloat(fstr, 64) f, err = strconv.ParseFloat(fstr, 64)
if err != nil { } else {
return err
}
default:
return errors.New("unit not support") return errors.New("unit not support")
} }
if err != nil {
return err
}
q.s = s q.s = s
q.i = int64(f * float64(base)) q.i = int64(f * float64(base))
@@ -143,8 +137,8 @@ func (p PortsRangeSlice) String() string {
func NewPortsRangeSliceFromString(str string) ([]PortsRange, error) { func NewPortsRangeSliceFromString(str string) ([]PortsRange, error) {
str = strings.TrimSpace(str) str = strings.TrimSpace(str)
out := []PortsRange{} out := []PortsRange{}
numRanges := strings.Split(str, ",") numRanges := strings.SplitSeq(str, ",")
for _, numRangeStr := range numRanges { for numRangeStr := range numRanges {
// 1000-2000 or 2001 // 1000-2000 or 2001
numArray := strings.Split(numRangeStr, "-") numArray := strings.Split(numRangeStr, "-")
// length: only 1 or 2 is correct // length: only 1 or 2 is correct

View File

@@ -39,6 +39,31 @@ func TestBandwidthQuantity(t *testing.T) {
require.Equal(`{"b":"1KB","int":5}`, string(buf)) require.Equal(`{"b":"1KB","int":5}`, string(buf))
} }
func TestBandwidthQuantity_MB(t *testing.T) {
require := require.New(t)
var w Wrap
err := json.Unmarshal([]byte(`{"b":"2MB","int":1}`), &w)
require.NoError(err)
require.EqualValues(2*MB, w.B.Bytes())
buf, err := json.Marshal(&w)
require.NoError(err)
require.Equal(`{"b":"2MB","int":1}`, string(buf))
}
func TestBandwidthQuantity_InvalidUnit(t *testing.T) {
var w Wrap
err := json.Unmarshal([]byte(`{"b":"1GB","int":1}`), &w)
require.Error(t, err)
}
func TestBandwidthQuantity_InvalidNumber(t *testing.T) {
var w Wrap
err := json.Unmarshal([]byte(`{"b":"abcKB","int":1}`), &w)
require.Error(t, err)
}
func TestPortsRangeSlice2String(t *testing.T) { func TestPortsRangeSlice2String(t *testing.T) {
require := require.New(t) require := require.New(t)

View File

@@ -239,14 +239,14 @@ const (
) )
var proxyConfigTypeMap = map[ProxyType]reflect.Type{ var proxyConfigTypeMap = map[ProxyType]reflect.Type{
ProxyTypeTCP: reflect.TypeOf(TCPProxyConfig{}), ProxyTypeTCP: reflect.TypeFor[TCPProxyConfig](),
ProxyTypeUDP: reflect.TypeOf(UDPProxyConfig{}), ProxyTypeUDP: reflect.TypeFor[UDPProxyConfig](),
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConfig{}), ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConfig](),
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConfig{}), ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConfig](),
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConfig{}), ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConfig](),
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConfig{}), ProxyTypeSTCP: reflect.TypeFor[STCPProxyConfig](),
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConfig{}), ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConfig](),
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConfig{}), ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConfig](),
} }
func NewProxyConfigurerByType(proxyType ProxyType) ProxyConfigurer { func NewProxyConfigurerByType(proxyType ProxyType) ProxyConfigurer {

View File

@@ -37,16 +37,16 @@ const (
) )
var clientPluginOptionsTypeMap = map[string]reflect.Type{ var clientPluginOptionsTypeMap = map[string]reflect.Type{
PluginHTTP2HTTPS: reflect.TypeOf(HTTP2HTTPSPluginOptions{}), PluginHTTP2HTTPS: reflect.TypeFor[HTTP2HTTPSPluginOptions](),
PluginHTTPProxy: reflect.TypeOf(HTTPProxyPluginOptions{}), PluginHTTPProxy: reflect.TypeFor[HTTPProxyPluginOptions](),
PluginHTTPS2HTTP: reflect.TypeOf(HTTPS2HTTPPluginOptions{}), PluginHTTPS2HTTP: reflect.TypeFor[HTTPS2HTTPPluginOptions](),
PluginHTTPS2HTTPS: reflect.TypeOf(HTTPS2HTTPSPluginOptions{}), PluginHTTPS2HTTPS: reflect.TypeFor[HTTPS2HTTPSPluginOptions](),
PluginHTTP2HTTP: reflect.TypeOf(HTTP2HTTPPluginOptions{}), PluginHTTP2HTTP: reflect.TypeFor[HTTP2HTTPPluginOptions](),
PluginSocks5: reflect.TypeOf(Socks5PluginOptions{}), PluginSocks5: reflect.TypeFor[Socks5PluginOptions](),
PluginStaticFile: reflect.TypeOf(StaticFilePluginOptions{}), PluginStaticFile: reflect.TypeFor[StaticFilePluginOptions](),
PluginUnixDomainSocket: reflect.TypeOf(UnixDomainSocketPluginOptions{}), PluginUnixDomainSocket: reflect.TypeFor[UnixDomainSocketPluginOptions](),
PluginTLS2Raw: reflect.TypeOf(TLS2RawPluginOptions{}), PluginTLS2Raw: reflect.TypeFor[TLS2RawPluginOptions](),
PluginVirtualNet: reflect.TypeOf(VirtualNetPluginOptions{}), PluginVirtualNet: reflect.TypeFor[VirtualNetPluginOptions](),
} }
type ClientPluginOptions interface { type ClientPluginOptions interface {

View File

@@ -79,9 +79,9 @@ const (
) )
var visitorConfigTypeMap = map[VisitorType]reflect.Type{ var visitorConfigTypeMap = map[VisitorType]reflect.Type{
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConfig{}), VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConfig](),
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConfig{}), VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConfig](),
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConfig{}), VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConfig](),
} }
type TypedVisitorConfig struct { type TypedVisitorConfig struct {

View File

@@ -25,7 +25,7 @@ const (
) )
var visitorPluginOptionsTypeMap = map[string]reflect.Type{ var visitorPluginOptionsTypeMap = map[string]reflect.Type{
VisitorPluginVirtualNet: reflect.TypeOf(VirtualNetVisitorPluginOptions{}), VisitorPluginVirtualNet: reflect.TypeFor[VirtualNetVisitorPluginOptions](),
} }
type VisitorPluginOptions interface { type VisitorPluginOptions interface {

View File

@@ -143,7 +143,6 @@ func (m *serverMetrics) OpenConnection(name string, _ string) {
proxyStats, ok := m.info.ProxyStatistics[name] proxyStats, ok := m.info.ProxyStatistics[name]
if ok { if ok {
proxyStats.CurConns.Inc(1) proxyStats.CurConns.Inc(1)
m.info.ProxyStatistics[name] = proxyStats
} }
} }
@@ -155,7 +154,6 @@ func (m *serverMetrics) CloseConnection(name string, _ string) {
proxyStats, ok := m.info.ProxyStatistics[name] proxyStats, ok := m.info.ProxyStatistics[name]
if ok { if ok {
proxyStats.CurConns.Dec(1) proxyStats.CurConns.Dec(1)
m.info.ProxyStatistics[name] = proxyStats
} }
} }
@@ -168,7 +166,6 @@ func (m *serverMetrics) AddTrafficIn(name string, _ string, trafficBytes int64)
proxyStats, ok := m.info.ProxyStatistics[name] proxyStats, ok := m.info.ProxyStatistics[name]
if ok { if ok {
proxyStats.TrafficIn.Inc(trafficBytes) proxyStats.TrafficIn.Inc(trafficBytes)
m.info.ProxyStatistics[name] = proxyStats
} }
} }
@@ -181,7 +178,6 @@ func (m *serverMetrics) AddTrafficOut(name string, _ string, trafficBytes int64)
proxyStats, ok := m.info.ProxyStatistics[name] proxyStats, ok := m.info.ProxyStatistics[name]
if ok { if ok {
proxyStats.TrafficOut.Inc(trafficBytes) proxyStats.TrafficOut.Inc(trafficBytes)
m.info.ProxyStatistics[name] = proxyStats
} }
} }
@@ -203,6 +199,25 @@ func (m *serverMetrics) GetServer() *ServerStats {
return s return s
} }
func toProxyStats(name string, proxyStats *ProxyStatistics) *ProxyStats {
ps := &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
return ps
}
func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats { func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
res := make([]*ProxyStats, 0) res := make([]*ProxyStats, 0)
m.mu.Lock() m.mu.Lock()
@@ -212,23 +227,7 @@ func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
if proxyStats.ProxyType != proxyType { if proxyStats.ProxyType != proxyType {
continue continue
} }
res = append(res, toProxyStats(name, proxyStats))
ps := &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
res = append(res, ps)
} }
return res return res
} }
@@ -237,31 +236,9 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
for name, proxyStats := range m.info.ProxyStatistics { proxyStats, ok := m.info.ProxyStatistics[proxyName]
if proxyStats.ProxyType != proxyType { if ok && proxyStats.ProxyType == proxyType {
continue res = toProxyStats(proxyName, proxyStats)
}
if name != proxyName {
continue
}
res = &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
break
} }
return return
} }
@@ -272,21 +249,7 @@ func (m *serverMetrics) GetProxyByName(proxyName string) (res *ProxyStats) {
proxyStats, ok := m.info.ProxyStatistics[proxyName] proxyStats, ok := m.info.ProxyStatistics[proxyName]
if ok { if ok {
res = &ProxyStats{ res = toProxyStats(proxyName, proxyStats)
Name: proxyName,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
} }
return return
} }

View File

@@ -61,7 +61,7 @@ var msgTypeMap = map[byte]any{
TypeNatHoleReport: NatHoleReport{}, TypeNatHoleReport: NatHoleReport{},
} }
var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name() var TypeNameNatHoleResp = reflect.TypeFor[NatHoleResp]().Name()
type ClientSpec struct { type ClientSpec struct {
// Due to the support of VirtualClient, frps needs to know the client type in order to // Due to the support of VirtualClient, frps needs to know the client type in order to
@@ -184,7 +184,7 @@ type Pong struct {
} }
type UDPPacket struct { type UDPPacket struct {
Content string `json:"c,omitempty"` Content []byte `json:"c,omitempty"`
LocalAddr *net.UDPAddr `json:"l,omitempty"` LocalAddr *net.UDPAddr `json:"l,omitempty"`
RemoteAddr *net.UDPAddr `json:"r,omitempty"` RemoteAddr *net.UDPAddr `json:"r,omitempty"`
} }

View File

@@ -16,9 +16,8 @@ func StripUserPrefix(user, name string) string {
if user == "" { if user == "" {
return name return name
} }
prefix := user + "." if trimmed, ok := strings.CutPrefix(name, user+"."); ok {
if strings.HasPrefix(name, prefix) { return trimmed
return strings.TrimPrefix(name, prefix)
} }
return name return name
} }

View File

@@ -151,7 +151,7 @@ func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore { func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
behaviors := getBehaviorByMode(mode) behaviors := getBehaviorByMode(mode)
scores := make([]*BehaviorScore, 0, len(behaviors)) scores := make([]*BehaviorScore, 0, len(behaviors))
for i := 0; i < len(behaviors); i++ { for i := range behaviors {
score := receiverScore score := receiverScore
if behaviors[i].A.Role == DetectRoleSender { if behaviors[i].A.Role == DetectRoleSender {
score = senderScore score = senderScore

View File

@@ -70,12 +70,8 @@ func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, err
continue continue
} }
if portNum > portMax { portMax = max(portMax, portNum)
portMax = portNum portMin = min(portMin, portNum)
}
if portNum < portMin {
portMin = portNum
}
if baseIP != ip { if baseIP != ip {
ipChanged = true ipChanged = true
} }

View File

@@ -152,7 +152,9 @@ func (c *Controller) GenSid() string {
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) { func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) {
if m.PreCheck { if m.PreCheck {
c.mu.RLock()
cfg, ok := c.clientCfgs[m.ProxyName] cfg, ok := c.clientCfgs[m.ProxyName]
c.mu.RUnlock()
if !ok { if !ok {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName))) _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
return return

View File

@@ -298,11 +298,13 @@ func waitDetectMessage(
n, raddr, err := conn.ReadFromUDP(buf) n, raddr, err := conn.ReadFromUDP(buf)
_ = conn.SetReadDeadline(time.Time{}) _ = conn.SetReadDeadline(time.Time{})
if err != nil { if err != nil {
pool.PutBuf(buf)
return nil, err return nil, err
} }
xl.Debugf("get udp message local %s, from %s", conn.LocalAddr(), raddr) xl.Debugf("get udp message local %s, from %s", conn.LocalAddr(), raddr)
var m msg.NatHoleSid var m msg.NatHoleSid
if err := DecodeMessageInto(buf[:n], key, &m); err != nil { if err := DecodeMessageInto(buf[:n], key, &m); err != nil {
pool.PutBuf(buf)
xl.Warnf("decode sid message error: %v", err) xl.Warnf("decode sid message error: %v", err)
continue continue
} }
@@ -408,7 +410,7 @@ func sendSidMessageToRandomPorts(
xl := xlog.FromContextSafe(ctx) xl := xlog.FromContextSafe(ctx)
used := sets.New[int]() used := sets.New[int]()
getUnusedPort := func() int { getUnusedPort := func() int {
for i := 0; i < 10; i++ { for range 10 {
port := rand.IntN(65535-1024) + 1024 port := rand.IntN(65535-1024) + 1024
if !used.Has(port) { if !used.Has(port) {
used.Insert(port) used.Insert(port)
@@ -418,7 +420,7 @@ func sendSidMessageToRandomPorts(
return 0 return 0
} }
for i := 0; i < count; i++ { for range count {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return

View File

@@ -21,6 +21,7 @@ import (
stdlog "log" stdlog "log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"time"
"github.com/fatedier/golib/pool" "github.com/fatedier/golib/pool"
@@ -68,7 +69,7 @@ func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin
p.s = &http.Server{ p.s = &http.Server{
Handler: rp, Handler: rp,
ReadHeaderTimeout: 0, ReadHeaderTimeout: 60 * time.Second,
} }
go func() { go func() {

View File

@@ -22,6 +22,7 @@ import (
stdlog "log" stdlog "log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"time"
"github.com/fatedier/golib/pool" "github.com/fatedier/golib/pool"
@@ -77,7 +78,7 @@ func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugi
p.s = &http.Server{ p.s = &http.Server{
Handler: rp, Handler: rp,
ReadHeaderTimeout: 0, ReadHeaderTimeout: 60 * time.Second,
} }
go func() { go func() {

View File

@@ -62,11 +62,13 @@ func (p *TLS2RawPlugin) Handle(ctx context.Context, connInfo *ConnectionInfo) {
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
xl.Warnf("tls handshake error: %v", err) xl.Warnf("tls handshake error: %v", err)
tlsConn.Close()
return return
} }
rawConn, err := net.Dial("tcp", p.opts.LocalAddr) rawConn, err := net.Dial("tcp", p.opts.LocalAddr)
if err != nil { if err != nil {
xl.Warnf("dial to local addr error: %v", err) xl.Warnf("dial to local addr error: %v", err)
tlsConn.Close()
return return
} }

View File

@@ -54,10 +54,13 @@ func (uds *UnixDomainSocketPlugin) Handle(ctx context.Context, connInfo *Connect
localConn, err := net.DialUnix("unix", nil, uds.UnixAddr) localConn, err := net.DialUnix("unix", nil, uds.UnixAddr)
if err != nil { if err != nil {
xl.Warnf("dial to uds %s error: %v", uds.UnixAddr, err) xl.Warnf("dial to uds %s error: %v", uds.UnixAddr, err)
connInfo.Conn.Close()
return return
} }
if connInfo.ProxyProtocolHeader != nil { if connInfo.ProxyProtocolHeader != nil {
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil { if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
localConn.Close()
connInfo.Conn.Close()
return return
} }
} }

View File

@@ -24,6 +24,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
"slices"
"strings" "strings"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
@@ -64,12 +65,7 @@ func (p *httpPlugin) Name() string {
} }
func (p *httpPlugin) IsSupport(op string) bool { func (p *httpPlugin) IsSupport(op string) bool {
for _, v := range p.options.Ops { return slices.Contains(p.options.Ops, op)
if v == op {
return true
}
}
return false
} }
func (p *httpPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) { func (p *httpPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {

View File

@@ -153,10 +153,7 @@ func (p *VirtualNetPlugin) run() {
// Exponential backoff: 60s, 120s, 240s, 300s (capped) // Exponential backoff: 60s, 120s, 240s, 300s (capped)
baseDelay := 60 * time.Second baseDelay := 60 * time.Second
reconnectDelay = baseDelay * time.Duration(1<<uint(p.consecutiveErrors-1)) reconnectDelay = min(baseDelay*time.Duration(1<<uint(p.consecutiveErrors-1)), 300*time.Second)
if reconnectDelay > 300*time.Second {
reconnectDelay = 300 * time.Second
}
} else { } else {
// Reset consecutive errors on successful connection // Reset consecutive errors on successful connection
if p.consecutiveErrors > 0 { if p.consecutiveErrors > 0 {

View File

@@ -16,6 +16,7 @@ package featuregate
import ( import (
"fmt" "fmt"
"maps"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@@ -92,10 +93,7 @@ type featureGate struct {
// NewFeatureGate creates a new feature gate with the default features // NewFeatureGate creates a new feature gate with the default features
func NewFeatureGate() MutableFeatureGate { func NewFeatureGate() MutableFeatureGate {
known := map[Feature]FeatureSpec{} known := maps.Clone(defaultFeatures)
for k, v := range defaultFeatures {
known[k] = v
}
f := &featureGate{} f := &featureGate{}
f.known.Store(known) f.known.Store(known)
@@ -109,14 +107,8 @@ func (f *featureGate) SetFromMap(m map[string]bool) error {
defer f.lock.Unlock() defer f.lock.Unlock()
// Copy existing state // Copy existing state
known := map[Feature]FeatureSpec{} known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
for k, v := range f.known.Load().(map[Feature]FeatureSpec) { enabled := maps.Clone(f.enabled.Load().(map[Feature]bool))
known[k] = v
}
enabled := map[Feature]bool{}
for k, v := range f.enabled.Load().(map[Feature]bool) {
enabled[k] = v
}
// Apply the new settings // Apply the new settings
for k, v := range m { for k, v := range m {
@@ -147,10 +139,7 @@ func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
} }
// Copy existing state // Copy existing state
known := map[Feature]FeatureSpec{} known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
known[k] = v
}
// Add new features // Add new features
for name, spec := range features { for name, spec := range features {

View File

@@ -15,7 +15,6 @@
package udp package udp
import ( import (
"encoding/base64"
"net" "net"
"sync" "sync"
"time" "time"
@@ -28,16 +27,17 @@ import (
) )
func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket { func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket {
content := make([]byte, len(buf))
copy(content, buf)
return &msg.UDPPacket{ return &msg.UDPPacket{
Content: base64.StdEncoding.EncodeToString(buf), Content: content,
LocalAddr: laddr, LocalAddr: laddr,
RemoteAddr: raddr, RemoteAddr: raddr,
} }
} }
func GetContent(m *msg.UDPPacket) (buf []byte, err error) { func GetContent(m *msg.UDPPacket) (buf []byte, err error) {
buf, err = base64.StdEncoding.DecodeString(m.Content) return m.Content, nil
return
} }
func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh chan<- *msg.UDPPacket, bufSize int) { func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh chan<- *msg.UDPPacket, bufSize int) {
@@ -60,7 +60,7 @@ func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh
if err != nil { if err != nil {
return return
} }
// buf[:n] will be encoded to string, so the bytes can be reused // NewUDPPacket copies buf[:n], so the read buffer can be reused
udpMsg := NewUDPPacket(buf[:n], nil, remoteAddr) udpMsg := NewUDPPacket(buf[:n], nil, remoteAddr)
select { select {
@@ -85,6 +85,7 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
}() }()
buf := pool.GetBuf(bufSize) buf := pool.GetBuf(bufSize)
defer pool.PutBuf(buf)
for { for {
_ = udpConn.SetReadDeadline(time.Now().Add(30 * time.Second)) _ = udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
n, _, err := udpConn.ReadFromUDP(buf) n, _, err := udpConn.ReadFromUDP(buf)

View File

@@ -20,6 +20,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt"
"math/big" "math/big"
"os" "os"
"time" "time"
@@ -85,7 +86,9 @@ func newCertPool(caPath string) (*x509.CertPool, error) {
return nil, err return nil, err
} }
pool.AppendCertsFromPEM(caCrt) if !pool.AppendCertsFromPEM(caCrt) {
return nil, fmt.Errorf("failed to parse CA certificate from file %q: no valid PEM certificates found", caPath)
}
return pool, nil return pool, nil
} }

View File

@@ -89,11 +89,11 @@ func ParseBasicAuth(auth string) (username, password string, ok bool) {
return return
} }
cs := string(c) cs := string(c)
s := strings.IndexByte(cs, ':') before, after, found := strings.Cut(cs, ":")
if s < 0 { if !found {
return return
} }
return cs[:s], cs[s+1:], true return before, after, true
} }
func BasicAuth(username, passwd string) string { func BasicAuth(username, passwd string) string {

View File

@@ -86,11 +86,7 @@ func (c *FakeUDPConn) Read(b []byte) (n int, err error) {
c.lastActive = time.Now() c.lastActive = time.Now()
c.mu.Unlock() c.mu.Unlock()
if len(b) < len(content) { n = min(len(b), len(content))
n = len(b)
} else {
n = len(content)
}
copy(b, content) copy(b, content)
return n, nil return n, nil
} }
@@ -168,11 +164,15 @@ func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
return l, err return l, err
} }
readConn, err := net.ListenUDP("udp", udpAddr) readConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return l, err
}
l = &UDPListener{ l = &UDPListener{
addr: udpAddr, addr: udpAddr,
acceptCh: make(chan net.Conn), acceptCh: make(chan net.Conn),
writeCh: make(chan *UDPPacket, 1000), writeCh: make(chan *UDPPacket, 1000),
readConn: readConn,
fakeConns: make(map[string]*FakeUDPConn), fakeConns: make(map[string]*FakeUDPConn),
} }

View File

@@ -26,6 +26,7 @@ type WebsocketListener struct {
// ln: tcp listener for websocket connections // ln: tcp listener for websocket connections
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) { func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
wl = &WebsocketListener{ wl = &WebsocketListener{
ln: ln,
acceptCh: make(chan net.Conn), acceptCh: make(chan net.Conn),
} }

View File

@@ -68,8 +68,8 @@ func ParseRangeNumbers(rangeStr string) (numbers []int64, err error) {
rangeStr = strings.TrimSpace(rangeStr) rangeStr = strings.TrimSpace(rangeStr)
numbers = make([]int64, 0) numbers = make([]int64, 0)
// e.g. 1000-2000,2001,2002,3000-4000 // e.g. 1000-2000,2001,2002,3000-4000
numRanges := strings.Split(rangeStr, ",") numRanges := strings.SplitSeq(rangeStr, ",")
for _, numRangeStr := range numRanges { for numRangeStr := range numRanges {
// 1000-2000 or 2001 // 1000-2000 or 2001
numArray := strings.Split(numRangeStr, "-") numArray := strings.Split(numRangeStr, "-")
// length: only 1 or 2 is correct // length: only 1 or 2 is correct

View File

@@ -266,31 +266,13 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
go libio.Join(remote, client) go libio.Join(remote, client)
} }
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request { func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request {
user := "" user := ""
// If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header. // If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header.
if req.URL.Host != "" { if req.URL.Host != "" {
proxyAuth := req.Header.Get("Proxy-Authorization") proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" { if proxyAuth != "" {
user, _, _ = parseBasicAuth(proxyAuth) user, _, _ = httppkg.ParseBasicAuth(proxyAuth)
} }
} }
if user == "" { if user == "" {

View File

@@ -63,11 +63,12 @@ func (l *Logger) AddPrefix(prefix LogPrefix) *Logger {
if prefix.Priority <= 0 { if prefix.Priority <= 0 {
prefix.Priority = 10 prefix.Priority = 10
} }
for _, p := range l.prefixes { for i, p := range l.prefixes {
if p.Name == prefix.Name { if p.Name == prefix.Name {
found = true found = true
p.Value = prefix.Value l.prefixes[i].Value = prefix.Value
p.Priority = prefix.Priority l.prefixes[i].Priority = prefix.Priority
break
} }
} }
if !found { if !found {

View File

@@ -95,20 +95,33 @@ func (cm *ControlManager) Close() error {
return nil return nil
} }
type Control struct { // SessionContext encapsulates the input parameters for creating a new Control.
type SessionContext struct {
// all resource managers and controllers // all resource managers and controllers
rc *controller.ResourceController RC *controller.ResourceController
// proxy manager // proxy manager
pxyManager *proxy.Manager PxyManager *proxy.Manager
// plugin manager // plugin manager
pluginManager *plugin.Manager PluginManager *plugin.Manager
// verifies authentication based on selected method // verifies authentication based on selected method
authVerifier auth.Verifier AuthVerifier auth.Verifier
// key used for connection encryption // key used for connection encryption
encryptionKey []byte EncryptionKey []byte
// control connection
Conn net.Conn
// indicates whether the connection is encrypted
ConnEncrypted bool
// login message
LoginMsg *msg.Login
// server configuration
ServerCfg *v1.ServerConfig
// client registry
ClientRegistry *registry.ClientRegistry
}
type Control struct {
// session context
sessionCtx *SessionContext
// other components can use this to communicate with client // other components can use this to communicate with client
msgTransporter transport.MessageTransporter msgTransporter transport.MessageTransporter
@@ -117,12 +130,6 @@ type Control struct {
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types. // It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
msgDispatcher *msg.Dispatcher msgDispatcher *msg.Dispatcher
// login message
loginMsg *msg.Login
// control connection
conn net.Conn
// work connections // work connections
workConnCh chan net.Conn workConnCh chan net.Conn
@@ -145,61 +152,34 @@ type Control struct {
mu sync.RWMutex mu sync.RWMutex
// Server configuration information
serverCfg *v1.ServerConfig
clientRegistry *registry.ClientRegistry
xl *xlog.Logger xl *xlog.Logger
ctx context.Context ctx context.Context
doneCh chan struct{} doneCh chan struct{}
} }
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext. func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
func NewControl( poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
ctx context.Context,
rc *controller.ResourceController,
pxyManager *proxy.Manager,
pluginManager *plugin.Manager,
authVerifier auth.Verifier,
encryptionKey []byte,
ctlConn net.Conn,
ctlConnEncrypted bool,
loginMsg *msg.Login,
serverCfg *v1.ServerConfig,
) (*Control, error) {
poolCount := loginMsg.PoolCount
if poolCount > int(serverCfg.Transport.MaxPoolCount) {
poolCount = int(serverCfg.Transport.MaxPoolCount)
}
ctl := &Control{ ctl := &Control{
rc: rc, sessionCtx: sessionCtx,
pxyManager: pxyManager, workConnCh: make(chan net.Conn, poolCount+10),
pluginManager: pluginManager, proxies: make(map[string]proxy.Proxy),
authVerifier: authVerifier, poolCount: poolCount,
encryptionKey: encryptionKey, portsUsedNum: 0,
conn: ctlConn, runID: sessionCtx.LoginMsg.RunID,
loginMsg: loginMsg, xl: xlog.FromContextSafe(ctx),
workConnCh: make(chan net.Conn, poolCount+10), ctx: ctx,
proxies: make(map[string]proxy.Proxy), doneCh: make(chan struct{}),
poolCount: poolCount,
portsUsedNum: 0,
runID: loginMsg.RunID,
serverCfg: serverCfg,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
} }
ctl.lastPing.Store(time.Now()) ctl.lastPing.Store(time.Now())
if ctlConnEncrypted { if sessionCtx.ConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey) cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW) ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
} else { } else {
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn) ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
} }
ctl.registerMsgHandlers() ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher) ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
@@ -213,7 +193,7 @@ func (ctl *Control) Start() {
RunID: ctl.runID, RunID: ctl.runID,
Error: "", Error: "",
} }
_ = msg.WriteMsg(ctl.conn, loginRespMsg) _ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
go func() { go func() {
for i := 0; i < ctl.poolCount; i++ { for i := 0; i < ctl.poolCount; i++ {
@@ -225,7 +205,7 @@ func (ctl *Control) Start() {
} }
func (ctl *Control) Close() error { func (ctl *Control) Close() error {
ctl.conn.Close() ctl.sessionCtx.Conn.Close()
return nil return nil
} }
@@ -233,7 +213,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
xl := ctl.xl xl := ctl.xl
xl.Infof("replaced by client [%s]", newCtl.runID) xl.Infof("replaced by client [%s]", newCtl.runID)
ctl.runID = "" ctl.runID = ""
ctl.conn.Close() ctl.sessionCtx.Conn.Close()
} }
func (ctl *Control) RegisterWorkConn(conn net.Conn) error { func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
@@ -291,7 +271,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
return return
} }
case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second): case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second):
err = fmt.Errorf("timeout trying to get work connection") err = fmt.Errorf("timeout trying to get work connection")
xl.Warnf("%v", err) xl.Warnf("%v", err)
return return
@@ -304,15 +284,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
} }
func (ctl *Control) heartbeatWorker() { func (ctl *Control) heartbeatWorker() {
if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 { if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 {
return return
} }
xl := ctl.xl xl := ctl.xl
go wait.Until(func() { go wait.Until(func() {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second { if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second {
xl.Warnf("heartbeat timeout") xl.Warnf("heartbeat timeout")
ctl.conn.Close() ctl.sessionCtx.Conn.Close()
return return
} }
}, time.Second, ctl.doneCh) }, time.Second, ctl.doneCh)
@@ -323,6 +303,30 @@ func (ctl *Control) WaitClosed() {
<-ctl.doneCh <-ctl.doneCh
} }
func (ctl *Control) loginUserInfo() plugin.UserInfo {
return plugin.UserInfo{
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
}
}
func (ctl *Control) closeProxy(pxy proxy.Proxy) {
pxy.Close()
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: ctl.loginUserInfo(),
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
}()
}
func (ctl *Control) worker() { func (ctl *Control) worker() {
xl := ctl.xl xl := ctl.xl
@@ -330,38 +334,23 @@ func (ctl *Control) worker() {
go ctl.msgDispatcher.Run() go ctl.msgDispatcher.Run()
<-ctl.msgDispatcher.Done() <-ctl.msgDispatcher.Done()
ctl.conn.Close() ctl.sessionCtx.Conn.Close()
ctl.mu.Lock() ctl.mu.Lock()
defer ctl.mu.Unlock()
close(ctl.workConnCh) close(ctl.workConnCh)
for workConn := range ctl.workConnCh { for workConn := range ctl.workConnCh {
workConn.Close() workConn.Close()
} }
proxies := ctl.proxies
ctl.proxies = make(map[string]proxy.Proxy)
ctl.mu.Unlock()
for _, pxy := range ctl.proxies { for _, pxy := range proxies {
pxy.Close() ctl.closeProxy(pxy)
ctl.pxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
}()
} }
metrics.Server.CloseClient() metrics.Server.CloseClient()
ctl.clientRegistry.MarkOfflineByRunID(ctl.runID) ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID)
xl.Infof("client exit success") xl.Infof("client exit success")
close(ctl.doneCh) close(ctl.doneCh)
} }
@@ -380,15 +369,11 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
inMsg := m.(*msg.NewProxy) inMsg := m.(*msg.NewProxy)
content := &plugin.NewProxyContent{ content := &plugin.NewProxyContent{
User: plugin.UserInfo{ User: ctl.loginUserInfo(),
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
NewProxy: *inMsg, NewProxy: *inMsg,
} }
var remoteAddr string var remoteAddr string
retContent, err := ctl.pluginManager.NewProxy(content) retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content)
if err == nil { if err == nil {
inMsg = &retContent.NewProxy inMsg = &retContent.NewProxy
remoteAddr, err = ctl.RegisterProxy(inMsg) remoteAddr, err = ctl.RegisterProxy(inMsg)
@@ -401,15 +386,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
if err != nil { if err != nil {
xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err) xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName), resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)) err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient))
} else { } else {
resp.RemoteAddr = remoteAddr resp.RemoteAddr = remoteAddr
xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType) xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
clientID := ctl.loginMsg.ClientID clientID := ctl.sessionCtx.LoginMsg.ClientID
if clientID == "" { if clientID == "" {
clientID = ctl.loginMsg.RunID clientID = ctl.sessionCtx.LoginMsg.RunID
} }
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID) metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID)
} }
_ = ctl.msgDispatcher.Send(resp) _ = ctl.msgDispatcher.Send(resp)
} }
@@ -419,22 +404,18 @@ func (ctl *Control) handlePing(m msg.Message) {
inMsg := m.(*msg.Ping) inMsg := m.(*msg.Ping)
content := &plugin.PingContent{ content := &plugin.PingContent{
User: plugin.UserInfo{ User: ctl.loginUserInfo(),
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
Ping: *inMsg, Ping: *inMsg,
} }
retContent, err := ctl.pluginManager.Ping(content) retContent, err := ctl.sessionCtx.PluginManager.Ping(content)
if err == nil { if err == nil {
inMsg = &retContent.Ping inMsg = &retContent.Ping
err = ctl.authVerifier.VerifyPing(inMsg) err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg)
} }
if err != nil { if err != nil {
xl.Warnf("received invalid ping: %v", err) xl.Warnf("received invalid ping: %v", err)
_ = ctl.msgDispatcher.Send(&msg.Pong{ _ = ctl.msgDispatcher.Send(&msg.Pong{
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)), Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)),
}) })
return return
} }
@@ -445,17 +426,17 @@ func (ctl *Control) handlePing(m msg.Message) {
func (ctl *Control) handleNatHoleVisitor(m msg.Message) { func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
inMsg := m.(*msg.NatHoleVisitor) inMsg := m.(*msg.NatHoleVisitor)
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User) ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User)
} }
func (ctl *Control) handleNatHoleClient(m msg.Message) { func (ctl *Control) handleNatHoleClient(m msg.Message) {
inMsg := m.(*msg.NatHoleClient) inMsg := m.(*msg.NatHoleClient)
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter) ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
} }
func (ctl *Control) handleNatHoleReport(m msg.Message) { func (ctl *Control) handleNatHoleReport(m msg.Message) {
inMsg := m.(*msg.NatHoleReport) inMsg := m.(*msg.NatHoleReport)
ctl.rc.NatHoleController.HandleReport(inMsg) ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg)
} }
func (ctl *Control) handleCloseProxy(m msg.Message) { func (ctl *Control) handleCloseProxy(m msg.Message) {
@@ -468,15 +449,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) {
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
var pxyConf v1.ProxyConfigurer var pxyConf v1.ProxyConfigurer
// Load configures from NewProxy message and validate. // Load configures from NewProxy message and validate.
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg) pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg)
if err != nil { if err != nil {
return return
} }
// User info // User info
userInfo := plugin.UserInfo{ userInfo := plugin.UserInfo{
User: ctl.loginMsg.User, User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.loginMsg.Metas, Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.runID, RunID: ctl.runID,
} }
@@ -484,22 +465,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
// In fact, it creates different proxies based on the proxy type. We just call run() here. // In fact, it creates different proxies based on the proxy type. We just call run() here.
pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{ pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{
UserInfo: userInfo, UserInfo: userInfo,
LoginMsg: ctl.loginMsg, LoginMsg: ctl.sessionCtx.LoginMsg,
PoolCount: ctl.poolCount, PoolCount: ctl.poolCount,
ResourceController: ctl.rc, ResourceController: ctl.sessionCtx.RC,
GetWorkConnFn: ctl.GetWorkConn, GetWorkConnFn: ctl.GetWorkConn,
Configurer: pxyConf, Configurer: pxyConf,
ServerCfg: ctl.serverCfg, ServerCfg: ctl.sessionCtx.ServerCfg,
EncryptionKey: ctl.encryptionKey, EncryptionKey: ctl.sessionCtx.EncryptionKey,
}) })
if err != nil { if err != nil {
return remoteAddr, err return remoteAddr, err
} }
// Check ports used number in each client // Check ports used number in each client
if ctl.serverCfg.MaxPortsPerClient > 0 { if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
ctl.mu.Lock() ctl.mu.Lock()
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) { if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) {
ctl.mu.Unlock() ctl.mu.Unlock()
err = fmt.Errorf("exceed the max_ports_per_client") err = fmt.Errorf("exceed the max_ports_per_client")
return return
@@ -516,7 +497,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}() }()
} }
if ctl.pxyManager.Exist(pxyMsg.ProxyName) { if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) {
err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName) err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName)
return return
} }
@@ -531,7 +512,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
} }
}() }()
err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy) err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy)
if err != nil { if err != nil {
return return
} }
@@ -550,28 +531,12 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
return return
} }
if ctl.serverCfg.MaxPortsPerClient > 0 { if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
ctl.portsUsedNum -= pxy.GetUsedPortsNum() ctl.portsUsedNum -= pxy.GetUsedPortsNum()
} }
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
delete(ctl.proxies, closeMsg.ProxyName) delete(ctl.proxies, closeMsg.ProxyName)
ctl.mu.Unlock() ctl.mu.Unlock()
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type) ctl.closeProxy(pxy)
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
}()
return return
} }

77
server/group/base.go Normal file
View File

@@ -0,0 +1,77 @@
package group
import (
"net"
"sync"
gerr "github.com/fatedier/golib/errors"
)
// baseGroup contains the shared plumbing for listener-based groups
// (TCP, HTTPS, TCPMux). Each concrete group embeds this and provides
// its own Listen method with protocol-specific validation.
type baseGroup struct {
group string
groupKey string
acceptCh chan net.Conn
realLn net.Listener
lns []*Listener
mu sync.Mutex
cleanupFn func()
}
// initBase resets the baseGroup for a fresh listen cycle.
// Must be called under mu when len(lns) == 0.
func (bg *baseGroup) initBase(group, groupKey string, realLn net.Listener, cleanupFn func()) {
bg.group = group
bg.groupKey = groupKey
bg.realLn = realLn
bg.acceptCh = make(chan net.Conn)
bg.cleanupFn = cleanupFn
}
// worker reads from the real listener and fans out to acceptCh.
// The parameters are captured at creation time so that the worker is
// bound to a specific listen cycle and cannot observe a later initBase.
func (bg *baseGroup) worker(realLn net.Listener, acceptCh chan<- net.Conn) {
for {
c, err := realLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
acceptCh <- c
})
if err != nil {
c.Close()
return
}
}
}
// newListener creates a new Listener wired to this baseGroup.
// Must be called under mu.
func (bg *baseGroup) newListener(addr net.Addr) *Listener {
ln := newListener(bg.acceptCh, addr, bg.closeListener)
bg.lns = append(bg.lns, ln)
return ln
}
// closeListener removes ln from the list. When the last listener is removed,
// it closes acceptCh, closes the real listener, and calls cleanupFn.
func (bg *baseGroup) closeListener(ln *Listener) {
bg.mu.Lock()
defer bg.mu.Unlock()
for i, l := range bg.lns {
if l == ln {
bg.lns = append(bg.lns[:i], bg.lns[i+1:]...)
break
}
}
if len(bg.lns) == 0 {
close(bg.acceptCh)
bg.realLn.Close()
bg.cleanupFn()
}
}

169
server/group/base_test.go Normal file
View File

@@ -0,0 +1,169 @@
package group
import (
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeLn is a controllable net.Listener for tests.
type fakeLn struct {
connCh chan net.Conn
closed chan struct{}
once sync.Once
}
func newFakeLn() *fakeLn {
return &fakeLn{
connCh: make(chan net.Conn, 8),
closed: make(chan struct{}),
}
}
func (f *fakeLn) Accept() (net.Conn, error) {
select {
case c := <-f.connCh:
return c, nil
case <-f.closed:
return nil, net.ErrClosed
}
}
func (f *fakeLn) Close() error {
f.once.Do(func() { close(f.closed) })
return nil
}
func (f *fakeLn) Addr() net.Addr { return fakeAddr("127.0.0.1:9999") }
func (f *fakeLn) inject(c net.Conn) {
select {
case f.connCh <- c:
case <-f.closed:
}
}
func TestBaseGroup_WorkerFanOut(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
go bg.worker(fl, bg.acceptCh)
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
select {
case got := <-bg.acceptCh:
assert.Equal(t, c1, got)
got.Close()
case <-time.After(time.Second):
t.Fatal("timed out waiting for connection on acceptCh")
}
fl.Close()
}
func TestBaseGroup_WorkerStopsOnListenerClose(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
done := make(chan struct{})
go func() {
bg.worker(fl, bg.acceptCh)
close(done)
}()
fl.Close()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("worker did not stop after listener close")
}
}
func TestBaseGroup_WorkerClosesConnOnClosedChannel(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
// Close acceptCh before worker sends.
close(bg.acceptCh)
done := make(chan struct{})
go func() {
bg.worker(fl, bg.acceptCh)
close(done)
}()
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("worker did not stop after panic recovery")
}
// c1 should have been closed by worker's panic recovery path.
buf := make([]byte, 1)
_, err := c1.Read(buf)
assert.Error(t, err, "connection should be closed by worker")
}
func TestBaseGroup_CloseLastListenerTriggersCleanup(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
cleanupCalled := 0
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
bg.mu.Lock()
ln1 := bg.newListener(fl.Addr())
ln2 := bg.newListener(fl.Addr())
bg.mu.Unlock()
go bg.worker(fl, bg.acceptCh)
ln1.Close()
assert.Equal(t, 0, cleanupCalled, "cleanup should not run while listeners remain")
ln2.Close()
assert.Equal(t, 1, cleanupCalled, "cleanup should run after last listener closed")
}
func TestBaseGroup_CloseOneOfTwoListeners(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
cleanupCalled := 0
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
bg.mu.Lock()
ln1 := bg.newListener(fl.Addr())
ln2 := bg.newListener(fl.Addr())
bg.mu.Unlock()
go bg.worker(fl, bg.acceptCh)
ln1.Close()
assert.Equal(t, 0, cleanupCalled)
// ln2 should still receive connections.
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
got, err := ln2.Accept()
require.NoError(t, err)
assert.Equal(t, c1, got)
got.Close()
ln2.Close()
assert.Equal(t, 1, cleanupCalled)
}

View File

@@ -24,4 +24,6 @@ var (
ErrListenerClosed = errors.New("group listener closed") ErrListenerClosed = errors.New("group listener closed")
ErrGroupDifferentPort = errors.New("group should have same remote port") ErrGroupDifferentPort = errors.New("group should have same remote port")
ErrProxyRepeated = errors.New("group proxy repeated") ErrProxyRepeated = errors.New("group proxy repeated")
errGroupStale = errors.New("stale group reference")
) )

View File

@@ -9,53 +9,42 @@ import (
"github.com/fatedier/frp/pkg/util/vhost" "github.com/fatedier/frp/pkg/util/vhost"
) )
// HTTPGroupController manages HTTP groups that use round-robin
// callback routing (fundamentally different from listener-based groups).
type HTTPGroupController struct { type HTTPGroupController struct {
// groups indexed by group name groupRegistry[*HTTPGroup]
groups map[string]*HTTPGroup
// register createConn for each group to vhostRouter.
// createConn will get a connection from one proxy of the group
vhostRouter *vhost.Routers vhostRouter *vhost.Routers
mu sync.Mutex
} }
func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController { func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController {
return &HTTPGroupController{ return &HTTPGroupController{
groups: make(map[string]*HTTPGroup), groupRegistry: newGroupRegistry[*HTTPGroup](),
vhostRouter: vhostRouter, vhostRouter: vhostRouter,
} }
} }
func (ctl *HTTPGroupController) Register( func (ctl *HTTPGroupController) Register(
proxyName, group, groupKey string, proxyName, group, groupKey string,
routeConfig vhost.RouteConfig, routeConfig vhost.RouteConfig,
) (err error) { ) error {
indexKey := group for {
ctl.mu.Lock() g := ctl.getOrCreate(group, func() *HTTPGroup {
g, ok := ctl.groups[indexKey] return NewHTTPGroup(ctl)
if !ok { })
g = NewHTTPGroup(ctl) err := g.Register(proxyName, group, groupKey, routeConfig)
ctl.groups[indexKey] = g if err == errGroupStale {
continue
}
return err
} }
ctl.mu.Unlock()
return g.Register(proxyName, group, groupKey, routeConfig)
} }
func (ctl *HTTPGroupController) UnRegister(proxyName, group string, _ vhost.RouteConfig) { func (ctl *HTTPGroupController) UnRegister(proxyName, group string, _ vhost.RouteConfig) {
indexKey := group g, ok := ctl.get(group)
ctl.mu.Lock()
defer ctl.mu.Unlock()
g, ok := ctl.groups[indexKey]
if !ok { if !ok {
return return
} }
g.UnRegister(proxyName)
isEmpty := g.UnRegister(proxyName)
if isEmpty {
delete(ctl.groups, indexKey)
}
} }
type HTTPGroup struct { type HTTPGroup struct {
@@ -87,6 +76,9 @@ func (g *HTTPGroup) Register(
) (err error) { ) (err error) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
if !g.ctl.isCurrent(group, func(cur *HTTPGroup) bool { return cur == g }) {
return errGroupStale
}
if len(g.createFuncs) == 0 { if len(g.createFuncs) == 0 {
// the first proxy in this group // the first proxy in this group
tmp := routeConfig // copy object tmp := routeConfig // copy object
@@ -123,7 +115,7 @@ func (g *HTTPGroup) Register(
return nil return nil
} }
func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) { func (g *HTTPGroup) UnRegister(proxyName string) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
delete(g.createFuncs, proxyName) delete(g.createFuncs, proxyName)
@@ -135,10 +127,11 @@ func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
} }
if len(g.createFuncs) == 0 { if len(g.createFuncs) == 0 {
isEmpty = true
g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser) g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser)
g.ctl.removeIf(g.group, func(cur *HTTPGroup) bool {
return cur == g
})
} }
return
} }
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) { func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
@@ -151,7 +144,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
location := g.location location := g.location
routeByHTTPUser := g.routeByHTTPUser routeByHTTPUser := g.routeByHTTPUser
if len(g.pxyNames) > 0 { if len(g.pxyNames) > 0 {
name := g.pxyNames[int(newIndex)%len(g.pxyNames)] name := g.pxyNames[newIndex%uint64(len(g.pxyNames))]
f = g.createFuncs[name] f = g.createFuncs[name]
} }
g.mu.RUnlock() g.mu.RUnlock()
@@ -174,7 +167,7 @@ func (g *HTTPGroup) chooseEndpoint() (string, error) {
location := g.location location := g.location
routeByHTTPUser := g.routeByHTTPUser routeByHTTPUser := g.routeByHTTPUser
if len(g.pxyNames) > 0 { if len(g.pxyNames) > 0 {
name = g.pxyNames[int(newIndex)%len(g.pxyNames)] name = g.pxyNames[newIndex%uint64(len(g.pxyNames))]
} }
g.mu.RUnlock() g.mu.RUnlock()

View File

@@ -17,25 +17,19 @@ package group
import ( import (
"context" "context"
"net" "net"
"sync"
gerr "github.com/fatedier/golib/errors"
"github.com/fatedier/frp/pkg/util/vhost" "github.com/fatedier/frp/pkg/util/vhost"
) )
type HTTPSGroupController struct { type HTTPSGroupController struct {
groups map[string]*HTTPSGroup groupRegistry[*HTTPSGroup]
httpsMuxer *vhost.HTTPSMuxer httpsMuxer *vhost.HTTPSMuxer
mu sync.Mutex
} }
func NewHTTPSGroupController(httpsMuxer *vhost.HTTPSMuxer) *HTTPSGroupController { func NewHTTPSGroupController(httpsMuxer *vhost.HTTPSMuxer) *HTTPSGroupController {
return &HTTPSGroupController{ return &HTTPSGroupController{
groups: make(map[string]*HTTPSGroup), groupRegistry: newGroupRegistry[*HTTPSGroup](),
httpsMuxer: httpsMuxer, httpsMuxer: httpsMuxer,
} }
} }
@@ -44,41 +38,28 @@ func (ctl *HTTPSGroupController) Listen(
group, groupKey string, group, groupKey string,
routeConfig vhost.RouteConfig, routeConfig vhost.RouteConfig,
) (l net.Listener, err error) { ) (l net.Listener, err error) {
indexKey := group for {
ctl.mu.Lock() g := ctl.getOrCreate(group, func() *HTTPSGroup {
g, ok := ctl.groups[indexKey] return NewHTTPSGroup(ctl)
if !ok { })
g = NewHTTPSGroup(ctl) l, err = g.Listen(ctx, group, groupKey, routeConfig)
ctl.groups[indexKey] = g if err == errGroupStale {
continue
}
return
} }
ctl.mu.Unlock()
return g.Listen(ctx, group, groupKey, routeConfig)
}
func (ctl *HTTPSGroupController) RemoveGroup(group string) {
ctl.mu.Lock()
defer ctl.mu.Unlock()
delete(ctl.groups, group)
} }
type HTTPSGroup struct { type HTTPSGroup struct {
group string baseGroup
groupKey string
domain string
acceptCh chan net.Conn domain string
httpsLn *vhost.Listener ctl *HTTPSGroupController
lns []*HTTPSGroupListener
ctl *HTTPSGroupController
mu sync.Mutex
} }
func NewHTTPSGroup(ctl *HTTPSGroupController) *HTTPSGroup { func NewHTTPSGroup(ctl *HTTPSGroupController) *HTTPSGroup {
return &HTTPSGroup{ return &HTTPSGroup{
lns: make([]*HTTPSGroupListener, 0), ctl: ctl,
ctl: ctl,
acceptCh: make(chan net.Conn),
} }
} }
@@ -86,23 +67,27 @@ func (g *HTTPSGroup) Listen(
ctx context.Context, ctx context.Context,
group, groupKey string, group, groupKey string,
routeConfig vhost.RouteConfig, routeConfig vhost.RouteConfig,
) (ln *HTTPSGroupListener, err error) { ) (ln *Listener, err error) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
if !g.ctl.isCurrent(group, func(cur *HTTPSGroup) bool { return cur == g }) {
return nil, errGroupStale
}
if len(g.lns) == 0 { if len(g.lns) == 0 {
// the first listener, listen on the real address // the first listener, listen on the real address
httpsLn, errRet := g.ctl.httpsMuxer.Listen(ctx, &routeConfig) httpsLn, errRet := g.ctl.httpsMuxer.Listen(ctx, &routeConfig)
if errRet != nil { if errRet != nil {
return nil, errRet return nil, errRet
} }
ln = newHTTPSGroupListener(group, g, httpsLn.Addr())
g.group = group
g.groupKey = groupKey
g.domain = routeConfig.Domain g.domain = routeConfig.Domain
g.httpsLn = httpsLn g.initBase(group, groupKey, httpsLn, func() {
g.lns = append(g.lns, ln) g.ctl.removeIf(g.group, func(cur *HTTPSGroup) bool {
go g.worker() return cur == g
})
})
ln = g.newListener(httpsLn.Addr())
go g.worker(httpsLn, g.acceptCh)
} else { } else {
// route config in the same group must be equal // route config in the same group must be equal
if g.group != group || g.domain != routeConfig.Domain { if g.group != group || g.domain != routeConfig.Domain {
@@ -111,87 +96,7 @@ func (g *HTTPSGroup) Listen(
if g.groupKey != groupKey { if g.groupKey != groupKey {
return nil, ErrGroupAuthFailed return nil, ErrGroupAuthFailed
} }
ln = newHTTPSGroupListener(group, g, g.lns[0].Addr()) ln = g.newListener(g.lns[0].Addr())
g.lns = append(g.lns, ln)
} }
return return
} }
func (g *HTTPSGroup) worker() {
for {
c, err := g.httpsLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
g.acceptCh <- c
})
if err != nil {
return
}
}
}
func (g *HTTPSGroup) Accept() <-chan net.Conn {
return g.acceptCh
}
func (g *HTTPSGroup) CloseListener(ln *HTTPSGroupListener) {
g.mu.Lock()
defer g.mu.Unlock()
for i, tmpLn := range g.lns {
if tmpLn == ln {
g.lns = append(g.lns[:i], g.lns[i+1:]...)
break
}
}
if len(g.lns) == 0 {
close(g.acceptCh)
if g.httpsLn != nil {
g.httpsLn.Close()
}
g.ctl.RemoveGroup(g.group)
}
}
type HTTPSGroupListener struct {
groupName string
group *HTTPSGroup
addr net.Addr
closeCh chan struct{}
}
func newHTTPSGroupListener(name string, group *HTTPSGroup, addr net.Addr) *HTTPSGroupListener {
return &HTTPSGroupListener{
groupName: name,
group: group,
addr: addr,
closeCh: make(chan struct{}),
}
}
func (ln *HTTPSGroupListener) Accept() (c net.Conn, err error) {
var ok bool
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok = <-ln.group.Accept():
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *HTTPSGroupListener) Addr() net.Addr {
return ln.addr
}
func (ln *HTTPSGroupListener) Close() (err error) {
close(ln.closeCh)
// remove self from HTTPSGroup
ln.group.CloseListener(ln)
return
}

49
server/group/listener.go Normal file
View File

@@ -0,0 +1,49 @@
package group
import (
"net"
"sync"
)
// Listener is a per-proxy virtual listener that receives connections
// from a shared group. It implements net.Listener.
type Listener struct {
acceptCh <-chan net.Conn
addr net.Addr
closeCh chan struct{}
onClose func(*Listener)
once sync.Once
}
func newListener(acceptCh <-chan net.Conn, addr net.Addr, onClose func(*Listener)) *Listener {
return &Listener{
acceptCh: acceptCh,
addr: addr,
closeCh: make(chan struct{}),
onClose: onClose,
}
}
func (ln *Listener) Accept() (net.Conn, error) {
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok := <-ln.acceptCh:
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *Listener) Addr() net.Addr {
return ln.addr
}
func (ln *Listener) Close() error {
ln.once.Do(func() {
close(ln.closeCh)
ln.onClose(ln)
})
return nil
}

View File

@@ -0,0 +1,68 @@
package group
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestListener_Accept(t *testing.T) {
acceptCh := make(chan net.Conn, 1)
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
c1, c2 := net.Pipe()
defer c1.Close()
defer c2.Close()
acceptCh <- c1
got, err := ln.Accept()
require.NoError(t, err)
assert.Equal(t, c1, got)
}
func TestListener_AcceptAfterChannelClose(t *testing.T) {
acceptCh := make(chan net.Conn)
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
close(acceptCh)
_, err := ln.Accept()
assert.ErrorIs(t, err, ErrListenerClosed)
}
func TestListener_AcceptAfterListenerClose(t *testing.T) {
acceptCh := make(chan net.Conn) // open, not closed
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
ln.Close()
_, err := ln.Accept()
assert.ErrorIs(t, err, ErrListenerClosed)
}
func TestListener_DoubleClose(t *testing.T) {
closeCalls := 0
ln := newListener(
make(chan net.Conn),
fakeAddr("127.0.0.1:1234"),
func(*Listener) { closeCalls++ },
)
assert.NotPanics(t, func() {
ln.Close()
ln.Close()
})
assert.Equal(t, 1, closeCalls, "onClose should be called exactly once")
}
func TestListener_Addr(t *testing.T) {
addr := fakeAddr("10.0.0.1:5555")
ln := newListener(make(chan net.Conn), addr, func(*Listener) {})
assert.Equal(t, addr, ln.Addr())
}
// fakeAddr implements net.Addr for testing.
type fakeAddr string
func (a fakeAddr) Network() string { return "tcp" }
func (a fakeAddr) String() string { return string(a) }

59
server/group/registry.go Normal file
View File

@@ -0,0 +1,59 @@
package group
import (
"sync"
)
// groupRegistry is a concurrent map of named groups with
// automatic creation on first access.
type groupRegistry[G any] struct {
groups map[string]G
mu sync.Mutex
}
func newGroupRegistry[G any]() groupRegistry[G] {
return groupRegistry[G]{
groups: make(map[string]G),
}
}
func (r *groupRegistry[G]) getOrCreate(key string, newFn func() G) G {
r.mu.Lock()
defer r.mu.Unlock()
g, ok := r.groups[key]
if !ok {
g = newFn()
r.groups[key] = g
}
return g
}
func (r *groupRegistry[G]) get(key string) (G, bool) {
r.mu.Lock()
defer r.mu.Unlock()
g, ok := r.groups[key]
return g, ok
}
// isCurrent returns true if key exists in the registry and matchFn
// returns true for the stored value.
func (r *groupRegistry[G]) isCurrent(key string, matchFn func(G) bool) bool {
r.mu.Lock()
defer r.mu.Unlock()
g, ok := r.groups[key]
return ok && matchFn(g)
}
// removeIf atomically looks up the group for key, calls fn on it,
// and removes the entry if fn returns true.
func (r *groupRegistry[G]) removeIf(key string, fn func(G) bool) {
r.mu.Lock()
defer r.mu.Unlock()
g, ok := r.groups[key]
if !ok {
return
}
if fn(g) {
delete(r.groups, key)
}
}

View File

@@ -0,0 +1,102 @@
package group
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetOrCreate_New(t *testing.T) {
r := newGroupRegistry[*int]()
called := 0
v := 42
got := r.getOrCreate("k", func() *int { called++; return &v })
assert.Equal(t, 1, called)
assert.Equal(t, &v, got)
}
func TestGetOrCreate_Existing(t *testing.T) {
r := newGroupRegistry[*int]()
v := 42
r.getOrCreate("k", func() *int { return &v })
called := 0
got := r.getOrCreate("k", func() *int { called++; return nil })
assert.Equal(t, 0, called)
assert.Equal(t, &v, got)
}
func TestGet_ExistingAndMissing(t *testing.T) {
r := newGroupRegistry[*int]()
v := 1
r.getOrCreate("k", func() *int { return &v })
got, ok := r.get("k")
assert.True(t, ok)
assert.Equal(t, &v, got)
_, ok = r.get("missing")
assert.False(t, ok)
}
func TestIsCurrent(t *testing.T) {
r := newGroupRegistry[*int]()
v1 := 1
v2 := 2
r.getOrCreate("k", func() *int { return &v1 })
assert.True(t, r.isCurrent("k", func(g *int) bool { return g == &v1 }))
assert.False(t, r.isCurrent("k", func(g *int) bool { return g == &v2 }))
assert.False(t, r.isCurrent("missing", func(g *int) bool { return true }))
}
func TestRemoveIf(t *testing.T) {
t.Run("removes when fn returns true", func(t *testing.T) {
r := newGroupRegistry[*int]()
v := 1
r.getOrCreate("k", func() *int { return &v })
r.removeIf("k", func(g *int) bool { return g == &v })
_, ok := r.get("k")
assert.False(t, ok)
})
t.Run("keeps when fn returns false", func(t *testing.T) {
r := newGroupRegistry[*int]()
v := 1
r.getOrCreate("k", func() *int { return &v })
r.removeIf("k", func(g *int) bool { return false })
_, ok := r.get("k")
assert.True(t, ok)
})
t.Run("noop on missing key", func(t *testing.T) {
r := newGroupRegistry[*int]()
r.removeIf("missing", func(g *int) bool { return true }) // should not panic
})
}
func TestConcurrentGetOrCreateAndRemoveIf(t *testing.T) {
r := newGroupRegistry[*int]()
const n = 100
var wg sync.WaitGroup
wg.Add(n * 2)
for i := range n {
v := i
go func() {
defer wg.Done()
r.getOrCreate("k", func() *int { return &v })
}()
go func() {
defer wg.Done()
r.removeIf("k", func(*int) bool { return true })
}()
}
wg.Wait()
// After all goroutines finish, accessing the key must not panic.
require.NotPanics(t, func() {
_, _ = r.get("k")
})
}

View File

@@ -17,107 +17,91 @@ package group
import ( import (
"net" "net"
"strconv" "strconv"
"sync"
gerr "github.com/fatedier/golib/errors"
"github.com/fatedier/frp/server/ports" "github.com/fatedier/frp/server/ports"
) )
// TCPGroupCtl manage all TCPGroups // TCPGroupCtl manages all TCPGroups.
type TCPGroupCtl struct { type TCPGroupCtl struct {
groups map[string]*TCPGroup groupRegistry[*TCPGroup]
// portManager is used to manage port
portManager *ports.Manager portManager *ports.Manager
mu sync.Mutex
} }
// NewTCPGroupCtl return a new TcpGroupCtl // NewTCPGroupCtl returns a new TCPGroupCtl.
func NewTCPGroupCtl(portManager *ports.Manager) *TCPGroupCtl { func NewTCPGroupCtl(portManager *ports.Manager) *TCPGroupCtl {
return &TCPGroupCtl{ return &TCPGroupCtl{
groups: make(map[string]*TCPGroup), groupRegistry: newGroupRegistry[*TCPGroup](),
portManager: portManager, portManager: portManager,
} }
} }
// Listen is the wrapper for TCPGroup's Listen // Listen is the wrapper for TCPGroup's Listen.
// If there are no group, we will create one here // If there is no group, one will be created.
func (tgc *TCPGroupCtl) Listen(proxyName string, group string, groupKey string, func (tgc *TCPGroupCtl) Listen(proxyName string, group string, groupKey string,
addr string, port int, addr string, port int,
) (l net.Listener, realPort int, err error) { ) (l net.Listener, realPort int, err error) {
tgc.mu.Lock() for {
tcpGroup, ok := tgc.groups[group] tcpGroup := tgc.getOrCreate(group, func() *TCPGroup {
if !ok { return NewTCPGroup(tgc)
tcpGroup = NewTCPGroup(tgc) })
tgc.groups[group] = tcpGroup l, realPort, err = tcpGroup.Listen(proxyName, group, groupKey, addr, port)
if err == errGroupStale {
continue
}
return
} }
tgc.mu.Unlock()
return tcpGroup.Listen(proxyName, group, groupKey, addr, port)
} }
// RemoveGroup remove TCPGroup from controller // TCPGroup routes connections to different proxies.
func (tgc *TCPGroupCtl) RemoveGroup(group string) {
tgc.mu.Lock()
defer tgc.mu.Unlock()
delete(tgc.groups, group)
}
// TCPGroup route connections to different proxies
type TCPGroup struct { type TCPGroup struct {
group string baseGroup
groupKey string
addr string addr string
port int port int
realPort int realPort int
acceptCh chan net.Conn
tcpLn net.Listener
lns []*TCPGroupListener
ctl *TCPGroupCtl ctl *TCPGroupCtl
mu sync.Mutex
} }
// NewTCPGroup return a new TCPGroup // NewTCPGroup returns a new TCPGroup.
func NewTCPGroup(ctl *TCPGroupCtl) *TCPGroup { func NewTCPGroup(ctl *TCPGroupCtl) *TCPGroup {
return &TCPGroup{ return &TCPGroup{
lns: make([]*TCPGroupListener, 0), ctl: ctl,
ctl: ctl,
acceptCh: make(chan net.Conn),
} }
} }
// Listen will return a new TCPGroupListener // Listen will return a new Listener.
// if TCPGroup already has a listener, just add a new TCPGroupListener to the queues // If TCPGroup already has a listener, just add a new Listener to the queues,
// otherwise, listen on the real address // otherwise listen on the real address.
func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TCPGroupListener, realPort int, err error) { func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *Listener, realPort int, err error) {
tg.mu.Lock() tg.mu.Lock()
defer tg.mu.Unlock() defer tg.mu.Unlock()
if !tg.ctl.isCurrent(group, func(cur *TCPGroup) bool { return cur == tg }) {
return nil, 0, errGroupStale
}
if len(tg.lns) == 0 { if len(tg.lns) == 0 {
// the first listener, listen on the real address // the first listener, listen on the real address
realPort, err = tg.ctl.portManager.Acquire(proxyName, port) realPort, err = tg.ctl.portManager.Acquire(proxyName, port)
if err != nil { if err != nil {
return return
} }
tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port))) tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(realPort)))
if errRet != nil { if errRet != nil {
tg.ctl.portManager.Release(realPort)
err = errRet err = errRet
return return
} }
ln = newTCPGroupListener(group, tg, tcpLn.Addr())
tg.group = group
tg.groupKey = groupKey
tg.addr = addr tg.addr = addr
tg.port = port tg.port = port
tg.realPort = realPort tg.realPort = realPort
tg.tcpLn = tcpLn tg.initBase(group, groupKey, tcpLn, func() {
tg.lns = append(tg.lns, ln) tg.ctl.portManager.Release(tg.realPort)
if tg.acceptCh == nil { tg.ctl.removeIf(tg.group, func(cur *TCPGroup) bool {
tg.acceptCh = make(chan net.Conn) return cur == tg
} })
go tg.worker() })
ln = tg.newListener(tcpLn.Addr())
go tg.worker(tcpLn, tg.acceptCh)
} else { } else {
// address and port in the same group must be equal // address and port in the same group must be equal
if tg.group != group || tg.addr != addr { if tg.group != group || tg.addr != addr {
@@ -132,92 +116,8 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
err = ErrGroupAuthFailed err = ErrGroupAuthFailed
return return
} }
ln = newTCPGroupListener(group, tg, tg.lns[0].Addr()) ln = tg.newListener(tg.lns[0].Addr())
realPort = tg.realPort realPort = tg.realPort
tg.lns = append(tg.lns, ln)
} }
return return
} }
// worker is called when the real tcp listener has been created
func (tg *TCPGroup) worker() {
for {
c, err := tg.tcpLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
tg.acceptCh <- c
})
if err != nil {
return
}
}
}
func (tg *TCPGroup) Accept() <-chan net.Conn {
return tg.acceptCh
}
// CloseListener remove the TCPGroupListener from the TCPGroup
func (tg *TCPGroup) CloseListener(ln *TCPGroupListener) {
tg.mu.Lock()
defer tg.mu.Unlock()
for i, tmpLn := range tg.lns {
if tmpLn == ln {
tg.lns = append(tg.lns[:i], tg.lns[i+1:]...)
break
}
}
if len(tg.lns) == 0 {
close(tg.acceptCh)
tg.tcpLn.Close()
tg.ctl.portManager.Release(tg.realPort)
tg.ctl.RemoveGroup(tg.group)
}
}
// TCPGroupListener
type TCPGroupListener struct {
groupName string
group *TCPGroup
addr net.Addr
closeCh chan struct{}
}
func newTCPGroupListener(name string, group *TCPGroup, addr net.Addr) *TCPGroupListener {
return &TCPGroupListener{
groupName: name,
group: group,
addr: addr,
closeCh: make(chan struct{}),
}
}
// Accept will accept connections from TCPGroup
func (ln *TCPGroupListener) Accept() (c net.Conn, err error) {
var ok bool
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok = <-ln.group.Accept():
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *TCPGroupListener) Addr() net.Addr {
return ln.addr
}
// Close close the listener
func (ln *TCPGroupListener) Close() (err error) {
close(ln.closeCh)
// remove self from TcpGroup
ln.group.CloseListener(ln)
return
}

View File

@@ -18,118 +18,100 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync"
gerr "github.com/fatedier/golib/errors"
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/tcpmux" "github.com/fatedier/frp/pkg/util/tcpmux"
"github.com/fatedier/frp/pkg/util/vhost" "github.com/fatedier/frp/pkg/util/vhost"
) )
// TCPMuxGroupCtl manage all TCPMuxGroups // TCPMuxGroupCtl manages all TCPMuxGroups.
type TCPMuxGroupCtl struct { type TCPMuxGroupCtl struct {
groups map[string]*TCPMuxGroup groupRegistry[*TCPMuxGroup]
// portManager is used to manage port
tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer
mu sync.Mutex
} }
// NewTCPMuxGroupCtl return a new TCPMuxGroupCtl // NewTCPMuxGroupCtl returns a new TCPMuxGroupCtl.
func NewTCPMuxGroupCtl(tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer) *TCPMuxGroupCtl { func NewTCPMuxGroupCtl(tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer) *TCPMuxGroupCtl {
return &TCPMuxGroupCtl{ return &TCPMuxGroupCtl{
groups: make(map[string]*TCPMuxGroup), groupRegistry: newGroupRegistry[*TCPMuxGroup](),
tcpMuxHTTPConnectMuxer: tcpMuxHTTPConnectMuxer, tcpMuxHTTPConnectMuxer: tcpMuxHTTPConnectMuxer,
} }
} }
// Listen is the wrapper for TCPMuxGroup's Listen // Listen is the wrapper for TCPMuxGroup's Listen.
// If there are no group, we will create one here // If there is no group, one will be created.
func (tmgc *TCPMuxGroupCtl) Listen( func (tmgc *TCPMuxGroupCtl) Listen(
ctx context.Context, ctx context.Context,
multiplexer, group, groupKey string, multiplexer, group, groupKey string,
routeConfig vhost.RouteConfig, routeConfig vhost.RouteConfig,
) (l net.Listener, err error) { ) (l net.Listener, err error) {
tmgc.mu.Lock() for {
tcpMuxGroup, ok := tmgc.groups[group] tcpMuxGroup := tmgc.getOrCreate(group, func() *TCPMuxGroup {
if !ok { return NewTCPMuxGroup(tmgc)
tcpMuxGroup = NewTCPMuxGroup(tmgc) })
tmgc.groups[group] = tcpMuxGroup
}
tmgc.mu.Unlock()
switch v1.TCPMultiplexerType(multiplexer) { switch v1.TCPMultiplexerType(multiplexer) {
case v1.TCPMultiplexerHTTPConnect: case v1.TCPMultiplexerHTTPConnect:
return tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig) l, err = tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig)
default: if err == errGroupStale {
err = fmt.Errorf("unknown multiplexer [%s]", multiplexer) continue
return }
return
default:
return nil, fmt.Errorf("unknown multiplexer [%s]", multiplexer)
}
} }
} }
// RemoveGroup remove TCPMuxGroup from controller // TCPMuxGroup routes connections to different proxies.
func (tmgc *TCPMuxGroupCtl) RemoveGroup(group string) {
tmgc.mu.Lock()
defer tmgc.mu.Unlock()
delete(tmgc.groups, group)
}
// TCPMuxGroup route connections to different proxies
type TCPMuxGroup struct { type TCPMuxGroup struct {
group string baseGroup
groupKey string
domain string domain string
routeByHTTPUser string routeByHTTPUser string
username string username string
password string password string
ctl *TCPMuxGroupCtl
acceptCh chan net.Conn
tcpMuxLn net.Listener
lns []*TCPMuxGroupListener
ctl *TCPMuxGroupCtl
mu sync.Mutex
} }
// NewTCPMuxGroup return a new TCPMuxGroup // NewTCPMuxGroup returns a new TCPMuxGroup.
func NewTCPMuxGroup(ctl *TCPMuxGroupCtl) *TCPMuxGroup { func NewTCPMuxGroup(ctl *TCPMuxGroupCtl) *TCPMuxGroup {
return &TCPMuxGroup{ return &TCPMuxGroup{
lns: make([]*TCPMuxGroupListener, 0), ctl: ctl,
ctl: ctl,
acceptCh: make(chan net.Conn),
} }
} }
// Listen will return a new TCPMuxGroupListener // HTTPConnectListen will return a new Listener.
// if TCPMuxGroup already has a listener, just add a new TCPMuxGroupListener to the queues // If TCPMuxGroup already has a listener, just add a new Listener to the queues,
// otherwise, listen on the real address // otherwise listen on the real address.
func (tmg *TCPMuxGroup) HTTPConnectListen( func (tmg *TCPMuxGroup) HTTPConnectListen(
ctx context.Context, ctx context.Context,
group, groupKey string, group, groupKey string,
routeConfig vhost.RouteConfig, routeConfig vhost.RouteConfig,
) (ln *TCPMuxGroupListener, err error) { ) (ln *Listener, err error) {
tmg.mu.Lock() tmg.mu.Lock()
defer tmg.mu.Unlock() defer tmg.mu.Unlock()
if !tmg.ctl.isCurrent(group, func(cur *TCPMuxGroup) bool { return cur == tmg }) {
return nil, errGroupStale
}
if len(tmg.lns) == 0 { if len(tmg.lns) == 0 {
// the first listener, listen on the real address // the first listener, listen on the real address
tcpMuxLn, errRet := tmg.ctl.tcpMuxHTTPConnectMuxer.Listen(ctx, &routeConfig) tcpMuxLn, errRet := tmg.ctl.tcpMuxHTTPConnectMuxer.Listen(ctx, &routeConfig)
if errRet != nil { if errRet != nil {
return nil, errRet return nil, errRet
} }
ln = newTCPMuxGroupListener(group, tmg, tcpMuxLn.Addr())
tmg.group = group
tmg.groupKey = groupKey
tmg.domain = routeConfig.Domain tmg.domain = routeConfig.Domain
tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser
tmg.username = routeConfig.Username tmg.username = routeConfig.Username
tmg.password = routeConfig.Password tmg.password = routeConfig.Password
tmg.tcpMuxLn = tcpMuxLn tmg.initBase(group, groupKey, tcpMuxLn, func() {
tmg.lns = append(tmg.lns, ln) tmg.ctl.removeIf(tmg.group, func(cur *TCPMuxGroup) bool {
if tmg.acceptCh == nil { return cur == tmg
tmg.acceptCh = make(chan net.Conn) })
} })
go tmg.worker() ln = tmg.newListener(tcpMuxLn.Addr())
go tmg.worker(tcpMuxLn, tmg.acceptCh)
} else { } else {
// route config in the same group must be equal // route config in the same group must be equal
if tmg.group != group || tmg.domain != routeConfig.Domain || if tmg.group != group || tmg.domain != routeConfig.Domain ||
@@ -141,90 +123,7 @@ func (tmg *TCPMuxGroup) HTTPConnectListen(
if tmg.groupKey != groupKey { if tmg.groupKey != groupKey {
return nil, ErrGroupAuthFailed return nil, ErrGroupAuthFailed
} }
ln = newTCPMuxGroupListener(group, tmg, tmg.lns[0].Addr()) ln = tmg.newListener(tmg.lns[0].Addr())
tmg.lns = append(tmg.lns, ln)
} }
return return
} }
// worker is called when the real TCP listener has been created
func (tmg *TCPMuxGroup) worker() {
for {
c, err := tmg.tcpMuxLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
tmg.acceptCh <- c
})
if err != nil {
return
}
}
}
func (tmg *TCPMuxGroup) Accept() <-chan net.Conn {
return tmg.acceptCh
}
// CloseListener remove the TCPMuxGroupListener from the TCPMuxGroup
func (tmg *TCPMuxGroup) CloseListener(ln *TCPMuxGroupListener) {
tmg.mu.Lock()
defer tmg.mu.Unlock()
for i, tmpLn := range tmg.lns {
if tmpLn == ln {
tmg.lns = append(tmg.lns[:i], tmg.lns[i+1:]...)
break
}
}
if len(tmg.lns) == 0 {
close(tmg.acceptCh)
tmg.tcpMuxLn.Close()
tmg.ctl.RemoveGroup(tmg.group)
}
}
// TCPMuxGroupListener
type TCPMuxGroupListener struct {
groupName string
group *TCPMuxGroup
addr net.Addr
closeCh chan struct{}
}
func newTCPMuxGroupListener(name string, group *TCPMuxGroup, addr net.Addr) *TCPMuxGroupListener {
return &TCPMuxGroupListener{
groupName: name,
group: group,
addr: addr,
closeCh: make(chan struct{}),
}
}
// Accept will accept connections from TCPMuxGroup
func (ln *TCPMuxGroupListener) Accept() (c net.Conn, err error) {
var ok bool
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok = <-ln.group.Accept():
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *TCPMuxGroupListener) Addr() net.Addr {
return ln.addr
}
// Close close the listener
func (ln *TCPMuxGroupListener) Close() (err error) {
close(ln.closeCh)
// remove self from TcpMuxGroup
ln.group.CloseListener(ln)
return
}

View File

@@ -31,7 +31,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.HTTPProxyConfig{}), NewHTTPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.HTTPProxyConfig](), NewHTTPProxy)
} }
type HTTPProxy struct { type HTTPProxy struct {
@@ -75,16 +75,13 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
} }
}() }()
addrs := make([]string, 0) domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
for _, domain := range pxy.cfg.CustomDomains {
if domain == "" {
continue
}
addrs := make([]string, 0)
for _, domain := range domains {
routeConfig.Domain = domain routeConfig.Domain = domain
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
tmpRouteConfig := routeConfig tmpRouteConfig := routeConfig
// handle group // handle group
@@ -93,12 +90,10 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
if err != nil { if err != nil {
return return
} }
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.LoadBalancer.Group, tmpRouteConfig) pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.LoadBalancer.Group, tmpRouteConfig)
}) })
} else { } else {
// no group
err = pxy.rc.HTTPReverseProxy.Register(routeConfig) err = pxy.rc.HTTPReverseProxy.Register(routeConfig)
if err != nil { if err != nil {
return return
@@ -112,39 +107,6 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
routeConfig.Domain, routeConfig.Location, pxy.cfg.LoadBalancer.Group, pxy.cfg.RouteByHTTPUser) routeConfig.Domain, routeConfig.Location, pxy.cfg.LoadBalancer.Group, pxy.cfg.RouteByHTTPUser)
} }
} }
if pxy.cfg.SubDomain != "" {
routeConfig.Domain = pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost
for _, location := range locations {
routeConfig.Location = location
tmpRouteConfig := routeConfig
// handle group
if pxy.cfg.LoadBalancer.Group != "" {
err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.LoadBalancer.Group, pxy.cfg.LoadBalancer.GroupKey, routeConfig)
if err != nil {
return
}
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.LoadBalancer.Group, tmpRouteConfig)
})
} else {
err = pxy.rc.HTTPReverseProxy.Register(routeConfig)
if err != nil {
return
}
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.rc.HTTPReverseProxy.UnRegister(tmpRouteConfig)
})
}
addrs = append(addrs, util.CanonicalAddr(tmpRouteConfig.Domain, pxy.serverCfg.VhostHTTPPort))
xl.Infof("http proxy listen for host [%s] location [%s] group [%s], routeByHTTPUser [%s]",
routeConfig.Domain, routeConfig.Location, pxy.cfg.LoadBalancer.Group, pxy.cfg.RouteByHTTPUser)
}
}
remoteAddr = strings.Join(addrs, ",") remoteAddr = strings.Join(addrs, ",")
return return
} }
@@ -168,6 +130,7 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey) rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
if err != nil { if err != nil {
xl.Errorf("create encryption stream error: %v", err) xl.Errorf("create encryption stream error: %v", err)
tmpConn.Close()
return return
} }
} }

View File

@@ -25,7 +25,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.HTTPSProxyConfig{}), NewHTTPSProxy) RegisterProxyFactory(reflect.TypeFor[*v1.HTTPSProxyConfig](), NewHTTPSProxy)
} }
type HTTPSProxy struct { type HTTPSProxy struct {
@@ -53,23 +53,10 @@ func (pxy *HTTPSProxy) Run() (remoteAddr string, err error) {
pxy.Close() pxy.Close()
} }
}() }()
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
addrs := make([]string, 0) addrs := make([]string, 0)
for _, domain := range pxy.cfg.CustomDomains { for _, domain := range domains {
if domain == "" {
continue
}
l, err := pxy.listenForDomain(routeConfig, domain)
if err != nil {
return "", err
}
pxy.listeners = append(pxy.listeners, l)
addrs = append(addrs, util.CanonicalAddr(domain, pxy.serverCfg.VhostHTTPSPort))
xl.Infof("https proxy listen for host [%s] group [%s]", domain, pxy.cfg.LoadBalancer.Group)
}
if pxy.cfg.SubDomain != "" {
domain := pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost
l, err := pxy.listenForDomain(routeConfig, domain) l, err := pxy.listenForDomain(routeConfig, domain)
if err != nil { if err != nil {
return "", err return "", err

View File

@@ -150,7 +150,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String()) dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16) dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
} }
err := msg.WriteMsg(workConn, &msg.StartWorkConn{ err = msg.WriteMsg(workConn, &msg.StartWorkConn{
ProxyName: pxy.GetName(), ProxyName: pxy.GetName(),
SrcAddr: srcAddr, SrcAddr: srcAddr,
SrcPort: uint16(srcPort), SrcPort: uint16(srcPort),
@@ -161,6 +161,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
if err != nil { if err != nil {
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i) xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
workConn.Close() workConn.Close()
workConn = nil
} else { } else {
break break
} }
@@ -173,6 +174,36 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
return return
} }
// startVisitorListener sets up a VisitorManager listener for visitor-based proxies (STCP, SUDP).
func (pxy *BaseProxy) startVisitorListener(secretKey string, allowUsers []string, proxyType string) error {
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, err := pxy.rc.VisitorManager.Listen(pxy.GetName(), secretKey, allowUsers)
if err != nil {
return err
}
pxy.listeners = append(pxy.listeners, listener)
pxy.xl.Infof("%s proxy custom listen success", proxyType)
pxy.startCommonTCPListenersHandler()
return nil
}
// buildDomains constructs a list of domains from custom domains and subdomain configuration.
func (pxy *BaseProxy) buildDomains(customDomains []string, subDomain string) []string {
domains := make([]string, 0, len(customDomains)+1)
for _, d := range customDomains {
if d != "" {
domains = append(domains, d)
}
}
if subDomain != "" {
domains = append(domains, subDomain+"."+pxy.serverCfg.SubDomainHost)
}
return domains
}
// startCommonTCPListenersHandler start a goroutine handler for each listener. // startCommonTCPListenersHandler start a goroutine handler for each listener.
func (pxy *BaseProxy) startCommonTCPListenersHandler() { func (pxy *BaseProxy) startCommonTCPListenersHandler() {
xl := xlog.FromContextSafe(pxy.ctx) xl := xlog.FromContextSafe(pxy.ctx)

View File

@@ -21,7 +21,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.STCPProxyConfig{}), NewSTCPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.STCPProxyConfig](), NewSTCPProxy)
} }
type STCPProxy struct { type STCPProxy struct {
@@ -41,21 +41,7 @@ func NewSTCPProxy(baseProxy *BaseProxy) Proxy {
} }
func (pxy *STCPProxy) Run() (remoteAddr string, err error) { func (pxy *STCPProxy) Run() (remoteAddr string, err error) {
xl := pxy.xl err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "stcp")
allowUsers := pxy.cfg.AllowUsers
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
if errRet != nil {
err = errRet
return
}
pxy.listeners = append(pxy.listeners, listener)
xl.Infof("stcp proxy custom listen success")
pxy.startCommonTCPListenersHandler()
return return
} }

View File

@@ -21,7 +21,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.SUDPProxyConfig{}), NewSUDPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.SUDPProxyConfig](), NewSUDPProxy)
} }
type SUDPProxy struct { type SUDPProxy struct {
@@ -41,21 +41,7 @@ func NewSUDPProxy(baseProxy *BaseProxy) Proxy {
} }
func (pxy *SUDPProxy) Run() (remoteAddr string, err error) { func (pxy *SUDPProxy) Run() (remoteAddr string, err error) {
xl := pxy.xl err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "sudp")
allowUsers := pxy.cfg.AllowUsers
// if allowUsers is empty, only allow same user from proxy
if len(allowUsers) == 0 {
allowUsers = []string{pxy.GetUserInfo().User}
}
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
if errRet != nil {
err = errRet
return
}
pxy.listeners = append(pxy.listeners, listener)
xl.Infof("sudp proxy custom listen success")
pxy.startCommonTCPListenersHandler()
return return
} }

View File

@@ -24,7 +24,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.TCPProxyConfig{}), NewTCPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.TCPProxyConfig](), NewTCPProxy)
} }
type TCPProxy struct { type TCPProxy struct {

View File

@@ -26,7 +26,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.TCPMuxProxyConfig{}), NewTCPMuxProxy) RegisterProxyFactory(reflect.TypeFor[*v1.TCPMuxProxyConfig](), NewTCPMuxProxy)
} }
type TCPMuxProxy struct { type TCPMuxProxy struct {
@@ -72,26 +72,16 @@ func (pxy *TCPMuxProxy) httpConnectListen(
} }
func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) { func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) {
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
addrs := make([]string, 0) addrs := make([]string, 0)
for _, domain := range pxy.cfg.CustomDomains { for _, domain := range domains {
if domain == "" {
continue
}
addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPassword, addrs) addrs, err = pxy.httpConnectListen(domain, pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPassword, addrs)
if err != nil { if err != nil {
return "", err return "", err
} }
} }
if pxy.cfg.SubDomain != "" {
addrs, err = pxy.httpConnectListen(pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost,
pxy.cfg.RouteByHTTPUser, pxy.cfg.HTTPUser, pxy.cfg.HTTPPassword, addrs)
if err != nil {
return "", err
}
}
pxy.startCommonTCPListenersHandler() pxy.startCommonTCPListenersHandler()
remoteAddr = strings.Join(addrs, ",") remoteAddr = strings.Join(addrs, ",")
return remoteAddr, err return remoteAddr, err

View File

@@ -35,7 +35,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.UDPProxyConfig{}), NewUDPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.UDPProxyConfig](), NewUDPProxy)
} }
type UDPProxy struct { type UDPProxy struct {
@@ -136,7 +136,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
continue continue
case *msg.UDPPacket: case *msg.UDPPacket:
if errRet := errors.PanicToError(func() { if errRet := errors.PanicToError(func() {
xl.Tracef("get udp message from workConn: %s", m.Content) xl.Tracef("get udp message from workConn, len: %d", len(m.Content))
pxy.readCh <- m pxy.readCh <- m
metrics.Server.AddTrafficOut( metrics.Server.AddTrafficOut(
pxy.GetName(), pxy.GetName(),
@@ -167,7 +167,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
conn.Close() conn.Close()
return return
} }
xl.Tracef("send message to udp workConn: %s", udpMsg.Content) xl.Tracef("send message to udp workConn, len: %d", len(udpMsg.Content))
metrics.Server.AddTrafficIn( metrics.Server.AddTrafficIn(
pxy.GetName(), pxy.GetName(),
pxy.GetConfigurer().GetBaseConfig().Type, pxy.GetConfigurer().GetBaseConfig().Type,

View File

@@ -24,7 +24,7 @@ import (
) )
func init() { func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.XTCPProxyConfig{}), NewXTCPProxy) RegisterProxyFactory(reflect.TypeFor[*v1.XTCPProxyConfig](), NewXTCPProxy)
} }
type XTCPProxy struct { type XTCPProxy struct {

View File

@@ -193,7 +193,7 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("create vhost tcpMuxer error, %v", err) return nil, fmt.Errorf("create vhost tcpMuxer error, %v", err)
} }
log.Infof("tcpmux httpconnect multiplexer listen on %s, passthough: %v", address, cfg.TCPMuxPassthrough) log.Infof("tcpmux httpconnect multiplexer listen on %s, passthrough: %v", address, cfg.TCPMuxPassthrough)
} }
// Init all plugins // Init all plugins
@@ -604,8 +604,18 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
return err return err
} }
// TODO(fatedier): use SessionContext ctl, err := NewControl(ctx, &SessionContext{
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg) RC: svr.rc,
PxyManager: svr.pxyManager,
PluginManager: svr.pluginManager,
AuthVerifier: authVerifier,
EncryptionKey: svr.auth.EncryptionKey(),
Conn: ctlConn,
ConnEncrypted: !internal,
LoginMsg: loginMsg,
ServerCfg: svr.cfg,
ClientRegistry: svr.clientRegistry,
})
if err != nil { if err != nil {
xl.Warnf("create new controller error: %v", err) xl.Warnf("create new controller error: %v", err)
// don't return detailed errors to client // don't return detailed errors to client
@@ -626,7 +636,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
ctl.Close() ctl.Close()
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User) return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
} }
ctl.clientRegistry = svr.clientRegistry
ctl.Start() ctl.Start()
@@ -652,9 +661,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
// server plugin hook // server plugin hook
content := &plugin.NewWorkConnContent{ content := &plugin.NewWorkConnContent{
User: plugin.UserInfo{ User: plugin.UserInfo{
User: ctl.loginMsg.User, User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.loginMsg.Metas, Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.loginMsg.RunID, RunID: ctl.sessionCtx.LoginMsg.RunID,
}, },
NewWorkConn: *newMsg, NewWorkConn: *newMsg,
} }
@@ -662,7 +671,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
if err == nil { if err == nil {
newMsg = &retContent.NewWorkConn newMsg = &retContent.NewWorkConn
// Check auth. // Check auth.
err = ctl.authVerifier.VerifyNewWorkConn(newMsg) err = ctl.sessionCtx.AuthVerifier.VerifyNewWorkConn(newMsg)
} }
if err != nil { if err != nil {
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID) xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
@@ -683,7 +692,7 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis
if !exist { if !exist {
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID) return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
} }
visitorUser = ctl.loginMsg.User visitorUser = ctl.sessionCtx.LoginMsg.User
} }
return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey, return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
newMsg.UseEncryption, newMsg.UseCompression, visitorUser) newMsg.UseEncryption, newMsg.UseCompression, visitorUser)

View File

@@ -26,7 +26,7 @@ var _ = ginkgo.Describe("[Feature: Example]", func() {
remotePort = %d remotePort = %d
`, framework.TCPEchoServerPort, remotePort) `, framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
}) })

View File

@@ -2,70 +2,85 @@ package framework
import ( import (
"fmt" "fmt"
"maps"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"time" "time"
"github.com/fatedier/frp/pkg/config"
flog "github.com/fatedier/frp/pkg/util/log" flog "github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/test/e2e/framework/consts"
"github.com/fatedier/frp/test/e2e/pkg/process" "github.com/fatedier/frp/test/e2e/pkg/process"
) )
// RunProcesses run multiple processes from templates. // RunProcesses starts one frps and zero or more frpc processes from templates.
// The first template should always be frps. func (f *Framework) RunProcesses(serverTemplate string, clientTemplates []string) (*process.Process, []*process.Process) {
func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []string) ([]*process.Process, []*process.Process) { templates := append([]string{serverTemplate}, clientTemplates...)
templates := make([]string, 0, len(serverTemplates)+len(clientTemplates))
templates = append(templates, serverTemplates...)
templates = append(templates, clientTemplates...)
outs, ports, err := f.RenderTemplates(templates) outs, ports, err := f.RenderTemplates(templates)
ExpectNoError(err) ExpectNoError(err)
ExpectTrue(len(templates) > 0)
for name, port := range ports { maps.Copy(f.usedPorts, ports)
f.usedPorts[name] = port
// Start frps.
serverPath := filepath.Join(f.TempDirectory, "frp-e2e-server-0")
err = os.WriteFile(serverPath, []byte(outs[0]), 0o600)
ExpectNoError(err)
if TestContext.Debug {
flog.Debugf("[%s] %s", serverPath, outs[0])
} }
currentServerProcesses := make([]*process.Process, 0, len(serverTemplates)) serverProcess := process.NewWithEnvs(TestContext.FRPServerPath, []string{"-c", serverPath}, f.osEnvs)
for i := range serverTemplates { f.serverConfPaths = append(f.serverConfPaths, serverPath)
path := filepath.Join(f.TempDirectory, fmt.Sprintf("frp-e2e-server-%d", i)) f.serverProcesses = append(f.serverProcesses, serverProcess)
err = os.WriteFile(path, []byte(outs[i]), 0o600) err = serverProcess.Start()
ExpectNoError(err) ExpectNoError(err)
if TestContext.Debug { if port, ok := ports[consts.PortServerName]; ok {
flog.Debugf("[%s] %s", path, outs[i]) ExpectNoError(WaitForTCPReady(net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), 5*time.Second))
} } else {
time.Sleep(2 * time.Second)
p := process.NewWithEnvs(TestContext.FRPServerPath, []string{"-c", path}, f.osEnvs)
f.serverConfPaths = append(f.serverConfPaths, path)
f.serverProcesses = append(f.serverProcesses, p)
currentServerProcesses = append(currentServerProcesses, p)
err = p.Start()
ExpectNoError(err)
time.Sleep(500 * time.Millisecond)
} }
time.Sleep(2 * time.Second)
currentClientProcesses := make([]*process.Process, 0, len(clientTemplates)) // Start frpc(s).
clientProcesses := make([]*process.Process, 0, len(clientTemplates))
for i := range clientTemplates { for i := range clientTemplates {
index := i + len(serverTemplates)
path := filepath.Join(f.TempDirectory, fmt.Sprintf("frp-e2e-client-%d", i)) path := filepath.Join(f.TempDirectory, fmt.Sprintf("frp-e2e-client-%d", i))
err = os.WriteFile(path, []byte(outs[index]), 0o600) err = os.WriteFile(path, []byte(outs[1+i]), 0o600)
ExpectNoError(err) ExpectNoError(err)
if TestContext.Debug { if TestContext.Debug {
flog.Debugf("[%s] %s", path, outs[index]) flog.Debugf("[%s] %s", path, outs[1+i])
} }
p := process.NewWithEnvs(TestContext.FRPClientPath, []string{"-c", path}, f.osEnvs) p := process.NewWithEnvs(TestContext.FRPClientPath, []string{"-c", path}, f.osEnvs)
f.clientConfPaths = append(f.clientConfPaths, path) f.clientConfPaths = append(f.clientConfPaths, path)
f.clientProcesses = append(f.clientProcesses, p) f.clientProcesses = append(f.clientProcesses, p)
currentClientProcesses = append(currentClientProcesses, p) clientProcesses = append(clientProcesses, p)
err = p.Start() err = p.Start()
ExpectNoError(err) ExpectNoError(err)
time.Sleep(500 * time.Millisecond)
} }
time.Sleep(3 * time.Second) // Wait for each client's proxies to register with frps.
// If any client has no proxies (e.g. visitor-only), fall back to sleep
// for the remaining time since visitors have no deterministic readiness signal.
allConfirmed := len(clientProcesses) > 0
start := time.Now()
for i, p := range clientProcesses {
configPath := f.clientConfPaths[len(f.clientConfPaths)-len(clientProcesses)+i]
if !waitForClientProxyReady(configPath, p, 5*time.Second) {
allConfirmed = false
}
}
if len(clientProcesses) > 0 && !allConfirmed {
remaining := 1500*time.Millisecond - time.Since(start)
if remaining > 0 {
time.Sleep(remaining)
}
}
return currentServerProcesses, currentClientProcesses return serverProcess, clientProcesses
} }
func (f *Framework) RunFrps(args ...string) (*process.Process, string, error) { func (f *Framework) RunFrps(args ...string) (*process.Process, string, error) {
@@ -73,11 +88,13 @@ func (f *Framework) RunFrps(args ...string) (*process.Process, string, error) {
f.serverProcesses = append(f.serverProcesses, p) f.serverProcesses = append(f.serverProcesses, p)
err := p.Start() err := p.Start()
if err != nil { if err != nil {
return p, p.StdOutput(), err return p, p.Output(), err
} }
// Give frps extra time to finish binding ports before proceeding. select {
time.Sleep(4 * time.Second) case <-p.Done():
return p, p.StdOutput(), nil case <-time.After(2 * time.Second):
}
return p, p.Output(), nil
} }
func (f *Framework) RunFrpc(args ...string) (*process.Process, string, error) { func (f *Framework) RunFrpc(args ...string) (*process.Process, string, error) {
@@ -85,10 +102,13 @@ func (f *Framework) RunFrpc(args ...string) (*process.Process, string, error) {
f.clientProcesses = append(f.clientProcesses, p) f.clientProcesses = append(f.clientProcesses, p)
err := p.Start() err := p.Start()
if err != nil { if err != nil {
return p, p.StdOutput(), err return p, p.Output(), err
} }
time.Sleep(2 * time.Second) select {
return p, p.StdOutput(), nil case <-p.Done():
case <-time.After(1500 * time.Millisecond):
}
return p, p.Output(), nil
} }
func (f *Framework) GenerateConfigFile(content string) string { func (f *Framework) GenerateConfigFile(content string) string {
@@ -98,3 +118,74 @@ func (f *Framework) GenerateConfigFile(content string) string {
ExpectNoError(err) ExpectNoError(err)
return path return path
} }
// waitForClientProxyReady parses the client config to extract proxy names,
// then waits for each proxy's "start proxy success" log in the process output.
// Returns true only if proxies were expected and all registered successfully.
func waitForClientProxyReady(configPath string, p *process.Process, timeout time.Duration) bool {
_, proxyCfgs, _, _, err := config.LoadClientConfig(configPath, false)
if err != nil || len(proxyCfgs) == 0 {
return false
}
// Use a single deadline so the total wait across all proxies does not exceed timeout.
deadline := time.Now().Add(timeout)
for _, cfg := range proxyCfgs {
remaining := time.Until(deadline)
if remaining <= 0 {
return false
}
name := cfg.GetBaseConfig().Name
pattern := fmt.Sprintf("[%s] start proxy success", name)
if err := p.WaitForOutput(pattern, 1, remaining); err != nil {
return false
}
}
return true
}
// WaitForTCPUnreachable polls a TCP address until a connection fails or timeout.
func WaitForTCPUnreachable(addr string, interval, timeout time.Duration) error {
if interval <= 0 {
return fmt.Errorf("invalid interval for TCP unreachable on %s: interval must be positive", addr)
}
if timeout <= 0 {
return fmt.Errorf("invalid timeout for TCP unreachable on %s: timeout must be positive", addr)
}
deadline := time.Now().Add(timeout)
for {
remaining := time.Until(deadline)
if remaining <= 0 {
return fmt.Errorf("timeout waiting for TCP unreachable on %s", addr)
}
dialTimeout := min(interval, remaining)
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
if err != nil {
return nil
}
conn.Close()
time.Sleep(min(interval, time.Until(deadline)))
}
}
// WaitForTCPReady polls a TCP address until a connection succeeds or timeout.
func WaitForTCPReady(addr string, timeout time.Duration) error {
if timeout <= 0 {
return fmt.Errorf("invalid timeout for TCP readiness on %s: timeout must be positive", addr)
}
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
conn.Close()
return nil
}
lastErr = err
time.Sleep(50 * time.Millisecond)
}
if lastErr == nil {
return fmt.Errorf("timeout waiting for TCP readiness on %s before any dial attempt", addr)
}
return fmt.Errorf("timeout waiting for TCP readiness on %s: %w", addr, lastErr)
}

View File

@@ -26,7 +26,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
proxyType := t proxyType := t
ginkgo.It(fmt.Sprintf("Expose a %s echo server", strings.ToUpper(proxyType)), func() { ginkgo.It(fmt.Sprintf("Expose a %s echo server", strings.ToUpper(proxyType)), func() {
serverConf := consts.LegacyDefaultServerConfig serverConf := consts.LegacyDefaultServerConfig
clientConf := consts.LegacyDefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
localPortName := "" localPortName := ""
protocol := "tcp" protocol := "tcp"
@@ -78,10 +79,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config // build all client config
for _, test := range tests { for _, test := range tests {
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
for _, test := range tests { for _, test := range tests {
framework.NewRequestExpect(f). framework.NewRequestExpect(f).
@@ -102,7 +103,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
vhost_http_port = %d vhost_http_port = %d
`, vhostHTTPPort) `, vhostHTTPPort)
clientConf := consts.LegacyDefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
getProxyConf := func(proxyName string, customDomains string, extra string) string { getProxyConf := func(proxyName string, customDomains string, extra string) string {
return fmt.Sprintf(` return fmt.Sprintf(`
@@ -147,13 +149,13 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
if tests[i].customDomains == "" { if tests[i].customDomains == "" {
tests[i].customDomains = test.proxyName + ".example.com" tests[i].customDomains = test.proxyName + ".example.com"
} }
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
for _, test := range tests { for _, test := range tests {
for _, domain := range strings.Split(test.customDomains, ",") { for domain := range strings.SplitSeq(test.customDomains, ",") {
domain = strings.TrimSpace(domain) domain = strings.TrimSpace(domain)
framework.NewRequestExpect(f). framework.NewRequestExpect(f).
Explain(test.proxyName + "-" + domain). Explain(test.proxyName + "-" + domain).
@@ -185,7 +187,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
`, vhostHTTPSPort) `, vhostHTTPSPort)
localPort := f.AllocPort() localPort := f.AllocPort()
clientConf := consts.LegacyDefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
getProxyConf := func(proxyName string, customDomains string, extra string) string { getProxyConf := func(proxyName string, customDomains string, extra string) string {
return fmt.Sprintf(` return fmt.Sprintf(`
[%s] [%s]
@@ -229,10 +232,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
if tests[i].customDomains == "" { if tests[i].customDomains == "" {
tests[i].customDomains = test.proxyName + ".example.com" tests[i].customDomains = test.proxyName + ".example.com"
} }
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
tlsConfig, err := transport.NewServerTLSConfig("", "", "") tlsConfig, err := transport.NewServerTLSConfig("", "", "")
framework.ExpectNoError(err) framework.ExpectNoError(err)
@@ -244,7 +247,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
f.RunServer("", localServer) f.RunServer("", localServer)
for _, test := range tests { for _, test := range tests {
for _, domain := range strings.Split(test.customDomains, ",") { for domain := range strings.SplitSeq(test.customDomains, ",") {
domain = strings.TrimSpace(domain) domain = strings.TrimSpace(domain)
framework.NewRequestExpect(f). framework.NewRequestExpect(f).
Explain(test.proxyName + "-" + domain). Explain(test.proxyName + "-" + domain).
@@ -282,9 +285,12 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
proxyType := t proxyType := t
ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() { ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() {
serverConf := consts.LegacyDefaultServerConfig serverConf := consts.LegacyDefaultServerConfig
clientServerConf := consts.LegacyDefaultClientConfig + "\nuser = user1" var clientServerConf strings.Builder
clientVisitorConf := consts.LegacyDefaultClientConfig + "\nuser = user1" clientServerConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user1")
clientUser2VisitorConf := consts.LegacyDefaultClientConfig + "\nuser = user2" var clientVisitorConf strings.Builder
clientVisitorConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user1")
var clientUser2VisitorConf strings.Builder
clientUser2VisitorConf.WriteString(consts.LegacyDefaultClientConfig + "\nuser = user2")
localPortName := "" localPortName := ""
protocol := "tcp" protocol := "tcp"
@@ -400,20 +406,20 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config // build all client config
for _, test := range tests { for _, test := range tests {
clientServerConf += getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n" clientServerConf.WriteString(getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n")
} }
for _, test := range tests { for _, test := range tests {
config := getProxyVisitorConf( config := getProxyVisitorConf(
test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig, test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig,
) + "\n" ) + "\n"
if test.deployUser2Client { if test.deployUser2Client {
clientUser2VisitorConf += config clientUser2VisitorConf.WriteString(config)
} else { } else {
clientVisitorConf += config clientVisitorConf.WriteString(config)
} }
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientServerConf, clientVisitorConf, clientUser2VisitorConf}) f.RunProcesses(serverConf, []string{clientServerConf.String(), clientVisitorConf.String(), clientUser2VisitorConf.String()})
for _, test := range tests { for _, test := range tests {
timeout := time.Second timeout := time.Second
@@ -440,7 +446,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
ginkgo.Describe("TCPMUX", func() { ginkgo.Describe("TCPMUX", func() {
ginkgo.It("Type tcpmux", func() { ginkgo.It("Type tcpmux", func() {
serverConf := consts.LegacyDefaultServerConfig serverConf := consts.LegacyDefaultServerConfig
clientConf := consts.LegacyDefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
tcpmuxHTTPConnectPortName := port.GenName("TCPMUX") tcpmuxHTTPConnectPortName := port.GenName("TCPMUX")
serverConf += fmt.Sprintf(` serverConf += fmt.Sprintf(`
@@ -483,14 +490,14 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config // build all client config
for _, test := range tests { for _, test := range tests {
clientConf += getProxyConf(test.proxyName, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, test.extraConfig) + "\n")
localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(f.AllocPort()), streamserver.WithRespContent([]byte(test.proxyName))) localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(f.AllocPort()), streamserver.WithRespContent([]byte(test.proxyName)))
f.RunServer(port.GenName(test.proxyName), localServer) f.RunServer(port.GenName(test.proxyName), localServer)
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
// Request without HTTP connect should get error // Request without HTTP connect should get error
framework.NewRequestExpect(f). framework.NewRequestExpect(f).

View File

@@ -48,7 +48,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
framework.TCPEchoServerPort, p2Port, framework.TCPEchoServerPort, p2Port,
framework.TCPEchoServerPort, p3Port) framework.TCPEchoServerPort, p3Port)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(p1Port).Ensure() framework.NewRequestExpect(f).Port(p1Port).Ensure()
framework.NewRequestExpect(f).Port(p2Port).Ensure() framework.NewRequestExpect(f).Port(p2Port).Ensure()
@@ -90,7 +90,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
admin_pwd = admin admin_pwd = admin
`, dashboardPort) `, dashboardPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.HTTP().HTTPPath("/healthz") r.HTTP().HTTPPath("/healthz")
@@ -116,7 +116,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
remote_port = %d remote_port = %d
`, adminPort, framework.TCPEchoServerPort, testPort) `, adminPort, framework.TCPEchoServerPort, testPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(testPort).Ensure() framework.NewRequestExpect(f).Port(testPort).Ensure()

View File

@@ -76,7 +76,7 @@ func runClientServerTest(f *framework.Framework, configures *generalTestConfigur
clientConfs = append(clientConfs, client2Conf) clientConfs = append(clientConfs, client2Conf)
} }
f.RunProcesses([]string{serverConf}, clientConfs) f.RunProcesses(serverConf, clientConfs)
if configures.testDelay > 0 { if configures.testDelay > 0 {
time.Sleep(configures.testDelay) time.Sleep(configures.testDelay)

View File

@@ -33,7 +33,7 @@ var _ = ginkgo.Describe("[Feature: Config]", func() {
`, "`", "`", framework.TCPEchoServerPort, portName) `, "`", "`", framework.TCPEchoServerPort, portName)
f.SetEnvs([]string{"FRP_TOKEN=123"}) f.SetEnvs([]string{"FRP_TOKEN=123"})
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).PortName(portName).Ensure() framework.NewRequestExpect(f).PortName(portName).Ensure()
}) })

View File

@@ -56,7 +56,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
locations = /bar locations = /bar
`, fooPort, barPort) `, fooPort, barPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
tests := []struct { tests := []struct {
path string path string
@@ -111,7 +111,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
custom_domains = normal.example.com custom_domains = normal.example.com
`, fooPort, barPort, otherPort) `, fooPort, barPort, otherPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// user1 // user1
framework.NewRequestExpect(f).Explain("user1").Port(vhostHTTPPort). framework.NewRequestExpect(f).Explain("user1").Port(vhostHTTPPort).
@@ -152,7 +152,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
http_pwd = test http_pwd = test
`, framework.HTTPSimpleServerPort) `, framework.HTTPSimpleServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// not set auth header // not set auth header
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
@@ -188,7 +188,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
custom_domains = *.example.com custom_domains = *.example.com
`, framework.HTTPSimpleServerPort) `, framework.HTTPSimpleServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// not match host // not match host
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
@@ -238,7 +238,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
subdomain = bar subdomain = bar
`, fooPort, barPort) `, fooPort, barPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// foo // foo
framework.NewRequestExpect(f).Explain("foo subdomain").Port(vhostHTTPPort). framework.NewRequestExpect(f).Explain("foo subdomain").Port(vhostHTTPPort).
@@ -279,7 +279,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
header_X-From-Where = frp header_X-From-Where = frp
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// not set auth header // not set auth header
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
@@ -312,7 +312,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
host_header_rewrite = rewrite.example.com host_header_rewrite = rewrite.example.com
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
@@ -360,7 +360,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
custom_domains = 127.0.0.1 custom_domains = 127.0.0.1
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
u := url.URL{Scheme: "ws", Host: "127.0.0.1:" + strconv.Itoa(vhostHTTPPort)} u := url.URL{Scheme: "ws", Host: "127.0.0.1:" + strconv.Itoa(vhostHTTPPort)}
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)

View File

@@ -28,7 +28,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
tcpPortName := port.GenName("TCP", port.WithRangePorts(10000, 11000)) tcpPortName := port.GenName("TCP", port.WithRangePorts(10000, 11000))
udpPortName := port.GenName("UDP", port.WithRangePorts(12000, 13000)) udpPortName := port.GenName("UDP", port.WithRangePorts(12000, 13000))
clientConf += fmt.Sprintf(` clientConf += fmt.Sprintf(`
[tcp-allowded-in-range] [tcp-allowed-in-range]
type = tcp type = tcp
local_port = {{ .%s }} local_port = {{ .%s }}
remote_port = {{ .%s }} remote_port = {{ .%s }}
@@ -58,7 +58,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
remote_port = 11003 remote_port = 11003
`, framework.UDPEchoServerPort) `, framework.UDPEchoServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// TCP // TCP
// Allowed in range // Allowed in range
@@ -97,7 +97,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
local_port = {{ .%s }} local_port = {{ .%s }}
`, adminPort, framework.TCPEchoServerPort, framework.UDPEchoServerPort) `, adminPort, framework.TCPEchoServerPort, framework.UDPEchoServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
client := f.APIClientForFrpc(adminPort) client := f.APIClientForFrpc(adminPort)
@@ -138,7 +138,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
custom_domains = example.com custom_domains = example.com
`, framework.HTTPSimpleServerPort) `, framework.HTTPSimpleServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("example.com") r.HTTP().HTTPHost("example.com")
@@ -165,7 +165,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
custom_domains = example.com custom_domains = example.com
`, framework.HTTPSimpleServerPort) `, framework.HTTPSimpleServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.HTTP().HTTPPath("/healthz") r.HTTP().HTTPPath("/healthz")

View File

@@ -76,7 +76,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
custom_domains = normal.example.com custom_domains = normal.example.com
`, fooPort, barPort, otherPort) `, fooPort, barPort, otherPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// user1 // user1
framework.NewRequestExpect(f).Explain("user1"). framework.NewRequestExpect(f).Explain("user1").
@@ -121,7 +121,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
http_pwd = test http_pwd = test
`, fooPort) `, fooPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// not set auth header // not set auth header
framework.NewRequestExpect(f).Explain("no auth"). framework.NewRequestExpect(f).Explain("no auth").
@@ -204,7 +204,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
custom_domains = normal.example.com custom_domains = normal.example.com
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f). framework.NewRequestExpect(f).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {

View File

@@ -41,7 +41,7 @@ var _ = ginkgo.Describe("[Feature: XTCP]", func() {
fallback_timeout_ms = 200 fallback_timeout_ms = 200
`, framework.TCPEchoServerPort, bindPortName) `, framework.TCPEchoServerPort, bindPortName)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f). framework.NewRequestExpect(f).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
r.Timeout(time.Second) r.Timeout(time.Second)

View File

@@ -35,7 +35,7 @@ var _ = ginkgo.Describe("[Feature: Bandwidth Limit]", func() {
bandwidth_limit = 10KB bandwidth_limit = 10KB
`, localPort, remotePort) `, localPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
content := strings.Repeat("a", 50*1024) // 5KB content := strings.Repeat("a", 50*1024) // 5KB
start := time.Now() start := time.Now()
@@ -89,7 +89,7 @@ var _ = ginkgo.Describe("[Feature: Bandwidth Limit]", func() {
remote_port = %d remote_port = %d
`, localPort, remotePort) `, localPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
content := strings.Repeat("a", 50*1024) // 5KB content := strings.Repeat("a", 50*1024) // 5KB
start := time.Now() start := time.Now()

View File

@@ -48,12 +48,10 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
return true return true
}) })
} }
for i := 0; i < 10; i++ { for range 10 {
wait.Add(1) wait.Go(func() {
go func() {
defer wait.Done()
expectFn() expectFn()
}() })
} }
wait.Wait() wait.Wait()
@@ -90,11 +88,11 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
group_key = 123 group_key = 123
`, fooPort, remotePort, barPort, remotePort) `, fooPort, remotePort, barPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
fooCount := 0 fooCount := 0
barCount := 0 barCount := 0
for i := 0; i < 10; i++ { for i := range 10 {
framework.NewRequestExpect(f).Explain("times " + strconv.Itoa(i)).Port(remotePort).Ensure(func(resp *request.Response) bool { framework.NewRequestExpect(f).Explain("times " + strconv.Itoa(i)).Port(remotePort).Ensure(func(resp *request.Response) bool {
switch string(resp.Content) { switch string(resp.Content) {
case "foo": case "foo":
@@ -146,11 +144,11 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
health_check_interval_s = 1 health_check_interval_s = 1
`, fooPort, remotePort, barPort, remotePort) `, fooPort, remotePort, barPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// check foo and bar is ok // check foo and bar is ok
results := []string{} results := []string{}
for i := 0; i < 10; i++ { for range 10 {
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool { framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
results = append(results, string(resp.Content)) results = append(results, string(resp.Content))
return true return true
@@ -161,7 +159,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
// close bar server, check foo is ok // close bar server, check foo is ok
barServer.Close() barServer.Close()
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
for i := 0; i < 10; i++ { for range 10 {
framework.NewRequestExpect(f).Port(remotePort).ExpectResp([]byte("foo")).Ensure() framework.NewRequestExpect(f).Port(remotePort).ExpectResp([]byte("foo")).Ensure()
} }
@@ -169,7 +167,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
f.RunServer("", barServer) f.RunServer("", barServer)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
results = []string{} results = []string{}
for i := 0; i < 10; i++ { for range 10 {
framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool { framework.NewRequestExpect(f).Port(remotePort).Ensure(validateFooBarResponse, func(resp *request.Response) bool {
results = append(results, string(resp.Content)) results = append(results, string(resp.Content))
return true return true
@@ -215,7 +213,7 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
health_check_url = /healthz health_check_url = /healthz
`, fooPort, barPort) `, fooPort, barPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// send first HTTP request // send first HTTP request
var contents []string var contents []string

View File

@@ -38,7 +38,7 @@ var _ = ginkgo.Describe("[Feature: Heartbeat]", func() {
`, serverPort, f.PortByName(framework.TCPEchoServerPort), remotePort) `, serverPort, f.PortByName(framework.TCPEchoServerPort), remotePort)
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Protocol("tcp").Port(remotePort).Ensure() framework.NewRequestExpect(f).Protocol("tcp").Port(remotePort).Ensure()

View File

@@ -33,7 +33,7 @@ var _ = ginkgo.Describe("[Feature: Monitor]", func() {
remote_port = %d remote_port = %d
`, framework.TCPEchoServerPort, remotePort) `, framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)

View File

@@ -44,7 +44,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() {
custom_domains = normal.example.com custom_domains = normal.example.com
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
@@ -90,7 +90,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() {
proxy_protocol_version = v2 proxy_protocol_version = v2
`, localPort, remotePort) `, localPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure(func(resp *request.Response) bool { framework.NewRequestExpect(f).Port(remotePort).Ensure(func(resp *request.Response) bool {
log.Tracef("proxy protocol get SourceAddr: %s", string(resp.Content)) log.Tracef("proxy protocol get SourceAddr: %s", string(resp.Content))
@@ -136,7 +136,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() {
proxy_protocol_version = v2 proxy_protocol_version = v2
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).Port(vhostHTTPPort).RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("normal.example.com") r.HTTP().HTTPHost("normal.example.com")

View File

@@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"github.com/onsi/ginkgo/v2" "github.com/onsi/ginkgo/v2"
@@ -22,7 +23,8 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
ginkgo.Describe("UnixDomainSocket", func() { ginkgo.Describe("UnixDomainSocket", func() {
ginkgo.It("Expose a unix domain socket echo server", func() { ginkgo.It("Expose a unix domain socket echo server", func() {
serverConf := consts.LegacyDefaultServerConfig serverConf := consts.LegacyDefaultServerConfig
clientConf := consts.LegacyDefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.LegacyDefaultClientConfig)
getProxyConf := func(proxyName string, portName string, extra string) string { getProxyConf := func(proxyName string, portName string, extra string) string {
return fmt.Sprintf(` return fmt.Sprintf(`
@@ -65,10 +67,10 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
// build all client config // build all client config
for _, test := range tests { for _, test := range tests {
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
for _, test := range tests { for _, test := range tests {
framework.NewRequestExpect(f).Port(f.PortByName(test.portName)).Ensure() framework.NewRequestExpect(f).Port(f.PortByName(test.portName)).Ensure()
@@ -90,7 +92,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
plugin_http_passwd = 123 plugin_http_passwd = 123
`, remotePort) `, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// http proxy, no auth info // http proxy, no auth info
framework.NewRequestExpect(f).PortName(framework.HTTPSimpleServerPort).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).PortName(framework.HTTPSimpleServerPort).RequestModify(func(r *request.Request) {
@@ -122,7 +124,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
plugin_passwd = 123 plugin_passwd = 123
`, remotePort) `, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// http proxy, no auth info // http proxy, no auth info
framework.NewRequestExpect(f).PortName(framework.TCPEchoServerPort).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).PortName(framework.TCPEchoServerPort).RequestModify(func(r *request.Request) {
@@ -166,7 +168,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
plugin_http_passwd = 123 plugin_http_passwd = 123
`, remotePort, f.TempDirectory, f.TempDirectory, f.TempDirectory) `, remotePort, f.TempDirectory, f.TempDirectory, f.TempDirectory)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// from tcp proxy // from tcp proxy
framework.NewRequestExpect(f).Request( framework.NewRequestExpect(f).Request(
@@ -200,7 +202,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
plugin_local_addr = 127.0.0.1:%d plugin_local_addr = 127.0.0.1:%d
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
tlsConfig, err := transport.NewServerTLSConfig("", "", "") tlsConfig, err := transport.NewServerTLSConfig("", "", "")
framework.ExpectNoError(err) framework.ExpectNoError(err)
@@ -244,7 +246,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
plugin_key_path = %s plugin_key_path = %s
`, localPort, crtPath, keyPath) `, localPort, crtPath, keyPath)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
localServer := httpserver.New( localServer := httpserver.New(
httpserver.WithBindPort(localPort), httpserver.WithBindPort(localPort),
@@ -288,7 +290,7 @@ var _ = ginkgo.Describe("[Feature: Client-Plugins]", func() {
plugin_key_path = %s plugin_key_path = %s
`, localPort, crtPath, keyPath) `, localPort, crtPath, keyPath)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
tlsConfig, err := transport.NewServerTLSConfig("", "", "") tlsConfig, err := transport.NewServerTLSConfig("", "", "")
framework.ExpectNoError(err) framework.ExpectNoError(err)

View File

@@ -71,7 +71,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
remote_port = %d remote_port = %d
`, framework.TCPEchoServerPort, remotePort2) `, framework.TCPEchoServerPort, remotePort2)
f.RunProcesses([]string{serverConf}, []string{clientConf, invalidTokenClientConf}) f.RunProcesses(serverConf, []string{clientConf, invalidTokenClientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
framework.NewRequestExpect(f).Port(remotePort2).ExpectError(true).Ensure() framework.NewRequestExpect(f).Port(remotePort2).ExpectError(true).Ensure()
@@ -119,7 +119,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
remote_port = %d remote_port = %d
`, framework.TCPEchoServerPort, remotePort) `, framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
}) })
@@ -153,7 +153,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
remote_port = 0 remote_port = 0
`, framework.TCPEchoServerPort) `, framework.TCPEchoServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
}) })
@@ -195,7 +195,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
remote_port = %d remote_port = %d
`, framework.TCPEchoServerPort, remotePort) `, framework.TCPEchoServerPort, remotePort)
_, clients := f.RunProcesses([]string{serverConf}, []string{clientConf}) _, clients := f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
@@ -250,7 +250,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
remote_port = %d remote_port = %d
`, framework.TCPEchoServerPort, remotePort) `, framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
@@ -297,7 +297,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
remote_port = %d remote_port = %d
`, framework.TCPEchoServerPort, remotePort) `, framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
@@ -342,7 +342,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
remote_port = %d remote_port = %d
`, framework.TCPEchoServerPort, remotePort) `, framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
@@ -389,7 +389,7 @@ var _ = ginkgo.Describe("[Feature: Server-Plugins]", func() {
remote_port = %d remote_port = %d
`, framework.TCPEchoServerPort, remotePort) `, framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()

View File

@@ -52,7 +52,7 @@ func (pa *Allocator) GetByName(portName string) int {
pa.mu.Lock() pa.mu.Lock()
defer pa.mu.Unlock() defer pa.mu.Unlock()
for i := 0; i < 20; i++ { for range 20 {
port := pa.getByRange(builder.rangePortFrom, builder.rangePortTo) port := pa.getByRange(builder.rangePortFrom, builder.rangePortTo)
if port == 0 { if port == 0 {
return 0 return 0

View File

@@ -3,15 +3,44 @@ package process
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt"
"os/exec" "os/exec"
"strings"
"sync"
"time"
) )
// SafeBuffer is a thread-safe wrapper around bytes.Buffer.
// It is safe to call Write and String concurrently.
type SafeBuffer struct {
mu sync.Mutex
buf bytes.Buffer
}
func (b *SafeBuffer) Write(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.Write(p)
}
func (b *SafeBuffer) String() string {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.String()
}
type Process struct { type Process struct {
cmd *exec.Cmd cmd *exec.Cmd
cancel context.CancelFunc cancel context.CancelFunc
errorOutput *bytes.Buffer errorOutput *SafeBuffer
stdOutput *bytes.Buffer stdOutput *SafeBuffer
done chan struct{}
closeOne sync.Once
waitErr error
started bool
beforeStopHandler func() beforeStopHandler func()
stopped bool stopped bool
} }
@@ -27,20 +56,45 @@ func NewWithEnvs(path string, params []string, envs []string) *Process {
p := &Process{ p := &Process{
cmd: cmd, cmd: cmd,
cancel: cancel, cancel: cancel,
done: make(chan struct{}),
} }
p.errorOutput = bytes.NewBufferString("") p.errorOutput = &SafeBuffer{}
p.stdOutput = bytes.NewBufferString("") p.stdOutput = &SafeBuffer{}
cmd.Stderr = p.errorOutput cmd.Stderr = p.errorOutput
cmd.Stdout = p.stdOutput cmd.Stdout = p.stdOutput
return p return p
} }
func (p *Process) Start() error { func (p *Process) Start() error {
return p.cmd.Start() if p.started {
return errors.New("process already started")
}
p.started = true
err := p.cmd.Start()
if err != nil {
p.waitErr = err
p.closeDone()
return err
}
go func() {
p.waitErr = p.cmd.Wait()
p.closeDone()
}()
return nil
}
func (p *Process) closeDone() {
p.closeOne.Do(func() { close(p.done) })
}
// Done returns a channel that is closed when the process exits.
func (p *Process) Done() <-chan struct{} {
return p.done
} }
func (p *Process) Stop() error { func (p *Process) Stop() error {
if p.stopped { if p.stopped || !p.started {
return nil return nil
} }
defer func() { defer func() {
@@ -50,7 +104,8 @@ func (p *Process) Stop() error {
p.beforeStopHandler() p.beforeStopHandler()
} }
p.cancel() p.cancel()
return p.cmd.Wait() <-p.done
return p.waitErr
} }
func (p *Process) ErrorOutput() string { func (p *Process) ErrorOutput() string {
@@ -61,6 +116,38 @@ func (p *Process) StdOutput() string {
return p.stdOutput.String() return p.stdOutput.String()
} }
func (p *Process) Output() string {
return p.stdOutput.String() + p.errorOutput.String()
}
// CountOutput returns how many times pattern appears in the current accumulated output.
func (p *Process) CountOutput(pattern string) int {
return strings.Count(p.Output(), pattern)
}
func (p *Process) SetBeforeStopHandler(fn func()) { func (p *Process) SetBeforeStopHandler(fn func()) {
p.beforeStopHandler = fn p.beforeStopHandler = fn
} }
// WaitForOutput polls the combined process output until the pattern is found
// count time(s) or the timeout is reached. It also returns early if the process exits.
func (p *Process) WaitForOutput(pattern string, count int, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
output := p.Output()
if strings.Count(output, pattern) >= count {
return nil
}
select {
case <-p.Done():
// Process exited, check one last time.
output = p.Output()
if strings.Count(output, pattern) >= count {
return nil
}
return fmt.Errorf("process exited before %d occurrence(s) of %q found", count, pattern)
case <-time.After(25 * time.Millisecond):
}
}
return fmt.Errorf("timeout waiting for %d occurrence(s) of %q", count, pattern)
}

View File

@@ -75,11 +75,11 @@ func (c *TunnelClient) serveListener() {
if err != nil { if err != nil {
return return
} }
go c.hanldeConn(conn) go c.handleConn(conn)
} }
} }
func (c *TunnelClient) hanldeConn(conn net.Conn) { func (c *TunnelClient) handleConn(conn net.Conn) {
defer conn.Close() defer conn.Close()
local, err := net.Dial("tcp", c.localAddr) local, err := net.Dial("tcp", c.localAddr)
if err != nil { if err != nil {

View File

@@ -35,7 +35,7 @@ var _ = ginkgo.Describe("[Feature: Annotations]", func() {
"frp.e2e.test/bar" = "value2" "frp.e2e.test/bar" = "value2"
`, framework.TCPEchoServerPort, p1Port) `, framework.TCPEchoServerPort, p1Port)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(p1Port).Ensure() framework.NewRequestExpect(f).Port(p1Port).Ensure()

View File

@@ -26,7 +26,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
proxyType := t proxyType := t
ginkgo.It(fmt.Sprintf("Expose a %s echo server", strings.ToUpper(proxyType)), func() { ginkgo.It(fmt.Sprintf("Expose a %s echo server", strings.ToUpper(proxyType)), func() {
serverConf := consts.DefaultServerConfig serverConf := consts.DefaultServerConfig
clientConf := consts.DefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
localPortName := "" localPortName := ""
protocol := "tcp" protocol := "tcp"
@@ -79,10 +80,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config // build all client config
for _, test := range tests { for _, test := range tests {
clientConf += getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, test.portName, test.extraConfig) + "\n")
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
for _, test := range tests { for _, test := range tests {
framework.NewRequestExpect(f). framework.NewRequestExpect(f).
@@ -103,7 +104,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
vhostHTTPPort = %d vhostHTTPPort = %d
`, vhostHTTPPort) `, vhostHTTPPort)
clientConf := consts.DefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
getProxyConf := func(proxyName string, customDomains string, extra string) string { getProxyConf := func(proxyName string, customDomains string, extra string) string {
return fmt.Sprintf(` return fmt.Sprintf(`
@@ -149,13 +151,13 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
if tests[i].customDomains == "" { if tests[i].customDomains == "" {
tests[i].customDomains = fmt.Sprintf(`["%s"]`, test.proxyName+".example.com") tests[i].customDomains = fmt.Sprintf(`["%s"]`, test.proxyName+".example.com")
} }
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
for _, test := range tests { for _, test := range tests {
for _, domain := range strings.Split(test.customDomains, ",") { for domain := range strings.SplitSeq(test.customDomains, ",") {
domain = strings.TrimSpace(domain) domain = strings.TrimSpace(domain)
domain = strings.TrimLeft(domain, "[\"") domain = strings.TrimLeft(domain, "[\"")
domain = strings.TrimRight(domain, "]\"") domain = strings.TrimRight(domain, "]\"")
@@ -189,7 +191,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
`, vhostHTTPSPort) `, vhostHTTPSPort)
localPort := f.AllocPort() localPort := f.AllocPort()
clientConf := consts.DefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
getProxyConf := func(proxyName string, customDomains string, extra string) string { getProxyConf := func(proxyName string, customDomains string, extra string) string {
return fmt.Sprintf(` return fmt.Sprintf(`
[[proxies]] [[proxies]]
@@ -234,10 +237,10 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
if tests[i].customDomains == "" { if tests[i].customDomains == "" {
tests[i].customDomains = fmt.Sprintf(`["%s"]`, test.proxyName+".example.com") tests[i].customDomains = fmt.Sprintf(`["%s"]`, test.proxyName+".example.com")
} }
clientConf += getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, tests[i].customDomains, test.extraConfig) + "\n")
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
tlsConfig, err := transport.NewServerTLSConfig("", "", "") tlsConfig, err := transport.NewServerTLSConfig("", "", "")
framework.ExpectNoError(err) framework.ExpectNoError(err)
@@ -249,7 +252,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
f.RunServer("", localServer) f.RunServer("", localServer)
for _, test := range tests { for _, test := range tests {
for _, domain := range strings.Split(test.customDomains, ",") { for domain := range strings.SplitSeq(test.customDomains, ",") {
domain = strings.TrimSpace(domain) domain = strings.TrimSpace(domain)
domain = strings.TrimLeft(domain, "[\"") domain = strings.TrimLeft(domain, "[\"")
domain = strings.TrimRight(domain, "]\"") domain = strings.TrimRight(domain, "]\"")
@@ -289,9 +292,12 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
proxyType := t proxyType := t
ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() { ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() {
serverConf := consts.DefaultServerConfig serverConf := consts.DefaultServerConfig
clientServerConf := consts.DefaultClientConfig + "\nuser = \"user1\"" var clientServerConf strings.Builder
clientVisitorConf := consts.DefaultClientConfig + "\nuser = \"user1\"" clientServerConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user1\"")
clientUser2VisitorConf := consts.DefaultClientConfig + "\nuser = \"user2\"" var clientVisitorConf strings.Builder
clientVisitorConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user1\"")
var clientUser2VisitorConf strings.Builder
clientUser2VisitorConf.WriteString(consts.DefaultClientConfig + "\nuser = \"user2\"")
localPortName := "" localPortName := ""
protocol := "tcp" protocol := "tcp"
@@ -407,20 +413,20 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config // build all client config
for _, test := range tests { for _, test := range tests {
clientServerConf += getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n" clientServerConf.WriteString(getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n")
} }
for _, test := range tests { for _, test := range tests {
config := getProxyVisitorConf( config := getProxyVisitorConf(
test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig, test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig,
) + "\n" ) + "\n"
if test.deployUser2Client { if test.deployUser2Client {
clientUser2VisitorConf += config clientUser2VisitorConf.WriteString(config)
} else { } else {
clientVisitorConf += config clientVisitorConf.WriteString(config)
} }
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientServerConf, clientVisitorConf, clientUser2VisitorConf}) f.RunProcesses(serverConf, []string{clientServerConf.String(), clientVisitorConf.String(), clientUser2VisitorConf.String()})
for _, test := range tests { for _, test := range tests {
timeout := time.Second timeout := time.Second
@@ -447,7 +453,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
ginkgo.Describe("TCPMUX", func() { ginkgo.Describe("TCPMUX", func() {
ginkgo.It("Type tcpmux", func() { ginkgo.It("Type tcpmux", func() {
serverConf := consts.DefaultServerConfig serverConf := consts.DefaultServerConfig
clientConf := consts.DefaultClientConfig var clientConf strings.Builder
clientConf.WriteString(consts.DefaultClientConfig)
tcpmuxHTTPConnectPortName := port.GenName("TCPMUX") tcpmuxHTTPConnectPortName := port.GenName("TCPMUX")
serverConf += fmt.Sprintf(` serverConf += fmt.Sprintf(`
@@ -491,14 +498,14 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
// build all client config // build all client config
for _, test := range tests { for _, test := range tests {
clientConf += getProxyConf(test.proxyName, test.extraConfig) + "\n" clientConf.WriteString(getProxyConf(test.proxyName, test.extraConfig) + "\n")
localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(f.AllocPort()), streamserver.WithRespContent([]byte(test.proxyName))) localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(f.AllocPort()), streamserver.WithRespContent([]byte(test.proxyName)))
f.RunServer(port.GenName(test.proxyName), localServer) f.RunServer(port.GenName(test.proxyName), localServer)
} }
// run frps and frpc // run frps and frpc
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf.String()})
// Request without HTTP connect should get error // Request without HTTP connect should get error
framework.NewRequestExpect(f). framework.NewRequestExpect(f).

View File

@@ -51,7 +51,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
framework.TCPEchoServerPort, p2Port, framework.TCPEchoServerPort, p2Port,
framework.TCPEchoServerPort, p3Port) framework.TCPEchoServerPort, p3Port)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(p1Port).Ensure() framework.NewRequestExpect(f).Port(p1Port).Ensure()
framework.NewRequestExpect(f).Port(p2Port).Ensure() framework.NewRequestExpect(f).Port(p2Port).Ensure()
@@ -93,7 +93,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
webServer.password = "admin" webServer.password = "admin"
`, dashboardPort) `, dashboardPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.HTTP().HTTPPath("/healthz") r.HTTP().HTTPPath("/healthz")
@@ -120,7 +120,7 @@ var _ = ginkgo.Describe("[Feature: ClientManage]", func() {
remotePort = %d remotePort = %d
`, adminPort, framework.TCPEchoServerPort, testPort) `, adminPort, framework.TCPEchoServerPort, testPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(testPort).Ensure() framework.NewRequestExpect(f).Port(testPort).Ensure()

View File

@@ -78,7 +78,7 @@ func runClientServerTest(f *framework.Framework, configures *generalTestConfigur
clientConfs = append(clientConfs, client2Conf) clientConfs = append(clientConfs, client2Conf)
} }
f.RunProcesses([]string{serverConf}, clientConfs) f.RunProcesses(serverConf, clientConfs)
if configures.testDelay > 0 { if configures.testDelay > 0 {
time.Sleep(configures.testDelay) time.Sleep(configures.testDelay)

View File

@@ -35,7 +35,7 @@ var _ = ginkgo.Describe("[Feature: Config]", func() {
`, "`", "`", framework.TCPEchoServerPort, portName) `, "`", "`", framework.TCPEchoServerPort, portName)
f.SetEnvs([]string{"FRP_TOKEN=123"}) f.SetEnvs([]string{"FRP_TOKEN=123"})
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).PortName(portName).Ensure() framework.NewRequestExpect(f).PortName(portName).Ensure()
}) })
@@ -69,7 +69,7 @@ var _ = ginkgo.Describe("[Feature: Config]", func() {
escapeTemplate("{{- end }}"), escapeTemplate("{{- end }}"),
) )
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
client := f.APIClientForFrpc(adminPort) client := f.APIClientForFrpc(adminPort)
checkProxyFn := func(name string, localPort, remotePort int) { checkProxyFn := func(name string, localPort, remotePort int) {
@@ -149,7 +149,7 @@ proxies:
remotePort: %d remotePort: %d
`, port.GenName("Server"), framework.TCPEchoServerPort, remotePort) `, port.GenName("Server"), framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
}) })
@@ -161,7 +161,7 @@ proxies:
"proxies": [{"name": "tcp", "type": "tcp", "localPort": {{ .%s }}, "remotePort": %d}]}`, "proxies": [{"name": "tcp", "type": "tcp", "localPort": {{ .%s }}, "remotePort": %d}]}`,
port.GenName("Server"), framework.TCPEchoServerPort, remotePort) port.GenName("Server"), framework.TCPEchoServerPort, remotePort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(remotePort).Ensure() framework.NewRequestExpect(f).Port(remotePort).Ensure()
}) })
}) })

View File

@@ -59,7 +59,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
locations = ["/bar"] locations = ["/bar"]
`, fooPort, barPort) `, fooPort, barPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
tests := []struct { tests := []struct {
path string path string
@@ -117,7 +117,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
customDomains = ["normal.example.com"] customDomains = ["normal.example.com"]
`, fooPort, barPort, otherPort) `, fooPort, barPort, otherPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// user1 // user1
framework.NewRequestExpect(f).Explain("user1").Port(vhostHTTPPort). framework.NewRequestExpect(f).Explain("user1").Port(vhostHTTPPort).
@@ -159,7 +159,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
httpPassword = "test" httpPassword = "test"
`, framework.HTTPSimpleServerPort) `, framework.HTTPSimpleServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// not set auth header // not set auth header
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
@@ -196,7 +196,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
customDomains = ["*.example.com"] customDomains = ["*.example.com"]
`, framework.HTTPSimpleServerPort) `, framework.HTTPSimpleServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// not match host // not match host
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
@@ -248,7 +248,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
subdomain = "bar" subdomain = "bar"
`, fooPort, barPort) `, fooPort, barPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// foo // foo
framework.NewRequestExpect(f).Explain("foo subdomain").Port(vhostHTTPPort). framework.NewRequestExpect(f).Explain("foo subdomain").Port(vhostHTTPPort).
@@ -290,7 +290,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
requestHeaders.set.x-from-where = "frp" requestHeaders.set.x-from-where = "frp"
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
@@ -323,7 +323,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
responseHeaders.set.x-from-where = "frp" responseHeaders.set.x-from-where = "frp"
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
@@ -357,7 +357,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
hostHeaderRewrite = "rewrite.example.com" hostHeaderRewrite = "rewrite.example.com"
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {
@@ -406,7 +406,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
customDomains = ["127.0.0.1"] customDomains = ["127.0.0.1"]
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
u := url.URL{Scheme: "ws", Host: "127.0.0.1:" + strconv.Itoa(vhostHTTPPort)} u := url.URL{Scheme: "ws", Host: "127.0.0.1:" + strconv.Itoa(vhostHTTPPort)}
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
@@ -447,7 +447,7 @@ var _ = ginkgo.Describe("[Feature: HTTP]", func() {
customDomains = ["normal.example.com"] customDomains = ["normal.example.com"]
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).Port(vhostHTTPPort). framework.NewRequestExpect(f).Port(vhostHTTPPort).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {

View File

@@ -33,7 +33,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
udpPortName := port.GenName("UDP", port.WithRangePorts(12000, 13000)) udpPortName := port.GenName("UDP", port.WithRangePorts(12000, 13000))
clientConf += fmt.Sprintf(` clientConf += fmt.Sprintf(`
[[proxies]] [[proxies]]
name = "tcp-allowded-in-range" name = "tcp-allowed-in-range"
type = "tcp" type = "tcp"
localPort = {{ .%s }} localPort = {{ .%s }}
remotePort = {{ .%s }} remotePort = {{ .%s }}
@@ -67,7 +67,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
remotePort = 11003 remotePort = 11003
`, framework.UDPEchoServerPort) `, framework.UDPEchoServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// TCP // TCP
// Allowed in range // Allowed in range
@@ -108,7 +108,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
localPort = {{ .%s }} localPort = {{ .%s }}
`, adminPort, framework.TCPEchoServerPort, framework.UDPEchoServerPort) `, adminPort, framework.TCPEchoServerPort, framework.UDPEchoServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
client := f.APIClientForFrpc(adminPort) client := f.APIClientForFrpc(adminPort)
@@ -150,7 +150,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
customDomains = ["example.com"] customDomains = ["example.com"]
`, framework.HTTPSimpleServerPort) `, framework.HTTPSimpleServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.HTTP().HTTPHost("example.com") r.HTTP().HTTPHost("example.com")
@@ -178,7 +178,7 @@ var _ = ginkgo.Describe("[Feature: Server Manager]", func() {
customDomains = ["example.com"] customDomains = ["example.com"]
`, framework.HTTPSimpleServerPort) `, framework.HTTPSimpleServerPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
r.HTTP().HTTPPath("/healthz") r.HTTP().HTTPPath("/healthz")

View File

@@ -79,7 +79,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
customDomains = ["normal.example.com"] customDomains = ["normal.example.com"]
`, fooPort, barPort, otherPort) `, fooPort, barPort, otherPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// user1 // user1
framework.NewRequestExpect(f).Explain("user1"). framework.NewRequestExpect(f).Explain("user1").
@@ -125,7 +125,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
httpPassword = "test" httpPassword = "test"
`, fooPort) `, fooPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
// not set auth header // not set auth header
framework.NewRequestExpect(f).Explain("no auth"). framework.NewRequestExpect(f).Explain("no auth").
@@ -209,7 +209,7 @@ var _ = ginkgo.Describe("[Feature: TCPMUX httpconnect]", func() {
customDomains = ["normal.example.com"] customDomains = ["normal.example.com"]
`, localPort) `, localPort)
f.RunProcesses([]string{serverConf}, []string{clientConf}) f.RunProcesses(serverConf, []string{clientConf})
framework.NewRequestExpect(f). framework.NewRequestExpect(f).
RequestModify(func(r *request.Request) { RequestModify(func(r *request.Request) {

Some files were not shown because too many files have changed in this diff Show More