diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index ca7d905b..84ff49a1 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -16,6 +16,7 @@ package proxy import ( "context" + "fmt" "io" "net" "reflect" @@ -122,6 +123,33 @@ func (pxy *BaseProxy) Close() { } } +// wrapWorkConn applies rate limiting, encryption, and compression +// to a work connection based on the proxy's transport configuration. +// The returned recycle function should be called when the stream is no longer in use +// to return compression resources to the pool. It is safe to not call recycle, +// in which case resources will be garbage collected normally. +func (pxy *BaseProxy) wrapWorkConn(conn net.Conn, encKey []byte) (io.ReadWriteCloser, func(), error) { + var rwc io.ReadWriteCloser = conn + if pxy.limiter != nil { + rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { + return conn.Close() + }) + } + if pxy.baseCfg.Transport.UseEncryption { + var err error + rwc, err = libio.WithEncryption(rwc, encKey) + if err != nil { + conn.Close() + return nil, nil, fmt.Errorf("create encryption stream error: %w", err) + } + } + var recycleFn func() + if pxy.baseCfg.Transport.UseCompression { + rwc, recycleFn = libio.WithCompressionFromPool(rwc) + } + return rwc, recycleFn, nil +} + func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { pxy.inWorkConnCallback = cb } @@ -139,30 +167,14 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) { xl := pxy.xl baseCfg := pxy.baseCfg - var ( - remote io.ReadWriteCloser - err error - ) - remote = workConn - if pxy.limiter != nil { - remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error { - return workConn.Close() - }) - } xl.Tracef("handle tcp work connection, useEncryption: %t, useCompression: %t", baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression) - if baseCfg.Transport.UseEncryption { - remote, err = libio.WithEncryption(remote, encKey) - if err != nil { - workConn.Close() - xl.Errorf("create encryption stream error: %v", err) - return - } - } - var compressionResourceRecycleFn func() - if baseCfg.Transport.UseCompression { - remote, compressionResourceRecycleFn = libio.WithCompressionFromPool(remote) + + remote, recycleFn, err := pxy.wrapWorkConn(workConn, encKey) + if err != nil { + xl.Errorf("wrap work connection: %v", err) + return } // check if we need to send proxy protocol info @@ -178,7 +190,6 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor } if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 { - // Use the common proxy protocol builder function header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion) connInfo.ProxyProtocolHeader = header } @@ -187,12 +198,18 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor if pxy.proxyPlugin != nil { // if plugin is set, let plugin handle connection first + // Don't recycle compression resources here because plugins may + // retain the connection after Handle returns. xl.Debugf("handle by plugin: %s", pxy.proxyPlugin.Name()) pxy.proxyPlugin.Handle(pxy.ctx, &connInfo) xl.Debugf("handle by plugin finished") return } + if recycleFn != nil { + defer recycleFn() + } + localConn, err := libnet.Dial( net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)), libnet.WithTimeout(10*time.Second), @@ -220,7 +237,4 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor if len(errs) > 0 { xl.Tracef("join connections errors: %v", errs) } - if compressionResourceRecycleFn != nil { - compressionResourceRecycleFn() - } } diff --git a/client/proxy/sudp.go b/client/proxy/sudp.go index 3a7af19c..da941446 100644 --- a/client/proxy/sudp.go +++ b/client/proxy/sudp.go @@ -17,7 +17,6 @@ package proxy import ( - "io" "net" "reflect" "strconv" @@ -25,12 +24,10 @@ import ( "time" "github.com/fatedier/golib/errors" - libio "github.com/fatedier/golib/io" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/proto/udp" - "github.com/fatedier/frp/pkg/util/limit" netpkg "github.com/fatedier/frp/pkg/util/net" ) @@ -83,27 +80,13 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) { xl := pxy.xl xl.Infof("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String()) - var rwc io.ReadWriteCloser = conn - var err error - if pxy.limiter != nil { - rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { - return conn.Close() - }) + remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey) + if err != nil { + xl.Errorf("wrap work connection: %v", err) + return } - if pxy.cfg.Transport.UseEncryption { - rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey) - if err != nil { - conn.Close() - xl.Errorf("create encryption stream error: %v", err) - return - } - } - if pxy.cfg.Transport.UseCompression { - rwc = libio.WithCompression(rwc) - } - conn = netpkg.WrapReadWriteCloserToConn(rwc, conn) - workConn := conn + workConn := netpkg.WrapReadWriteCloserToConn(remote, conn) readCh := make(chan *msg.UDPPacket, 1024) sendCh := make(chan msg.Message, 1024) isClose := false diff --git a/client/proxy/udp.go b/client/proxy/udp.go index 68426dc6..570da476 100644 --- a/client/proxy/udp.go +++ b/client/proxy/udp.go @@ -17,19 +17,16 @@ package proxy import ( - "io" "net" "reflect" "strconv" "time" "github.com/fatedier/golib/errors" - libio "github.com/fatedier/golib/io" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/proto/udp" - "github.com/fatedier/frp/pkg/util/limit" netpkg "github.com/fatedier/frp/pkg/util/net" ) @@ -94,28 +91,14 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) { // close resources related with old workConn pxy.Close() - var rwc io.ReadWriteCloser = conn - var err error - if pxy.limiter != nil { - rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { - return conn.Close() - }) + remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey) + if err != nil { + xl.Errorf("wrap work connection: %v", err) + return } - if pxy.cfg.Transport.UseEncryption { - rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey) - if err != nil { - conn.Close() - xl.Errorf("create encryption stream error: %v", err) - return - } - } - if pxy.cfg.Transport.UseCompression { - rwc = libio.WithCompression(rwc) - } - conn = netpkg.WrapReadWriteCloserToConn(rwc, conn) pxy.mu.Lock() - pxy.workConn = conn + pxy.workConn = netpkg.WrapReadWriteCloserToConn(remote, conn) pxy.readCh = make(chan *msg.UDPPacket, 1024) pxy.sendCh = make(chan msg.Message, 1024) pxy.closed = false