Compare commits

...

5 Commits

Author SHA1 Message Date
fatedier
9eafcc8a95 server/group: remove blank line between doc comment and type declaration 2026-03-08 18:49:14 +08:00
fatedier
2de56556d9 server/group: replace tautological assertion with require.NotPanics 2026-03-08 18:19:12 +08:00
fatedier
0125ca9437 server/group: refactor group package with shared abstractions and fix concurrency issues
Extract common patterns into reusable components:
- groupRegistry[G]: generic concurrent map for group lifecycle management
- baseGroup: shared plumbing for listener-based groups (TCP, HTTPS, TCPMux)
- Listener: unified virtual listener replacing 3 identical implementations

Fix concurrency issues:
- Stale-pointer race: isCurrent check + errGroupStale + controller retry loops
- Worker generation safety: pass realLn and acceptCh as params instead of reading mutable fields
- Connection leak: close conn on worker panic recovery path
- ABBA deadlock in HTTP UnRegister: consistent lock ordering (group.mu -> registry.mu)
- Round-robin overflow in HTTPGroup: use unsigned modulo

Add unit tests (17 tests) for registry, listener, and baseGroup.
Add TCPMux group load balancing e2e test.
2026-03-08 14:16:48 +08:00
fatedier
eeb0dacfc1 pkg/metrics/mem: remove redundant map write-backs and optimize proxy lookup (#5221)
Remove 4 redundant pointer map write-backs in OpenConnection,
CloseConnection, AddTrafficIn, and AddTrafficOut since the map stores
pointers and mutations are already visible without reassignment.

Optimize GetProxiesByTypeAndName from O(n) full map scan to O(1) direct
map lookup by proxy name.
2026-03-08 10:40:39 +08:00
Oleksandr Redko
535eb3db35 refactor: use maps.Clone and slices.Concat (#5220) 2026-03-08 10:38:16 +08:00
15 changed files with 729 additions and 460 deletions

View File

@@ -143,7 +143,6 @@ func (m *serverMetrics) OpenConnection(name string, _ string) {
proxyStats, ok := m.info.ProxyStatistics[name]
if ok {
proxyStats.CurConns.Inc(1)
m.info.ProxyStatistics[name] = proxyStats
}
}
@@ -155,7 +154,6 @@ func (m *serverMetrics) CloseConnection(name string, _ string) {
proxyStats, ok := m.info.ProxyStatistics[name]
if ok {
proxyStats.CurConns.Dec(1)
m.info.ProxyStatistics[name] = proxyStats
}
}
@@ -168,7 +166,6 @@ func (m *serverMetrics) AddTrafficIn(name string, _ string, trafficBytes int64)
proxyStats, ok := m.info.ProxyStatistics[name]
if ok {
proxyStats.TrafficIn.Inc(trafficBytes)
m.info.ProxyStatistics[name] = proxyStats
}
}
@@ -181,7 +178,6 @@ func (m *serverMetrics) AddTrafficOut(name string, _ string, trafficBytes int64)
proxyStats, ok := m.info.ProxyStatistics[name]
if ok {
proxyStats.TrafficOut.Inc(trafficBytes)
m.info.ProxyStatistics[name] = proxyStats
}
}
@@ -240,15 +236,9 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri
m.mu.Lock()
defer m.mu.Unlock()
for name, proxyStats := range m.info.ProxyStatistics {
if proxyStats.ProxyType != proxyType {
continue
}
if name != proxyName {
continue
}
res = toProxyStats(name, proxyStats)
break
proxyStats, ok := m.info.ProxyStatistics[proxyName]
if ok && proxyStats.ProxyType == proxyType {
res = toProxyStats(proxyName, proxyStats)
}
return
}

View File

