diff --git a/server/proxy/http.go b/server/proxy/http.go index 2c4f1fd4..761fa805 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -181,18 +181,26 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err }) } + name := pxy.GetName() + proxyType := pxy.GetConfigurer().GetBaseConfig().Type + rwc = wrapCountingReadWriteCloser(rwc, func(bytes int64) { + metrics.Server.AddTrafficOut(name, proxyType, bytes) + }, func(bytes int64) { + metrics.Server.AddTrafficIn(name, proxyType, bytes) + }) + workConn = netpkg.WrapReadWriteCloserToConn(rwc, tmpConn) - workConn = netpkg.WrapStatsConn(workConn, pxy.updateStatsAfterClosedConn) - metrics.Server.OpenConnection(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type) + workConn = netpkg.WrapCloseNotifyConn(workConn, func(error) { + pxy.updateStatsAfterClosedConn() + }) + metrics.Server.OpenConnection(name, proxyType) return } -func (pxy *HTTPProxy) updateStatsAfterClosedConn(totalRead, totalWrite int64) { +func (pxy *HTTPProxy) updateStatsAfterClosedConn() { name := pxy.GetName() proxyType := pxy.GetConfigurer().GetBaseConfig().Type metrics.Server.CloseConnection(name, proxyType) - metrics.Server.AddTrafficIn(name, proxyType, totalWrite) - metrics.Server.AddTrafficOut(name, proxyType, totalRead) } func (pxy *HTTPProxy) Close() { diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go index 564eca28..f3a30cc5 100644 --- a/server/proxy/proxy.go +++ b/server/proxy/proxy.go @@ -263,11 +263,18 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) { name := pxy.GetName() proxyType := cfg.Type + local = wrapCountingReadWriteCloser(local, nil, func(bytes int64) { + metrics.Server.AddTrafficIn(name, proxyType, bytes) + }) + userConn = netpkg.WrapReadWriteCloserToConn( + wrapCountingReadWriteCloser(userConn, nil, func(bytes int64) { + metrics.Server.AddTrafficOut(name, proxyType, bytes) + }), + userConn, + ) metrics.Server.OpenConnection(name, proxyType) - inCount, outCount, _ := libio.Join(local, userConn) + _, _, _ = libio.Join(local, userConn) metrics.Server.CloseConnection(name, proxyType) - metrics.Server.AddTrafficIn(name, proxyType, inCount) - metrics.Server.AddTrafficOut(name, proxyType, outCount) xl.Debugf("join connections closed") } diff --git a/server/proxy/traffic_counter.go b/server/proxy/traffic_counter.go new file mode 100644 index 00000000..c3297a74 --- /dev/null +++ b/server/proxy/traffic_counter.go @@ -0,0 +1,36 @@ +package proxy + +import "io" + +type countingReadWriteCloser struct { + io.ReadWriteCloser + onRead func(int64) + onWrite func(int64) +} + +func wrapCountingReadWriteCloser(rwc io.ReadWriteCloser, onRead, onWrite func(int64)) io.ReadWriteCloser { + if onRead == nil && onWrite == nil { + return rwc + } + return &countingReadWriteCloser{ + ReadWriteCloser: rwc, + onRead: onRead, + onWrite: onWrite, + } +} + +func (c *countingReadWriteCloser) Read(p []byte) (n int, err error) { + n, err = c.ReadWriteCloser.Read(p) + if n > 0 && c.onRead != nil { + c.onRead(int64(n)) + } + return +} + +func (c *countingReadWriteCloser) Write(p []byte) (n int, err error) { + n, err = c.ReadWriteCloser.Write(p) + if n > 0 && c.onWrite != nil { + c.onWrite(int64(n)) + } + return +}