diff --git a/server/group/base.go b/server/group/base.go new file mode 100644 index 00000000..5684d8ef --- /dev/null +++ b/server/group/base.go @@ -0,0 +1,77 @@ +package group + +import ( + "net" + "sync" + + gerr "github.com/fatedier/golib/errors" +) + +// baseGroup contains the shared plumbing for listener-based groups +// (TCP, HTTPS, TCPMux). Each concrete group embeds this and provides +// its own Listen method with protocol-specific validation. +type baseGroup struct { + group string + groupKey string + + acceptCh chan net.Conn + realLn net.Listener + lns []*Listener + mu sync.Mutex + cleanupFn func() +} + +// initBase resets the baseGroup for a fresh listen cycle. +// Must be called under mu when len(lns) == 0. +func (bg *baseGroup) initBase(group, groupKey string, realLn net.Listener, cleanupFn func()) { + bg.group = group + bg.groupKey = groupKey + bg.realLn = realLn + bg.acceptCh = make(chan net.Conn) + bg.cleanupFn = cleanupFn +} + +// worker reads from the real listener and fans out to acceptCh. +// The parameters are captured at creation time so that the worker is +// bound to a specific listen cycle and cannot observe a later initBase. +func (bg *baseGroup) worker(realLn net.Listener, acceptCh chan<- net.Conn) { + for { + c, err := realLn.Accept() + if err != nil { + return + } + err = gerr.PanicToError(func() { + acceptCh <- c + }) + if err != nil { + c.Close() + return + } + } +} + +// newListener creates a new Listener wired to this baseGroup. +// Must be called under mu. +func (bg *baseGroup) newListener(addr net.Addr) *Listener { + ln := newListener(bg.acceptCh, addr, bg.closeListener) + bg.lns = append(bg.lns, ln) + return ln +} + +// closeListener removes ln from the list. When the last listener is removed, +// it closes acceptCh, closes the real listener, and calls cleanupFn. +func (bg *baseGroup) closeListener(ln *Listener) { + bg.mu.Lock() + defer bg.mu.Unlock() + for i, l := range bg.lns { + if l == ln { + bg.lns = append(bg.lns[:i], bg.lns[i+1:]...) + break + } + } + if len(bg.lns) == 0 { + close(bg.acceptCh) + bg.realLn.Close() + bg.cleanupFn() + } +} diff --git a/server/group/base_test.go b/server/group/base_test.go new file mode 100644 index 00000000..1b470841 --- /dev/null +++ b/server/group/base_test.go @@ -0,0 +1,169 @@ +package group + +import ( + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeLn is a controllable net.Listener for tests. +type fakeLn struct { + connCh chan net.Conn + closed chan struct{} + once sync.Once +} + +func newFakeLn() *fakeLn { + return &fakeLn{ + connCh: make(chan net.Conn, 8), + closed: make(chan struct{}), + } +} + +func (f *fakeLn) Accept() (net.Conn, error) { + select { + case c := <-f.connCh: + return c, nil + case <-f.closed: + return nil, net.ErrClosed + } +} + +func (f *fakeLn) Close() error { + f.once.Do(func() { close(f.closed) }) + return nil +} + +func (f *fakeLn) Addr() net.Addr { return fakeAddr("127.0.0.1:9999") } + +func (f *fakeLn) inject(c net.Conn) { + select { + case f.connCh <- c: + case <-f.closed: + } +} + +func TestBaseGroup_WorkerFanOut(t *testing.T) { + fl := newFakeLn() + var bg baseGroup + bg.initBase("g", "key", fl, func() {}) + + go bg.worker(fl, bg.acceptCh) + + c1, c2 := net.Pipe() + defer c2.Close() + fl.inject(c1) + + select { + case got := <-bg.acceptCh: + assert.Equal(t, c1, got) + got.Close() + case <-time.After(time.Second): + t.Fatal("timed out waiting for connection on acceptCh") + } + + fl.Close() +} + +func TestBaseGroup_WorkerStopsOnListenerClose(t *testing.T) { + fl := newFakeLn() + var bg baseGroup + bg.initBase("g", "key", fl, func() {}) + + done := make(chan struct{}) + go func() { + bg.worker(fl, bg.acceptCh) + close(done) + }() + + fl.Close() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("worker did not stop after listener close") + } +} + +func TestBaseGroup_WorkerClosesConnOnClosedChannel(t *testing.T) { + fl := newFakeLn() + var bg baseGroup + bg.initBase("g", "key", fl, func() {}) + + // Close acceptCh before worker sends. + close(bg.acceptCh) + + done := make(chan struct{}) + go func() { + bg.worker(fl, bg.acceptCh) + close(done) + }() + + c1, c2 := net.Pipe() + defer c2.Close() + fl.inject(c1) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("worker did not stop after panic recovery") + } + + // c1 should have been closed by worker's panic recovery path. + buf := make([]byte, 1) + _, err := c1.Read(buf) + assert.Error(t, err, "connection should be closed by worker") +} + +func TestBaseGroup_CloseLastListenerTriggersCleanup(t *testing.T) { + fl := newFakeLn() + var bg baseGroup + cleanupCalled := 0 + bg.initBase("g", "key", fl, func() { cleanupCalled++ }) + + bg.mu.Lock() + ln1 := bg.newListener(fl.Addr()) + ln2 := bg.newListener(fl.Addr()) + bg.mu.Unlock() + + go bg.worker(fl, bg.acceptCh) + + ln1.Close() + assert.Equal(t, 0, cleanupCalled, "cleanup should not run while listeners remain") + + ln2.Close() + assert.Equal(t, 1, cleanupCalled, "cleanup should run after last listener closed") +} + +func TestBaseGroup_CloseOneOfTwoListeners(t *testing.T) { + fl := newFakeLn() + var bg baseGroup + cleanupCalled := 0 + bg.initBase("g", "key", fl, func() { cleanupCalled++ }) + + bg.mu.Lock() + ln1 := bg.newListener(fl.Addr()) + ln2 := bg.newListener(fl.Addr()) + bg.mu.Unlock() + + go bg.worker(fl, bg.acceptCh) + + ln1.Close() + assert.Equal(t, 0, cleanupCalled) + + // ln2 should still receive connections. + c1, c2 := net.Pipe() + defer c2.Close() + fl.inject(c1) + + got, err := ln2.Accept() + require.NoError(t, err) + assert.Equal(t, c1, got) + got.Close() + + ln2.Close() + assert.Equal(t, 1, cleanupCalled) +} diff --git a/server/group/group.go b/server/group/group.go index ab38cf45..1fbedf5c 100644 --- a/server/group/group.go +++ b/server/group/group.go @@ -24,4 +24,6 @@ var ( ErrListenerClosed = errors.New("group listener closed") ErrGroupDifferentPort = errors.New("group should have same remote port") ErrProxyRepeated = errors.New("group proxy repeated") + + errGroupStale = errors.New("stale group reference") ) diff --git a/server/group/http.go b/server/group/http.go index 26af595e..dd905581 100644 --- a/server/group/http.go +++ b/server/group/http.go @@ -9,53 +9,42 @@ import ( "github.com/fatedier/frp/pkg/util/vhost" ) +// HTTPGroupController manages HTTP groups that use round-robin +// callback routing (fundamentally different from listener-based groups). type HTTPGroupController struct { - // groups indexed by group name - groups map[string]*HTTPGroup - - // register createConn for each group to vhostRouter. - // createConn will get a connection from one proxy of the group + groupRegistry[*HTTPGroup] vhostRouter *vhost.Routers - - mu sync.Mutex } func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController { return &HTTPGroupController{ - groups: make(map[string]*HTTPGroup), - vhostRouter: vhostRouter, + groupRegistry: newGroupRegistry[*HTTPGroup](), + vhostRouter: vhostRouter, } } func (ctl *HTTPGroupController) Register( proxyName, group, groupKey string, routeConfig vhost.RouteConfig, -) (err error) { - indexKey := group - ctl.mu.Lock() - g, ok := ctl.groups[indexKey] - if !ok { - g = NewHTTPGroup(ctl) - ctl.groups[indexKey] = g +) error { + for { + g := ctl.getOrCreate(group, func() *HTTPGroup { + return NewHTTPGroup(ctl) + }) + err := g.Register(proxyName, group, groupKey, routeConfig) + if err == errGroupStale { + continue + } + return err } - ctl.mu.Unlock() - - return g.Register(proxyName, group, groupKey, routeConfig) } func (ctl *HTTPGroupController) UnRegister(proxyName, group string, _ vhost.RouteConfig) { - indexKey := group - ctl.mu.Lock() - defer ctl.mu.Unlock() - g, ok := ctl.groups[indexKey] + g, ok := ctl.get(group) if !ok { return } - - isEmpty := g.UnRegister(proxyName) - if isEmpty { - delete(ctl.groups, indexKey) - } + g.UnRegister(proxyName) } type HTTPGroup struct { @@ -87,6 +76,9 @@ func (g *HTTPGroup) Register( ) (err error) { g.mu.Lock() defer g.mu.Unlock() + if !g.ctl.isCurrent(group, func(cur *HTTPGroup) bool { return cur == g }) { + return errGroupStale + } if len(g.createFuncs) == 0 { // the first proxy in this group tmp := routeConfig // copy object @@ -123,7 +115,7 @@ func (g *HTTPGroup) Register( return nil } -func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) { +func (g *HTTPGroup) UnRegister(proxyName string) { g.mu.Lock() defer g.mu.Unlock() delete(g.createFuncs, proxyName) @@ -135,10 +127,11 @@ func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) { } if len(g.createFuncs) == 0 { - isEmpty = true g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser) + g.ctl.removeIf(g.group, func(cur *HTTPGroup) bool { + return cur == g + }) } - return } func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) { @@ -151,7 +144,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) { location := g.location routeByHTTPUser := g.routeByHTTPUser if len(g.pxyNames) > 0 { - name := g.pxyNames[int(newIndex)%len(g.pxyNames)] + name := g.pxyNames[newIndex%uint64(len(g.pxyNames))] f = g.createFuncs[name] } g.mu.RUnlock() @@ -174,7 +167,7 @@ func (g *HTTPGroup) chooseEndpoint() (string, error) { location := g.location routeByHTTPUser := g.routeByHTTPUser if len(g.pxyNames) > 0 { - name = g.pxyNames[int(newIndex)%len(g.pxyNames)] + name = g.pxyNames[newIndex%uint64(len(g.pxyNames))] } g.mu.RUnlock() diff --git a/server/group/https.go b/server/group/https.go index 4089b0cb..1ab97578 100644 --- a/server/group/https.go +++ b/server/group/https.go @@ -17,25 +17,19 @@ package group import ( "context" "net" - "sync" - - gerr "github.com/fatedier/golib/errors" "github.com/fatedier/frp/pkg/util/vhost" ) type HTTPSGroupController struct { - groups map[string]*HTTPSGroup - + groupRegistry[*HTTPSGroup] httpsMuxer *vhost.HTTPSMuxer - - mu sync.Mutex } func NewHTTPSGroupController(httpsMuxer *vhost.HTTPSMuxer) *HTTPSGroupController { return &HTTPSGroupController{ - groups: make(map[string]*HTTPSGroup), - httpsMuxer: httpsMuxer, + groupRegistry: newGroupRegistry[*HTTPSGroup](), + httpsMuxer: httpsMuxer, } } @@ -44,41 +38,28 @@ func (ctl *HTTPSGroupController) Listen( group, groupKey string, routeConfig vhost.RouteConfig, ) (l net.Listener, err error) { - indexKey := group - ctl.mu.Lock() - g, ok := ctl.groups[indexKey] - if !ok { - g = NewHTTPSGroup(ctl) - ctl.groups[indexKey] = g + for { + g := ctl.getOrCreate(group, func() *HTTPSGroup { + return NewHTTPSGroup(ctl) + }) + l, err = g.Listen(ctx, group, groupKey, routeConfig) + if err == errGroupStale { + continue + } + return } - ctl.mu.Unlock() - - return g.Listen(ctx, group, groupKey, routeConfig) -} - -func (ctl *HTTPSGroupController) RemoveGroup(group string) { - ctl.mu.Lock() - defer ctl.mu.Unlock() - delete(ctl.groups, group) } type HTTPSGroup struct { - group string - groupKey string - domain string + baseGroup - acceptCh chan net.Conn - httpsLn *vhost.Listener - lns []*HTTPSGroupListener - ctl *HTTPSGroupController - mu sync.Mutex + domain string + ctl *HTTPSGroupController } func NewHTTPSGroup(ctl *HTTPSGroupController) *HTTPSGroup { return &HTTPSGroup{ - lns: make([]*HTTPSGroupListener, 0), - ctl: ctl, - acceptCh: make(chan net.Conn), + ctl: ctl, } } @@ -86,23 +67,27 @@ func (g *HTTPSGroup) Listen( ctx context.Context, group, groupKey string, routeConfig vhost.RouteConfig, -) (ln *HTTPSGroupListener, err error) { +) (ln *Listener, err error) { g.mu.Lock() defer g.mu.Unlock() + if !g.ctl.isCurrent(group, func(cur *HTTPSGroup) bool { return cur == g }) { + return nil, errGroupStale + } if len(g.lns) == 0 { // the first listener, listen on the real address httpsLn, errRet := g.ctl.httpsMuxer.Listen(ctx, &routeConfig) if errRet != nil { return nil, errRet } - ln = newHTTPSGroupListener(group, g, httpsLn.Addr()) - g.group = group - g.groupKey = groupKey g.domain = routeConfig.Domain - g.httpsLn = httpsLn - g.lns = append(g.lns, ln) - go g.worker() + g.initBase(group, groupKey, httpsLn, func() { + g.ctl.removeIf(g.group, func(cur *HTTPSGroup) bool { + return cur == g + }) + }) + ln = g.newListener(httpsLn.Addr()) + go g.worker(httpsLn, g.acceptCh) } else { // route config in the same group must be equal if g.group != group || g.domain != routeConfig.Domain { @@ -111,87 +96,7 @@ func (g *HTTPSGroup) Listen( if g.groupKey != groupKey { return nil, ErrGroupAuthFailed } - ln = newHTTPSGroupListener(group, g, g.lns[0].Addr()) - g.lns = append(g.lns, ln) + ln = g.newListener(g.lns[0].Addr()) } return } - -func (g *HTTPSGroup) worker() { - for { - c, err := g.httpsLn.Accept() - if err != nil { - return - } - err = gerr.PanicToError(func() { - g.acceptCh <- c - }) - if err != nil { - return - } - } -} - -func (g *HTTPSGroup) Accept() <-chan net.Conn { - return g.acceptCh -} - -func (g *HTTPSGroup) CloseListener(ln *HTTPSGroupListener) { - g.mu.Lock() - defer g.mu.Unlock() - for i, tmpLn := range g.lns { - if tmpLn == ln { - g.lns = append(g.lns[:i], g.lns[i+1:]...) - break - } - } - if len(g.lns) == 0 { - close(g.acceptCh) - if g.httpsLn != nil { - g.httpsLn.Close() - } - g.ctl.RemoveGroup(g.group) - } -} - -type HTTPSGroupListener struct { - groupName string - group *HTTPSGroup - - addr net.Addr - closeCh chan struct{} -} - -func newHTTPSGroupListener(name string, group *HTTPSGroup, addr net.Addr) *HTTPSGroupListener { - return &HTTPSGroupListener{ - groupName: name, - group: group, - addr: addr, - closeCh: make(chan struct{}), - } -} - -func (ln *HTTPSGroupListener) Accept() (c net.Conn, err error) { - var ok bool - select { - case <-ln.closeCh: - return nil, ErrListenerClosed - case c, ok = <-ln.group.Accept(): - if !ok { - return nil, ErrListenerClosed - } - return c, nil - } -} - -func (ln *HTTPSGroupListener) Addr() net.Addr { - return ln.addr -} - -func (ln *HTTPSGroupListener) Close() (err error) { - close(ln.closeCh) - - // remove self from HTTPSGroup - ln.group.CloseListener(ln) - return -} diff --git a/server/group/listener.go b/server/group/listener.go new file mode 100644 index 00000000..33c5c0df --- /dev/null +++ b/server/group/listener.go @@ -0,0 +1,49 @@ +package group + +import ( + "net" + "sync" +) + +// Listener is a per-proxy virtual listener that receives connections +// from a shared group. It implements net.Listener. +type Listener struct { + acceptCh <-chan net.Conn + addr net.Addr + closeCh chan struct{} + onClose func(*Listener) + once sync.Once +} + +func newListener(acceptCh <-chan net.Conn, addr net.Addr, onClose func(*Listener)) *Listener { + return &Listener{ + acceptCh: acceptCh, + addr: addr, + closeCh: make(chan struct{}), + onClose: onClose, + } +} + +func (ln *Listener) Accept() (net.Conn, error) { + select { + case <-ln.closeCh: + return nil, ErrListenerClosed + case c, ok := <-ln.acceptCh: + if !ok { + return nil, ErrListenerClosed + } + return c, nil + } +} + +func (ln *Listener) Addr() net.Addr { + return ln.addr +} + +func (ln *Listener) Close() error { + ln.once.Do(func() { + close(ln.closeCh) + ln.onClose(ln) + }) + return nil +} diff --git a/server/group/listener_test.go b/server/group/listener_test.go new file mode 100644 index 00000000..4e3e30e6 --- /dev/null +++ b/server/group/listener_test.go @@ -0,0 +1,68 @@ +package group + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestListener_Accept(t *testing.T) { + acceptCh := make(chan net.Conn, 1) + ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {}) + + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + + acceptCh <- c1 + got, err := ln.Accept() + require.NoError(t, err) + assert.Equal(t, c1, got) +} + +func TestListener_AcceptAfterChannelClose(t *testing.T) { + acceptCh := make(chan net.Conn) + ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {}) + + close(acceptCh) + _, err := ln.Accept() + assert.ErrorIs(t, err, ErrListenerClosed) +} + +func TestListener_AcceptAfterListenerClose(t *testing.T) { + acceptCh := make(chan net.Conn) // open, not closed + ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {}) + + ln.Close() + _, err := ln.Accept() + assert.ErrorIs(t, err, ErrListenerClosed) +} + +func TestListener_DoubleClose(t *testing.T) { + closeCalls := 0 + ln := newListener( + make(chan net.Conn), + fakeAddr("127.0.0.1:1234"), + func(*Listener) { closeCalls++ }, + ) + + assert.NotPanics(t, func() { + ln.Close() + ln.Close() + }) + assert.Equal(t, 1, closeCalls, "onClose should be called exactly once") +} + +func TestListener_Addr(t *testing.T) { + addr := fakeAddr("10.0.0.1:5555") + ln := newListener(make(chan net.Conn), addr, func(*Listener) {}) + assert.Equal(t, addr, ln.Addr()) +} + +// fakeAddr implements net.Addr for testing. +type fakeAddr string + +func (a fakeAddr) Network() string { return "tcp" } +func (a fakeAddr) String() string { return string(a) } diff --git a/server/group/registry.go b/server/group/registry.go new file mode 100644 index 00000000..4064c535 --- /dev/null +++ b/server/group/registry.go @@ -0,0 +1,59 @@ +package group + +import ( + "sync" +) + +// groupRegistry is a concurrent map of named groups with +// automatic creation on first access. +type groupRegistry[G any] struct { + groups map[string]G + mu sync.Mutex +} + +func newGroupRegistry[G any]() groupRegistry[G] { + return groupRegistry[G]{ + groups: make(map[string]G), + } +} + +func (r *groupRegistry[G]) getOrCreate(key string, newFn func() G) G { + r.mu.Lock() + defer r.mu.Unlock() + g, ok := r.groups[key] + if !ok { + g = newFn() + r.groups[key] = g + } + return g +} + +func (r *groupRegistry[G]) get(key string) (G, bool) { + r.mu.Lock() + defer r.mu.Unlock() + g, ok := r.groups[key] + return g, ok +} + +// isCurrent returns true if key exists in the registry and matchFn +// returns true for the stored value. +func (r *groupRegistry[G]) isCurrent(key string, matchFn func(G) bool) bool { + r.mu.Lock() + defer r.mu.Unlock() + g, ok := r.groups[key] + return ok && matchFn(g) +} + +// removeIf atomically looks up the group for key, calls fn on it, +// and removes the entry if fn returns true. +func (r *groupRegistry[G]) removeIf(key string, fn func(G) bool) { + r.mu.Lock() + defer r.mu.Unlock() + g, ok := r.groups[key] + if !ok { + return + } + if fn(g) { + delete(r.groups, key) + } +} diff --git a/server/group/registry_test.go b/server/group/registry_test.go new file mode 100644 index 00000000..106d3998 --- /dev/null +++ b/server/group/registry_test.go @@ -0,0 +1,102 @@ +package group + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetOrCreate_New(t *testing.T) { + r := newGroupRegistry[*int]() + called := 0 + v := 42 + got := r.getOrCreate("k", func() *int { called++; return &v }) + assert.Equal(t, 1, called) + assert.Equal(t, &v, got) +} + +func TestGetOrCreate_Existing(t *testing.T) { + r := newGroupRegistry[*int]() + v := 42 + r.getOrCreate("k", func() *int { return &v }) + + called := 0 + got := r.getOrCreate("k", func() *int { called++; return nil }) + assert.Equal(t, 0, called) + assert.Equal(t, &v, got) +} + +func TestGet_ExistingAndMissing(t *testing.T) { + r := newGroupRegistry[*int]() + v := 1 + r.getOrCreate("k", func() *int { return &v }) + + got, ok := r.get("k") + assert.True(t, ok) + assert.Equal(t, &v, got) + + _, ok = r.get("missing") + assert.False(t, ok) +} + +func TestIsCurrent(t *testing.T) { + r := newGroupRegistry[*int]() + v1 := 1 + v2 := 2 + r.getOrCreate("k", func() *int { return &v1 }) + + assert.True(t, r.isCurrent("k", func(g *int) bool { return g == &v1 })) + assert.False(t, r.isCurrent("k", func(g *int) bool { return g == &v2 })) + assert.False(t, r.isCurrent("missing", func(g *int) bool { return true })) +} + +func TestRemoveIf(t *testing.T) { + t.Run("removes when fn returns true", func(t *testing.T) { + r := newGroupRegistry[*int]() + v := 1 + r.getOrCreate("k", func() *int { return &v }) + r.removeIf("k", func(g *int) bool { return g == &v }) + _, ok := r.get("k") + assert.False(t, ok) + }) + + t.Run("keeps when fn returns false", func(t *testing.T) { + r := newGroupRegistry[*int]() + v := 1 + r.getOrCreate("k", func() *int { return &v }) + r.removeIf("k", func(g *int) bool { return false }) + _, ok := r.get("k") + assert.True(t, ok) + }) + + t.Run("noop on missing key", func(t *testing.T) { + r := newGroupRegistry[*int]() + r.removeIf("missing", func(g *int) bool { return true }) // should not panic + }) +} + +func TestConcurrentGetOrCreateAndRemoveIf(t *testing.T) { + r := newGroupRegistry[*int]() + const n = 100 + var wg sync.WaitGroup + wg.Add(n * 2) + for i := range n { + v := i + go func() { + defer wg.Done() + r.getOrCreate("k", func() *int { return &v }) + }() + go func() { + defer wg.Done() + r.removeIf("k", func(*int) bool { return true }) + }() + } + wg.Wait() + + // After all goroutines finish, accessing the key must not panic. + require.NotPanics(t, func() { + _, _ = r.get("k") + }) +} diff --git a/server/group/tcp.go b/server/group/tcp.go index f52d6407..d6bfbcff 100644 --- a/server/group/tcp.go +++ b/server/group/tcp.go @@ -17,83 +17,67 @@ package group import ( "net" "strconv" - "sync" - - gerr "github.com/fatedier/golib/errors" "github.com/fatedier/frp/server/ports" ) -// TCPGroupCtl manage all TCPGroups +// TCPGroupCtl manages all TCPGroups. type TCPGroupCtl struct { - groups map[string]*TCPGroup - - // portManager is used to manage port + groupRegistry[*TCPGroup] portManager *ports.Manager - mu sync.Mutex } -// NewTCPGroupCtl return a new TcpGroupCtl +// NewTCPGroupCtl returns a new TCPGroupCtl. func NewTCPGroupCtl(portManager *ports.Manager) *TCPGroupCtl { return &TCPGroupCtl{ - groups: make(map[string]*TCPGroup), - portManager: portManager, + groupRegistry: newGroupRegistry[*TCPGroup](), + portManager: portManager, } } -// Listen is the wrapper for TCPGroup's Listen -// If there are no group, we will create one here +// Listen is the wrapper for TCPGroup's Listen. +// If there is no group, one will be created. func (tgc *TCPGroupCtl) Listen(proxyName string, group string, groupKey string, addr string, port int, ) (l net.Listener, realPort int, err error) { - tgc.mu.Lock() - tcpGroup, ok := tgc.groups[group] - if !ok { - tcpGroup = NewTCPGroup(tgc) - tgc.groups[group] = tcpGroup + for { + tcpGroup := tgc.getOrCreate(group, func() *TCPGroup { + return NewTCPGroup(tgc) + }) + l, realPort, err = tcpGroup.Listen(proxyName, group, groupKey, addr, port) + if err == errGroupStale { + continue + } + return } - tgc.mu.Unlock() - - return tcpGroup.Listen(proxyName, group, groupKey, addr, port) } -// RemoveGroup remove TCPGroup from controller -func (tgc *TCPGroupCtl) RemoveGroup(group string) { - tgc.mu.Lock() - defer tgc.mu.Unlock() - delete(tgc.groups, group) -} - -// TCPGroup route connections to different proxies +// TCPGroup routes connections to different proxies. type TCPGroup struct { - group string - groupKey string + baseGroup + addr string port int realPort int - - acceptCh chan net.Conn - tcpLn net.Listener - lns []*TCPGroupListener ctl *TCPGroupCtl - mu sync.Mutex } -// NewTCPGroup return a new TCPGroup +// NewTCPGroup returns a new TCPGroup. func NewTCPGroup(ctl *TCPGroupCtl) *TCPGroup { return &TCPGroup{ - lns: make([]*TCPGroupListener, 0), - ctl: ctl, - acceptCh: make(chan net.Conn), + ctl: ctl, } } -// Listen will return a new TCPGroupListener -// if TCPGroup already has a listener, just add a new TCPGroupListener to the queues -// otherwise, listen on the real address -func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TCPGroupListener, realPort int, err error) { +// Listen will return a new Listener. +// If TCPGroup already has a listener, just add a new Listener to the queues, +// otherwise listen on the real address. +func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *Listener, realPort int, err error) { tg.mu.Lock() defer tg.mu.Unlock() + if !tg.ctl.isCurrent(group, func(cur *TCPGroup) bool { return cur == tg }) { + return nil, 0, errGroupStale + } if len(tg.lns) == 0 { // the first listener, listen on the real address realPort, err = tg.ctl.portManager.Acquire(proxyName, port) @@ -106,19 +90,18 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr err = errRet return } - ln = newTCPGroupListener(group, tg, tcpLn.Addr()) - tg.group = group - tg.groupKey = groupKey tg.addr = addr tg.port = port tg.realPort = realPort - tg.tcpLn = tcpLn - tg.lns = append(tg.lns, ln) - if tg.acceptCh == nil { - tg.acceptCh = make(chan net.Conn) - } - go tg.worker() + tg.initBase(group, groupKey, tcpLn, func() { + tg.ctl.portManager.Release(tg.realPort) + tg.ctl.removeIf(tg.group, func(cur *TCPGroup) bool { + return cur == tg + }) + }) + ln = tg.newListener(tcpLn.Addr()) + go tg.worker(tcpLn, tg.acceptCh) } else { // address and port in the same group must be equal if tg.group != group || tg.addr != addr { @@ -133,92 +116,8 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr err = ErrGroupAuthFailed return } - ln = newTCPGroupListener(group, tg, tg.lns[0].Addr()) + ln = tg.newListener(tg.lns[0].Addr()) realPort = tg.realPort - tg.lns = append(tg.lns, ln) } return } - -// worker is called when the real tcp listener has been created -func (tg *TCPGroup) worker() { - for { - c, err := tg.tcpLn.Accept() - if err != nil { - return - } - err = gerr.PanicToError(func() { - tg.acceptCh <- c - }) - if err != nil { - return - } - } -} - -func (tg *TCPGroup) Accept() <-chan net.Conn { - return tg.acceptCh -} - -// CloseListener remove the TCPGroupListener from the TCPGroup -func (tg *TCPGroup) CloseListener(ln *TCPGroupListener) { - tg.mu.Lock() - defer tg.mu.Unlock() - for i, tmpLn := range tg.lns { - if tmpLn == ln { - tg.lns = append(tg.lns[:i], tg.lns[i+1:]...) - break - } - } - if len(tg.lns) == 0 { - close(tg.acceptCh) - tg.tcpLn.Close() - tg.ctl.portManager.Release(tg.realPort) - tg.ctl.RemoveGroup(tg.group) - } -} - -// TCPGroupListener -type TCPGroupListener struct { - groupName string - group *TCPGroup - - addr net.Addr - closeCh chan struct{} -} - -func newTCPGroupListener(name string, group *TCPGroup, addr net.Addr) *TCPGroupListener { - return &TCPGroupListener{ - groupName: name, - group: group, - addr: addr, - closeCh: make(chan struct{}), - } -} - -// Accept will accept connections from TCPGroup -func (ln *TCPGroupListener) Accept() (c net.Conn, err error) { - var ok bool - select { - case <-ln.closeCh: - return nil, ErrListenerClosed - case c, ok = <-ln.group.Accept(): - if !ok { - return nil, ErrListenerClosed - } - return c, nil - } -} - -func (ln *TCPGroupListener) Addr() net.Addr { - return ln.addr -} - -// Close close the listener -func (ln *TCPGroupListener) Close() (err error) { - close(ln.closeCh) - - // remove self from TcpGroup - ln.group.CloseListener(ln) - return -} diff --git a/server/group/tcpmux.go b/server/group/tcpmux.go index 1712bc74..e17a152a 100644 --- a/server/group/tcpmux.go +++ b/server/group/tcpmux.go @@ -18,118 +18,100 @@ import ( "context" "fmt" "net" - "sync" - - gerr "github.com/fatedier/golib/errors" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/util/tcpmux" "github.com/fatedier/frp/pkg/util/vhost" ) -// TCPMuxGroupCtl manage all TCPMuxGroups +// TCPMuxGroupCtl manages all TCPMuxGroups. type TCPMuxGroupCtl struct { - groups map[string]*TCPMuxGroup - - // portManager is used to manage port + groupRegistry[*TCPMuxGroup] tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer - mu sync.Mutex } -// NewTCPMuxGroupCtl return a new TCPMuxGroupCtl +// NewTCPMuxGroupCtl returns a new TCPMuxGroupCtl. func NewTCPMuxGroupCtl(tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer) *TCPMuxGroupCtl { return &TCPMuxGroupCtl{ - groups: make(map[string]*TCPMuxGroup), + groupRegistry: newGroupRegistry[*TCPMuxGroup](), tcpMuxHTTPConnectMuxer: tcpMuxHTTPConnectMuxer, } } -// Listen is the wrapper for TCPMuxGroup's Listen -// If there are no group, we will create one here +// Listen is the wrapper for TCPMuxGroup's Listen. +// If there is no group, one will be created. func (tmgc *TCPMuxGroupCtl) Listen( ctx context.Context, multiplexer, group, groupKey string, routeConfig vhost.RouteConfig, ) (l net.Listener, err error) { - tmgc.mu.Lock() - tcpMuxGroup, ok := tmgc.groups[group] - if !ok { - tcpMuxGroup = NewTCPMuxGroup(tmgc) - tmgc.groups[group] = tcpMuxGroup - } - tmgc.mu.Unlock() + for { + tcpMuxGroup := tmgc.getOrCreate(group, func() *TCPMuxGroup { + return NewTCPMuxGroup(tmgc) + }) - switch v1.TCPMultiplexerType(multiplexer) { - case v1.TCPMultiplexerHTTPConnect: - return tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig) - default: - err = fmt.Errorf("unknown multiplexer [%s]", multiplexer) - return + switch v1.TCPMultiplexerType(multiplexer) { + case v1.TCPMultiplexerHTTPConnect: + l, err = tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig) + if err == errGroupStale { + continue + } + return + default: + return nil, fmt.Errorf("unknown multiplexer [%s]", multiplexer) + } } } -// RemoveGroup remove TCPMuxGroup from controller -func (tmgc *TCPMuxGroupCtl) RemoveGroup(group string) { - tmgc.mu.Lock() - defer tmgc.mu.Unlock() - delete(tmgc.groups, group) -} - -// TCPMuxGroup route connections to different proxies +// TCPMuxGroup routes connections to different proxies. type TCPMuxGroup struct { - group string - groupKey string + baseGroup + domain string routeByHTTPUser string username string password string - - acceptCh chan net.Conn - tcpMuxLn net.Listener - lns []*TCPMuxGroupListener - ctl *TCPMuxGroupCtl - mu sync.Mutex + ctl *TCPMuxGroupCtl } -// NewTCPMuxGroup return a new TCPMuxGroup +// NewTCPMuxGroup returns a new TCPMuxGroup. func NewTCPMuxGroup(ctl *TCPMuxGroupCtl) *TCPMuxGroup { return &TCPMuxGroup{ - lns: make([]*TCPMuxGroupListener, 0), - ctl: ctl, - acceptCh: make(chan net.Conn), + ctl: ctl, } } -// Listen will return a new TCPMuxGroupListener -// if TCPMuxGroup already has a listener, just add a new TCPMuxGroupListener to the queues -// otherwise, listen on the real address +// HTTPConnectListen will return a new Listener. +// If TCPMuxGroup already has a listener, just add a new Listener to the queues, +// otherwise listen on the real address. func (tmg *TCPMuxGroup) HTTPConnectListen( ctx context.Context, group, groupKey string, routeConfig vhost.RouteConfig, -) (ln *TCPMuxGroupListener, err error) { +) (ln *Listener, err error) { tmg.mu.Lock() defer tmg.mu.Unlock() + if !tmg.ctl.isCurrent(group, func(cur *TCPMuxGroup) bool { return cur == tmg }) { + return nil, errGroupStale + } if len(tmg.lns) == 0 { // the first listener, listen on the real address tcpMuxLn, errRet := tmg.ctl.tcpMuxHTTPConnectMuxer.Listen(ctx, &routeConfig) if errRet != nil { return nil, errRet } - ln = newTCPMuxGroupListener(group, tmg, tcpMuxLn.Addr()) - tmg.group = group - tmg.groupKey = groupKey tmg.domain = routeConfig.Domain tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser tmg.username = routeConfig.Username tmg.password = routeConfig.Password - tmg.tcpMuxLn = tcpMuxLn - tmg.lns = append(tmg.lns, ln) - if tmg.acceptCh == nil { - tmg.acceptCh = make(chan net.Conn) - } - go tmg.worker() + tmg.initBase(group, groupKey, tcpMuxLn, func() { + tmg.ctl.removeIf(tmg.group, func(cur *TCPMuxGroup) bool { + return cur == tmg + }) + }) + ln = tmg.newListener(tcpMuxLn.Addr()) + go tmg.worker(tcpMuxLn, tmg.acceptCh) } else { // route config in the same group must be equal if tmg.group != group || tmg.domain != routeConfig.Domain || @@ -141,90 +123,7 @@ func (tmg *TCPMuxGroup) HTTPConnectListen( if tmg.groupKey != groupKey { return nil, ErrGroupAuthFailed } - ln = newTCPMuxGroupListener(group, tmg, tmg.lns[0].Addr()) - tmg.lns = append(tmg.lns, ln) + ln = tmg.newListener(tmg.lns[0].Addr()) } return } - -// worker is called when the real TCP listener has been created -func (tmg *TCPMuxGroup) worker() { - for { - c, err := tmg.tcpMuxLn.Accept() - if err != nil { - return - } - err = gerr.PanicToError(func() { - tmg.acceptCh <- c - }) - if err != nil { - return - } - } -} - -func (tmg *TCPMuxGroup) Accept() <-chan net.Conn { - return tmg.acceptCh -} - -// CloseListener remove the TCPMuxGroupListener from the TCPMuxGroup -func (tmg *TCPMuxGroup) CloseListener(ln *TCPMuxGroupListener) { - tmg.mu.Lock() - defer tmg.mu.Unlock() - for i, tmpLn := range tmg.lns { - if tmpLn == ln { - tmg.lns = append(tmg.lns[:i], tmg.lns[i+1:]...) - break - } - } - if len(tmg.lns) == 0 { - close(tmg.acceptCh) - tmg.tcpMuxLn.Close() - tmg.ctl.RemoveGroup(tmg.group) - } -} - -// TCPMuxGroupListener -type TCPMuxGroupListener struct { - groupName string - group *TCPMuxGroup - - addr net.Addr - closeCh chan struct{} -} - -func newTCPMuxGroupListener(name string, group *TCPMuxGroup, addr net.Addr) *TCPMuxGroupListener { - return &TCPMuxGroupListener{ - groupName: name, - group: group, - addr: addr, - closeCh: make(chan struct{}), - } -} - -// Accept will accept connections from TCPMuxGroup -func (ln *TCPMuxGroupListener) Accept() (c net.Conn, err error) { - var ok bool - select { - case <-ln.closeCh: - return nil, ErrListenerClosed - case c, ok = <-ln.group.Accept(): - if !ok { - return nil, ErrListenerClosed - } - return c, nil - } -} - -func (ln *TCPMuxGroupListener) Addr() net.Addr { - return ln.addr -} - -// Close close the listener -func (ln *TCPMuxGroupListener) Close() (err error) { - close(ln.closeCh) - - // remove self from TcpMuxGroup - ln.group.CloseListener(ln) - return -} diff --git a/test/e2e/v1/features/group.go b/test/e2e/v1/features/group.go index 850a932c..2a7891e7 100644 --- a/test/e2e/v1/features/group.go +++ b/test/e2e/v1/features/group.go @@ -186,6 +186,68 @@ var _ = ginkgo.Describe("[Feature: Group]", func() { framework.ExpectTrue(fooCount > 1 && barCount > 1, "fooCount: %d, barCount: %d", fooCount, barCount) }) + + ginkgo.It("TCPMux httpconnect", func() { + vhostPort := f.AllocPort() + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` + tcpmuxHTTPConnectPort = %d + `, vhostPort) + clientConf := consts.DefaultClientConfig + + fooPort := f.AllocPort() + fooServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(fooPort), streamserver.WithRespContent([]byte("foo"))) + f.RunServer("", fooServer) + + barPort := f.AllocPort() + barServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(barPort), streamserver.WithRespContent([]byte("bar"))) + f.RunServer("", barServer) + + clientConf += fmt.Sprintf(` + [[proxies]] + name = "foo" + type = "tcpmux" + multiplexer = "httpconnect" + localPort = %d + customDomains = ["tcpmux-group.example.com"] + loadBalancer.group = "test" + loadBalancer.groupKey = "123" + + [[proxies]] + name = "bar" + type = "tcpmux" + multiplexer = "httpconnect" + localPort = %d + customDomains = ["tcpmux-group.example.com"] + loadBalancer.group = "test" + loadBalancer.groupKey = "123" + `, fooPort, barPort) + + f.RunProcesses([]string{serverConf}, []string{clientConf}) + + proxyURL := fmt.Sprintf("http://127.0.0.1:%d", vhostPort) + fooCount := 0 + barCount := 0 + for i := range 10 { + framework.NewRequestExpect(f). + Explain("times " + strconv.Itoa(i)). + RequestModify(func(r *request.Request) { + r.Addr("tcpmux-group.example.com").Proxy(proxyURL) + }). + Ensure(func(resp *request.Response) bool { + switch string(resp.Content) { + case "foo": + fooCount++ + case "bar": + barCount++ + default: + return false + } + return true + }) + } + + framework.ExpectTrue(fooCount > 1 && barCount > 1, "fooCount: %d, barCount: %d", fooCount, barCount) + }) }) ginkgo.Describe("Health Check", func() {