@@ -93,8 +93,7 @@ type featureGate struct {
// NewFeatureGate creates a new feature gate with the default features
func NewFeatureGate() MutableFeatureGate {
known := map[Feature]FeatureSpec{}
maps.Copy(known, defaultFeatures)
known := maps.Clone(defaultFeatures)
f := &featureGate{}
f.known.Store(known)
@@ -108,10 +107,8 @@ func (f *featureGate) SetFromMap(m map[string]bool) error {
defer f.lock.Unlock()
// Copy existing state
known := map[Feature]FeatureSpec{}
maps.Copy(known, f.known.Load().(map[Feature]FeatureSpec))
enabled := map[Feature]bool{}
maps.Copy(enabled, f.enabled.Load().(map[Feature]bool))
known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
enabled := maps.Clone(f.enabled.Load().(map[Feature]bool))
// Apply the new settings
for k, v := range m {
@@ -142,8 +139,7 @@ func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
}
// Copy existing state
known := map[Feature]FeatureSpec{}
maps.Copy(known, f.known.Load().(map[Feature]FeatureSpec))
known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
// Add new features
for name, spec := range features {

77
server/group/base.go Normal file
View File

@@ -0,0 +1,77 @@
package group
import (
"net"
"sync"
gerr "github.com/fatedier/golib/errors"
)
// baseGroup contains the shared plumbing for listener-based groups
// (TCP, HTTPS, TCPMux). Each concrete group embeds this and provides
// its own Listen method with protocol-specific validation.
type baseGroup struct {
group string
groupKey string
acceptCh chan net.Conn
realLn net.Listener
lns []*Listener
mu sync.Mutex
cleanupFn func()
}
// initBase resets the baseGroup for a fresh listen cycle.
// Must be called under mu when len(lns) == 0.
func (bg *baseGroup) initBase(group, groupKey string, realLn net.Listener, cleanupFn func()) {
bg.group = group
bg.groupKey = groupKey
bg.realLn = realLn
bg.acceptCh = make(chan net.Conn)
bg.cleanupFn = cleanupFn
}
// worker reads from the real listener and fans out to acceptCh.
// The parameters are captured at creation time so that the worker is
// bound to a specific listen cycle and cannot observe a later initBase.
func (bg *baseGroup) worker(realLn net.Listener, acceptCh chan<- net.Conn) {
for {
c, err := realLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
acceptCh <- c
})
if err != nil {
c.Close()
return
}
}
}
// newListener creates a new Listener wired to this baseGroup.
// Must be called under mu.
func (bg *baseGroup) newListener(addr net.Addr) *Listener {
ln := newListener(bg.acceptCh, addr, bg.closeListener)
bg.lns = append(bg.lns, ln)
return ln
}
// closeListener removes ln from the list. When the last listener is removed,
// it closes acceptCh, closes the real listener, and calls cleanupFn.
func (bg *baseGroup) closeListener(ln *Listener) {
bg.mu.Lock()
defer bg.mu.Unlock()
for i, l := range bg.lns {
if l == ln {
bg.lns = append(bg.lns[:i], bg.lns[i+1:]...)
break
}
}
if len(bg.lns) == 0 {
close(bg.acceptCh)
bg.realLn.Close()
bg.cleanupFn()
}
}

169
server/group/base_test.go Normal file
View File

@@ -0,0 +1,169 @@
package group
import (
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeLn is a controllable net.Listener for tests.
type fakeLn struct {
connCh chan net.Conn
closed chan struct{}
once sync.Once
}
func newFakeLn() *fakeLn {
return &fakeLn{
connCh: make(chan net.Conn, 8),
closed: make(chan struct{}),
}
}
func (f *fakeLn) Accept() (net.Conn, error) {
select {
case c := <-f.connCh:
return c, nil
case <-f.closed:
return nil, net.ErrClosed
}
}
func (f *fakeLn) Close() error {
f.once.Do(func() { close(f.closed) })
return nil
}
func (f *fakeLn) Addr() net.Addr { return fakeAddr("127.0.0.1:9999") }
func (f *fakeLn) inject(c net.Conn) {
select {
case f.connCh <- c:
case <-f.closed:
}
}
func TestBaseGroup_WorkerFanOut(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
go bg.worker(fl, bg.acceptCh)
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
select {
case got := <-bg.acceptCh:
assert.Equal(t, c1, got)
got.Close()
case <-time.After(time.Second):
t.Fatal("timed out waiting for connection on acceptCh")
}
fl.Close()
}
func TestBaseGroup_WorkerStopsOnListenerClose(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
done := make(chan struct{})
go func() {
bg.worker(fl, bg.acceptCh)
close(done)
}()
fl.Close()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("worker did not stop after listener close")
}
}
func TestBaseGroup_WorkerClosesConnOnClosedChannel(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
// Close acceptCh before worker sends.
close(bg.acceptCh)
done := make(chan struct{})
go func() {
bg.worker(fl, bg.acceptCh)
close(done)
}()
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("worker did not stop after panic recovery")
}
// c1 should have been closed by worker's panic recovery path.
buf := make([]byte, 1)
_, err := c1.Read(buf)
assert.Error(t, err, "connection should be closed by worker")
}
func TestBaseGroup_CloseLastListenerTriggersCleanup(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
cleanupCalled := 0
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
bg.mu.Lock()
ln1 := bg.newListener(fl.Addr())
ln2 := bg.newListener(fl.Addr())
bg.mu.Unlock()
go bg.worker(fl, bg.acceptCh)
ln1.Close()
assert.Equal(t, 0, cleanupCalled, "cleanup should not run while listeners remain")
ln2.Close()
assert.Equal(t, 1, cleanupCalled, "cleanup should run after last listener closed")
}
func TestBaseGroup_CloseOneOfTwoListeners(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
cleanupCalled := 0
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
bg.mu.Lock()
ln1 := bg.newListener(fl.Addr())
ln2 := bg.newListener(fl.Addr())
bg.mu.Unlock()
go bg.worker(fl, bg.acceptCh)
ln1.Close()
assert.Equal(t, 0, cleanupCalled)
// ln2 should still receive connections.
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
got, err := ln2.Accept()
require.NoError(t, err)
assert.Equal(t, c1, got)
got.Close()
ln2.Close()
assert.Equal(t, 1, cleanupCalled)
}

View File

@@ -24,4 +24,6 @@ var (
ErrListenerClosed = errors.New("group listener closed")
ErrGroupDifferentPort = errors.New("group should have same remote port")
ErrProxyRepeated = errors.New("group proxy repeated")
errGroupStale = errors.New("stale group reference")
)

View File

@@ -9,53 +9,42 @@ import (
"github.com/fatedier/frp/pkg/util/vhost"
)
// HTTPGroupController manages HTTP groups that use round-robin
// callback routing (fundamentally different from listener-based groups).
type HTTPGroupController struct {
// groups indexed by group name
groups map[string]*HTTPGroup
// register createConn for each group to vhostRouter.
// createConn will get a connection from one proxy of the group
groupRegistry[*HTTPGroup]
vhostRouter *vhost.Routers
mu sync.Mutex
}
func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController {
return &HTTPGroupController{
groups: make(map[string]*HTTPGroup),
vhostRouter: vhostRouter,
groupRegistry: newGroupRegistry[*HTTPGroup](),
vhostRouter: vhostRouter,
}
}
func (ctl *HTTPGroupController) Register(
proxyName, group, groupKey string,
routeConfig vhost.RouteConfig,
) (err error) {
indexKey := group
ctl.mu.Lock()
g, ok := ctl.groups[indexKey]
if !ok {
g = NewHTTPGroup(ctl)
ctl.groups[indexKey] = g
) error {
for {
g := ctl.getOrCreate(group, func() *HTTPGroup {
return NewHTTPGroup(ctl)
})
err := g.Register(proxyName, group, groupKey, routeConfig)
if err == errGroupStale {
continue
}
return err
}
ctl.mu.Unlock()
return g.Register(proxyName, group, groupKey, routeConfig)
}
func (ctl *HTTPGroupController) UnRegister(proxyName, group string, _ vhost.RouteConfig) {
indexKey := group
ctl.mu.Lock()
defer ctl.mu.Unlock()
g, ok := ctl.groups[indexKey]
g, ok := ctl.get(group)
if !ok {
return
}
isEmpty := g.UnRegister(proxyName)
if isEmpty {
delete(ctl.groups, indexKey)
}
g.UnRegister(proxyName)
}
type HTTPGroup struct {
@@ -87,6 +76,9 @@ func (g *HTTPGroup) Register(
) (err error) {
g.mu.Lock()
defer g.mu.Unlock()
if !g.ctl.isCurrent(group, func(cur *HTTPGroup) bool { return cur == g }) {
return errGroupStale
}
if len(g.createFuncs) == 0 {
// the first proxy in this group
tmp := routeConfig // copy object
@@ -123,7 +115,7 @@ func (g *HTTPGroup) Register(
return nil
}
func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
func (g *HTTPGroup) UnRegister(proxyName string) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.createFuncs, proxyName)
@@ -135,10 +127,11 @@ func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
}
if len(g.createFuncs) == 0 {
isEmpty = true
g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser)
g.ctl.removeIf(g.group, func(cur *HTTPGroup) bool {
return cur == g
})
}
return
}
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
@@ -151,7 +144,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
location := g.location
routeByHTTPUser := g.routeByHTTPUser
if len(g.pxyNames) > 0 {
name := g.pxyNames[int(newIndex)%len(g.pxyNames)]
name := g.pxyNames[newIndex%uint64(len(g.pxyNames))]
f = g.createFuncs[name]
}
g.mu.RUnlock()
@@ -174,7 +167,7 @@ func (g *HTTPGroup) chooseEndpoint() (string, error) {
location := g.location
routeByHTTPUser := g.routeByHTTPUser
if len(g.pxyNames) > 0 {
name = g.pxyNames[int(newIndex)%len(g.pxyNames)]
name = g.pxyNames[newIndex%uint64(len(g.pxyNames))]
}
g.mu.RUnlock()

View File

@@ -17,25 +17,19 @@ package group
import (
"context"
"net"
"sync"
gerr "github.com/fatedier/golib/errors"
"github.com/fatedier/frp/pkg/util/vhost"
)
type HTTPSGroupController struct {
groups map[string]*HTTPSGroup
groupRegistry[*HTTPSGroup]
httpsMuxer *vhost.HTTPSMuxer
mu sync.Mutex
}
func NewHTTPSGroupController(httpsMuxer *vhost.HTTPSMuxer) *HTTPSGroupController {
return &HTTPSGroupController{
groups: make(map[string]*HTTPSGroup),
httpsMuxer: httpsMuxer,
groupRegistry: newGroupRegistry[*HTTPSGroup](),
httpsMuxer: httpsMuxer,
}
}
@@ -44,41 +38,28 @@ func (ctl *HTTPSGroupController) Listen(
group, groupKey string,
routeConfig vhost.RouteConfig,
) (l net.Listener, err error) {
indexKey := group
ctl.mu.Lock()
g, ok := ctl.groups[indexKey]
if !ok {
g = NewHTTPSGroup(ctl)
ctl.groups[indexKey] = g
for {
g := ctl.getOrCreate(group, func() *HTTPSGroup {
return NewHTTPSGroup(ctl)
})
l, err = g.Listen(ctx, group, groupKey, routeConfig)
if err == errGroupStale {
continue
}
return
}
ctl.mu.Unlock()
return g.Listen(ctx, group, groupKey, routeConfig)
}
func (ctl *HTTPSGroupController) RemoveGroup(group string) {
ctl.mu.Lock()
defer ctl.mu.Unlock()
delete(ctl.groups, group)
}
type HTTPSGroup struct {
group string
groupKey string
domain string
baseGroup
acceptCh chan net.Conn
httpsLn *vhost.Listener
lns []*HTTPSGroupListener
ctl *HTTPSGroupController
mu sync.Mutex
domain string
ctl *HTTPSGroupController
}
func NewHTTPSGroup(ctl *HTTPSGroupController) *HTTPSGroup {
return &HTTPSGroup{
lns: make([]*HTTPSGroupListener, 0),
ctl: ctl,
acceptCh: make(chan net.Conn),
ctl: ctl,
}
}
@@ -86,23 +67,27 @@ func (g *HTTPSGroup) Listen(
ctx context.Context,
group, groupKey string,
routeConfig vhost.RouteConfig,
) (ln *HTTPSGroupListener, err error) {
) (ln *Listener, err error) {
g.mu.Lock()
defer g.mu.Unlock()
if !g.ctl.isCurrent(group, func(cur *HTTPSGroup) bool { return cur == g }) {
return nil, errGroupStale
}
if len(g.lns) == 0 {
// the first listener, listen on the real address
httpsLn, errRet := g.ctl.httpsMuxer.Listen(ctx, &routeConfig)
if errRet != nil {
return nil, errRet
}
ln = newHTTPSGroupListener(group, g, httpsLn.Addr())
g.group = group
g.groupKey = groupKey
g.domain = routeConfig.Domain
g.httpsLn = httpsLn
g.lns = append(g.lns, ln)
go g.worker()
g.initBase(group, groupKey, httpsLn, func() {
g.ctl.removeIf(g.group, func(cur *HTTPSGroup) bool {
return cur == g
})
})
ln = g.newListener(httpsLn.Addr())
go g.worker(httpsLn, g.acceptCh)
} else {
// route config in the same group must be equal
if g.group != group || g.domain != routeConfig.Domain {
@@ -111,87 +96,7 @@ func (g *HTTPSGroup) Listen(
if g.groupKey != groupKey {
return nil, ErrGroupAuthFailed
}
ln = newHTTPSGroupListener(group, g, g.lns[0].Addr())
g.lns = append(g.lns, ln)
ln = g.newListener(g.lns[0].Addr())
}
return
}
func (g *HTTPSGroup) worker() {
for {
c, err := g.httpsLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
g.acceptCh <- c
})
if err != nil {
return
}
}
}
func (g *HTTPSGroup) Accept() <-chan net.Conn {
return g.acceptCh
}
func (g *HTTPSGroup) CloseListener(ln *HTTPSGroupListener) {
g.mu.Lock()
defer g.mu.Unlock()
for i, tmpLn := range g.lns {
if tmpLn == ln {
g.lns = append(g.lns[:i], g.lns[i+1:]...)
break
}
}
if len(g.lns) == 0 {
close(g.acceptCh)
if g.httpsLn != nil {
g.httpsLn.Close()
}
g.ctl.RemoveGroup(g.group)
}
}
type HTTPSGroupListener struct {
groupName string
group *HTTPSGroup
addr net.Addr
closeCh chan struct{}
}
func newHTTPSGroupListener(name string, group *HTTPSGroup, addr net.Addr) *HTTPSGroupListener {
return &HTTPSGroupListener{
groupName: name,
group: group,
addr: addr,
closeCh: make(chan struct{}),
}
}
func (ln *HTTPSGroupListener) Accept() (c net.Conn, err error) {
var ok bool
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok = <-ln.group.Accept():
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *HTTPSGroupListener) Addr() net.Addr {
return ln.addr
}
func (ln *HTTPSGroupListener) Close() (err error) {
close(ln.closeCh)
// remove self from HTTPSGroup
ln.group.CloseListener(ln)
return
}

49
server/group/listener.go Normal file
View File

@@ -0,0 +1,49 @@
package group
import (
"net"
"sync"
)
// Listener is a per-proxy virtual listener that receives connections
// from a shared group. It implements net.Listener.
type Listener struct {
acceptCh <-chan net.Conn
addr net.Addr
closeCh chan struct{}
onClose func(*Listener)
once sync.Once
}
func newListener(acceptCh <-chan net.Conn, addr net.Addr, onClose func(*Listener)) *Listener {
return &Listener{
acceptCh: acceptCh,
addr: addr,
closeCh: make(chan struct{}),
onClose: onClose,
}
}
func (ln *Listener) Accept() (net.Conn, error) {
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok := <-ln.acceptCh:
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *Listener) Addr() net.Addr {
return ln.addr
}
func (ln *Listener) Close() error {
ln.once.Do(func() {
close(ln.closeCh)
ln.onClose(ln)
})
return nil
}

View File

@@ -0,0 +1,68 @@
package group
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestListener_Accept(t *testing.T) {
acceptCh := make(chan net.Conn, 1)
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
c1, c2 := net.Pipe()
defer c1.Close()
defer c2.Close()
acceptCh <- c1
got, err := ln.Accept()
require.NoError(t, err)
assert.Equal(t, c1, got)
}
func TestListener_AcceptAfterChannelClose(t *testing.T) {
acceptCh := make(chan net.Conn)
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
close(acceptCh)
_, err := ln.Accept()
assert.ErrorIs(t, err, ErrListenerClosed)
}
func TestListener_AcceptAfterListenerClose(t *testing.T) {
acceptCh := make(chan net.Conn) // open, not closed
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
ln.Close()
_, err := ln.Accept()
assert.ErrorIs(t, err, ErrListenerClosed)
}
func TestListener_DoubleClose(t *testing.T) {
closeCalls := 0
ln := newListener(
make(chan net.Conn),
fakeAddr("127.0.0.1:1234"),
func(*Listener) { closeCalls++ },
)
assert.NotPanics(t, func() {
ln.Close()
ln.Close()
})
assert.Equal(t, 1, closeCalls, "onClose should be called exactly once")
}
func TestListener_Addr(t *testing.T) {
addr := fakeAddr("10.0.0.1:5555")
ln := newListener(make(chan net.Conn), addr, func(*Listener) {})
assert.Equal(t, addr, ln.Addr())
}
// fakeAddr implements net.Addr for testing.
type fakeAddr string
func (a fakeAddr) Network() string { return "tcp" }
func (a fakeAddr) String() string { return string(a) }

59
server/group/registry.go Normal file
View File

@@ -0,0 +1,59 @@
package group
import (
"sync"
)
// groupRegistry is a concurrent map of named groups with
// automatic creation on first access.
type groupRegistry[G any] struct {
groups map[string]G
mu sync.Mutex
}
func newGroupRegistry[G any]() groupRegistry[G] {
return groupRegistry[G]{
groups: make(map[string]G),
}
}
func (r *groupRegistry[G]) getOrCreate(key string, newFn func() G) G {
r.mu.Lock()
defer r.mu.Unlock()
g, ok := r.groups[key]
if !ok {
g = newFn()
r.groups[key] = g
}
return g
}
func (r *groupRegistry[G]) get(key string) (G, bool) {
r.mu.Lock()
defer r.mu.Unlock()
g, ok := r.groups[key]
return g, ok
}
// isCurrent returns true if key exists in the registry and matchFn
// returns true for the stored value.
func (r *groupRegistry[G]) isCurrent(key string, matchFn func(G) bool) bool {
r.mu.Lock()
defer r.mu.Unlock()
g, ok := r.groups[key]
return ok && matchFn(g)
}
// removeIf atomically looks up the group for key, calls fn on it,
// and removes the entry if fn returns true.
func (r *groupRegistry[G]) removeIf(key string, fn func(G) bool) {
r.mu.Lock()
defer r.mu.Unlock()
g, ok := r.groups[key]
if !ok {
return
}
if fn(g) {
delete(r.groups, key)
}
}

View File

@@ -0,0 +1,102 @@
package group
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetOrCreate_New(t *testing.T) {
r := newGroupRegistry[*int]()
called := 0
v := 42
got := r.getOrCreate("k", func() *int { called++; return &v })
assert.Equal(t, 1, called)
assert.Equal(t, &v, got)
}
func TestGetOrCreate_Existing(t *testing.T) {
r := newGroupRegistry[*int]()
v := 42
r.getOrCreate("k", func() *int { return &v })
called := 0
got := r.getOrCreate("k", func() *int { called++; return nil })
assert.Equal(t, 0, called)
assert.Equal(t, &v, got)
}
func TestGet_ExistingAndMissing(t *testing.T) {
r := newGroupRegistry[*int]()
v := 1
r.getOrCreate("k", func() *int { return &v })
got, ok := r.get("k")
assert.True(t, ok)
assert.Equal(t, &v, got)
_, ok = r.get("missing")
assert.False(t, ok)
}
func TestIsCurrent(t *testing.T) {
r := newGroupRegistry[*int]()
v1 := 1
v2 := 2
r.getOrCreate("k", func() *int { return &v1 })
assert.True(t, r.isCurrent("k", func(g *int) bool { return g == &v1 }))
assert.False(t, r.isCurrent("k", func(g *int) bool { return g == &v2 }))
assert.False(t, r.isCurrent("missing", func(g *int) bool { return true }))
}
func TestRemoveIf(t *testing.T) {
t.Run("removes when fn returns true", func(t *testing.T) {
r := newGroupRegistry[*int]()
v := 1
r.getOrCreate("k", func() *int { return &v })
r.removeIf("k", func(g *int) bool { return g == &v })
_, ok := r.get("k")
assert.False(t, ok)
})
t.Run("keeps when fn returns false", func(t *testing.T) {
r := newGroupRegistry[*int]()
v := 1
r.getOrCreate("k", func() *int { return &v })
r.removeIf("k", func(g *int) bool { return false })
_, ok := r.get("k")
assert.True(t, ok)
})
t.Run("noop on missing key", func(t *testing.T) {
r := newGroupRegistry[*int]()
r.removeIf("missing", func(g *int) bool { return true }) // should not panic
})
}
func TestConcurrentGetOrCreateAndRemoveIf(t *testing.T) {
r := newGroupRegistry[*int]()
const n = 100
var wg sync.WaitGroup
wg.Add(n * 2)
for i := range n {
v := i
go func() {
defer wg.Done()
r.getOrCreate("k", func() *int { return &v })
}()
go func() {
defer wg.Done()
r.removeIf("k", func(*int) bool { return true })
}()
}
wg.Wait()
// After all goroutines finish, accessing the key must not panic.
require.NotPanics(t, func() {
_, _ = r.get("k")
})
}

View File

@@ -17,83 +17,67 @@ package group
import (
"net"
"strconv"
"sync"
gerr "github.com/fatedier/golib/errors"
"github.com/fatedier/frp/server/ports"
)
// TCPGroupCtl manage all TCPGroups
// TCPGroupCtl manages all TCPGroups.
type TCPGroupCtl struct {
groups map[string]*TCPGroup
// portManager is used to manage port
groupRegistry[*TCPGroup]
portManager *ports.Manager
mu sync.Mutex
}
// NewTCPGroupCtl return a new TcpGroupCtl
// NewTCPGroupCtl returns a new TCPGroupCtl.
func NewTCPGroupCtl(portManager *ports.Manager) *TCPGroupCtl {
return &TCPGroupCtl{
groups: make(map[string]*TCPGroup),
portManager: portManager,
groupRegistry: newGroupRegistry[*TCPGroup](),
portManager: portManager,
}
}
// Listen is the wrapper for TCPGroup's Listen
// If there are no group, we will create one here
// Listen is the wrapper for TCPGroup's Listen.
// If there is no group, one will be created.
func (tgc *TCPGroupCtl) Listen(proxyName string, group string, groupKey string,
addr string, port int,
) (l net.Listener, realPort int, err error) {
tgc.mu.Lock()
tcpGroup, ok := tgc.groups[group]
if !ok {
tcpGroup = NewTCPGroup(tgc)
tgc.groups[group] = tcpGroup
for {
tcpGroup := tgc.getOrCreate(group, func() *TCPGroup {
return NewTCPGroup(tgc)
})
l, realPort, err = tcpGroup.Listen(proxyName, group, groupKey, addr, port)
if err == errGroupStale {
continue
}
return
}
tgc.mu.Unlock()
return tcpGroup.Listen(proxyName, group, groupKey, addr, port)
}
// RemoveGroup remove TCPGroup from controller
func (tgc *TCPGroupCtl) RemoveGroup(group string) {
tgc.mu.Lock()
defer tgc.mu.Unlock()
delete(tgc.groups, group)
}
// TCPGroup route connections to different proxies
// TCPGroup routes connections to different proxies.
type TCPGroup struct {
group string
groupKey string
baseGroup
addr string
port int
realPort int
acceptCh chan net.Conn
tcpLn net.Listener
lns []*TCPGroupListener
ctl *TCPGroupCtl
mu sync.Mutex
}
// NewTCPGroup return a new TCPGroup
// NewTCPGroup returns a new TCPGroup.
func NewTCPGroup(ctl *TCPGroupCtl) *TCPGroup {
return &TCPGroup{
lns: make([]*TCPGroupListener, 0),
ctl: ctl,
acceptCh: make(chan net.Conn),
ctl: ctl,
}
}
// Listen will return a new TCPGroupListener
// if TCPGroup already has a listener, just add a new TCPGroupListener to the queues
// otherwise, listen on the real address
func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TCPGroupListener, realPort int, err error) {
// Listen will return a new Listener.
// If TCPGroup already has a listener, just add a new Listener to the queues,
// otherwise listen on the real address.
func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *Listener, realPort int, err error) {
tg.mu.Lock()
defer tg.mu.Unlock()
if !tg.ctl.isCurrent(group, func(cur *TCPGroup) bool { return cur == tg }) {
return nil, 0, errGroupStale
}
if len(tg.lns) == 0 {
// the first listener, listen on the real address
realPort, err = tg.ctl.portManager.Acquire(proxyName, port)
@@ -106,19 +90,18 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
err = errRet
return
}
ln = newTCPGroupListener(group, tg, tcpLn.Addr())
tg.group = group
tg.groupKey = groupKey
tg.addr = addr
tg.port = port
tg.realPort = realPort
tg.tcpLn = tcpLn
tg.lns = append(tg.lns, ln)
if tg.acceptCh == nil {
tg.acceptCh = make(chan net.Conn)
}
go tg.worker()
tg.initBase(group, groupKey, tcpLn, func() {
tg.ctl.portManager.Release(tg.realPort)
tg.ctl.removeIf(tg.group, func(cur *TCPGroup) bool {
return cur == tg
})
})
ln = tg.newListener(tcpLn.Addr())
go tg.worker(tcpLn, tg.acceptCh)
} else {
// address and port in the same group must be equal
if tg.group != group || tg.addr != addr {
@@ -133,92 +116,8 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
err = ErrGroupAuthFailed
return
}
ln = newTCPGroupListener(group, tg, tg.lns[0].Addr())
ln = tg.newListener(tg.lns[0].Addr())
realPort = tg.realPort
tg.lns = append(tg.lns, ln)
}
return
}
// worker is called when the real tcp listener has been created
func (tg *TCPGroup) worker() {
for {
c, err := tg.tcpLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
tg.acceptCh <- c
})
if err != nil {
return
}
}
}
func (tg *TCPGroup) Accept() <-chan net.Conn {
return tg.acceptCh
}
// CloseListener remove the TCPGroupListener from the TCPGroup
func (tg *TCPGroup) CloseListener(ln *TCPGroupListener) {
tg.mu.Lock()
defer tg.mu.Unlock()
for i, tmpLn := range tg.lns {
if tmpLn == ln {
tg.lns = append(tg.lns[:i], tg.lns[i+1:]...)
break
}
}
if len(tg.lns) == 0 {
close(tg.acceptCh)
tg.tcpLn.Close()
tg.ctl.portManager.Release(tg.realPort)
tg.ctl.RemoveGroup(tg.group)
}
}
// TCPGroupListener
type TCPGroupListener struct {
groupName string
group *TCPGroup
addr net.Addr
closeCh chan struct{}
}
func newTCPGroupListener(name string, group *TCPGroup, addr net.Addr) *TCPGroupListener {
return &TCPGroupListener{
groupName: name,
group: group,
addr: addr,
closeCh: make(chan struct{}),
}
}
// Accept will accept connections from TCPGroup
func (ln *TCPGroupListener) Accept() (c net.Conn, err error) {
var ok bool
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok = <-ln.group.Accept():
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *TCPGroupListener) Addr() net.Addr {
return ln.addr
}
// Close close the listener
func (ln *TCPGroupListener) Close() (err error) {
close(ln.closeCh)
// remove self from TcpGroup
ln.group.CloseListener(ln)
return
}

