From bb3d0e71401b94433753398272f35fdea27097b1 Mon Sep 17 00:00:00 2001 From: fatedier Date: Sat, 7 Mar 2026 12:00:27 +0800 Subject: [PATCH] deduplicate common logic across proxy, visitor, and metrics modules (#5213) - Replace duplicate parseBasicAuth with existing httppkg.ParseBasicAuth - Extract buildDomains helper in BaseProxy for HTTP/HTTPS/TCPMux proxies - Extract toProxyStats helper to deduplicate ProxyStats construction - Extract startVisitorListener helper in BaseProxy for STCP/SUDP proxies - Extract acceptLoop helper in BaseVisitor for STCP/XTCP visitors --- client/visitor/stcp.go | 28 ++------------- client/visitor/visitor.go | 12 +++++++ client/visitor/xtcp.go | 28 ++------------- pkg/metrics/mem/server.go | 71 ++++++++++++--------------------------- pkg/util/vhost/http.go | 20 +---------- server/proxy/http.go | 10 +----- server/proxy/https.go | 10 +----- server/proxy/proxy.go | 30 +++++++++++++++++ server/proxy/stcp.go | 16 +-------- server/proxy/sudp.go | 16 +-------- server/proxy/tcpmux.go | 10 +----- 11 files changed, 74 insertions(+), 177 deletions(-) diff --git a/client/visitor/stcp.go b/client/visitor/stcp.go index 870e48d9..6529142c 100644 --- a/client/visitor/stcp.go +++ b/client/visitor/stcp.go @@ -42,10 +42,10 @@ func (sv *STCPVisitor) Run() (err error) { if err != nil { return } - go sv.worker() + go sv.acceptLoop(sv.l, "stcp local", sv.handleConn) } - go sv.internalConnWorker() + go sv.acceptLoop(sv.internalLn, "stcp internal", sv.handleConn) if sv.plugin != nil { sv.plugin.Start() @@ -57,30 +57,6 @@ func (sv *STCPVisitor) Close() { sv.BaseVisitor.Close() } -func (sv *STCPVisitor) worker() { - xl := xlog.FromContextSafe(sv.ctx) - for { - conn, err := sv.l.Accept() - if err != nil { - xl.Warnf("stcp local listener closed") - return - } - go sv.handleConn(conn) - } -} - -func (sv *STCPVisitor) internalConnWorker() { - xl := xlog.FromContextSafe(sv.ctx) - for { - conn, err := sv.internalLn.Accept() - if err != nil { - xl.Warnf("stcp internal listener closed") - return - } - go sv.handleConn(conn) - } -} - func (sv *STCPVisitor) handleConn(userConn net.Conn) { xl := xlog.FromContextSafe(sv.ctx) var tunnelErr error diff --git a/client/visitor/visitor.go b/client/visitor/visitor.go index 87e4f29f..51999499 100644 --- a/client/visitor/visitor.go +++ b/client/visitor/visitor.go @@ -119,6 +119,18 @@ func (v *BaseVisitor) AcceptConn(conn net.Conn) error { return v.internalLn.PutConn(conn) } +func (v *BaseVisitor) acceptLoop(l net.Listener, name string, handleConn func(net.Conn)) { + xl := xlog.FromContextSafe(v.ctx) + for { + conn, err := l.Accept() + if err != nil { + xl.Warnf("%s listener closed", name) + return + } + go handleConn(conn) + } +} + func (v *BaseVisitor) Close() { if v.l != nil { v.l.Close() diff --git a/client/visitor/xtcp.go b/client/visitor/xtcp.go index 2273a271..b2f4ef37 100644 --- a/client/visitor/xtcp.go +++ b/client/visitor/xtcp.go @@ -65,10 +65,10 @@ func (sv *XTCPVisitor) Run() (err error) { if err != nil { return } - go sv.worker() + go sv.acceptLoop(sv.l, "xtcp local", sv.handleConn) } - go sv.internalConnWorker() + go sv.acceptLoop(sv.internalLn, "xtcp internal", sv.handleConn) go sv.processTunnelStartEvents() if sv.cfg.KeepTunnelOpen { sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour) @@ -93,30 +93,6 @@ func (sv *XTCPVisitor) Close() { } } -func (sv *XTCPVisitor) worker() { - xl := xlog.FromContextSafe(sv.ctx) - for { - conn, err := sv.l.Accept() - if err != nil { - xl.Warnf("xtcp local listener closed") - return - } - go sv.handleConn(conn) - } -} - -func (sv *XTCPVisitor) internalConnWorker() { - xl := xlog.FromContextSafe(sv.ctx) - for { - conn, err := sv.internalLn.Accept() - if err != nil { - xl.Warnf("xtcp internal listener closed") - return - } - go sv.handleConn(conn) - } -} - func (sv *XTCPVisitor) processTunnelStartEvents() { for { select { diff --git a/pkg/metrics/mem/server.go b/pkg/metrics/mem/server.go index 677788d3..16a6491e 100644 --- a/pkg/metrics/mem/server.go +++ b/pkg/metrics/mem/server.go @@ -203,6 +203,25 @@ func (m *serverMetrics) GetServer() *ServerStats { return s } +func toProxyStats(name string, proxyStats *ProxyStatistics) *ProxyStats { + ps := &ProxyStats{ + Name: name, + Type: proxyStats.ProxyType, + User: proxyStats.User, + ClientID: proxyStats.ClientID, + TodayTrafficIn: proxyStats.TrafficIn.TodayCount(), + TodayTrafficOut: proxyStats.TrafficOut.TodayCount(), + CurConns: int64(proxyStats.CurConns.Count()), + } + if !proxyStats.LastStartTime.IsZero() { + ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05") + } + if !proxyStats.LastCloseTime.IsZero() { + ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05") + } + return ps +} + func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats { res := make([]*ProxyStats, 0) m.mu.Lock() @@ -212,23 +231,7 @@ func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats { if proxyStats.ProxyType != proxyType { continue } - - ps := &ProxyStats{ - Name: name, - Type: proxyStats.ProxyType, - User: proxyStats.User, - ClientID: proxyStats.ClientID, - TodayTrafficIn: proxyStats.TrafficIn.TodayCount(), - TodayTrafficOut: proxyStats.TrafficOut.TodayCount(), - CurConns: int64(proxyStats.CurConns.Count()), - } - if !proxyStats.LastStartTime.IsZero() { - ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05") - } - if !proxyStats.LastCloseTime.IsZero() { - ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05") - } - res = append(res, ps) + res = append(res, toProxyStats(name, proxyStats)) } return res } @@ -241,26 +244,10 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri if proxyStats.ProxyType != proxyType { continue } - if name != proxyName { continue } - - res = &ProxyStats{ - Name: name, - Type: proxyStats.ProxyType, - User: proxyStats.User, - ClientID: proxyStats.ClientID, - TodayTrafficIn: proxyStats.TrafficIn.TodayCount(), - TodayTrafficOut: proxyStats.TrafficOut.TodayCount(), - CurConns: int64(proxyStats.CurConns.Count()), - } - if !proxyStats.LastStartTime.IsZero() { - res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05") - } - if !proxyStats.LastCloseTime.IsZero() { - res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05") - } + res = toProxyStats(name, proxyStats) break } return @@ -272,21 +259,7 @@ func (m *serverMetrics) GetProxyByName(proxyName string) (res *ProxyStats) { proxyStats, ok := m.info.ProxyStatistics[proxyName] if ok { - res = &ProxyStats{ - Name: proxyName, - Type: proxyStats.ProxyType, - User: proxyStats.User, - ClientID: proxyStats.ClientID, - TodayTrafficIn: proxyStats.TrafficIn.TodayCount(), - TodayTrafficOut: proxyStats.TrafficOut.TodayCount(), - CurConns: int64(proxyStats.CurConns.Count()), - } - if !proxyStats.LastStartTime.IsZero() { - res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05") - } - if !proxyStats.LastCloseTime.IsZero() { - res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05") - } + res = toProxyStats(proxyName, proxyStats) } return } diff --git a/pkg/util/vhost/http.go b/pkg/util/vhost/http.go index 05ec174b..d12e7916 100644 --- a/pkg/util/vhost/http.go +++ b/pkg/util/vhost/http.go @@ -266,31 +266,13 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req go libio.Join(remote, client) } -func parseBasicAuth(auth string) (username, password string, ok bool) { - const prefix = "Basic " - // Case insensitive prefix match. See Issue 22736. - if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { - return - } - c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) - if err != nil { - return - } - cs := string(c) - s := strings.IndexByte(cs, ':') - if s < 0 { - return - } - return cs[:s], cs[s+1:], true -} - func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request { user := "" // If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header. if req.URL.Host != "" { proxyAuth := req.Header.Get("Proxy-Authorization") if proxyAuth != "" { - user, _, _ = parseBasicAuth(proxyAuth) + user, _, _ = httppkg.ParseBasicAuth(proxyAuth) } } if user == "" { diff --git a/server/proxy/http.go b/server/proxy/http.go index 31b00410..e5df06c5 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -75,15 +75,7 @@ func (pxy *HTTPProxy) Run() (remoteAddr string, err error) { } }() - domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1) - for _, d := range pxy.cfg.CustomDomains { - if d != "" { - domains = append(domains, d) - } - } - if pxy.cfg.SubDomain != "" { - domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost) - } + domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain) addrs := make([]string, 0) for _, domain := range domains { diff --git a/server/proxy/https.go b/server/proxy/https.go index b65cac34..ec720242 100644 --- a/server/proxy/https.go +++ b/server/proxy/https.go @@ -53,15 +53,7 @@ func (pxy *HTTPSProxy) Run() (remoteAddr string, err error) { pxy.Close() } }() - domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1) - for _, d := range pxy.cfg.CustomDomains { - if d != "" { - domains = append(domains, d) - } - } - if pxy.cfg.SubDomain != "" { - domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost) - } + domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain) addrs := make([]string, 0) for _, domain := range domains { diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go index 564eca28..dcc652a0 100644 --- a/server/proxy/proxy.go +++ b/server/proxy/proxy.go @@ -173,6 +173,36 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn, return } +// startVisitorListener sets up a VisitorManager listener for visitor-based proxies (STCP, SUDP). +func (pxy *BaseProxy) startVisitorListener(secretKey string, allowUsers []string, proxyType string) error { + // if allowUsers is empty, only allow same user from proxy + if len(allowUsers) == 0 { + allowUsers = []string{pxy.GetUserInfo().User} + } + listener, err := pxy.rc.VisitorManager.Listen(pxy.GetName(), secretKey, allowUsers) + if err != nil { + return err + } + pxy.listeners = append(pxy.listeners, listener) + pxy.xl.Infof("%s proxy custom listen success", proxyType) + pxy.startCommonTCPListenersHandler() + return nil +} + +// buildDomains constructs a list of domains from custom domains and subdomain configuration. +func (pxy *BaseProxy) buildDomains(customDomains []string, subDomain string) []string { + domains := make([]string, 0, len(customDomains)+1) + for _, d := range customDomains { + if d != "" { + domains = append(domains, d) + } + } + if subDomain != "" { + domains = append(domains, subDomain+"."+pxy.serverCfg.SubDomainHost) + } + return domains +} + // startCommonTCPListenersHandler start a goroutine handler for each listener. func (pxy *BaseProxy) startCommonTCPListenersHandler() { xl := xlog.FromContextSafe(pxy.ctx) diff --git a/server/proxy/stcp.go b/server/proxy/stcp.go index 06b1b17f..113fd13b 100644 --- a/server/proxy/stcp.go +++ b/server/proxy/stcp.go @@ -41,21 +41,7 @@ func NewSTCPProxy(baseProxy *BaseProxy) Proxy { } func (pxy *STCPProxy) Run() (remoteAddr string, err error) { - xl := pxy.xl - allowUsers := pxy.cfg.AllowUsers - // if allowUsers is empty, only allow same user from proxy - if len(allowUsers) == 0 { - allowUsers = []string{pxy.GetUserInfo().User} - } - listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers) - if errRet != nil { - err = errRet - return - } - pxy.listeners = append(pxy.listeners, listener) - xl.Infof("stcp proxy custom listen success") - - pxy.startCommonTCPListenersHandler() + err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "stcp") return } diff --git a/server/proxy/sudp.go b/server/proxy/sudp.go index f37fb423..00438882 100644 --- a/server/proxy/sudp.go +++ b/server/proxy/sudp.go @@ -41,21 +41,7 @@ func NewSUDPProxy(baseProxy *BaseProxy) Proxy { } func (pxy *SUDPProxy) Run() (remoteAddr string, err error) { - xl := pxy.xl - allowUsers := pxy.cfg.AllowUsers - // if allowUsers is empty, only allow same user from proxy - if len(allowUsers) == 0 { - allowUsers = []string{pxy.GetUserInfo().User} - } - listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Secretkey, allowUsers) - if errRet != nil { - err = errRet - return - } - pxy.listeners = append(pxy.listeners, listener) - xl.Infof("sudp proxy custom listen success") - - pxy.startCommonTCPListenersHandler() + err = pxy.startVisitorListener(pxy.cfg.Secretkey, pxy.cfg.AllowUsers, "sudp") return } diff --git a/server/proxy/tcpmux.go b/server/proxy/tcpmux.go index f68ad12d..2a0c8512 100644 --- a/server/proxy/tcpmux.go +++ b/server/proxy/tcpmux.go @@ -72,15 +72,7 @@ func (pxy *TCPMuxProxy) httpConnectListen( } func (pxy *TCPMuxProxy) httpConnectRun() (remoteAddr string, err error) { - domains := make([]string, 0, len(pxy.cfg.CustomDomains)+1) - for _, d := range pxy.cfg.CustomDomains { - if d != "" { - domains = append(domains, d) - } - } - if pxy.cfg.SubDomain != "" { - domains = append(domains, pxy.cfg.SubDomain+"."+pxy.serverCfg.SubDomainHost) - } + domains := pxy.buildDomains(pxy.cfg.CustomDomains, pxy.cfg.SubDomain) addrs := make([]string, 0) for _, domain := range domains {