mirror of
https://github.com/fatedier/frp.git
synced 2026-03-08 02:49:10 +08:00
Compare commits
4 Commits
2f70a2c905
...
443b9bca66
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
443b9bca66 | ||
|
|
c70ceff370 | ||
|
|
bb3d0e7140 | ||
|
|
cf396563f8 |
@@ -16,6 +16,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
@@ -122,6 +123,33 @@ func (pxy *BaseProxy) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapWorkConn applies rate limiting, encryption, and compression
|
||||
// to a work connection based on the proxy's transport configuration.
|
||||
// The returned recycle function should be called when the stream is no longer in use
|
||||
// to return compression resources to the pool. It is safe to not call recycle,
|
||||
// in which case resources will be garbage collected normally.
|
||||
func (pxy *BaseProxy) wrapWorkConn(conn net.Conn, encKey []byte) (io.ReadWriteCloser, func(), error) {
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
}
|
||||
if pxy.baseCfg.Transport.UseEncryption {
|
||||
var err error
|
||||
rwc, err = libio.WithEncryption(rwc, encKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, fmt.Errorf("create encryption stream error: %w", err)
|
||||
}
|
||||
}
|
||||
var recycleFn func()
|
||||
if pxy.baseCfg.Transport.UseCompression {
|
||||
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
|
||||
}
|
||||
return rwc, recycleFn, nil
|
||||
}
|
||||
|
||||
func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||
pxy.inWorkConnCallback = cb
|
||||
}
|
||||
@@ -139,30 +167,14 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
|
||||
func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) {
|
||||
xl := pxy.xl
|
||||
baseCfg := pxy.baseCfg
|
||||
var (
|
||||
remote io.ReadWriteCloser
|
||||
err error
|
||||
)
|
||||
remote = workConn
|
||||
if pxy.limiter != nil {
|
||||
remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error {
|
||||
return workConn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
xl.Tracef("handle tcp work connection, useEncryption: %t, useCompression: %t",
|
||||
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
|
||||
if baseCfg.Transport.UseEncryption {
|
||||
remote, err = libio.WithEncryption(remote, encKey)
|
||||
if err != nil {
|
||||
workConn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
var compressionResourceRecycleFn func()
|
||||
if baseCfg.Transport.UseCompression {
|
||||
remote, compressionResourceRecycleFn = libio.WithCompressionFromPool(remote)
|
||||
|
||||
remote, recycleFn, err := pxy.wrapWorkConn(workConn, encKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// check if we need to send proxy protocol info
|
||||
@@ -178,7 +190,6 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
}
|
||||
|
||||
if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 {
|
||||
// Use the common proxy protocol builder function
|
||||
header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion)
|
||||
connInfo.ProxyProtocolHeader = header
|
||||
}
|
||||
@@ -187,12 +198,18 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
|
||||
if pxy.proxyPlugin != nil {
|
||||
// if plugin is set, let plugin handle connection first
|
||||
// Don't recycle compression resources here because plugins may
|
||||
// retain the connection after Handle returns.
|
||||
xl.Debugf("handle by plugin: %s", pxy.proxyPlugin.Name())
|
||||
pxy.proxyPlugin.Handle(pxy.ctx, &connInfo)
|
||||
xl.Debugf("handle by plugin finished")
|
||||
return
|
||||
}
|
||||
|
||||
if recycleFn != nil {
|
||||
defer recycleFn()
|
||||
}
|
||||
|
||||
localConn, err := libnet.Dial(
|
||||
net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)),
|
||||
libnet.WithTimeout(10*time.Second),
|
||||
@@ -220,7 +237,4 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
if len(errs) > 0 {
|
||||
xl.Tracef("join connections errors: %v", errs)
|
||||
}
|
||||
if compressionResourceRecycleFn != nil {
|
||||
compressionResourceRecycleFn()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -25,12 +24,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/udp"
|
||||
"github.com/fatedier/frp/pkg/util/limit"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
@@ -83,27 +80,13 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
xl := pxy.xl
|
||||
xl.Infof("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String())
|
||||
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
var err error
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
if pxy.cfg.Transport.UseEncryption {
|
||||
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if pxy.cfg.Transport.UseCompression {
|
||||
rwc = libio.WithCompression(rwc)
|
||||
}
|
||||
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
|
||||
|
||||
workConn := conn
|
||||
workConn := netpkg.WrapReadWriteCloserToConn(remote, conn)
|
||||
readCh := make(chan *msg.UDPPacket, 1024)
|
||||
sendCh := make(chan msg.Message, 1024)
|
||||
isClose := false
|
||||
|
||||
@@ -17,19 +17,16 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/udp"
|
||||
"github.com/fatedier/frp/pkg/util/limit"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
@@ -94,28 +91,14 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
// close resources related with old workConn
|
||||
pxy.Close()
|
||||
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
var err error
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
if pxy.cfg.Transport.UseEncryption {
|
||||
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if pxy.cfg.Transport.UseCompression {
|
||||
rwc = libio.WithCompression(rwc)
|
||||
}
|
||||
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
|
||||
|
||||
pxy.mu.Lock()
|
||||
pxy.workConn = conn
|
||||
pxy.workConn = netpkg.WrapReadWriteCloserToConn(remote, conn)
|
||||
pxy.readCh = make(chan *msg.UDPPacket, 1024)
|
||||
pxy.sendCh = make(chan msg.Message, 1024)
|
||||
pxy.closed = false
|
||||
|
||||
@@ -42,10 +42,10 @@ func (sv *STCPVisitor) Run() (err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go sv.worker()
|
||||
go sv.acceptLoop(sv.l, "stcp local", sv.handleConn)
|
||||
}
|
||||
|
||||
go sv.internalConnWorker()
|
||||
go sv.acceptLoop(sv.internalLn, "stcp internal", sv.handleConn)
|
||||
|
||||
if sv.plugin != nil {
|
||||
sv.plugin.Start()
|
||||
@@ -57,30 +57,6 @@ func (sv *STCPVisitor) Close() {
|
||||
sv.BaseVisitor.Close()
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) worker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("stcp local listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) internalConnWorker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.internalLn.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("stcp internal listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) handleConn(userConn net.Conn) {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
var tunnelErr error
|
||||
|
||||
@@ -119,6 +119,18 @@ func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
|
||||
return v.internalLn.PutConn(conn)
|
||||
}
|
||||
|
||||
func (v *BaseVisitor) acceptLoop(l net.Listener, name string, handleConn func(net.Conn)) {
|
||||
xl := xlog.FromContextSafe(v.ctx)
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("%s listener closed", name)
|
||||
return
|
||||
}
|
||||
go handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *BaseVisitor) Close() {
|
||||
if v.l != nil {
|
||||
v.l.Close()
|
||||
|
||||
@@ -65,10 +65,10 @@ func (sv *XTCPVisitor) Run() (err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go sv.worker()
|
||||
go sv.acceptLoop(sv.l, "xtcp local", sv.handleConn)
|
||||
}
|
||||
|
||||
go sv.internalConnWorker()
|
||||
go sv.acceptLoop(sv.internalLn, "xtcp internal", sv.handleConn)
|
||||
go sv.processTunnelStartEvents()
|
||||
if sv.cfg.KeepTunnelOpen {
|
||||
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
|
||||
@@ -93,30 +93,6 @@ func (sv *XTCPVisitor) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) worker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("xtcp local listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) internalConnWorker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.internalLn.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("xtcp internal listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) processTunnelStartEvents() {
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -203,6 +203,25 @@ func (m *serverMetrics) GetServer() *ServerStats {
|
||||
return s
|
||||
}
|
||||
|
||||
func toProxyStats(name string, proxyStats *ProxyStatistics) *ProxyStats {
|
||||
ps := &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
return ps
|
||||
}
|
||||
|
||||
func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
|
||||
res := make([]*ProxyStats, 0)
|
||||
m.mu.Lock()
|
||||
@@ -212,23 +231,7 @@ func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
|
||||
if proxyStats.ProxyType != proxyType {
|
||||
continue
|
||||
}
|
||||
|
||||
ps := &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
res = append(res, ps)
|
||||
res = append(res, toProxyStats(name, proxyStats))
|
||||
}
|
||||
return res
|
||||
}
|
||||
@@ -241,26 +244,10 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri
|
||||
if proxyStats.ProxyType != proxyType {
|
||||
continue
|
||||
}
|
||||
|
||||
if name != proxyName {
|
||||
continue
|
||||
}
|
||||
|
||||
res = &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
res = toProxyStats(name, proxyStats)
|
||||
break
|
||||
}
|
||||
return
|
||||
@@ -272,21 +259,7 @@ func (m *serverMetrics) GetProxyByName(proxyName string) (res *ProxyStats) {
|
||||
|
||||
proxyStats, ok := m.info.ProxyStatistics[proxyName]
|
||||
if ok {
|
||||
res = &ProxyStats{
|
||||
Name: proxyName,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
res = toProxyStats(proxyName, proxyStats)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -152,7 +152,9 @@ func (c *Controller) GenSid() string {
|
||||
|
||||
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) {
|
||||
if m.PreCheck {
|
||||
c.mu.RLock()
|
||||
cfg, ok := c.clientCfgs[m.ProxyName]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
|
||||
return
|
||||
|
||||
@@ -168,11 +168,15 @@ func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
|
||||
return l, err
|
||||
}
|
||||
readConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return l, err
|
||||
}
|
||||
|
||||
l = &UDPListener{
|
||||
addr: udpAddr,
|
||||
acceptCh: make(chan net.Conn),
|
||||
writeCh: make(chan *UDPPacket, 1000),
|
||||
readConn: readConn,
|
||||
fakeConns: make(map[string]*FakeUDPConn),
|
||||
}
|
||||
|
||||
|
||||
@@ -266,31 +266,13 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
|
||||
go libio.Join(remote, client)
|
||||
}
|
||||
|
||||
func parseBasicAuth(auth string) (username, password string, ok bool) {
|
||||
const prefix = "Basic "
|
||||
// Case insensitive prefix match. See Issue 22736.
|
||||
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
|
||||
return
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
return
|
||||
}
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request {
|
||||
user := ""
|
||||
// If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header.
|
||||
if req.URL.Host != "" {
|
||||
proxyAuth := req.Header.Get("Proxy-Authorization")
|
||||
if proxyAuth != "" {
|
||||
user, _, _ = parseBasicAuth(proxyAuth)
|
||||
user, _, _ = httppkg.ParseBasicAuth(proxyAuth)
|
||||
}
|
||||
}
|
||||
if user == "" {
|
||||
|
||||
@@ -95,20 +95,33 @@ func (cm *ControlManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
// SessionContext encapsulates the input parameters for creating a new Control.
|
||||
type SessionContext struct {
|
||||
// all resource managers and controllers
|
||||
rc *controller.ResourceController
|
||||
|
||||
RC *controller.ResourceController
|
||||
// proxy manager
|
||||
pxyManager *proxy.Manager
|
||||
|
||||
PxyManager *proxy.Manager
|
||||
// plugin manager
|
||||
pluginManager *plugin.Manager
|
||||
|
||||
PluginManager *plugin.Manager
|
||||
// verifies authentication based on selected method
|
||||
authVerifier auth.Verifier
|
||||
AuthVerifier auth.Verifier
|
||||
// key used for connection encryption
|
||||
encryptionKey []byte
|
||||
EncryptionKey []byte
|
||||
// control connection
|
||||
Conn net.Conn
|
||||
// indicates whether the connection is encrypted
|
||||
ConnEncrypted bool
|
||||
// login message
|
||||
LoginMsg *msg.Login
|
||||
// server configuration
|
||||
ServerCfg *v1.ServerConfig
|
||||
// client registry
|
||||
ClientRegistry *registry.ClientRegistry
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
// session context
|
||||
sessionCtx *SessionContext
|
||||
|
||||
// other components can use this to communicate with client
|
||||
msgTransporter transport.MessageTransporter
|
||||
@@ -117,12 +130,6 @@ type Control struct {
|
||||
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
|
||||
msgDispatcher *msg.Dispatcher
|
||||
|
||||
// login message
|
||||
loginMsg *msg.Login
|
||||
|
||||
// control connection
|
||||
conn net.Conn
|
||||
|
||||
// work connections
|
||||
workConnCh chan net.Conn
|
||||
|
||||
@@ -145,61 +152,37 @@ type Control struct {
|
||||
|
||||
mu sync.RWMutex
|
||||
|
||||
// Server configuration information
|
||||
serverCfg *v1.ServerConfig
|
||||
|
||||
clientRegistry *registry.ClientRegistry
|
||||
|
||||
xl *xlog.Logger
|
||||
ctx context.Context
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext.
|
||||
func NewControl(
|
||||
ctx context.Context,
|
||||
rc *controller.ResourceController,
|
||||
pxyManager *proxy.Manager,
|
||||
pluginManager *plugin.Manager,
|
||||
authVerifier auth.Verifier,
|
||||
encryptionKey []byte,
|
||||
ctlConn net.Conn,
|
||||
ctlConnEncrypted bool,
|
||||
loginMsg *msg.Login,
|
||||
serverCfg *v1.ServerConfig,
|
||||
) (*Control, error) {
|
||||
poolCount := loginMsg.PoolCount
|
||||
if poolCount > int(serverCfg.Transport.MaxPoolCount) {
|
||||
poolCount = int(serverCfg.Transport.MaxPoolCount)
|
||||
func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
|
||||
poolCount := sessionCtx.LoginMsg.PoolCount
|
||||
if poolCount > int(sessionCtx.ServerCfg.Transport.MaxPoolCount) {
|
||||
poolCount = int(sessionCtx.ServerCfg.Transport.MaxPoolCount)
|
||||
}
|
||||
ctl := &Control{
|
||||
rc: rc,
|
||||
pxyManager: pxyManager,
|
||||
pluginManager: pluginManager,
|
||||
authVerifier: authVerifier,
|
||||
encryptionKey: encryptionKey,
|
||||
conn: ctlConn,
|
||||
loginMsg: loginMsg,
|
||||
workConnCh: make(chan net.Conn, poolCount+10),
|
||||
proxies: make(map[string]proxy.Proxy),
|
||||
poolCount: poolCount,
|
||||
portsUsedNum: 0,
|
||||
runID: loginMsg.RunID,
|
||||
serverCfg: serverCfg,
|
||||
xl: xlog.FromContextSafe(ctx),
|
||||
ctx: ctx,
|
||||
doneCh: make(chan struct{}),
|
||||
sessionCtx: sessionCtx,
|
||||
workConnCh: make(chan net.Conn, poolCount+10),
|
||||
proxies: make(map[string]proxy.Proxy),
|
||||
poolCount: poolCount,
|
||||
portsUsedNum: 0,
|
||||
runID: sessionCtx.LoginMsg.RunID,
|
||||
xl: xlog.FromContextSafe(ctx),
|
||||
ctx: ctx,
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
ctl.lastPing.Store(time.Now())
|
||||
|
||||
if ctlConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey)
|
||||
if sessionCtx.ConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
|
||||
} else {
|
||||
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn)
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
}
|
||||
ctl.registerMsgHandlers()
|
||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||
@@ -213,7 +196,7 @@ func (ctl *Control) Start() {
|
||||
RunID: ctl.runID,
|
||||
Error: "",
|
||||
}
|
||||
_ = msg.WriteMsg(ctl.conn, loginRespMsg)
|
||||
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
|
||||
|
||||
go func() {
|
||||
for i := 0; i < ctl.poolCount; i++ {
|
||||
@@ -225,7 +208,7 @@ func (ctl *Control) Start() {
|
||||
}
|
||||
|
||||
func (ctl *Control) Close() error {
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -233,7 +216,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
|
||||
xl := ctl.xl
|
||||
xl.Infof("replaced by client [%s]", newCtl.runID)
|
||||
ctl.runID = ""
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
}
|
||||
|
||||
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
|
||||
@@ -291,7 +274,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second):
|
||||
case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second):
|
||||
err = fmt.Errorf("timeout trying to get work connection")
|
||||
xl.Warnf("%v", err)
|
||||
return
|
||||
@@ -304,15 +287,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
|
||||
}
|
||||
|
||||
func (ctl *Control) heartbeatWorker() {
|
||||
if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 {
|
||||
if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
xl := ctl.xl
|
||||
go wait.Until(func() {
|
||||
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
|
||||
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second {
|
||||
xl.Warnf("heartbeat timeout")
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
return
|
||||
}
|
||||
}, time.Second, ctl.doneCh)
|
||||
@@ -330,7 +313,7 @@ func (ctl *Control) worker() {
|
||||
go ctl.msgDispatcher.Run()
|
||||
|
||||
<-ctl.msgDispatcher.Done()
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
@@ -342,26 +325,26 @@ func (ctl *Control) worker() {
|
||||
|
||||
for _, pxy := range ctl.proxies {
|
||||
pxy.Close()
|
||||
ctl.pxyManager.Del(pxy.GetName())
|
||||
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
|
||||
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
|
||||
|
||||
notifyContent := &plugin.CloseProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.sessionCtx.LoginMsg.RunID,
|
||||
},
|
||||
CloseProxy: msg.CloseProxy{
|
||||
ProxyName: pxy.GetName(),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
_ = ctl.pluginManager.CloseProxy(notifyContent)
|
||||
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
|
||||
}()
|
||||
}
|
||||
|
||||
metrics.Server.CloseClient()
|
||||
ctl.clientRegistry.MarkOfflineByRunID(ctl.runID)
|
||||
ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID)
|
||||
xl.Infof("client exit success")
|
||||
close(ctl.doneCh)
|
||||
}
|
||||
@@ -381,14 +364,14 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
|
||||
|
||||
content := &plugin.NewProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.sessionCtx.LoginMsg.RunID,
|
||||
},
|
||||
NewProxy: *inMsg,
|
||||
}
|
||||
var remoteAddr string
|
||||
retContent, err := ctl.pluginManager.NewProxy(content)
|
||||
retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content)
|
||||
if err == nil {
|
||||
inMsg = &retContent.NewProxy
|
||||
remoteAddr, err = ctl.RegisterProxy(inMsg)
|
||||
@@ -401,15 +384,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
|
||||
if err != nil {
|
||||
xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
|
||||
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
|
||||
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
|
||||
err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient))
|
||||
} else {
|
||||
resp.RemoteAddr = remoteAddr
|
||||
xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
|
||||
clientID := ctl.loginMsg.ClientID
|
||||
clientID := ctl.sessionCtx.LoginMsg.ClientID
|
||||
if clientID == "" {
|
||||
clientID = ctl.loginMsg.RunID
|
||||
clientID = ctl.sessionCtx.LoginMsg.RunID
|
||||
}
|
||||
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID)
|
||||
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID)
|
||||
}
|
||||
_ = ctl.msgDispatcher.Send(resp)
|
||||
}
|
||||
@@ -420,21 +403,21 @@ func (ctl *Control) handlePing(m msg.Message) {
|
||||
|
||||
content := &plugin.PingContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.sessionCtx.LoginMsg.RunID,
|
||||
},
|
||||
Ping: *inMsg,
|
||||
}
|
||||
retContent, err := ctl.pluginManager.Ping(content)
|
||||
retContent, err := ctl.sessionCtx.PluginManager.Ping(content)
|
||||
if err == nil {
|
||||
inMsg = &retContent.Ping
|
||||
err = ctl.authVerifier.VerifyPing(inMsg)
|
||||
err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warnf("received invalid ping: %v", err)
|
||||
_ = ctl.msgDispatcher.Send(&msg.Pong{
|
||||
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
|
||||
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)),
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -445,17 +428,17 @@ func (ctl *Control) handlePing(m msg.Message) {
|
||||
|
||||
func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleVisitor)
|
||||
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleNatHoleClient(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleClient)
|
||||
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleNatHoleReport(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleReport)
|
||||
ctl.rc.NatHoleController.HandleReport(inMsg)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleCloseProxy(m msg.Message) {
|
||||
@@ -468,15 +451,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) {
|
||||
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
|
||||
var pxyConf v1.ProxyConfigurer
|
||||
// Load configures from NewProxy message and validate.
|
||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg)
|
||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// User info
|
||||
userInfo := plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.runID,
|
||||
}
|
||||
|
||||
@@ -484,22 +467,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
// In fact, it creates different proxies based on the proxy type. We just call run() here.
|
||||
pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{
|
||||
UserInfo: userInfo,
|
||||
LoginMsg: ctl.loginMsg,
|
||||
LoginMsg: ctl.sessionCtx.LoginMsg,
|
||||
PoolCount: ctl.poolCount,
|
||||
ResourceController: ctl.rc,
|
||||
ResourceController: ctl.sessionCtx.RC,
|
||||
GetWorkConnFn: ctl.GetWorkConn,
|
||||
Configurer: pxyConf,
|
||||
ServerCfg: ctl.serverCfg,
|
||||
EncryptionKey: ctl.encryptionKey,
|
||||
ServerCfg: ctl.sessionCtx.ServerCfg,
|
||||
EncryptionKey: ctl.sessionCtx.EncryptionKey,
|
||||
})
|
||||
if err != nil {
|
||||
return remoteAddr, err
|
||||
}
|
||||
|
||||
// Check ports used number in each client
|
||||
if ctl.serverCfg.MaxPortsPerClient > 0 {
|
||||
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
|
||||
ctl.mu.Lock()
|
||||
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) {
|
||||
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) {
|
||||
ctl.mu.Unlock()
|
||||
err = fmt.Errorf("exceed the max_ports_per_client")
|
||||
return
|
||||
@@ -516,7 +499,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
}()
|
||||
}
|
||||
|
||||
if ctl.pxyManager.Exist(pxyMsg.ProxyName) {
|
||||
if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) {
|
||||
err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName)
|
||||
return
|
||||
}
|
||||
@@ -531,7 +514,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
}
|
||||
}()
|
||||
|
||||
err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy)
|
||||
err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -550,11 +533,11 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
if ctl.serverCfg.MaxPortsPerClient > 0 {
|
||||
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
|
||||
ctl.portsUsedNum -= pxy.GetUsedPortsNum()
|
||||
}
|
||||
pxy.Close()
|
||||
ctl.pxyManager.Del(pxy.GetName())
|
||||
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
|
||||
delete(ctl.proxies, closeMsg.ProxyName)
|
||||
ctl.mu.Unlock()
|
||||
|
||||
@@ -562,16 +545,16 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
|
||||
|
||||
notifyContent := &plugin.CloseProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.sessionCtx.LoginMsg.RunID,
|
||||
},
|
||||
CloseProxy: msg.CloseProxy{
|
||||
ProxyName: pxy.GetName(),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
_ = ctl.pluginManager.CloseProxy(notifyContent)
|
||||
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,15 +75,7 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1)
|
||||
for _, d := range pxy.cfg.CustomDomains {
|
||||
if d != "" {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
}
|
||||
if pxy.cfg.SubDomain != "" {
|
||||
domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost)
|
||||
}
|
||||
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for _, domain := range domains {
|
||||
|
||||
@@ -53,15 +53,7 @@ func (pxy *HTTPSProxy) Run() (remoteAddr string, err error) {
|
||||
pxy.Close()
|
||||
}
|
||||
}()
|
||||
domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1)
|
||||
for _, d := range pxy.cfg.CustomDomains {
|
||||
if d != "" {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
}
|
||||
if pxy.cfg.SubDomain != "" {
|
||||
domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost)
|
||||
}
|
||||
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for _, domain := range domains {
|
||||
|
||||
@@ -150,7 +150,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
|
||||
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
|
||||
}
|
||||
err := msg.WriteMsg(workConn, &msg.StartWorkConn{
|
||||
err = msg.WriteMsg(workConn, &msg.StartWorkConn{
|
||||
ProxyName: pxy.GetName(),
|
||||
SrcAddr: srcAddr,
|
||||
SrcPort: uint16(srcPort),
|
||||
@@ -161,6 +161,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
if err != nil {
|
||||
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
|
||||
workConn.Close()
|
||||
workConn = nil
|
||||
} else {
|
||||
break
|
||||
}
|
||||
@@ -173,6 +174,36 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
return
|
||||
}
|
||||
|
||||
// startVisitorListener sets up a VisitorManager listener for visitor-based proxies (STCP, SUDP).
|
||||
func (pxy *BaseProxy) startVisitorListener(secretKey string, allowUsers []string, proxyType string) error {
|
||||
// if allowUsers is empty, only allow same user from proxy
|
||||
if len(allowUsers) == 0 {
|
||||
allowUsers = []string{pxy.GetUserInfo().User}
|
||||
}
|
||||
listener, err := pxy.rc.VisitorManager.Listen(pxy.GetName(), secretKey, allowUsers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pxy.listeners = append(pxy.listeners, listener)
|
||||
pxy.xl.Infof("%s proxy custom listen success", proxyType)
|
||||
pxy.startCommonTCPListenersHandler()
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDomains constructs a list of domains from custom domains and subdomain configuration.
|
||||
func (pxy *BaseProxy) buildDomains(customDomains []string, subDomain string) []string {
|
||||
domains := make([]string, 0, len(customDomains)+1)
|
||||
for _, d := range customDomains {
|
||||
if d != "" {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
}
|
||||
if subDomain != "" {
|
||||
domains = append(domains, subDomain+"."+pxy.serverCfg.SubDomainHost)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// startCommonTCPListenersHandler start a goroutine handler for each listener.
|
||||
func (pxy *BaseProxy) startCommonTCPListenersHandler() {
|
||||
xl := xlog.FromContextSafe(pxy.ctx)
|
||||
|
||||
@@ -41,21 +41,7 @@ func NewSTCPProxy(baseProxy *BaseProxy) Proxy {
|
||||
}
|
||||
|
||||
func (pxy *STCPProxy) Run() (remoteAddr string, err error) {
|
||||
xl := pxy.xl
|
||||
allowUsers := pxy.cfg.AllowUsers
|
||||
// if allowUsers is empty, only allow same user from proxy
|
||||
if len(allowUsers) == 0 {
|
||||
allowUsers = []string{pxy.GetUserInfo().User}
|
||||
}
|
||||
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
|
||||
if errRet != nil {
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
pxy.listeners = append(pxy.listeners, listener)
|
||||
xl.Infof("stcp proxy custom listen success")
|
||||
|
||||
pxy.startCommonTCPListenersHandler()
|
||||
err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "stcp")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -41,21 +41,7 @@ func NewSUDPProxy(baseProxy *BaseProxy) Proxy {
|
||||
}
|
||||
|
||||
func (pxy *SUDPProxy) Run() (remoteAddr string, err error) {
|
||||
xl := pxy.xl
|
||||
allowUsers := pxy.cfg.AllowUsers
|
||||
// if allowUsers is empty, only allow same user from proxy
|
||||
if len(allowUsers) == 0 {
|
||||
allowUsers = []string{pxy.GetUserInfo().User}
|
||||
}
|
||||
listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers)
|
||||
if errRet != nil {
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
pxy.listeners = append(pxy.listeners, listener)
|
||||
xl.Infof("sudp proxy custom listen success")
|
||||
|
||||
pxy.startCommonTCPListenersHandler()
|
||||
err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "sudp")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -72,15 +72,7 @@ func (pxy *TCPMuxProxy) httpConnectListen(
|
||||
}
|
||||
|
||||
func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) {
|
||||
domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1)
|
||||
for _, d := range pxy.cfg.CustomDomains {
|
||||
if d != "" {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
}
|
||||
if pxy.cfg.SubDomain != "" {
|
||||
domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost)
|
||||
}
|
||||
domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain)
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for _, domain := range domains {
|
||||
|
||||
@@ -604,8 +604,18 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(fatedier): use SessionContext
|
||||
ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, authVerifier, svr.auth.EncryptionKey(), ctlConn, !internal, loginMsg, svr.cfg)
|
||||
ctl, err := NewControl(ctx, &SessionContext{
|
||||
RC: svr.rc,
|
||||
PxyManager: svr.pxyManager,
|
||||
PluginManager: svr.pluginManager,
|
||||
AuthVerifier: authVerifier,
|
||||
EncryptionKey: svr.auth.EncryptionKey(),
|
||||
Conn: ctlConn,
|
||||
ConnEncrypted: !internal,
|
||||
LoginMsg: loginMsg,
|
||||
ServerCfg: svr.cfg,
|
||||
ClientRegistry: svr.clientRegistry,
|
||||
})
|
||||
if err != nil {
|
||||
xl.Warnf("create new controller error: %v", err)
|
||||
// don't return detailed errors to client
|
||||
@@ -626,7 +636,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
ctl.Close()
|
||||
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
|
||||
}
|
||||
ctl.clientRegistry = svr.clientRegistry
|
||||
|
||||
ctl.Start()
|
||||
|
||||
@@ -652,9 +661,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
|
||||
// server plugin hook
|
||||
content := &plugin.NewWorkConnContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.sessionCtx.LoginMsg.RunID,
|
||||
},
|
||||
NewWorkConn: *newMsg,
|
||||
}
|
||||
@@ -662,7 +671,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
|
||||
if err == nil {
|
||||
newMsg = &retContent.NewWorkConn
|
||||
// Check auth.
|
||||
err = ctl.authVerifier.VerifyNewWorkConn(newMsg)
|
||||
err = ctl.sessionCtx.AuthVerifier.VerifyNewWorkConn(newMsg)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
||||
@@ -683,7 +692,7 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis
|
||||
if !exist {
|
||||
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
|
||||
}
|
||||
visitorUser = ctl.loginMsg.User
|
||||
visitorUser = ctl.sessionCtx.LoginMsg.User
|
||||
}
|
||||
return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
|
||||
newMsg.UseEncryption, newMsg.UseCompression, visitorUser)
|
||||
|
||||
Reference in New Issue
Block a user