View File

@@ -18,118 +18,100 @@ import (
"context"
"fmt"
"net"
"sync"
gerr "github.com/fatedier/golib/errors"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/tcpmux"
"github.com/fatedier/frp/pkg/util/vhost"
)
// TCPMuxGroupCtl manage all TCPMuxGroups
// TCPMuxGroupCtl manages all TCPMuxGroups.
type TCPMuxGroupCtl struct {
groups map[string]*TCPMuxGroup
// portManager is used to manage port
groupRegistry[*TCPMuxGroup]
tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer
mu sync.Mutex
}
// NewTCPMuxGroupCtl return a new TCPMuxGroupCtl
// NewTCPMuxGroupCtl returns a new TCPMuxGroupCtl.
func NewTCPMuxGroupCtl(tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer) *TCPMuxGroupCtl {
return &TCPMuxGroupCtl{
groups: make(map[string]*TCPMuxGroup),
groupRegistry: newGroupRegistry[*TCPMuxGroup](),
tcpMuxHTTPConnectMuxer: tcpMuxHTTPConnectMuxer,
}
}
// Listen is the wrapper for TCPMuxGroup's Listen
// If there are no group, we will create one here
// Listen is the wrapper for TCPMuxGroup's Listen.
// If there is no group, one will be created.
func (tmgc *TCPMuxGroupCtl) Listen(
ctx context.Context,
multiplexer, group, groupKey string,
routeConfig vhost.RouteConfig,
) (l net.Listener, err error) {
tmgc.mu.Lock()
tcpMuxGroup, ok := tmgc.groups[group]
if !ok {
tcpMuxGroup = NewTCPMuxGroup(tmgc)
tmgc.groups[group] = tcpMuxGroup
}
tmgc.mu.Unlock()
for {
tcpMuxGroup := tmgc.getOrCreate(group, func() *TCPMuxGroup {
return NewTCPMuxGroup(tmgc)
})
switch v1.TCPMultiplexerType(multiplexer) {
case v1.TCPMultiplexerHTTPConnect:
return tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig)
default:
err = fmt.Errorf("unknown multiplexer [%s]", multiplexer)
return
switch v1.TCPMultiplexerType(multiplexer) {
case v1.TCPMultiplexerHTTPConnect:
l, err = tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig)
if err == errGroupStale {
continue
}
return
default:
return nil, fmt.Errorf("unknown multiplexer [%s]", multiplexer)
}
}
}
// RemoveGroup remove TCPMuxGroup from controller
func (tmgc *TCPMuxGroupCtl) RemoveGroup(group string) {
tmgc.mu.Lock()
defer tmgc.mu.Unlock()
delete(tmgc.groups, group)
}
// TCPMuxGroup route connections to different proxies
// TCPMuxGroup routes connections to different proxies.
type TCPMuxGroup struct {
group string
groupKey string
baseGroup
domain string
routeByHTTPUser string
username string
password string
acceptCh chan net.Conn
tcpMuxLn net.Listener
lns []*TCPMuxGroupListener
ctl *TCPMuxGroupCtl
mu sync.Mutex
ctl *TCPMuxGroupCtl
}
// NewTCPMuxGroup return a new TCPMuxGroup
// NewTCPMuxGroup returns a new TCPMuxGroup.
func NewTCPMuxGroup(ctl *TCPMuxGroupCtl) *TCPMuxGroup {
return &TCPMuxGroup{
lns: make([]*TCPMuxGroupListener, 0),
ctl: ctl,
acceptCh: make(chan net.Conn),
ctl: ctl,
}
}
// Listen will return a new TCPMuxGroupListener
// if TCPMuxGroup already has a listener, just add a new TCPMuxGroupListener to the queues
// otherwise, listen on the real address
// HTTPConnectListen will return a new Listener.
// If TCPMuxGroup already has a listener, just add a new Listener to the queues,
// otherwise listen on the real address.
func (tmg *TCPMuxGroup) HTTPConnectListen(
ctx context.Context,
group, groupKey string,
routeConfig vhost.RouteConfig,
) (ln *TCPMuxGroupListener, err error) {
) (ln *Listener, err error) {
tmg.mu.Lock()
defer tmg.mu.Unlock()
if !tmg.ctl.isCurrent(group, func(cur *TCPMuxGroup) bool { return cur == tmg }) {
return nil, errGroupStale
}
if len(tmg.lns) == 0 {
// the first listener, listen on the real address
tcpMuxLn, errRet := tmg.ctl.tcpMuxHTTPConnectMuxer.Listen(ctx, &routeConfig)
if errRet != nil {
return nil, errRet
}
ln = newTCPMuxGroupListener(group, tmg, tcpMuxLn.Addr())
tmg.group = group
tmg.groupKey = groupKey
tmg.domain = routeConfig.Domain
tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser
tmg.username = routeConfig.Username
tmg.password = routeConfig.Password
tmg.tcpMuxLn = tcpMuxLn
tmg.lns = append(tmg.lns, ln)
if tmg.acceptCh == nil {
tmg.acceptCh = make(chan net.Conn)
}
go tmg.worker()
tmg.initBase(group, groupKey, tcpMuxLn, func() {
tmg.ctl.removeIf(tmg.group, func(cur *TCPMuxGroup) bool {
return cur == tmg
})
})
ln = tmg.newListener(tcpMuxLn.Addr())
go tmg.worker(tcpMuxLn, tmg.acceptCh)
} else {
// route config in the same group must be equal
if tmg.group != group || tmg.domain != routeConfig.Domain ||
@@ -141,90 +123,7 @@ func (tmg *TCPMuxGroup) HTTPConnectListen(
if tmg.groupKey != groupKey {
return nil, ErrGroupAuthFailed
}
ln = newTCPMuxGroupListener(group, tmg, tmg.lns[0].Addr())
tmg.lns = append(tmg.lns, ln)
ln = tmg.newListener(tmg.lns[0].Addr())
}
return
}
// worker is called when the real TCP listener has been created
func (tmg *TCPMuxGroup) worker() {
for {
c, err := tmg.tcpMuxLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
tmg.acceptCh <- c
})
if err != nil {
return
}
}
}
func (tmg *TCPMuxGroup) Accept() <-chan net.Conn {
return tmg.acceptCh
}
// CloseListener remove the TCPMuxGroupListener from the TCPMuxGroup
func (tmg *TCPMuxGroup) CloseListener(ln *TCPMuxGroupListener) {
tmg.mu.Lock()
defer tmg.mu.Unlock()
for i, tmpLn := range tmg.lns {
if tmpLn == ln {
tmg.lns = append(tmg.lns[:i], tmg.lns[i+1:]...)
break
}
}
if len(tmg.lns) == 0 {
close(tmg.acceptCh)
tmg.tcpMuxLn.Close()
tmg.ctl.RemoveGroup(tmg.group)
}
}
// TCPMuxGroupListener
type TCPMuxGroupListener struct {
groupName string
group *TCPMuxGroup
addr net.Addr
closeCh chan struct{}
}
func newTCPMuxGroupListener(name string, group *TCPMuxGroup, addr net.Addr) *TCPMuxGroupListener {
return &TCPMuxGroupListener{
groupName: name,
group: group,
addr: addr,
closeCh: make(chan struct{}),
}
}
// Accept will accept connections from TCPMuxGroup
func (ln *TCPMuxGroupListener) Accept() (c net.Conn, err error) {
var ok bool
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok = <-ln.group.Accept():
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *TCPMuxGroupListener) Addr() net.Addr {
return ln.addr
}
// Close close the listener
func (ln *TCPMuxGroupListener) Close() (err error) {
close(ln.closeCh)
// remove self from TcpMuxGroup
ln.group.CloseListener(ln)
return
}

