mirror of
https://github.com/fatedier/frp.git
synced 2026-03-18 07:49:16 +08:00
server/group: refactor with shared abstractions and fix concurrency issues (#5222)
* server/group: refactor group package with shared abstractions and fix concurrency issues Extract common patterns into reusable components: - groupRegistry[G]: generic concurrent map for group lifecycle management - baseGroup: shared plumbing for listener-based groups (TCP, HTTPS, TCPMux) - Listener: unified virtual listener replacing 3 identical implementations Fix concurrency issues: - Stale-pointer race: isCurrent check + errGroupStale + controller retry loops - Worker generation safety: pass realLn and acceptCh as params instead of reading mutable fields - Connection leak: close conn on worker panic recovery path - ABBA deadlock in HTTP UnRegister: consistent lock ordering (group.mu -> registry.mu) - Round-robin overflow in HTTPGroup: use unsigned modulo Add unit tests (17 tests) for registry, listener, and baseGroup. Add TCPMux group load balancing e2e test. * server/group: replace tautological assertion with require.NotPanics * server/group: remove blank line between doc comment and type declaration
This commit is contained in:
77
server/group/base.go
Normal file
77
server/group/base.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
)
|
||||
|
||||
// baseGroup contains the shared plumbing for listener-based groups
|
||||
// (TCP, HTTPS, TCPMux). Each concrete group embeds this and provides
|
||||
// its own Listen method with protocol-specific validation.
|
||||
type baseGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
|
||||
acceptCh chan net.Conn
|
||||
realLn net.Listener
|
||||
lns []*Listener
|
||||
mu sync.Mutex
|
||||
cleanupFn func()
|
||||
}
|
||||
|
||||
// initBase resets the baseGroup for a fresh listen cycle.
|
||||
// Must be called under mu when len(lns) == 0.
|
||||
func (bg *baseGroup) initBase(group, groupKey string, realLn net.Listener, cleanupFn func()) {
|
||||
bg.group = group
|
||||
bg.groupKey = groupKey
|
||||
bg.realLn = realLn
|
||||
bg.acceptCh = make(chan net.Conn)
|
||||
bg.cleanupFn = cleanupFn
|
||||
}
|
||||
|
||||
// worker reads from the real listener and fans out to acceptCh.
|
||||
// The parameters are captured at creation time so that the worker is
|
||||
// bound to a specific listen cycle and cannot observe a later initBase.
|
||||
func (bg *baseGroup) worker(realLn net.Listener, acceptCh chan<- net.Conn) {
|
||||
for {
|
||||
c, err := realLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newListener creates a new Listener wired to this baseGroup.
|
||||
// Must be called under mu.
|
||||
func (bg *baseGroup) newListener(addr net.Addr) *Listener {
|
||||
ln := newListener(bg.acceptCh, addr, bg.closeListener)
|
||||
bg.lns = append(bg.lns, ln)
|
||||
return ln
|
||||
}
|
||||
|
||||
// closeListener removes ln from the list. When the last listener is removed,
|
||||
// it closes acceptCh, closes the real listener, and calls cleanupFn.
|
||||
func (bg *baseGroup) closeListener(ln *Listener) {
|
||||
bg.mu.Lock()
|
||||
defer bg.mu.Unlock()
|
||||
for i, l := range bg.lns {
|
||||
if l == ln {
|
||||
bg.lns = append(bg.lns[:i], bg.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(bg.lns) == 0 {
|
||||
close(bg.acceptCh)
|
||||
bg.realLn.Close()
|
||||
bg.cleanupFn()
|
||||
}
|
||||
}
|
||||
169
server/group/base_test.go
Normal file
169
server/group/base_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeLn is a controllable net.Listener for tests.
|
||||
type fakeLn struct {
|
||||
connCh chan net.Conn
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newFakeLn() *fakeLn {
|
||||
return &fakeLn{
|
||||
connCh: make(chan net.Conn, 8),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLn) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case c := <-f.connCh:
|
||||
return c, nil
|
||||
case <-f.closed:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLn) Close() error {
|
||||
f.once.Do(func() { close(f.closed) })
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeLn) Addr() net.Addr { return fakeAddr("127.0.0.1:9999") }
|
||||
|
||||
func (f *fakeLn) inject(c net.Conn) {
|
||||
select {
|
||||
case f.connCh <- c:
|
||||
case <-f.closed:
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerFanOut(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
select {
|
||||
case got := <-bg.acceptCh:
|
||||
assert.Equal(t, c1, got)
|
||||
got.Close()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for connection on acceptCh")
|
||||
}
|
||||
|
||||
fl.Close()
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerStopsOnListenerClose(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bg.worker(fl, bg.acceptCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fl.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("worker did not stop after listener close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerClosesConnOnClosedChannel(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
// Close acceptCh before worker sends.
|
||||
close(bg.acceptCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bg.worker(fl, bg.acceptCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("worker did not stop after panic recovery")
|
||||
}
|
||||
|
||||
// c1 should have been closed by worker's panic recovery path.
|
||||
buf := make([]byte, 1)
|
||||
_, err := c1.Read(buf)
|
||||
assert.Error(t, err, "connection should be closed by worker")
|
||||
}
|
||||
|
||||
func TestBaseGroup_CloseLastListenerTriggersCleanup(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
cleanupCalled := 0
|
||||
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
|
||||
|
||||
bg.mu.Lock()
|
||||
ln1 := bg.newListener(fl.Addr())
|
||||
ln2 := bg.newListener(fl.Addr())
|
||||
bg.mu.Unlock()
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
ln1.Close()
|
||||
assert.Equal(t, 0, cleanupCalled, "cleanup should not run while listeners remain")
|
||||
|
||||
ln2.Close()
|
||||
assert.Equal(t, 1, cleanupCalled, "cleanup should run after last listener closed")
|
||||
}
|
||||
|
||||
func TestBaseGroup_CloseOneOfTwoListeners(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
cleanupCalled := 0
|
||||
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
|
||||
|
||||
bg.mu.Lock()
|
||||
ln1 := bg.newListener(fl.Addr())
|
||||
ln2 := bg.newListener(fl.Addr())
|
||||
bg.mu.Unlock()
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
ln1.Close()
|
||||
assert.Equal(t, 0, cleanupCalled)
|
||||
|
||||
// ln2 should still receive connections.
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
got, err := ln2.Accept()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, c1, got)
|
||||
got.Close()
|
||||
|
||||
ln2.Close()
|
||||
assert.Equal(t, 1, cleanupCalled)
|
||||
}
|
||||
@@ -24,4 +24,6 @@ var (
|
||||
ErrListenerClosed = errors.New("group listener closed")
|
||||
ErrGroupDifferentPort = errors.New("group should have same remote port")
|
||||
ErrProxyRepeated = errors.New("group proxy repeated")
|
||||
|
||||
errGroupStale = errors.New("stale group reference")
|
||||
)
|
||||
|
||||
@@ -9,53 +9,42 @@ import (
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
// HTTPGroupController manages HTTP groups that use round-robin
|
||||
// callback routing (fundamentally different from listener-based groups).
|
||||
type HTTPGroupController struct {
|
||||
// groups indexed by group name
|
||||
groups map[string]*HTTPGroup
|
||||
|
||||
// register createConn for each group to vhostRouter.
|
||||
// createConn will get a connection from one proxy of the group
|
||||
groupRegistry[*HTTPGroup]
|
||||
vhostRouter *vhost.Routers
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController {
|
||||
return &HTTPGroupController{
|
||||
groups: make(map[string]*HTTPGroup),
|
||||
vhostRouter: vhostRouter,
|
||||
groupRegistry: newGroupRegistry[*HTTPGroup](),
|
||||
vhostRouter: vhostRouter,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctl *HTTPGroupController) Register(
|
||||
proxyName, group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (err error) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
if !ok {
|
||||
g = NewHTTPGroup(ctl)
|
||||
ctl.groups[indexKey] = g
|
||||
) error {
|
||||
for {
|
||||
g := ctl.getOrCreate(group, func() *HTTPGroup {
|
||||
return NewHTTPGroup(ctl)
|
||||
})
|
||||
err := g.Register(proxyName, group, groupKey, routeConfig)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
ctl.mu.Unlock()
|
||||
|
||||
return g.Register(proxyName, group, groupKey, routeConfig)
|
||||
}
|
||||
|
||||
func (ctl *HTTPGroupController) UnRegister(proxyName, group string, _ vhost.RouteConfig) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
g, ok := ctl.get(group)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
isEmpty := g.UnRegister(proxyName)
|
||||
if isEmpty {
|
||||
delete(ctl.groups, indexKey)
|
||||
}
|
||||
g.UnRegister(proxyName)
|
||||
}
|
||||
|
||||
type HTTPGroup struct {
|
||||
@@ -87,6 +76,9 @@ func (g *HTTPGroup) Register(
|
||||
) (err error) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if !g.ctl.isCurrent(group, func(cur *HTTPGroup) bool { return cur == g }) {
|
||||
return errGroupStale
|
||||
}
|
||||
if len(g.createFuncs) == 0 {
|
||||
// the first proxy in this group
|
||||
tmp := routeConfig // copy object
|
||||
@@ -123,7 +115,7 @@ func (g *HTTPGroup) Register(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
|
||||
func (g *HTTPGroup) UnRegister(proxyName string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
delete(g.createFuncs, proxyName)
|
||||
@@ -135,10 +127,11 @@ func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
|
||||
}
|
||||
|
||||
if len(g.createFuncs) == 0 {
|
||||
isEmpty = true
|
||||
g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser)
|
||||
g.ctl.removeIf(g.group, func(cur *HTTPGroup) bool {
|
||||
return cur == g
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||
@@ -151,7 +144,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||
location := g.location
|
||||
routeByHTTPUser := g.routeByHTTPUser
|
||||
if len(g.pxyNames) > 0 {
|
||||
name := g.pxyNames[int(newIndex)%len(g.pxyNames)]
|
||||
name := g.pxyNames[newIndex%uint64(len(g.pxyNames))]
|
||||
f = g.createFuncs[name]
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
@@ -174,7 +167,7 @@ func (g *HTTPGroup) chooseEndpoint() (string, error) {
|
||||
location := g.location
|
||||
routeByHTTPUser := g.routeByHTTPUser
|
||||
if len(g.pxyNames) > 0 {
|
||||
name = g.pxyNames[int(newIndex)%len(g.pxyNames)]
|
||||
name = g.pxyNames[newIndex%uint64(len(g.pxyNames))]
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
|
||||
|
||||
@@ -17,25 +17,19 @@ package group
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
type HTTPSGroupController struct {
|
||||
groups map[string]*HTTPSGroup
|
||||
|
||||
groupRegistry[*HTTPSGroup]
|
||||
httpsMuxer *vhost.HTTPSMuxer
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHTTPSGroupController(httpsMuxer *vhost.HTTPSMuxer) *HTTPSGroupController {
|
||||
return &HTTPSGroupController{
|
||||
groups: make(map[string]*HTTPSGroup),
|
||||
httpsMuxer: httpsMuxer,
|
||||
groupRegistry: newGroupRegistry[*HTTPSGroup](),
|
||||
httpsMuxer: httpsMuxer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,41 +38,28 @@ func (ctl *HTTPSGroupController) Listen(
|
||||
group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (l net.Listener, err error) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
if !ok {
|
||||
g = NewHTTPSGroup(ctl)
|
||||
ctl.groups[indexKey] = g
|
||||
for {
|
||||
g := ctl.getOrCreate(group, func() *HTTPSGroup {
|
||||
return NewHTTPSGroup(ctl)
|
||||
})
|
||||
l, err = g.Listen(ctx, group, groupKey, routeConfig)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
ctl.mu.Unlock()
|
||||
|
||||
return g.Listen(ctx, group, groupKey, routeConfig)
|
||||
}
|
||||
|
||||
func (ctl *HTTPSGroupController) RemoveGroup(group string) {
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
delete(ctl.groups, group)
|
||||
}
|
||||
|
||||
type HTTPSGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
domain string
|
||||
baseGroup
|
||||
|
||||
acceptCh chan net.Conn
|
||||
httpsLn *vhost.Listener
|
||||
lns []*HTTPSGroupListener
|
||||
ctl *HTTPSGroupController
|
||||
mu sync.Mutex
|
||||
domain string
|
||||
ctl *HTTPSGroupController
|
||||
}
|
||||
|
||||
func NewHTTPSGroup(ctl *HTTPSGroupController) *HTTPSGroup {
|
||||
return &HTTPSGroup{
|
||||
lns: make([]*HTTPSGroupListener, 0),
|
||||
ctl: ctl,
|
||||
acceptCh: make(chan net.Conn),
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,23 +67,27 @@ func (g *HTTPSGroup) Listen(
|
||||
ctx context.Context,
|
||||
group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (ln *HTTPSGroupListener, err error) {
|
||||
) (ln *Listener, err error) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if !g.ctl.isCurrent(group, func(cur *HTTPSGroup) bool { return cur == g }) {
|
||||
return nil, errGroupStale
|
||||
}
|
||||
if len(g.lns) == 0 {
|
||||
// the first listener, listen on the real address
|
||||
httpsLn, errRet := g.ctl.httpsMuxer.Listen(ctx, &routeConfig)
|
||||
if errRet != nil {
|
||||
return nil, errRet
|
||||
}
|
||||
ln = newHTTPSGroupListener(group, g, httpsLn.Addr())
|
||||
|
||||
g.group = group
|
||||
g.groupKey = groupKey
|
||||
g.domain = routeConfig.Domain
|
||||
g.httpsLn = httpsLn
|
||||
g.lns = append(g.lns, ln)
|
||||
go g.worker()
|
||||
g.initBase(group, groupKey, httpsLn, func() {
|
||||
g.ctl.removeIf(g.group, func(cur *HTTPSGroup) bool {
|
||||
return cur == g
|
||||
})
|
||||
})
|
||||
ln = g.newListener(httpsLn.Addr())
|
||||
go g.worker(httpsLn, g.acceptCh)
|
||||
} else {
|
||||
// route config in the same group must be equal
|
||||
if g.group != group || g.domain != routeConfig.Domain {
|
||||
@@ -111,87 +96,7 @@ func (g *HTTPSGroup) Listen(
|
||||
if g.groupKey != groupKey {
|
||||
return nil, ErrGroupAuthFailed
|
||||
}
|
||||
ln = newHTTPSGroupListener(group, g, g.lns[0].Addr())
|
||||
g.lns = append(g.lns, ln)
|
||||
ln = g.newListener(g.lns[0].Addr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) worker() {
|
||||
for {
|
||||
c, err := g.httpsLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
g.acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) Accept() <-chan net.Conn {
|
||||
return g.acceptCh
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) CloseListener(ln *HTTPSGroupListener) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
for i, tmpLn := range g.lns {
|
||||
if tmpLn == ln {
|
||||
g.lns = append(g.lns[:i], g.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(g.lns) == 0 {
|
||||
close(g.acceptCh)
|
||||
if g.httpsLn != nil {
|
||||
g.httpsLn.Close()
|
||||
}
|
||||
g.ctl.RemoveGroup(g.group)
|
||||
}
|
||||
}
|
||||
|
||||
type HTTPSGroupListener struct {
|
||||
groupName string
|
||||
group *HTTPSGroup
|
||||
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newHTTPSGroupListener(name string, group *HTTPSGroup, addr net.Addr) *HTTPSGroupListener {
|
||||
return &HTTPSGroupListener{
|
||||
groupName: name,
|
||||
group: group,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Accept() (c net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok = <-ln.group.Accept():
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Close() (err error) {
|
||||
close(ln.closeCh)
|
||||
|
||||
// remove self from HTTPSGroup
|
||||
ln.group.CloseListener(ln)
|
||||
return
|
||||
}
|
||||
|
||||
49
server/group/listener.go
Normal file
49
server/group/listener.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Listener is a per-proxy virtual listener that receives connections
|
||||
// from a shared group. It implements net.Listener.
|
||||
type Listener struct {
|
||||
acceptCh <-chan net.Conn
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
onClose func(*Listener)
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newListener(acceptCh <-chan net.Conn, addr net.Addr, onClose func(*Listener)) *Listener {
|
||||
return &Listener{
|
||||
acceptCh: acceptCh,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
onClose: onClose,
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *Listener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok := <-ln.acceptCh:
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *Listener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
func (ln *Listener) Close() error {
|
||||
ln.once.Do(func() {
|
||||
close(ln.closeCh)
|
||||
ln.onClose(ln)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
68
server/group/listener_test.go
Normal file
68
server/group/listener_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestListener_Accept(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn, 1)
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
|
||||
acceptCh <- c1
|
||||
got, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, c1, got)
|
||||
}
|
||||
|
||||
func TestListener_AcceptAfterChannelClose(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn)
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
close(acceptCh)
|
||||
_, err := ln.Accept()
|
||||
assert.ErrorIs(t, err, ErrListenerClosed)
|
||||
}
|
||||
|
||||
func TestListener_AcceptAfterListenerClose(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn) // open, not closed
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
ln.Close()
|
||||
_, err := ln.Accept()
|
||||
assert.ErrorIs(t, err, ErrListenerClosed)
|
||||
}
|
||||
|
||||
func TestListener_DoubleClose(t *testing.T) {
|
||||
closeCalls := 0
|
||||
ln := newListener(
|
||||
make(chan net.Conn),
|
||||
fakeAddr("127.0.0.1:1234"),
|
||||
func(*Listener) { closeCalls++ },
|
||||
)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
ln.Close()
|
||||
ln.Close()
|
||||
})
|
||||
assert.Equal(t, 1, closeCalls, "onClose should be called exactly once")
|
||||
}
|
||||
|
||||
func TestListener_Addr(t *testing.T) {
|
||||
addr := fakeAddr("10.0.0.1:5555")
|
||||
ln := newListener(make(chan net.Conn), addr, func(*Listener) {})
|
||||
assert.Equal(t, addr, ln.Addr())
|
||||
}
|
||||
|
||||
// fakeAddr implements net.Addr for testing.
|
||||
type fakeAddr string
|
||||
|
||||
func (a fakeAddr) Network() string { return "tcp" }
|
||||
func (a fakeAddr) String() string { return string(a) }
|
||||
59
server/group/registry.go
Normal file
59
server/group/registry.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// groupRegistry is a concurrent map of named groups with
|
||||
// automatic creation on first access.
|
||||
type groupRegistry[G any] struct {
|
||||
groups map[string]G
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newGroupRegistry[G any]() groupRegistry[G] {
|
||||
return groupRegistry[G]{
|
||||
groups: make(map[string]G),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *groupRegistry[G]) getOrCreate(key string, newFn func() G) G {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
g, ok := r.groups[key]
|
||||
if !ok {
|
||||
g = newFn()
|
||||
r.groups[key] = g
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func (r *groupRegistry[G]) get(key string) (G, bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
g, ok := r.groups[key]
|
||||
return g, ok
|
||||
}
|
||||
|
||||
// isCurrent returns true if key exists in the registry and matchFn
|
||||
// returns true for the stored value.
|
||||
func (r *groupRegistry[G]) isCurrent(key string, matchFn func(G) bool) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
g, ok := r.groups[key]
|
||||
return ok && matchFn(g)
|
||||
}
|
||||
|
||||
// removeIf atomically looks up the group for key, calls fn on it,
|
||||
// and removes the entry if fn returns true.
|
||||
func (r *groupRegistry[G]) removeIf(key string, fn func(G) bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
g, ok := r.groups[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if fn(g) {
|
||||
delete(r.groups, key)
|
||||
}
|
||||
}
|
||||
102
server/group/registry_test.go
Normal file
102
server/group/registry_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetOrCreate_New(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
called := 0
|
||||
v := 42
|
||||
got := r.getOrCreate("k", func() *int { called++; return &v })
|
||||
assert.Equal(t, 1, called)
|
||||
assert.Equal(t, &v, got)
|
||||
}
|
||||
|
||||
func TestGetOrCreate_Existing(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v := 42
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
|
||||
called := 0
|
||||
got := r.getOrCreate("k", func() *int { called++; return nil })
|
||||
assert.Equal(t, 0, called)
|
||||
assert.Equal(t, &v, got)
|
||||
}
|
||||
|
||||
func TestGet_ExistingAndMissing(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v := 1
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
|
||||
got, ok := r.get("k")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, &v, got)
|
||||
|
||||
_, ok = r.get("missing")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestIsCurrent(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v1 := 1
|
||||
v2 := 2
|
||||
r.getOrCreate("k", func() *int { return &v1 })
|
||||
|
||||
assert.True(t, r.isCurrent("k", func(g *int) bool { return g == &v1 }))
|
||||
assert.False(t, r.isCurrent("k", func(g *int) bool { return g == &v2 }))
|
||||
assert.False(t, r.isCurrent("missing", func(g *int) bool { return true }))
|
||||
}
|
||||
|
||||
func TestRemoveIf(t *testing.T) {
|
||||
t.Run("removes when fn returns true", func(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v := 1
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
r.removeIf("k", func(g *int) bool { return g == &v })
|
||||
_, ok := r.get("k")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("keeps when fn returns false", func(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
v := 1
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
r.removeIf("k", func(g *int) bool { return false })
|
||||
_, ok := r.get("k")
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("noop on missing key", func(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
r.removeIf("missing", func(g *int) bool { return true }) // should not panic
|
||||
})
|
||||
}
|
||||
|
||||
func TestConcurrentGetOrCreateAndRemoveIf(t *testing.T) {
|
||||
r := newGroupRegistry[*int]()
|
||||
const n = 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n * 2)
|
||||
for i := range n {
|
||||
v := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.getOrCreate("k", func() *int { return &v })
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.removeIf("k", func(*int) bool { return true })
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// After all goroutines finish, accessing the key must not panic.
|
||||
require.NotPanics(t, func() {
|
||||
_, _ = r.get("k")
|
||||
})
|
||||
}
|
||||
@@ -17,83 +17,67 @@ package group
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
|
||||
"github.com/fatedier/frp/server/ports"
|
||||
)
|
||||
|
||||
// TCPGroupCtl manage all TCPGroups
|
||||
// TCPGroupCtl manages all TCPGroups.
|
||||
type TCPGroupCtl struct {
|
||||
groups map[string]*TCPGroup
|
||||
|
||||
// portManager is used to manage port
|
||||
groupRegistry[*TCPGroup]
|
||||
portManager *ports.Manager
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTCPGroupCtl return a new TcpGroupCtl
|
||||
// NewTCPGroupCtl returns a new TCPGroupCtl.
|
||||
func NewTCPGroupCtl(portManager *ports.Manager) *TCPGroupCtl {
|
||||
return &TCPGroupCtl{
|
||||
groups: make(map[string]*TCPGroup),
|
||||
portManager: portManager,
|
||||
groupRegistry: newGroupRegistry[*TCPGroup](),
|
||||
portManager: portManager,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen is the wrapper for TCPGroup's Listen
|
||||
// If there are no group, we will create one here
|
||||
// Listen is the wrapper for TCPGroup's Listen.
|
||||
// If there is no group, one will be created.
|
||||
func (tgc *TCPGroupCtl) Listen(proxyName string, group string, groupKey string,
|
||||
addr string, port int,
|
||||
) (l net.Listener, realPort int, err error) {
|
||||
tgc.mu.Lock()
|
||||
tcpGroup, ok := tgc.groups[group]
|
||||
if !ok {
|
||||
tcpGroup = NewTCPGroup(tgc)
|
||||
tgc.groups[group] = tcpGroup
|
||||
for {
|
||||
tcpGroup := tgc.getOrCreate(group, func() *TCPGroup {
|
||||
return NewTCPGroup(tgc)
|
||||
})
|
||||
l, realPort, err = tcpGroup.Listen(proxyName, group, groupKey, addr, port)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
tgc.mu.Unlock()
|
||||
|
||||
return tcpGroup.Listen(proxyName, group, groupKey, addr, port)
|
||||
}
|
||||
|
||||
// RemoveGroup remove TCPGroup from controller
|
||||
func (tgc *TCPGroupCtl) RemoveGroup(group string) {
|
||||
tgc.mu.Lock()
|
||||
defer tgc.mu.Unlock()
|
||||
delete(tgc.groups, group)
|
||||
}
|
||||
|
||||
// TCPGroup route connections to different proxies
|
||||
// TCPGroup routes connections to different proxies.
|
||||
type TCPGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
baseGroup
|
||||
|
||||
addr string
|
||||
port int
|
||||
realPort int
|
||||
|
||||
acceptCh chan net.Conn
|
||||
tcpLn net.Listener
|
||||
lns []*TCPGroupListener
|
||||
ctl *TCPGroupCtl
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTCPGroup return a new TCPGroup
|
||||
// NewTCPGroup returns a new TCPGroup.
|
||||
func NewTCPGroup(ctl *TCPGroupCtl) *TCPGroup {
|
||||
return &TCPGroup{
|
||||
lns: make([]*TCPGroupListener, 0),
|
||||
ctl: ctl,
|
||||
acceptCh: make(chan net.Conn),
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen will return a new TCPGroupListener
|
||||
// if TCPGroup already has a listener, just add a new TCPGroupListener to the queues
|
||||
// otherwise, listen on the real address
|
||||
func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TCPGroupListener, realPort int, err error) {
|
||||
// Listen will return a new Listener.
|
||||
// If TCPGroup already has a listener, just add a new Listener to the queues,
|
||||
// otherwise listen on the real address.
|
||||
func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *Listener, realPort int, err error) {
|
||||
tg.mu.Lock()
|
||||
defer tg.mu.Unlock()
|
||||
if !tg.ctl.isCurrent(group, func(cur *TCPGroup) bool { return cur == tg }) {
|
||||
return nil, 0, errGroupStale
|
||||
}
|
||||
if len(tg.lns) == 0 {
|
||||
// the first listener, listen on the real address
|
||||
realPort, err = tg.ctl.portManager.Acquire(proxyName, port)
|
||||
@@ -106,19 +90,18 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
ln = newTCPGroupListener(group, tg, tcpLn.Addr())
|
||||
|
||||
tg.group = group
|
||||
tg.groupKey = groupKey
|
||||
tg.addr = addr
|
||||
tg.port = port
|
||||
tg.realPort = realPort
|
||||
tg.tcpLn = tcpLn
|
||||
tg.lns = append(tg.lns, ln)
|
||||
if tg.acceptCh == nil {
|
||||
tg.acceptCh = make(chan net.Conn)
|
||||
}
|
||||
go tg.worker()
|
||||
tg.initBase(group, groupKey, tcpLn, func() {
|
||||
tg.ctl.portManager.Release(tg.realPort)
|
||||
tg.ctl.removeIf(tg.group, func(cur *TCPGroup) bool {
|
||||
return cur == tg
|
||||
})
|
||||
})
|
||||
ln = tg.newListener(tcpLn.Addr())
|
||||
go tg.worker(tcpLn, tg.acceptCh)
|
||||
} else {
|
||||
// address and port in the same group must be equal
|
||||
if tg.group != group || tg.addr != addr {
|
||||
@@ -133,92 +116,8 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
|
||||
err = ErrGroupAuthFailed
|
||||
return
|
||||
}
|
||||
ln = newTCPGroupListener(group, tg, tg.lns[0].Addr())
|
||||
ln = tg.newListener(tg.lns[0].Addr())
|
||||
realPort = tg.realPort
|
||||
tg.lns = append(tg.lns, ln)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// worker is called when the real tcp listener has been created
|
||||
func (tg *TCPGroup) worker() {
|
||||
for {
|
||||
c, err := tg.tcpLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
tg.acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tg *TCPGroup) Accept() <-chan net.Conn {
|
||||
return tg.acceptCh
|
||||
}
|
||||
|
||||
// CloseListener remove the TCPGroupListener from the TCPGroup
|
||||
func (tg *TCPGroup) CloseListener(ln *TCPGroupListener) {
|
||||
tg.mu.Lock()
|
||||
defer tg.mu.Unlock()
|
||||
for i, tmpLn := range tg.lns {
|
||||
if tmpLn == ln {
|
||||
tg.lns = append(tg.lns[:i], tg.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(tg.lns) == 0 {
|
||||
close(tg.acceptCh)
|
||||
tg.tcpLn.Close()
|
||||
tg.ctl.portManager.Release(tg.realPort)
|
||||
tg.ctl.RemoveGroup(tg.group)
|
||||
}
|
||||
}
|
||||
|
||||
// TCPGroupListener
|
||||
type TCPGroupListener struct {
|
||||
groupName string
|
||||
group *TCPGroup
|
||||
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newTCPGroupListener(name string, group *TCPGroup, addr net.Addr) *TCPGroupListener {
|
||||
return &TCPGroupListener{
|
||||
groupName: name,
|
||||
group: group,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Accept will accept connections from TCPGroup
|
||||
func (ln *TCPGroupListener) Accept() (c net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok = <-ln.group.Accept():
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *TCPGroupListener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
// Close close the listener
|
||||
func (ln *TCPGroupListener) Close() (err error) {
|
||||
close(ln.closeCh)
|
||||
|
||||
// remove self from TcpGroup
|
||||
ln.group.CloseListener(ln)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -18,118 +18,100 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/tcpmux"
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
// TCPMuxGroupCtl manage all TCPMuxGroups
|
||||
// TCPMuxGroupCtl manages all TCPMuxGroups.
|
||||
type TCPMuxGroupCtl struct {
|
||||
groups map[string]*TCPMuxGroup
|
||||
|
||||
// portManager is used to manage port
|
||||
groupRegistry[*TCPMuxGroup]
|
||||
tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTCPMuxGroupCtl return a new TCPMuxGroupCtl
|
||||
// NewTCPMuxGroupCtl returns a new TCPMuxGroupCtl.
|
||||
func NewTCPMuxGroupCtl(tcpMuxHTTPConnectMuxer *tcpmux.HTTPConnectTCPMuxer) *TCPMuxGroupCtl {
|
||||
return &TCPMuxGroupCtl{
|
||||
groups: make(map[string]*TCPMuxGroup),
|
||||
groupRegistry: newGroupRegistry[*TCPMuxGroup](),
|
||||
tcpMuxHTTPConnectMuxer: tcpMuxHTTPConnectMuxer,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen is the wrapper for TCPMuxGroup's Listen
|
||||
// If there are no group, we will create one here
|
||||
// Listen is the wrapper for TCPMuxGroup's Listen.
|
||||
// If there is no group, one will be created.
|
||||
func (tmgc *TCPMuxGroupCtl) Listen(
|
||||
ctx context.Context,
|
||||
multiplexer, group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (l net.Listener, err error) {
|
||||
tmgc.mu.Lock()
|
||||
tcpMuxGroup, ok := tmgc.groups[group]
|
||||
if !ok {
|
||||
tcpMuxGroup = NewTCPMuxGroup(tmgc)
|
||||
tmgc.groups[group] = tcpMuxGroup
|
||||
}
|
||||
tmgc.mu.Unlock()
|
||||
for {
|
||||
tcpMuxGroup := tmgc.getOrCreate(group, func() *TCPMuxGroup {
|
||||
return NewTCPMuxGroup(tmgc)
|
||||
})
|
||||
|
||||
switch v1.TCPMultiplexerType(multiplexer) {
|
||||
case v1.TCPMultiplexerHTTPConnect:
|
||||
return tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig)
|
||||
default:
|
||||
err = fmt.Errorf("unknown multiplexer [%s]", multiplexer)
|
||||
return
|
||||
switch v1.TCPMultiplexerType(multiplexer) {
|
||||
case v1.TCPMultiplexerHTTPConnect:
|
||||
l, err = tcpMuxGroup.HTTPConnectListen(ctx, group, groupKey, routeConfig)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown multiplexer [%s]", multiplexer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveGroup remove TCPMuxGroup from controller
|
||||
func (tmgc *TCPMuxGroupCtl) RemoveGroup(group string) {
|
||||
tmgc.mu.Lock()
|
||||
defer tmgc.mu.Unlock()
|
||||
delete(tmgc.groups, group)
|
||||
}
|
||||
|
||||
// TCPMuxGroup route connections to different proxies
|
||||
// TCPMuxGroup routes connections to different proxies.
|
||||
type TCPMuxGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
baseGroup
|
||||
|
||||
domain string
|
||||
routeByHTTPUser string
|
||||
username string
|
||||
password string
|
||||
|
||||
acceptCh chan net.Conn
|
||||
tcpMuxLn net.Listener
|
||||
lns []*TCPMuxGroupListener
|
||||
ctl *TCPMuxGroupCtl
|
||||
mu sync.Mutex
|
||||
ctl *TCPMuxGroupCtl
|
||||
}
|
||||
|
||||
// NewTCPMuxGroup return a new TCPMuxGroup
|
||||
// NewTCPMuxGroup returns a new TCPMuxGroup.
|
||||
func NewTCPMuxGroup(ctl *TCPMuxGroupCtl) *TCPMuxGroup {
|
||||
return &TCPMuxGroup{
|
||||
lns: make([]*TCPMuxGroupListener, 0),
|
||||
ctl: ctl,
|
||||
acceptCh: make(chan net.Conn),
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen will return a new TCPMuxGroupListener
|
||||
// if TCPMuxGroup already has a listener, just add a new TCPMuxGroupListener to the queues
|
||||
// otherwise, listen on the real address
|
||||
// HTTPConnectListen will return a new Listener.
|
||||
// If TCPMuxGroup already has a listener, just add a new Listener to the queues,
|
||||
// otherwise listen on the real address.
|
||||
func (tmg *TCPMuxGroup) HTTPConnectListen(
|
||||
ctx context.Context,
|
||||
group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (ln *TCPMuxGroupListener, err error) {
|
||||
) (ln *Listener, err error) {
|
||||
tmg.mu.Lock()
|
||||
defer tmg.mu.Unlock()
|
||||
if !tmg.ctl.isCurrent(group, func(cur *TCPMuxGroup) bool { return cur == tmg }) {
|
||||
return nil, errGroupStale
|
||||
}
|
||||
if len(tmg.lns) == 0 {
|
||||
// the first listener, listen on the real address
|
||||
tcpMuxLn, errRet := tmg.ctl.tcpMuxHTTPConnectMuxer.Listen(ctx, &routeConfig)
|
||||
if errRet != nil {
|
||||
return nil, errRet
|
||||
}
|
||||
ln = newTCPMuxGroupListener(group, tmg, tcpMuxLn.Addr())
|
||||
|
||||
tmg.group = group
|
||||
tmg.groupKey = groupKey
|
||||
tmg.domain = routeConfig.Domain
|
||||
tmg.routeByHTTPUser = routeConfig.RouteByHTTPUser
|
||||
tmg.username = routeConfig.Username
|
||||
tmg.password = routeConfig.Password
|
||||
tmg.tcpMuxLn = tcpMuxLn
|
||||
tmg.lns = append(tmg.lns, ln)
|
||||
if tmg.acceptCh == nil {
|
||||
tmg.acceptCh = make(chan net.Conn)
|
||||
}
|
||||
go tmg.worker()
|
||||
tmg.initBase(group, groupKey, tcpMuxLn, func() {
|
||||
tmg.ctl.removeIf(tmg.group, func(cur *TCPMuxGroup) bool {
|
||||
return cur == tmg
|
||||
})
|
||||
})
|
||||
ln = tmg.newListener(tcpMuxLn.Addr())
|
||||
go tmg.worker(tcpMuxLn, tmg.acceptCh)
|
||||
} else {
|
||||
// route config in the same group must be equal
|
||||
if tmg.group != group || tmg.domain != routeConfig.Domain ||
|
||||
@@ -141,90 +123,7 @@ func (tmg *TCPMuxGroup) HTTPConnectListen(
|
||||
if tmg.groupKey != groupKey {
|
||||
return nil, ErrGroupAuthFailed
|
||||
}
|
||||
ln = newTCPMuxGroupListener(group, tmg, tmg.lns[0].Addr())
|
||||
tmg.lns = append(tmg.lns, ln)
|
||||
ln = tmg.newListener(tmg.lns[0].Addr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// worker is called when the real TCP listener has been created
|
||||
func (tmg *TCPMuxGroup) worker() {
|
||||
for {
|
||||
c, err := tmg.tcpMuxLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
tmg.acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tmg *TCPMuxGroup) Accept() <-chan net.Conn {
|
||||
return tmg.acceptCh
|
||||
}
|
||||
|
||||
// CloseListener remove the TCPMuxGroupListener from the TCPMuxGroup
|
||||
func (tmg *TCPMuxGroup) CloseListener(ln *TCPMuxGroupListener) {
|
||||
tmg.mu.Lock()
|
||||
defer tmg.mu.Unlock()
|
||||
for i, tmpLn := range tmg.lns {
|
||||
if tmpLn == ln {
|
||||
tmg.lns = append(tmg.lns[:i], tmg.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(tmg.lns) == 0 {
|
||||
close(tmg.acceptCh)
|
||||
tmg.tcpMuxLn.Close()
|
||||
tmg.ctl.RemoveGroup(tmg.group)
|
||||
}
|
||||
}
|
||||
|
||||
// TCPMuxGroupListener
|
||||
type TCPMuxGroupListener struct {
|
||||
groupName string
|
||||
group *TCPMuxGroup
|
||||
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newTCPMuxGroupListener(name string, group *TCPMuxGroup, addr net.Addr) *TCPMuxGroupListener {
|
||||
return &TCPMuxGroupListener{
|
||||
groupName: name,
|
||||
group: group,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Accept will accept connections from TCPMuxGroup
|
||||
func (ln *TCPMuxGroupListener) Accept() (c net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok = <-ln.group.Accept():
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *TCPMuxGroupListener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
// Close close the listener
|
||||
func (ln *TCPMuxGroupListener) Close() (err error) {
|
||||
close(ln.closeCh)
|
||||
|
||||
// remove self from TcpMuxGroup
|
||||
ln.group.CloseListener(ln)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -186,6 +186,68 @@ var _ = ginkgo.Describe("[Feature: Group]", func() {
|
||||
|
||||
framework.ExpectTrue(fooCount > 1 && barCount > 1, "fooCount: %d, barCount: %d", fooCount, barCount)
|
||||
})
|
||||
|
||||
ginkgo.It("TCPMux httpconnect", func() {
|
||||
vhostPort := f.AllocPort()
|
||||
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||
tcpmuxHTTPConnectPort = %d
|
||||
`, vhostPort)
|
||||
clientConf := consts.DefaultClientConfig
|
||||
|
||||
fooPort := f.AllocPort()
|
||||
fooServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(fooPort), streamserver.WithRespContent([]byte("foo")))
|
||||
f.RunServer("", fooServer)
|
||||
|
||||
barPort := f.AllocPort()
|
||||
barServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(barPort), streamserver.WithRespContent([]byte("bar")))
|
||||
f.RunServer("", barServer)
|
||||
|
||||
clientConf += fmt.Sprintf(`
|
||||
[[proxies]]
|
||||
name = "foo"
|
||||
type = "tcpmux"
|
||||
multiplexer = "httpconnect"
|
||||
localPort = %d
|
||||
customDomains = ["tcpmux-group.example.com"]
|
||||
loadBalancer.group = "test"
|
||||
loadBalancer.groupKey = "123"
|
||||
|
||||
[[proxies]]
|
||||
name = "bar"
|
||||
type = "tcpmux"
|
||||
multiplexer = "httpconnect"
|
||||
localPort = %d
|
||||
customDomains = ["tcpmux-group.example.com"]
|
||||
loadBalancer.group = "test"
|
||||
loadBalancer.groupKey = "123"
|
||||
`, fooPort, barPort)
|
||||
|
||||
f.RunProcesses([]string{serverConf}, []string{clientConf})
|
||||
|
||||
proxyURL := fmt.Sprintf("http://127.0.0.1:%d", vhostPort)
|
||||
fooCount := 0
|
||||
barCount := 0
|
||||
for i := range 10 {
|
||||
framework.NewRequestExpect(f).
|
||||
Explain("times " + strconv.Itoa(i)).
|
||||
RequestModify(func(r *request.Request) {
|
||||
r.Addr("tcpmux-group.example.com").Proxy(proxyURL)
|
||||
}).
|
||||
Ensure(func(resp *request.Response) bool {
|
||||
switch string(resp.Content) {
|
||||
case "foo":
|
||||
fooCount++
|
||||
case "bar":
|
||||
barCount++
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
framework.ExpectTrue(fooCount > 1 && barCount > 1, "fooCount: %d, barCount: %d", fooCount, barCount)
|
||||
})
|
||||
})
|
||||
|
||||
ginkgo.Describe("Health Check", func() {
|
||||
|
||||
Reference in New Issue
Block a user