View File

@@ -5,6 +5,7 @@ import (
"maps"
"os"
"path/filepath"
"slices"
"time"
flog "github.com/fatedier/frp/pkg/util/log"
@@ -14,9 +15,7 @@ import (
// RunProcesses run multiple processes from templates.
// The first template should always be frps.
func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []string) ([]*process.Process, []*process.Process) {
templates := make([]string, 0, len(serverTemplates)+len(clientTemplates))
templates = append(templates, serverTemplates...)
templates = append(templates, clientTemplates...)
templates := slices.Concat(serverTemplates, clientTemplates)
outs, ports, err := f.RenderTemplates(templates)
ExpectNoError(err)
ExpectTrue(len(templates) > 0)

View File

@@ -186,6 +186,68 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
framework.ExpectTrue(fooCount > 1 && barCount > 1, "fooCount: %d, barCount: %d", fooCount, barCount)
})
ginkgo.It("TCPMux httpconnect", func() {
vhostPort := f.AllocPort()
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
tcpmuxHTTPConnectPort = %d
`, vhostPort)
clientConf := consts.DefaultClientConfig
fooPort := f.AllocPort()
fooServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(fooPort), streamserver.WithRespContent([]byte("foo")))
f.RunServer("", fooServer)
barPort := f.AllocPort()
barServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(barPort), streamserver.WithRespContent([]byte("bar")))
f.RunServer("", barServer)
clientConf += fmt.Sprintf(`
[[proxies]]
name = "foo"
type = "tcpmux"
multiplexer = "httpconnect"
localPort = %d
customDomains = ["tcpmux-group.example.com"]
loadBalancer.group = "test"
loadBalancer.groupKey = "123"
[[proxies]]
name = "bar"
type = "tcpmux"
multiplexer = "httpconnect"
localPort = %d
customDomains = ["tcpmux-group.example.com"]
loadBalancer.group = "test"
loadBalancer.groupKey = "123"
`, fooPort, barPort)
f.RunProcesses([]string{serverConf}, []string{clientConf})
proxyURL := fmt.Sprintf("http://127.0.0.1:%d", vhostPort)
fooCount := 0
barCount := 0
for i := range 10 {
framework.NewRequestExpect(f).
Explain("times " + strconv.Itoa(i)).
RequestModify(func(r *request.Request) {
r.Addr("tcpmux-group.example.com").Proxy(proxyURL)
}).
Ensure(func(resp *request.Response) bool {
switch string(resp.Content) {
case "foo":
fooCount++
case "bar":
barCount++
default:
return false
}
return true
})
}
framework.ExpectTrue(fooCount > 1 && barCount > 1, "fooCount: %d, barCount: %d", fooCount, barCount)
})
})
ginkgo.Describe("Health Check", func() {