diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml
index 35cb5f65..3217c234 100644
--- a/.github/workflows/golangci-lint.yml
+++ b/.github/workflows/golangci-lint.yml
@@ -32,4 +32,4 @@ jobs:
uses: golangci/golangci-lint-action@v9
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
- version: v2.10
+ version: v2.11
diff --git a/.golangci.yml b/.golangci.yml
index 89222745..886b55fb 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -34,7 +34,7 @@ linters:
disabled-checks:
- exitAfterDefer
gosec:
- excludes: ["G115", "G117", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"]
+ excludes: ["G115", "G117", "G118", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"]
severity: low
confidence: low
govet:
diff --git a/client/connector.go b/client/connector.go
index 51536750..34b24f0d 100644
--- a/client/connector.go
+++ b/client/connector.go
@@ -29,6 +29,8 @@ import (
"github.com/samber/lo"
v1 "github.com/fatedier/frp/pkg/config/v1"
+ "github.com/fatedier/frp/pkg/msg"
+ "github.com/fatedier/frp/pkg/proto/wire"
"github.com/fatedier/frp/pkg/transport"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog"
@@ -41,6 +43,39 @@ type Connector interface {
Close() error
}
+type MessageConnector interface {
+ Connect() (*msg.Conn, error)
+ Close() error
+}
+
+type messageConnector struct {
+ connector Connector
+ wireProtocol string
+}
+
+func newMessageConnector(connector Connector, wireProtocol string) *messageConnector {
+ return &messageConnector{
+ connector: connector,
+ wireProtocol: wireProtocol,
+ }
+}
+
+func (c *messageConnector) Connect() (*msg.Conn, error) {
+ conn, err := c.connector.Connect()
+ if err != nil {
+ return nil, err
+ }
+ if err = wire.WriteMagicIfV2(conn, c.wireProtocol); err != nil {
+ conn.Close()
+ return nil, err
+ }
+ return msg.NewConn(conn, msg.NewReadWriter(conn, c.wireProtocol)), nil
+}
+
+func (c *messageConnector) Close() error {
+ return c.connector.Close()
+}
+
// defaultConnectorImpl is the default implementation of Connector for normal frpc.
type defaultConnectorImpl struct {
ctx context.Context
diff --git a/client/control.go b/client/control.go
index 020ac94f..6e3002bc 100644
--- a/client/control.go
+++ b/client/control.go
@@ -27,7 +27,6 @@ import (
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/transport"
- netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet"
@@ -41,13 +40,11 @@ type SessionContext struct {
// It should be attached to the login message when reconnecting.
RunID string
// Underlying control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
- Conn net.Conn
- // Indicates whether the connection is encrypted.
- ConnEncrypted bool
+ Conn *msg.Conn
// Auth runtime used for login, heartbeats, and encryption.
Auth *auth.ClientAuth
- // Connector is used to create new connections, which could be real TCP connections or virtual streams.
- Connector Connector
+ // Connector is used to create message connections to frps.
+ Connector MessageConnector
// Virtual net controller
VnetController *vnet.Controller
}
@@ -91,15 +88,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
}
ctl.lastPong.Store(time.Now())
- if sessionCtx.ConnEncrypted {
- cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.Auth.EncryptionKey())
- if err != nil {
- return nil, err
- }
- ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
- } else {
- ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
- }
+ ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
@@ -139,14 +128,14 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) {
workConn.Close()
return
}
- if err = msg.WriteMsg(workConn, m); err != nil {
+ if err = workConn.WriteMsg(m); err != nil {
xl.Warnf("work connection write to server error: %v", err)
workConn.Close()
return
}
var startMsg msg.StartWorkConn
- if err = msg.ReadMsgInto(workConn, &startMsg); err != nil {
+ if err = workConn.ReadMsgInto(&startMsg); err != nil {
xl.Tracef("work connection closed before response StartWorkConn message: %v", err)
workConn.Close()
return
@@ -227,7 +216,7 @@ func (ctl *Control) Done() <-chan struct{} {
}
// connectServer return a new connection to frps
-func (ctl *Control) connectServer() (net.Conn, error) {
+func (ctl *Control) connectServer() (*msg.Conn, error) {
return ctl.sessionCtx.Connector.Connect()
}
diff --git a/client/control_session.go b/client/control_session.go
new file mode 100644
index 00000000..4438acfe
--- /dev/null
+++ b/client/control_session.go
@@ -0,0 +1,172 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package client
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "runtime"
+ "time"
+
+ "github.com/samber/lo"
+
+ "github.com/fatedier/frp/pkg/auth"
+ v1 "github.com/fatedier/frp/pkg/config/v1"
+ "github.com/fatedier/frp/pkg/msg"
+ "github.com/fatedier/frp/pkg/proto/wire"
+ netpkg "github.com/fatedier/frp/pkg/util/net"
+ "github.com/fatedier/frp/pkg/util/version"
+ "github.com/fatedier/frp/pkg/vnet"
+)
+
+type controlSessionDialer struct {
+ ctx context.Context
+
+ common *v1.ClientCommonConfig
+ auth *auth.ClientAuth
+ clientSpec *msg.ClientSpec
+ vnetController *vnet.Controller
+
+ connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
+}
+
+func (d *controlSessionDialer) Dial(previousRunID string) (*SessionContext, error) {
+ connector := d.connectorCreator(d.ctx, d.common)
+ if err := connector.Open(); err != nil {
+ return nil, err
+ }
+
+ success := false
+ defer func() {
+ if !success {
+ _ = connector.Close()
+ }
+ }()
+
+ conn, err := connector.Connect()
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if !success {
+ _ = conn.Close()
+ }
+ }()
+
+ loginMsg, err := d.buildLoginMsg(previousRunID)
+ if err != nil {
+ return nil, err
+ }
+
+ loginRespMsg, err := d.exchangeLogin(conn, loginMsg)
+ if err != nil {
+ return nil, err
+ }
+ if loginRespMsg.Error != "" {
+ return nil, errors.New(loginRespMsg.Error)
+ }
+
+ var controlRW io.ReadWriter = conn
+ if d.clientSpec == nil || d.clientSpec.Type != "ssh-tunnel" {
+ controlRW, err = netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey())
+ if err != nil {
+ return nil, fmt.Errorf("create control crypto read writer: %w", err)
+ }
+ }
+
+ success = true
+ return &SessionContext{
+ Common: d.common,
+ RunID: loginRespMsg.RunID,
+ Conn: msg.NewConn(conn, msg.NewReadWriter(controlRW, d.common.Transport.WireProtocol)),
+ Auth: d.auth,
+ Connector: newMessageConnector(connector, d.common.Transport.WireProtocol),
+ VnetController: d.vnetController,
+ }, nil
+}
+
+func (d *controlSessionDialer) buildLoginMsg(previousRunID string) (*msg.Login, error) {
+ hostname, _ := os.Hostname()
+ loginMsg := &msg.Login{
+ Arch: runtime.GOARCH,
+ Os: runtime.GOOS,
+ Hostname: hostname,
+ PoolCount: d.common.Transport.PoolCount,
+ User: d.common.User,
+ ClientID: d.common.ClientID,
+ Version: version.Full(),
+ Timestamp: time.Now().Unix(),
+ RunID: previousRunID,
+ Metas: d.common.Metadatas,
+ }
+ if d.clientSpec != nil {
+ loginMsg.ClientSpec = *d.clientSpec
+ }
+
+ if err := d.auth.Setter.SetLogin(loginMsg); err != nil {
+ return nil, err
+ }
+ return loginMsg, nil
+}
+
+func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*msg.LoginResp, error) {
+ rw := msg.NewV1ReadWriter(conn)
+ var wireConn *wire.Conn
+
+ if d.common.Transport.WireProtocol == wire.ProtocolV2 {
+ if err := wire.WriteMagic(conn); err != nil {
+ return nil, err
+ }
+
+ wireConn = wire.NewConn(conn)
+ rw = msg.NewV2ReadWriterWithConn(wireConn)
+ hello := wire.DefaultClientHello(wire.BootstrapInfo{
+ Transport: d.common.Transport.Protocol,
+ TLS: lo.FromPtr(d.common.Transport.TLS.Enable) || d.common.Transport.Protocol == "wss" || d.common.Transport.Protocol == "quic",
+ TCPMux: lo.FromPtr(d.common.Transport.TCPMux),
+ })
+ if err := wireConn.WriteJSONFrame(wire.FrameTypeClientHello, hello); err != nil {
+ return nil, err
+ }
+ }
+ if err := rw.WriteMsg(loginMsg); err != nil {
+ return nil, err
+ }
+
+ _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
+ defer func() {
+ _ = conn.SetReadDeadline(time.Time{})
+ }()
+
+ if wireConn != nil {
+ var serverHello wire.ServerHello
+ if err := wireConn.ReadJSONFrame(wire.FrameTypeServerHello, &serverHello); err != nil {
+ return nil, err
+ }
+ if serverHello.Error != "" {
+ return nil, errors.New(serverHello.Error)
+ }
+ }
+
+ var loginRespMsg msg.LoginResp
+ if err := rw.ReadMsgInto(&loginRespMsg); err != nil {
+ return nil, err
+ }
+ return &loginRespMsg, nil
+}
diff --git a/client/control_session_test.go b/client/control_session_test.go
new file mode 100644
index 00000000..9e59c9cd
--- /dev/null
+++ b/client/control_session_test.go
@@ -0,0 +1,245 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package client
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/fatedier/frp/pkg/auth"
+ v1 "github.com/fatedier/frp/pkg/config/v1"
+ "github.com/fatedier/frp/pkg/msg"
+ "github.com/fatedier/frp/pkg/proto/wire"
+)
+
+type testConnector struct {
+ conn net.Conn
+ closed atomic.Bool
+}
+
+func (c *testConnector) Open() error {
+ return nil
+}
+
+func (c *testConnector) Connect() (net.Conn, error) {
+ return c.conn, nil
+}
+
+func (c *testConnector) Close() error {
+ c.closed.Store(true)
+ return nil
+}
+
+type trackingConn struct {
+ net.Conn
+ closed atomic.Bool
+}
+
+func (c *trackingConn) Close() error {
+ c.closed.Store(true)
+ return c.Conn.Close()
+}
+
+func newTestControlSessionDialer(t *testing.T, protocol string, connector Connector, clientSpec *msg.ClientSpec) *controlSessionDialer {
+ t.Helper()
+
+ authRuntime, err := auth.BuildClientAuth(&v1.AuthClientConfig{
+ Method: v1.AuthMethodToken,
+ Token: "token",
+ })
+ require.NoError(t, err)
+
+ return &controlSessionDialer{
+ ctx: context.Background(),
+ common: &v1.ClientCommonConfig{
+ User: "test-user",
+ Transport: v1.ClientTransportConfig{
+ Protocol: "tcp",
+ WireProtocol: protocol,
+ },
+ },
+ auth: authRuntime,
+ clientSpec: clientSpec,
+ connectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector {
+ return connector
+ },
+ }
+}
+
+func TestControlSessionDialerDialV1(t *testing.T) {
+ clientRaw, serverRaw := net.Pipe()
+ defer serverRaw.Close()
+
+ connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
+ serverErrCh := make(chan error, 1)
+ go func() {
+ rw := msg.NewV1ReadWriter(serverRaw)
+ var loginMsg msg.Login
+ if err := rw.ReadMsgInto(&loginMsg); err != nil {
+ serverErrCh <- err
+ return
+ }
+ if loginMsg.RunID != "previous-run-id" {
+ serverErrCh <- fmt.Errorf("unexpected previous run id: %s", loginMsg.RunID)
+ return
+ }
+ if loginMsg.User != "test-user" {
+ serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
+ return
+ }
+ serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v1"})
+ }()
+
+ dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
+ sessionCtx, err := dialer.Dial("previous-run-id")
+ require.NoError(t, err)
+ defer sessionCtx.Conn.Close()
+ defer sessionCtx.Connector.Close()
+
+ require.Equal(t, "run-v1", sessionCtx.RunID)
+ require.NotNil(t, sessionCtx.Conn)
+ require.NotNil(t, sessionCtx.Connector)
+ require.False(t, connector.closed.Load())
+ require.NoError(t, <-serverErrCh)
+}
+
+func TestControlSessionDialerDialV2(t *testing.T) {
+ clientRaw, serverRaw := net.Pipe()
+ defer serverRaw.Close()
+
+ connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
+ serverErrCh := make(chan error, 1)
+ go func() {
+ magic := make([]byte, len(wire.MagicV2))
+ if _, err := io.ReadFull(serverRaw, magic); err != nil {
+ serverErrCh <- err
+ return
+ }
+ if string(magic) != wire.MagicV2 {
+ serverErrCh <- fmt.Errorf("unexpected magic: %q", string(magic))
+ return
+ }
+
+ wireConn := wire.NewConn(serverRaw)
+ var hello wire.ClientHello
+ if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil {
+ serverErrCh <- err
+ return
+ }
+ if err := wire.ValidateClientHello(hello); err != nil {
+ serverErrCh <- err
+ return
+ }
+
+ rw := msg.NewV2ReadWriterWithConn(wireConn)
+ var loginMsg msg.Login
+ if err := rw.ReadMsgInto(&loginMsg); err != nil {
+ serverErrCh <- err
+ return
+ }
+ if loginMsg.User != "test-user" {
+ serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
+ return
+ }
+ if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil {
+ serverErrCh <- err
+ return
+ }
+ serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"})
+ }()
+
+ dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
+ sessionCtx, err := dialer.Dial("")
+ require.NoError(t, err)
+ defer sessionCtx.Conn.Close()
+ defer sessionCtx.Connector.Close()
+
+ require.Equal(t, "run-v2", sessionCtx.RunID)
+ require.NotNil(t, sessionCtx.Conn)
+ require.NotNil(t, sessionCtx.Connector)
+ require.False(t, connector.closed.Load())
+ require.NoError(t, <-serverErrCh)
+}
+
+func TestControlSessionDialerDialLoginErrorClosesResources(t *testing.T) {
+ clientRaw, serverRaw := net.Pipe()
+ defer serverRaw.Close()
+
+ clientConn := &trackingConn{Conn: clientRaw}
+ connector := &testConnector{conn: clientConn}
+ serverErrCh := make(chan error, 1)
+ go func() {
+ rw := msg.NewV1ReadWriter(serverRaw)
+ var loginMsg msg.Login
+ if err := rw.ReadMsgInto(&loginMsg); err != nil {
+ serverErrCh <- err
+ return
+ }
+ serverErrCh <- rw.WriteMsg(&msg.LoginResp{Error: "login denied"})
+ }()
+
+ dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
+ sessionCtx, err := dialer.Dial("")
+ require.Nil(t, sessionCtx)
+ require.ErrorContains(t, err, "login denied")
+ require.True(t, clientConn.closed.Load())
+ require.True(t, connector.closed.Load())
+ require.NoError(t, <-serverErrCh)
+}
+
+func TestControlSessionDialerDialSSHTunnelSkipsControlEncryption(t *testing.T) {
+ clientRaw, serverRaw := net.Pipe()
+ defer serverRaw.Close()
+
+ connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
+ serverErrCh := make(chan error, 1)
+ go func() {
+ rw := msg.NewV1ReadWriter(serverRaw)
+ var loginMsg msg.Login
+ if err := rw.ReadMsgInto(&loginMsg); err != nil {
+ serverErrCh <- err
+ return
+ }
+ if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-ssh-tunnel"}); err != nil {
+ serverErrCh <- err
+ return
+ }
+
+ _ = serverRaw.SetReadDeadline(time.Now().Add(time.Second))
+ var ping msg.Ping
+ if err := rw.ReadMsgInto(&ping); err != nil {
+ serverErrCh <- err
+ return
+ }
+ serverErrCh <- nil
+ }()
+
+ dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, &msg.ClientSpec{Type: "ssh-tunnel"})
+ sessionCtx, err := dialer.Dial("")
+ require.NoError(t, err)
+ defer sessionCtx.Conn.Close()
+ defer sessionCtx.Connector.Close()
+
+ require.Equal(t, "run-ssh-tunnel", sessionCtx.RunID)
+ require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{}))
+ require.NoError(t, <-serverErrCh)
+}
diff --git a/client/service.go b/client/service.go
index 5c51fd67..e6e42e46 100644
--- a/client/service.go
+++ b/client/service.go
@@ -21,7 +21,6 @@ import (
"net"
"net/http"
"os"
- "runtime"
"sync"
"time"
@@ -38,7 +37,6 @@ import (
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
- "github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet"
@@ -303,80 +301,20 @@ func (svr *Service) keepControllerWorking() {
), true, svr.ctx.Done())
}
-// login creates a connection to frps and registers it self as a client
-// conn: control connection
-// session: if it's not nil, using tcp mux
-func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
- xl := xlog.FromContextSafe(svr.ctx)
- connector = svr.connectorCreator(svr.ctx, svr.common)
- if err = connector.Open(); err != nil {
- return nil, nil, err
- }
-
- defer func() {
- if err != nil {
- connector.Close()
- }
- }()
-
- conn, err = connector.Connect()
- if err != nil {
- return
- }
-
- hostname, _ := os.Hostname()
-
- loginMsg := &msg.Login{
- Arch: runtime.GOARCH,
- Os: runtime.GOOS,
- Hostname: hostname,
- PoolCount: svr.common.Transport.PoolCount,
- User: svr.common.User,
- ClientID: svr.common.ClientID,
- Version: version.Full(),
- Timestamp: time.Now().Unix(),
- RunID: svr.runID,
- Metas: svr.common.Metadatas,
- }
- if svr.clientSpec != nil {
- loginMsg.ClientSpec = *svr.clientSpec
- }
-
- // Add auth
- if err = svr.auth.Setter.SetLogin(loginMsg); err != nil {
- return
- }
-
- if err = msg.WriteMsg(conn, loginMsg); err != nil {
- return
- }
-
- var loginRespMsg msg.LoginResp
- _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
- if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil {
- return
- }
- _ = conn.SetReadDeadline(time.Time{})
-
- if loginRespMsg.Error != "" {
- err = fmt.Errorf("%s", loginRespMsg.Error)
- xl.Errorf("%s", loginRespMsg.Error)
- return
- }
-
- svr.runID = loginRespMsg.RunID
- xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
-
- xl.Infof("login to server success, get run id [%s]", loginRespMsg.RunID)
- return
-}
-
func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) {
xl := xlog.FromContextSafe(svr.ctx)
loginFunc := func() (bool, error) {
xl.Infof("try to connect to server...")
- conn, connector, err := svr.login()
+ dialer := &controlSessionDialer{
+ ctx: svr.ctx,
+ common: svr.common,
+ auth: svr.auth,
+ clientSpec: svr.clientSpec,
+ vnetController: svr.vnetController,
+ connectorCreator: svr.connectorCreator,
+ }
+ sessionCtx, err := dialer.Dial(svr.runID)
if err != nil {
xl.Warnf("connect to server error: %v", err)
if firstLoginExit {
@@ -385,25 +323,19 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
return false, err
}
+ svr.runID = sessionCtx.RunID
+ xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
+ xl.Infof("login to server success, get run id [%s]", svr.runID)
+
svr.cfgMu.RLock()
proxyCfgs := svr.proxyCfgs
visitorCfgs := svr.visitorCfgs
svr.cfgMu.RUnlock()
- connEncrypted := svr.clientSpec == nil || svr.clientSpec.Type != "ssh-tunnel"
-
- sessionCtx := &SessionContext{
- Common: svr.common,
- RunID: svr.runID,
- Conn: conn,
- ConnEncrypted: connEncrypted,
- Auth: svr.auth,
- Connector: connector,
- VnetController: svr.vnetController,
- }
ctl, err := NewControl(svr.ctx, sessionCtx)
if err != nil {
- conn.Close()
+ sessionCtx.Conn.Close()
+ sessionCtx.Connector.Close()
xl.Errorf("new control error: %v", err)
return false, err
}
diff --git a/client/visitor/visitor.go b/client/visitor/visitor.go
index 14a7aa37..dff0bb94 100644
--- a/client/visitor/visitor.go
+++ b/client/visitor/visitor.go
@@ -38,7 +38,7 @@ import (
// Helper wraps some functions for visitor to use.
type Helper interface {
// ConnectServer directly connects to the frp server.
- ConnectServer() (net.Conn, error)
+ ConnectServer() (*msg.Conn, error)
// TransferConn transfers the connection to another visitor.
TransferConn(string, net.Conn) error
// MsgTransporter returns the message transporter that is used to send and receive messages
@@ -167,15 +167,15 @@ func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, e
UseEncryption: cfg.Transport.UseEncryption,
UseCompression: cfg.Transport.UseCompression,
}
- err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
+ err = visitorConn.WriteMsg(newVisitorConnMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
}
- var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
- err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
+ var newVisitorConnRespMsg msg.NewVisitorConnResp
+ err = visitorConn.ReadMsgInto(&newVisitorConnRespMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
diff --git a/client/visitor/visitor_manager.go b/client/visitor/visitor_manager.go
index b60f7047..1ca194bd 100644
--- a/client/visitor/visitor_manager.go
+++ b/client/visitor/visitor_manager.go
@@ -25,6 +25,7 @@ import (
"github.com/samber/lo"
v1 "github.com/fatedier/frp/pkg/config/v1"
+ "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet"
@@ -49,7 +50,7 @@ func NewManager(
ctx context.Context,
runID string,
clientCfg *v1.ClientCommonConfig,
- connectServer func() (net.Conn, error),
+ connectServer func() (*msg.Conn, error),
msgTransporter transport.MessageTransporter,
vnetController *vnet.Controller,
) *Manager {
@@ -199,14 +200,14 @@ func (vm *Manager) GetVisitorCfg(name string) (v1.VisitorConfigurer, bool) {
}
type visitorHelperImpl struct {
- connectServerFn func() (net.Conn, error)
+ connectServerFn func() (*msg.Conn, error)
msgTransporter transport.MessageTransporter
vnetController *vnet.Controller
transferConnFn func(name string, conn net.Conn) error
runID string
}
-func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) {
+func (v *visitorHelperImpl) ConnectServer() (*msg.Conn, error) {
return v.connectServerFn()
}
diff --git a/conf/frpc_full_example.toml b/conf/frpc_full_example.toml
index 50a00cfd..b141b0ed 100644
--- a/conf/frpc_full_example.toml
+++ b/conf/frpc_full_example.toml
@@ -103,6 +103,10 @@ transport.poolCount = 5
# supports tcp, kcp, quic, websocket and wss now, default is tcp
transport.protocol = "tcp"
+# FRP wire protocol used inside the selected transport.
+# supports v1 and v2, default is v1. v2 requires frps support and must be enabled explicitly.
+# transport.wireProtocol = "v1"
+
# set client binding ip when connect server, default is empty.
# only when protocol = tcp or websocket, the value will be used.
transport.connectServerLocalIP = "0.0.0.0"
diff --git a/pkg/config/v1/client.go b/pkg/config/v1/client.go
index 783eee8e..05d02c94 100644
--- a/pkg/config/v1/client.go
+++ b/pkg/config/v1/client.go
@@ -104,6 +104,9 @@ type ClientTransportConfig struct {
// Valid values are "tcp", "kcp", "quic", "websocket" and "wss". By default, this value
// is "tcp".
Protocol string `json:"protocol,omitempty"`
+ // WireProtocol specifies the frpc/frps internal wire protocol version.
+ // Valid values are "v1" and "v2". By default, this value is "v1".
+ WireProtocol string `json:"wireProtocol,omitempty"`
// The maximum amount of time a dial to server will wait for a connect to complete.
DialServerTimeout int64 `json:"dialServerTimeout,omitempty"`
// DialServerKeepAlive specifies the interval between keep-alive probes for an active network connection between frpc and frps.
@@ -143,6 +146,7 @@ type ClientTransportConfig struct {
func (c *ClientTransportConfig) Complete() {
c.Protocol = util.EmptyOr(c.Protocol, "tcp")
+ c.WireProtocol = util.EmptyOr(c.WireProtocol, "v1")
c.DialServerTimeout = util.EmptyOr(c.DialServerTimeout, 10)
c.DialServerKeepAlive = util.EmptyOr(c.DialServerKeepAlive, 7200)
c.ProxyURL = util.EmptyOr(c.ProxyURL, os.Getenv("http_proxy"))
diff --git a/pkg/config/v1/client_test.go b/pkg/config/v1/client_test.go
index 5473a5f6..67214a5d 100644
--- a/pkg/config/v1/client_test.go
+++ b/pkg/config/v1/client_test.go
@@ -29,6 +29,7 @@ func TestClientConfigComplete(t *testing.T) {
require.EqualValues("token", c.Auth.Method)
require.Equal(true, lo.FromPtr(c.Transport.TCPMux))
+ require.Equal("v1", c.Transport.WireProtocol)
require.Equal(true, lo.FromPtr(c.LoginFailExit))
require.Equal(true, lo.FromPtr(c.Transport.TLS.Enable))
require.Equal(true, lo.FromPtr(c.Transport.TLS.DisableCustomTLSFirstByte))
diff --git a/pkg/config/v1/validation/client.go b/pkg/config/v1/validation/client.go
index c90d525d..5c8433d5 100644
--- a/pkg/config/v1/validation/client.go
+++ b/pkg/config/v1/validation/client.go
@@ -146,6 +146,9 @@ func validateTransportConfig(c *v1.ClientTransportConfig) (Warning, error) {
if !slices.Contains(SupportedTransportProtocols, c.Protocol) {
errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, optional values are %v", SupportedTransportProtocols))
}
+ if !slices.Contains(SupportedWireProtocols, c.WireProtocol) {
+ errs = AppendError(errs, fmt.Errorf("invalid transport.wireProtocol, optional values are %v", SupportedWireProtocols))
+ }
return warnings, errs
}
diff --git a/pkg/config/v1/validation/validation.go b/pkg/config/v1/validation/validation.go
index 4ca6b67f..417cec0f 100644
--- a/pkg/config/v1/validation/validation.go
+++ b/pkg/config/v1/validation/validation.go
@@ -29,6 +29,10 @@ var (
"websocket",
"wss",
}
+ SupportedWireProtocols = []string{
+ "v1",
+ "v2",
+ }
SupportedAuthMethods = []v1.AuthMethod{
"token",
diff --git a/pkg/msg/conn_test.go b/pkg/msg/conn_test.go
new file mode 100644
index 00000000..f3498f03
--- /dev/null
+++ b/pkg/msg/conn_test.go
@@ -0,0 +1,56 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package msg
+
+import (
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/fatedier/frp/pkg/proto/wire"
+)
+
+func TestConnReadWriteMsg(t *testing.T) {
+ tests := []struct {
+ name string
+ protocol string
+ }{
+ {name: "v1", protocol: wire.ProtocolV1},
+ {name: "v2", protocol: wire.ProtocolV2},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ client, server := net.Pipe()
+ defer client.Close()
+ defer server.Close()
+
+ clientConn := NewConn(client, NewReadWriter(client, tt.protocol))
+ serverConn := NewConn(server, NewReadWriter(server, tt.protocol))
+
+ in := &Ping{PrivilegeKey: "key", Timestamp: 123}
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- clientConn.WriteMsg(in)
+ }()
+
+ out, err := serverConn.ReadMsg()
+ require.NoError(t, err)
+ require.Equal(t, in, out)
+ require.NoError(t, <-errCh)
+ })
+ }
+}
diff --git a/pkg/msg/handler.go b/pkg/msg/handler.go
index 243e599a..b073e59b 100644
--- a/pkg/msg/handler.go
+++ b/pkg/msg/handler.go
@@ -15,10 +15,90 @@
package msg
import (
+ "context"
"io"
+ "net"
"reflect"
+
+ "github.com/fatedier/frp/pkg/proto/wire"
)
+type ReadWriter interface {
+ ReadMsg() (Message, error)
+ ReadMsgInto(Message) error
+ WriteMsg(Message) error
+}
+
+type Conn struct {
+ net.Conn
+ rw ReadWriter
+}
+
+func NewConn(conn net.Conn, rw ReadWriter) *Conn {
+ return &Conn{
+ Conn: conn,
+ rw: rw,
+ }
+}
+
+func (c *Conn) ReadMsg() (Message, error) {
+ return c.rw.ReadMsg()
+}
+
+func (c *Conn) ReadMsgInto(m Message) error {
+ return c.rw.ReadMsgInto(m)
+}
+
+func (c *Conn) WriteMsg(m Message) error {
+ return c.rw.WriteMsg(m)
+}
+
+func (c *Conn) Context() context.Context {
+ if getter, ok := c.Conn.(interface{ Context() context.Context }); ok {
+ return getter.Context()
+ }
+ return context.Background()
+}
+
+func (c *Conn) WithContext(ctx context.Context) {
+ if setter, ok := c.Conn.(interface{ WithContext(context.Context) }); ok {
+ setter.WithContext(ctx)
+ }
+}
+
+type V1ReadWriter struct {
+ rw io.ReadWriter
+}
+
+func NewV1ReadWriter(rw io.ReadWriter) ReadWriter {
+ return &V1ReadWriter{rw: rw}
+}
+
+// NewReadWriter wraps rw with the message codec for the selected wire protocol.
+// An empty protocol keeps the historical v1 behavior for tests and older call sites.
+func NewReadWriter(rw io.ReadWriter, wireProtocol string) ReadWriter {
+ switch wireProtocol {
+ case wire.ProtocolV2:
+ return NewV2ReadWriter(rw)
+ case "", wire.ProtocolV1:
+ return NewV1ReadWriter(rw)
+ default:
+ return NewV1ReadWriter(rw)
+ }
+}
+
+func (rw *V1ReadWriter) ReadMsg() (Message, error) {
+ return ReadMsg(rw.rw)
+}
+
+func (rw *V1ReadWriter) ReadMsgInto(m Message) error {
+ return ReadMsgInto(rw.rw, m)
+}
+
+func (rw *V1ReadWriter) WriteMsg(m Message) error {
+ return WriteMsg(rw.rw, m)
+}
+
func AsyncHandler(f func(Message)) func(Message) {
return func(m Message) {
go f(m)
@@ -27,7 +107,7 @@ func AsyncHandler(f func(Message)) func(Message) {
// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn.
type Dispatcher struct {
- rw io.ReadWriter
+ rw ReadWriter
sendCh chan Message
doneCh chan struct{}
@@ -35,7 +115,7 @@ type Dispatcher struct {
defaultHandler func(Message)
}
-func NewDispatcher(rw io.ReadWriter) *Dispatcher {
+func NewDispatcher(rw ReadWriter) *Dispatcher {
return &Dispatcher{
rw: rw,
sendCh: make(chan Message, 100),
@@ -56,14 +136,14 @@ func (d *Dispatcher) sendLoop() {
case <-d.doneCh:
return
case m := <-d.sendCh:
- _ = WriteMsg(d.rw, m)
+ _ = d.rw.WriteMsg(m)
}
}
}
func (d *Dispatcher) readLoop() {
for {
- m, err := ReadMsg(d.rw)
+ m, err := d.rw.ReadMsg()
if err != nil {
close(d.doneCh)
return
diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go
index e8bcbc35..52948bd5 100644
--- a/pkg/msg/msg.go
+++ b/pkg/msg/msg.go
@@ -20,24 +20,24 @@ import (
)
const (
- TypeLogin = 'o'
- TypeLoginResp = '1'
- TypeNewProxy = 'p'
- TypeNewProxyResp = '2'
- TypeCloseProxy = 'c'
- TypeNewWorkConn = 'w'
- TypeReqWorkConn = 'r'
- TypeStartWorkConn = 's'
- TypeNewVisitorConn = 'v'
- TypeNewVisitorConnResp = '3'
- TypePing = 'h'
- TypePong = '4'
- TypeUDPPacket = 'u'
- TypeNatHoleVisitor = 'i'
- TypeNatHoleClient = 'n'
- TypeNatHoleResp = 'm'
- TypeNatHoleSid = '5'
- TypeNatHoleReport = '6'
+ TypeLogin byte = 'o'
+ TypeLoginResp byte = '1'
+ TypeNewProxy byte = 'p'
+ TypeNewProxyResp byte = '2'
+ TypeCloseProxy byte = 'c'
+ TypeNewWorkConn byte = 'w'
+ TypeReqWorkConn byte = 'r'
+ TypeStartWorkConn byte = 's'
+ TypeNewVisitorConn byte = 'v'
+ TypeNewVisitorConnResp byte = '3'
+ TypePing byte = 'h'
+ TypePong byte = '4'
+ TypeUDPPacket byte = 'u'
+ TypeNatHoleVisitor byte = 'i'
+ TypeNatHoleClient byte = 'n'
+ TypeNatHoleResp byte = 'm'
+ TypeNatHoleSid byte = '5'
+ TypeNatHoleReport byte = '6'
)
var msgTypeMap = map[byte]any{
diff --git a/pkg/msg/msg_test.go b/pkg/msg/msg_test.go
new file mode 100644
index 00000000..06272faf
--- /dev/null
+++ b/pkg/msg/msg_test.go
@@ -0,0 +1,55 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package msg
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestV1MessageTypeIDsAreStable(t *testing.T) {
+ require.Equal(t, byte('o'), TypeLogin)
+ require.Equal(t, byte('1'), TypeLoginResp)
+ require.Equal(t, byte('p'), TypeNewProxy)
+ require.Equal(t, byte('2'), TypeNewProxyResp)
+ require.Equal(t, byte('c'), TypeCloseProxy)
+ require.Equal(t, byte('w'), TypeNewWorkConn)
+ require.Equal(t, byte('r'), TypeReqWorkConn)
+ require.Equal(t, byte('s'), TypeStartWorkConn)
+ require.Equal(t, byte('v'), TypeNewVisitorConn)
+ require.Equal(t, byte('3'), TypeNewVisitorConnResp)
+ require.Equal(t, byte('h'), TypePing)
+ require.Equal(t, byte('4'), TypePong)
+ require.Equal(t, byte('u'), TypeUDPPacket)
+ require.Equal(t, byte('i'), TypeNatHoleVisitor)
+ require.Equal(t, byte('n'), TypeNatHoleClient)
+ require.Equal(t, byte('m'), TypeNatHoleResp)
+ require.Equal(t, byte('5'), TypeNatHoleSid)
+ require.Equal(t, byte('6'), TypeNatHoleReport)
+}
+
+func TestMessageTypeMapIsCompleteAndUnique(t *testing.T) {
+ require.Len(t, msgTypeMap, 18)
+
+ msgTypes := make(map[reflect.Type]struct{}, len(msgTypeMap))
+
+ for _, m := range msgTypeMap {
+ msgType := reflect.TypeOf(m)
+ require.NotContains(t, msgTypes, msgType)
+ msgTypes[msgType] = struct{}{}
+ }
+}
diff --git a/pkg/msg/wire_v2.go b/pkg/msg/wire_v2.go
new file mode 100644
index 00000000..8d2cd88d
--- /dev/null
+++ b/pkg/msg/wire_v2.go
@@ -0,0 +1,192 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package msg
+
+import (
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "reflect"
+
+ "github.com/fatedier/frp/pkg/proto/wire"
+)
+
+const (
+ V2TypeLogin uint16 = 1
+ V2TypeLoginResp uint16 = 2
+ V2TypeNewProxy uint16 = 3
+ V2TypeNewProxyResp uint16 = 4
+ V2TypeCloseProxy uint16 = 5
+ V2TypeNewWorkConn uint16 = 6
+ V2TypeReqWorkConn uint16 = 7
+ V2TypeStartWorkConn uint16 = 8
+ V2TypeNewVisitorConn uint16 = 9
+ V2TypeNewVisitorConnResp uint16 = 10
+ V2TypePing uint16 = 11
+ V2TypePong uint16 = 12
+ V2TypeUDPPacket uint16 = 13
+ V2TypeNatHoleVisitor uint16 = 14
+ V2TypeNatHoleClient uint16 = 15
+ V2TypeNatHoleResp uint16 = 16
+ V2TypeNatHoleSid uint16 = 17
+ V2TypeNatHoleReport uint16 = 18
+)
+
+var v2MsgTypeMap = map[uint16]any{
+ V2TypeLogin: Login{},
+ V2TypeLoginResp: LoginResp{},
+ V2TypeNewProxy: NewProxy{},
+ V2TypeNewProxyResp: NewProxyResp{},
+ V2TypeCloseProxy: CloseProxy{},
+ V2TypeNewWorkConn: NewWorkConn{},
+ V2TypeReqWorkConn: ReqWorkConn{},
+ V2TypeStartWorkConn: StartWorkConn{},
+ V2TypeNewVisitorConn: NewVisitorConn{},
+ V2TypeNewVisitorConnResp: NewVisitorConnResp{},
+ V2TypePing: Ping{},
+ V2TypePong: Pong{},
+ V2TypeUDPPacket: UDPPacket{},
+ V2TypeNatHoleVisitor: NatHoleVisitor{},
+ V2TypeNatHoleClient: NatHoleClient{},
+ V2TypeNatHoleResp: NatHoleResp{},
+ V2TypeNatHoleSid: NatHoleSid{},
+ V2TypeNatHoleReport: NatHoleReport{},
+}
+
+var v2MsgReflectTypeMap, v2MsgTypeIDMap = buildV2MsgTypeMaps()
+
+func buildV2MsgTypeMaps() (map[uint16]reflect.Type, map[reflect.Type]uint16) {
+ reflectTypeMap := make(map[uint16]reflect.Type, len(v2MsgTypeMap))
+ typeIDMap := make(map[reflect.Type]uint16, len(v2MsgTypeMap))
+ for typeID, m := range v2MsgTypeMap {
+ t := reflect.TypeOf(m)
+ reflectTypeMap[typeID] = t
+ typeIDMap[t] = typeID
+ }
+ return reflectTypeMap, typeIDMap
+}
+
+type V2ReadWriter struct {
+ conn *wire.Conn
+}
+
+func NewV2ReadWriter(rw io.ReadWriter) *V2ReadWriter {
+ return NewV2ReadWriterWithConn(wire.NewConn(rw))
+}
+
+func NewV2ReadWriterWithConn(conn *wire.Conn) *V2ReadWriter {
+ return &V2ReadWriter{conn: conn}
+}
+
+func (rw *V2ReadWriter) WireConn() *wire.Conn {
+ return rw.conn
+}
+
+func (rw *V2ReadWriter) ReadMsg() (Message, error) {
+ f, err := rw.conn.ReadFrame()
+ if err != nil {
+ return nil, err
+ }
+ return DecodeV2MessageFrame(f)
+}
+
+func (rw *V2ReadWriter) ReadMsgInto(m Message) error {
+ f, err := rw.conn.ReadFrame()
+ if err != nil {
+ return err
+ }
+ return DecodeV2MessageFrameInto(f, m)
+}
+
+func (rw *V2ReadWriter) WriteMsg(m Message) error {
+ f, err := EncodeV2MessageFrame(m)
+ if err != nil {
+ return err
+ }
+ return rw.conn.WriteFrame(f)
+}
+
+func DecodeV2MessageFrame(f *wire.Frame) (Message, error) {
+ if f.Type != wire.FrameTypeMessage {
+ return nil, fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage)
+ }
+ if len(f.Payload) < 2 {
+ return nil, fmt.Errorf("message frame payload too short")
+ }
+ typeID := binary.BigEndian.Uint16(f.Payload[:2])
+ t, ok := v2MsgReflectTypeMap[typeID]
+ if !ok {
+ return nil, fmt.Errorf("unknown v2 message type %d", typeID)
+ }
+ m := reflect.New(t).Interface()
+ if err := json.Unmarshal(f.Payload[2:], m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+func DecodeV2MessageFrameInto(f *wire.Frame, out Message) error {
+ if f.Type != wire.FrameTypeMessage {
+ return fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage)
+ }
+ if len(f.Payload) < 2 {
+ return fmt.Errorf("message frame payload too short")
+ }
+
+ typeID := binary.BigEndian.Uint16(f.Payload[:2])
+ outType := reflect.TypeOf(out)
+ if outType == nil || outType.Kind() != reflect.Pointer {
+ return fmt.Errorf("message target must be a pointer")
+ }
+ elemType := outType.Elem()
+ expectedTypeID, ok := v2MsgTypeIDMap[elemType]
+ if !ok {
+ return fmt.Errorf("unknown v2 message type %s", elemType.String())
+ }
+ if typeID != expectedTypeID {
+ actualType, ok := v2MsgReflectTypeMap[typeID]
+ if !ok {
+ return fmt.Errorf("unknown v2 message type %d", typeID)
+ }
+ return fmt.Errorf("unexpected message type %s, want %s", actualType.String(), elemType.String())
+ }
+ return json.Unmarshal(f.Payload[2:], out)
+}
+
+func EncodeV2MessageFrame(m Message) (*wire.Frame, error) {
+ t := reflect.TypeOf(m)
+ if t == nil {
+ return nil, fmt.Errorf("nil message")
+ }
+ if t.Kind() == reflect.Pointer {
+ t = t.Elem()
+ }
+ typeID, ok := v2MsgTypeIDMap[t]
+ if !ok {
+ return nil, fmt.Errorf("unknown v2 message type %s", t.String())
+ }
+ content, err := json.Marshal(m)
+ if err != nil {
+ return nil, err
+ }
+ payload := make([]byte, 2+len(content))
+ binary.BigEndian.PutUint16(payload[:2], typeID)
+ copy(payload[2:], content)
+ return &wire.Frame{
+ Type: wire.FrameTypeMessage,
+ Payload: payload,
+ }, nil
+}
diff --git a/pkg/msg/wire_v2_test.go b/pkg/msg/wire_v2_test.go
new file mode 100644
index 00000000..5879fc37
--- /dev/null
+++ b/pkg/msg/wire_v2_test.go
@@ -0,0 +1,121 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package msg
+
+import (
+ "bytes"
+ "encoding/binary"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/fatedier/frp/pkg/proto/wire"
+)
+
+func TestV2ReadWriterRoundTrip(t *testing.T) {
+ var buf bytes.Buffer
+ rw := NewV2ReadWriter(&buf)
+
+ in := &Login{
+ Version: "test-version",
+ RunID: "run-id",
+ User: "user",
+ }
+ require.NoError(t, rw.WriteMsg(in))
+
+ out, err := rw.ReadMsg()
+ require.NoError(t, err)
+ require.Equal(t, in, out)
+}
+
+func TestNewReadWriter(t *testing.T) {
+ require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, ""))
+ require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV1))
+ require.IsType(t, &V2ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV2))
+}
+
+func TestV2MessageTypeIDsAreStable(t *testing.T) {
+ require.Equal(t, uint16(1), V2TypeLogin)
+ require.Equal(t, uint16(2), V2TypeLoginResp)
+ require.Equal(t, uint16(3), V2TypeNewProxy)
+ require.Equal(t, uint16(4), V2TypeNewProxyResp)
+ require.Equal(t, uint16(5), V2TypeCloseProxy)
+ require.Equal(t, uint16(6), V2TypeNewWorkConn)
+ require.Equal(t, uint16(7), V2TypeReqWorkConn)
+ require.Equal(t, uint16(8), V2TypeStartWorkConn)
+ require.Equal(t, uint16(9), V2TypeNewVisitorConn)
+ require.Equal(t, uint16(10), V2TypeNewVisitorConnResp)
+ require.Equal(t, uint16(11), V2TypePing)
+ require.Equal(t, uint16(12), V2TypePong)
+ require.Equal(t, uint16(13), V2TypeUDPPacket)
+ require.Equal(t, uint16(14), V2TypeNatHoleVisitor)
+ require.Equal(t, uint16(15), V2TypeNatHoleClient)
+ require.Equal(t, uint16(16), V2TypeNatHoleResp)
+ require.Equal(t, uint16(17), V2TypeNatHoleSid)
+ require.Equal(t, uint16(18), V2TypeNatHoleReport)
+}
+
+func TestV2MessageFrameEncoding(t *testing.T) {
+ frame, err := EncodeV2MessageFrame(&ReqWorkConn{})
+ require.NoError(t, err)
+ require.Equal(t, wire.FrameTypeMessage, frame.Type)
+ require.Len(t, frame.Payload, 4)
+ require.Equal(t, V2TypeReqWorkConn, binary.BigEndian.Uint16(frame.Payload[:2]))
+
+ out, err := DecodeV2MessageFrame(frame)
+ require.NoError(t, err)
+ require.IsType(t, &ReqWorkConn{}, out)
+}
+
+func TestDecodeV2MessageFrameInto(t *testing.T) {
+ in := &StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234}
+ frame, err := EncodeV2MessageFrame(in)
+ require.NoError(t, err)
+
+ var out StartWorkConn
+ require.NoError(t, DecodeV2MessageFrameInto(frame, &out))
+ require.Equal(t, *in, out)
+}
+
+func TestDecodeV2MessageFrameRejectsInvalidFrame(t *testing.T) {
+ _, err := DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeClientHello})
+ require.ErrorContains(t, err, "unexpected frame type")
+
+ _, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: []byte{0}})
+ require.ErrorContains(t, err, "payload too short")
+
+ payload := make([]byte, 4)
+ binary.BigEndian.PutUint16(payload[:2], 65535)
+ copy(payload[2:], []byte("{}"))
+ _, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: payload})
+ require.ErrorContains(t, err, "unknown v2 message type")
+}
+
+func TestDecodeV2MessageFrameIntoRejectsWrongTarget(t *testing.T) {
+ frame, err := EncodeV2MessageFrame(&ReqWorkConn{})
+ require.NoError(t, err)
+
+ var out StartWorkConn
+ err = DecodeV2MessageFrameInto(frame, &out)
+ require.ErrorContains(t, err, "unexpected message type")
+
+ err = DecodeV2MessageFrameInto(frame, StartWorkConn{})
+ require.ErrorContains(t, err, "must be a pointer")
+}
+
+func TestEncodeV2MessageFrameRejectsUnknownMessage(t *testing.T) {
+ _, err := EncodeV2MessageFrame(struct{}{})
+ require.ErrorContains(t, err, "unknown v2 message type")
+}
diff --git a/pkg/proto/wire/wire.go b/pkg/proto/wire/wire.go
new file mode 100644
index 00000000..5c38292d
--- /dev/null
+++ b/pkg/proto/wire/wire.go
@@ -0,0 +1,222 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "slices"
+
+ libnet "github.com/fatedier/golib/net"
+)
+
+const (
+ ProtocolV1 = "v1"
+ ProtocolV2 = "v2"
+
+ WireVersionV2 = 2
+
+ FrameTypeClientHello uint16 = 1
+ FrameTypeServerHello uint16 = 2
+ FrameTypeMessage uint16 = 16
+
+ MessageCodecJSON = "json"
+ DefaultMaxFramePayloadSize = 64 * 1024
+
+ MagicV2 = "FRP\x00\x02\r\n"
+)
+
+type Frame struct {
+ Type uint16
+ Flags uint16
+ Payload []byte
+}
+
+type Conn struct {
+ rw io.ReadWriter
+ maxFramePayloadSize uint32
+}
+
+func NewConn(rw io.ReadWriter) *Conn {
+ return &Conn{
+ rw: rw,
+ maxFramePayloadSize: DefaultMaxFramePayloadSize,
+ }
+}
+
+func (c *Conn) ReadFrame() (*Frame, error) {
+ header := make([]byte, 8)
+ if _, err := io.ReadFull(c.rw, header); err != nil {
+ return nil, err
+ }
+
+ frameType := binary.BigEndian.Uint16(header[0:2])
+ flags := binary.BigEndian.Uint16(header[2:4])
+ length := binary.BigEndian.Uint32(header[4:8])
+ if flags != 0 {
+ return nil, fmt.Errorf("unsupported frame flags: %d", flags)
+ }
+ if length > c.maxFramePayloadSize {
+ return nil, fmt.Errorf("frame payload length %d exceeds limit %d", length, c.maxFramePayloadSize)
+ }
+
+ payload := make([]byte, length)
+ if _, err := io.ReadFull(c.rw, payload); err != nil {
+ return nil, err
+ }
+ return &Frame{
+ Type: frameType,
+ Flags: flags,
+ Payload: payload,
+ }, nil
+}
+
+func (c *Conn) WriteFrame(f *Frame) error {
+ if f.Flags != 0 {
+ return fmt.Errorf("unsupported frame flags: %d", f.Flags)
+ }
+ if len(f.Payload) > int(c.maxFramePayloadSize) {
+ return fmt.Errorf("frame payload length %d exceeds limit %d", len(f.Payload), c.maxFramePayloadSize)
+ }
+
+ header := make([]byte, 8)
+ binary.BigEndian.PutUint16(header[0:2], f.Type)
+ binary.BigEndian.PutUint16(header[2:4], f.Flags)
+ binary.BigEndian.PutUint32(header[4:8], uint32(len(f.Payload)))
+ if _, err := c.rw.Write(header); err != nil {
+ return err
+ }
+ _, err := c.rw.Write(f.Payload)
+ return err
+}
+
+func (c *Conn) ReadJSONFrame(frameType uint16, out any) error {
+ f, err := c.ReadFrame()
+ if err != nil {
+ return err
+ }
+ if f.Type != frameType {
+ return fmt.Errorf("unexpected frame type %d, want %d", f.Type, frameType)
+ }
+ return c.UnmarshalFrame(f, out)
+}
+
+func (c *Conn) UnmarshalFrame(f *Frame, out any) error {
+ return json.Unmarshal(f.Payload, out)
+}
+
+func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
+ payload, err := json.Marshal(in)
+ if err != nil {
+ return err
+ }
+ return c.WriteFrame(&Frame{
+ Type: frameType,
+ Payload: payload,
+ })
+}
+
+func WriteMagic(w io.Writer) error {
+ _, err := io.WriteString(w, MagicV2)
+ return err
+}
+
+func WriteMagicIfV2(w io.Writer, wireProtocol string) error {
+ if wireProtocol != ProtocolV2 {
+ return nil
+ }
+ return WriteMagic(w)
+}
+
+func CheckMagic(conn net.Conn) (out net.Conn, isV2 bool, err error) {
+ sharedConn, r := libnet.NewSharedConnSize(conn, len(MagicV2))
+ buf := make([]byte, len(MagicV2))
+ if _, err = io.ReadFull(r, buf); err != nil {
+ return nil, false, err
+ }
+ for i := range MagicV2 {
+ if buf[i] != MagicV2[i] {
+ return sharedConn, false, nil
+ }
+ }
+ return conn, true, nil
+}
+
+type BootstrapInfo struct {
+ Transport string `json:"transport,omitempty"`
+ TLS bool `json:"tls,omitempty"`
+ TCPMux bool `json:"tcpMux,omitempty"`
+}
+
+type ClientHello struct {
+ Bootstrap BootstrapInfo `json:"bootstrap,omitempty"`
+ Capabilities ClientCapabilities `json:"capabilities,omitempty"`
+}
+
+type ClientCapabilities struct {
+ Message MessageCapabilities `json:"message,omitempty"`
+}
+
+type MessageCapabilities struct {
+ Codecs []string `json:"codecs,omitempty"`
+}
+
+type ServerHello struct {
+ Selected ServerSelection `json:"selected,omitempty"`
+ Error string `json:"error,omitempty"`
+}
+
+type ServerSelection struct {
+ Message MessageSelection `json:"message,omitempty"`
+}
+
+type MessageSelection struct {
+ Codec string `json:"codec,omitempty"`
+}
+
+func DefaultClientHello(bootstrap BootstrapInfo) ClientHello {
+ return ClientHello{
+ Bootstrap: bootstrap,
+ Capabilities: ClientCapabilities{
+ Message: MessageCapabilities{
+ Codecs: []string{MessageCodecJSON},
+ },
+ },
+ }
+}
+
+func DefaultServerHello() ServerHello {
+ return ServerHello{
+ Selected: ServerSelection{
+ Message: MessageSelection{
+ Codec: MessageCodecJSON,
+ },
+ },
+ }
+}
+
+func Supports(list []string, value string) bool {
+ return slices.Contains(list, value)
+}
+
+func ValidateClientHello(h ClientHello) error {
+ if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) {
+ return fmt.Errorf("unsupported message codec")
+ }
+ return nil
+}
diff --git a/pkg/proto/wire/wire_test.go b/pkg/proto/wire/wire_test.go
new file mode 100644
index 00000000..19adb0be
--- /dev/null
+++ b/pkg/proto/wire/wire_test.go
@@ -0,0 +1,120 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+ "bytes"
+ "encoding/binary"
+ "io"
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestFrameRoundTrip(t *testing.T) {
+ var buf bytes.Buffer
+ conn := NewConn(&buf)
+
+ in := DefaultClientHello(BootstrapInfo{
+ Transport: "tcp",
+ TLS: true,
+ TCPMux: true,
+ })
+ require.NoError(t, conn.WriteJSONFrame(FrameTypeClientHello, in))
+
+ var out ClientHello
+ require.NoError(t, conn.ReadJSONFrame(FrameTypeClientHello, &out))
+ require.Equal(t, in, out)
+}
+
+func TestReadFrameRejectsUnsupportedFlags(t *testing.T) {
+ var buf bytes.Buffer
+ header := make([]byte, 8)
+ binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage)
+ binary.BigEndian.PutUint16(header[2:4], 1)
+ binary.BigEndian.PutUint32(header[4:8], 0)
+ buf.Write(header)
+
+ _, err := NewConn(&buf).ReadFrame()
+ require.ErrorContains(t, err, "unsupported frame flags")
+}
+
+func TestReadFrameRejectsOversizedPayload(t *testing.T) {
+ var buf bytes.Buffer
+ header := make([]byte, 8)
+ binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage)
+ binary.BigEndian.PutUint32(header[4:8], DefaultMaxFramePayloadSize+1)
+ buf.Write(header)
+
+ _, err := NewConn(&buf).ReadFrame()
+ require.ErrorContains(t, err, "exceeds limit")
+}
+
+func TestCheckMagicV2ConsumesMagic(t *testing.T) {
+ client, server := net.Pipe()
+ defer server.Close()
+
+ want := []byte("payload")
+ go func() {
+ defer client.Close()
+ _, _ = client.Write(append([]byte(MagicV2), want...))
+ }()
+
+ out, isV2, err := CheckMagic(server)
+ require.NoError(t, err)
+ require.True(t, isV2)
+
+ got := make([]byte, len(want))
+ _, err = io.ReadFull(out, got)
+ require.NoError(t, err)
+ require.Equal(t, want, got)
+}
+
+func TestWriteMagicIfV2(t *testing.T) {
+ var buf bytes.Buffer
+ require.NoError(t, WriteMagicIfV2(&buf, ProtocolV1))
+ require.Empty(t, buf.Bytes())
+
+ require.NoError(t, WriteMagicIfV2(&buf, ProtocolV2))
+ require.Equal(t, []byte(MagicV2), buf.Bytes())
+}
+
+func TestCheckMagicV1PreservesReadBytes(t *testing.T) {
+ client, server := net.Pipe()
+ defer server.Close()
+
+ want := []byte("legacy payload")
+ go func() {
+ defer client.Close()
+ _, _ = client.Write(want)
+ }()
+
+ out, isV2, err := CheckMagic(server)
+ require.NoError(t, err)
+ require.False(t, isV2)
+
+ got, err := io.ReadAll(out)
+ require.NoError(t, err)
+ require.Equal(t, want, got)
+}
+
+func TestValidateClientHello(t *testing.T) {
+ require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{})))
+
+ hello := DefaultClientHello(BootstrapInfo{})
+ hello.Capabilities.Message.Codecs = []string{"unknown"}
+ require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec")
+}
diff --git a/pkg/util/net/conn.go b/pkg/util/net/conn.go
index 914a7bb5..aa83b409 100644
--- a/pkg/util/net/conn.go
+++ b/pkg/util/net/conn.go
@@ -133,7 +133,7 @@ type CloseNotifyConn struct {
net.Conn
// 1 means closed
- closeFlag int32
+ closeFlag atomic.Int32
closeFn func(error)
}
@@ -147,7 +147,7 @@ func WrapCloseNotifyConn(c net.Conn, closeFn func(error)) *CloseNotifyConn {
}
func (cc *CloseNotifyConn) Close() (err error) {
- pflag := atomic.SwapInt32(&cc.closeFlag, 1)
+ pflag := cc.closeFlag.Swap(1)
if pflag == 0 {
err = cc.Conn.Close()
if cc.closeFn != nil {
@@ -159,7 +159,7 @@ func (cc *CloseNotifyConn) Close() (err error) {
// CloseWithError closes the connection and passes the error to the close callback.
func (cc *CloseNotifyConn) CloseWithError(err error) error {
- pflag := atomic.SwapInt32(&cc.closeFlag, 1)
+ pflag := cc.closeFlag.Swap(1)
if pflag == 0 {
closeErr := cc.Conn.Close()
if cc.closeFn != nil {
@@ -173,7 +173,7 @@ func (cc *CloseNotifyConn) CloseWithError(err error) error {
type StatsConn struct {
net.Conn
- closed int64 // 1 means closed
+ closed atomic.Int64 // 1 means closed
totalRead int64
totalWrite int64
statsFunc func(totalRead, totalWrite int64)
@@ -199,7 +199,7 @@ func (statsConn *StatsConn) Write(p []byte) (n int, err error) {
}
func (statsConn *StatsConn) Close() (err error) {
- old := atomic.SwapInt64(&statsConn.closed, 1)
+ old := statsConn.closed.Swap(1)
if old != 1 {
err = statsConn.Conn.Close()
if statsConn.statsFunc != nil {
diff --git a/server/control.go b/server/control.go
index 314669ff..d7b41c43 100644
--- a/server/control.go
+++ b/server/control.go
@@ -17,7 +17,6 @@ package server
import (
"context"
"fmt"
- "net"
"runtime/debug"
"sync"
"sync/atomic"
@@ -32,9 +31,7 @@ import (
"github.com/fatedier/frp/pkg/msg"
plugin "github.com/fatedier/frp/pkg/plugin/server"
"github.com/fatedier/frp/pkg/transport"
- netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
- "github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/server/controller"
@@ -108,9 +105,7 @@ type SessionContext struct {
// key used for connection encryption
EncryptionKey []byte
// control connection
- Conn net.Conn
- // indicates whether the connection is encrypted
- ConnEncrypted bool
+ Conn *msg.Conn
// login message
LoginMsg *msg.Login
// server configuration
@@ -131,7 +126,7 @@ type Control struct {
msgDispatcher *msg.Dispatcher
// work connections
- workConnCh chan net.Conn
+ workConnCh chan *proxy.WorkConn
// proxies in one client
proxies map[string]proxy.Proxy
@@ -161,7 +156,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
ctl := &Control{
sessionCtx: sessionCtx,
- workConnCh: make(chan net.Conn, poolCount+10),
+ workConnCh: make(chan *proxy.WorkConn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
@@ -172,29 +167,14 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
}
ctl.lastPing.Store(time.Now())
- if sessionCtx.ConnEncrypted {
- cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
- if err != nil {
- return nil, err
- }
- ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
- } else {
- ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
- }
+ ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
return ctl, nil
}
-// Start send a login success message to client and start working.
+// Start starts the control session workers after login succeeds.
func (ctl *Control) Start() {
- loginRespMsg := &msg.LoginResp{
- Version: version.Full(),
- RunID: ctl.runID,
- Error: "",
- }
- _ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
-
go func() {
for i := 0; i < ctl.poolCount; i++ {
// ignore error here, that means that this control is closed
@@ -216,7 +196,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
ctl.sessionCtx.Conn.Close()
}
-func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
+func (ctl *Control) RegisterWorkConn(conn *proxy.WorkConn) error {
xl := ctl.xl
defer func() {
if err := recover(); err != nil {
@@ -239,7 +219,7 @@ func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
// If no workConn available in the pool, send message to frpc to get one or more
// and wait until it is available.
// return an error if wait timeout
-func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
+func (ctl *Control) GetWorkConn() (workConn *proxy.WorkConn, err error) {
xl := ctl.xl
defer func() {
if err := recover(); err != nil {
diff --git a/server/group/http.go b/server/group/http.go
index dd905581..21bbfd29 100644
--- a/server/group/http.go
+++ b/server/group/http.go
@@ -57,7 +57,7 @@ type HTTPGroup struct {
// CreateConnFuncs indexed by proxy name
createFuncs map[string]vhost.CreateConnFunc
pxyNames []string
- index uint64
+ index atomic.Uint64
ctl *HTTPGroupController
mu sync.RWMutex
}
@@ -136,7 +136,7 @@ func (g *HTTPGroup) UnRegister(proxyName string) {
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
var f vhost.CreateConnFunc
- newIndex := atomic.AddUint64(&g.index, 1)
+ newIndex := g.index.Add(1)
g.mu.RLock()
group := g.group
@@ -158,7 +158,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
}
func (g *HTTPGroup) chooseEndpoint() (string, error) {
- newIndex := atomic.AddUint64(&g.index, 1)
+ newIndex := g.index.Add(1)
name := ""
g.mu.RLock()
diff --git a/server/http/controller.go b/server/http/controller.go
index a1842788..65f6c00f 100644
--- a/server/http/controller.go
+++ b/server/http/controller.go
@@ -287,6 +287,7 @@ func buildClientInfoResp(info registry.ClientInfo) model.ClientInfoResp {
ClientID: info.ClientID(),
RunID: info.RunID,
Version: info.Version,
+ WireProtocol: info.WireProtocol,
Hostname: info.Hostname,
ClientIP: info.IP,
FirstConnectedAt: toUnix(info.FirstConnectedAt),
diff --git a/server/http/controller_test.go b/server/http/controller_test.go
index 3ef50776..44509371 100644
--- a/server/http/controller_test.go
+++ b/server/http/controller_test.go
@@ -17,8 +17,11 @@ package http
import (
"encoding/json"
"testing"
+ "time"
v1 "github.com/fatedier/frp/pkg/config/v1"
+ "github.com/fatedier/frp/pkg/proto/wire"
+ "github.com/fatedier/frp/server/registry"
)
func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) {
@@ -69,3 +72,24 @@ func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) {
t.Fatalf("plugin httpPassword mismatch, want %q got %#v", "password", got)
}
}
+
+func TestBuildClientInfoRespIncludesWireProtocol(t *testing.T) {
+ info := registry.ClientInfo{
+ Key: "user.client",
+ User: "user",
+ RawClientID: "client",
+ RunID: "run-id",
+ Version: "1.0.0",
+ WireProtocol: wire.ProtocolV2,
+ Hostname: "host",
+ IP: "127.0.0.1",
+ FirstConnectedAt: time.Unix(1, 0),
+ LastConnectedAt: time.Unix(2, 0),
+ Online: true,
+ }
+
+ resp := buildClientInfoResp(info)
+ if resp.WireProtocol != wire.ProtocolV2 {
+ t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, resp.WireProtocol)
+ }
+}
diff --git a/server/http/model/types.go b/server/http/model/types.go
index 92467e4b..010c5079 100644
--- a/server/http/model/types.go
+++ b/server/http/model/types.go
@@ -46,6 +46,7 @@ type ClientInfoResp struct {
ClientID string `json:"clientID"`
RunID string `json:"runID"`
Version string `json:"version,omitempty"`
+ WireProtocol string `json:"wireProtocol,omitempty"`
Hostname string `json:"hostname"`
ClientIP string `json:"clientIP,omitempty"`
FirstConnectedAt int64 `json:"firstConnectedAt"`
diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go
index 5b7898eb..7175a606 100644
--- a/server/proxy/proxy.go
+++ b/server/proxy/proxy.go
@@ -44,7 +44,26 @@ func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy) P
proxyFactoryRegistry[proxyConfType] = factory
}
-type GetWorkConnFn func() (net.Conn, error)
+type WorkConn struct {
+ conn *msg.Conn
+}
+
+func NewWorkConn(conn *msg.Conn) *WorkConn {
+ return &WorkConn{conn: conn}
+}
+
+func (c *WorkConn) Start(m *msg.StartWorkConn) (net.Conn, error) {
+ if err := c.conn.WriteMsg(m); err != nil {
+ return nil, err
+ }
+ return c.conn, nil
+}
+
+func (c *WorkConn) Close() error {
+ return c.conn.Close()
+}
+
+type GetWorkConnFn func() (*WorkConn, error)
type Proxy interface {
Context() context.Context
@@ -125,13 +144,13 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
xl := xlog.FromContextSafe(pxy.ctx)
// try all connections from the pool
for i := 0; i < pxy.poolCount+1; i++ {
- if workConn, err = pxy.getWorkConnFn(); err != nil {
+ var pxyWorkConn *WorkConn
+ if pxyWorkConn, err = pxy.getWorkConnFn(); err != nil {
xl.Warnf("failed to get work connection: %v", err)
return
}
- xl.Debugf("get a new work connection: [%s]", workConn.RemoteAddr().String())
+ xl.Debugf("get a new work connection: [%s]", pxyWorkConn.conn.RemoteAddr().String())
xl.Spawn().AppendPrefix(pxy.GetName())
- workConn = netpkg.NewContextConn(pxy.ctx, workConn)
var (
srcAddr string
@@ -150,7 +169,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
}
- err = msg.WriteMsg(workConn, &msg.StartWorkConn{
+ workConn, err = pxyWorkConn.Start(&msg.StartWorkConn{
ProxyName: pxy.GetName(),
SrcAddr: srcAddr,
SrcPort: uint16(srcPort),
@@ -160,9 +179,10 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
})
if err != nil {
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
- workConn.Close()
+ pxyWorkConn.Close()
workConn = nil
} else {
+ workConn = netpkg.NewContextConn(pxy.ctx, workConn)
break
}
}
diff --git a/server/proxy/proxy_test.go b/server/proxy/proxy_test.go
new file mode 100644
index 00000000..38f9dd42
--- /dev/null
+++ b/server/proxy/proxy_test.go
@@ -0,0 +1,53 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proxy
+
+import (
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/fatedier/frp/pkg/msg"
+)
+
+func TestWorkConnStartWritesStartWorkConn(t *testing.T) {
+ client, server := net.Pipe()
+ defer client.Close()
+ defer server.Close()
+
+ serverMsgConn := msg.NewConn(server, msg.NewV2ReadWriter(server))
+ clientMsgConn := msg.NewConn(client, msg.NewV2ReadWriter(client))
+ workConn := NewWorkConn(serverMsgConn)
+
+ in := &msg.StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234}
+ type startResult struct {
+ conn net.Conn
+ err error
+ }
+ resultCh := make(chan startResult, 1)
+ go func() {
+ conn, err := workConn.Start(in)
+ resultCh <- startResult{conn: conn, err: err}
+ }()
+
+ out, err := clientMsgConn.ReadMsg()
+ require.NoError(t, err)
+ require.Equal(t, in, out)
+
+ result := <-resultCh
+ require.NoError(t, result.err)
+ require.Same(t, serverMsgConn, result.conn)
+}
diff --git a/server/registry/registry.go b/server/registry/registry.go
index 01c44947..c521d1c2 100644
--- a/server/registry/registry.go
+++ b/server/registry/registry.go
@@ -29,6 +29,7 @@ type ClientInfo struct {
Hostname string
IP string
Version string
+ WireProtocol string
FirstConnectedAt time.Time
LastConnectedAt time.Time
DisconnectedAt time.Time
@@ -51,7 +52,7 @@ func NewClientRegistry() *ClientRegistry {
}
// Register stores/updates metadata for a client and returns the registry key plus whether it conflicts with an online client.
-func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, remoteAddr string) (key string, conflict bool) {
+func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, remoteAddr, wireProtocol string) (key string, conflict bool) {
if runID == "" {
return "", false
}
@@ -88,6 +89,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version,
info.Hostname = hostname
info.IP = remoteAddr
info.Version = version
+ info.WireProtocol = wireProtocol
if info.FirstConnectedAt.IsZero() {
info.FirstConnectedAt = now
}
diff --git a/server/registry/registry_test.go b/server/registry/registry_test.go
new file mode 100644
index 00000000..cf428964
--- /dev/null
+++ b/server/registry/registry_test.go
@@ -0,0 +1,37 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package registry
+
+import (
+ "testing"
+
+ "github.com/fatedier/frp/pkg/proto/wire"
+)
+
+func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) {
+ registry := NewClientRegistry()
+ key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2)
+ if conflict {
+ t.Fatal("unexpected client conflict")
+ }
+
+ info, ok := registry.GetByKey(key)
+ if !ok {
+ t.Fatalf("client %q not found", key)
+ }
+ if info.WireProtocol != wire.ProtocolV2 {
+ t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol)
+ }
+}
diff --git a/server/service.go b/server/service.go
index 28ccb451..5077c21f 100644
--- a/server/service.go
+++ b/server/service.go
@@ -19,6 +19,7 @@ import (
"context"
"crypto/tls"
"fmt"
+ "io"
"net"
"net/http"
"os"
@@ -37,6 +38,7 @@ import (
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/nathole"
plugin "github.com/fatedier/frp/pkg/plugin/server"
+ "github.com/fatedier/frp/pkg/proto/wire"
"github.com/fatedier/frp/pkg/ssh"
"github.com/fatedier/frp/pkg/transport"
httppkg "github.com/fatedier/frp/pkg/util/http"
@@ -432,20 +434,15 @@ func (svr *Service) Close() error {
func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, internal bool) {
xl := xlog.FromContextSafe(ctx)
- var (
- rawMsg msg.Message
- err error
- )
-
- _ = conn.SetReadDeadline(time.Now().Add(connReadTimeout))
- if rawMsg, err = msg.ReadMsg(conn); err != nil {
- log.Tracef("failed to read message: %v", err)
+ acceptedConn, err := svr.acceptConnection(ctx, conn)
+ if err != nil {
+ log.Tracef("failed to accept frp connection: %v", err)
conn.Close()
return
}
- _ = conn.SetReadDeadline(time.Time{})
+ conn = acceptedConn.conn
- switch m := rawMsg.(type) {
+ switch m := acceptedConn.firstMsg.(type) {
case *msg.Login:
// server plugin hook
content := &plugin.LoginContent{
@@ -453,35 +450,66 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
ClientAddress: conn.RemoteAddr().String(),
}
retContent, err := svr.pluginManager.Login(content)
+ var ctl *Control
if err == nil {
m = &retContent.Login
- err = svr.RegisterControl(conn, m, internal)
+ controlConn := acceptedConn.conn
+ if !internal {
+ var controlRW io.ReadWriter
+ controlRW, err = netpkg.NewCryptoReadWriter(conn, svr.auth.EncryptionKey())
+ if err == nil {
+ controlConn = acceptedConn.messageConnFor(controlRW)
+ }
+ }
+ if err == nil {
+ ctl, err = svr.RegisterControl(controlConn, m, internal, acceptedConn.wireProtocol)
+ }
}
- // If login failed, send error message there.
- // Otherwise send success message in control's work goroutine.
if err != nil {
xl.Warnf("register control error: %v", err)
- _ = msg.WriteMsg(conn, &msg.LoginResp{
+ _ = acceptedConn.conn.WriteMsg(&msg.LoginResp{
Version: version.Full(),
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
})
conn.Close()
+ return
}
+ if err = acceptedConn.conn.WriteMsg(&msg.LoginResp{
+ Version: version.Full(),
+ RunID: ctl.runID,
+ Error: "",
+ }); err != nil {
+ xl.Warnf("write login response error: %v", err)
+ svr.ctlManager.Del(m.RunID, ctl)
+ svr.clientRegistry.MarkOfflineByRunID(m.RunID)
+ conn.Close()
+ return
+ }
+ ctl.Start()
+ metrics.Server.NewClient()
+ go func() {
+ // block until control closed
+ ctl.WaitClosed()
+ svr.ctlManager.Del(m.RunID, ctl)
+ }()
case *msg.NewWorkConn:
- if err := svr.RegisterWorkConn(conn, m); err != nil {
+ if err := svr.RegisterWorkConn(acceptedConn.conn, m); err != nil {
+ _ = acceptedConn.conn.WriteMsg(&msg.StartWorkConn{
+ Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
+ })
conn.Close()
}
case *msg.NewVisitorConn:
if err = svr.RegisterVisitorConn(conn, m); err != nil {
xl.Warnf("register visitor conn error: %v", err)
- _ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{
+ _ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{
ProxyName: m.ProxyName,
Error: util.GenerateResponseErrorString("register visitor conn error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
})
conn.Close()
} else {
- _ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{
+ _ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{
ProxyName: m.ProxyName,
Error: "",
})
@@ -492,6 +520,87 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
}
}
+type acceptedConnection struct {
+ conn *msg.Conn
+ wireProtocol string
+ firstMsg msg.Message
+}
+
+func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*acceptedConnection, error) {
+ _ = conn.SetReadDeadline(time.Now().Add(connReadTimeout))
+ checkedConn, isV2, err := wire.CheckMagic(conn)
+ if err != nil {
+ return nil, fmt.Errorf("read wire protocol magic: %w", err)
+ }
+
+ wireProtocol := wire.ProtocolV1
+ if isV2 {
+ wireProtocol = wire.ProtocolV2
+ }
+
+ conn = netpkg.NewContextConn(ctx, checkedConn)
+ acceptedConn := &acceptedConnection{wireProtocol: wireProtocol}
+ if isV2 {
+ wireConn := wire.NewConn(conn)
+ rw := msg.NewV2ReadWriterWithConn(wireConn)
+ acceptedConn.conn = msg.NewConn(conn, rw)
+ acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(wireConn)
+ } else {
+ rw := msg.NewV1ReadWriter(conn)
+ acceptedConn.conn = msg.NewConn(conn, rw)
+ acceptedConn.firstMsg, err = acceptedConn.conn.ReadMsg()
+ }
+ if err != nil {
+ return nil, err
+ }
+ _ = conn.SetReadDeadline(time.Time{})
+ return acceptedConn, nil
+}
+
+func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn {
+ return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol))
+}
+
+func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, error) {
+ frame, err := wireConn.ReadFrame()
+ if err != nil {
+ return nil, fmt.Errorf("read v2 frame: %w", err)
+ }
+ if frame.Type == wire.FrameTypeClientHello {
+ if err := ac.handleClientHello(wireConn, frame); err != nil {
+ return nil, err
+ }
+ frame, err = wireConn.ReadFrame()
+ if err != nil {
+ return nil, fmt.Errorf("read first v2 message frame: %w", err)
+ }
+ }
+
+ m, err := msg.DecodeV2MessageFrame(frame)
+ if err != nil {
+ return nil, fmt.Errorf("decode v2 message: %w", err)
+ }
+ return m, nil
+}
+
+func (ac *acceptedConnection) handleClientHello(wireConn *wire.Conn, frame *wire.Frame) error {
+ var hello wire.ClientHello
+ if err := wireConn.UnmarshalFrame(frame, &hello); err != nil {
+ return fmt.Errorf("decode ClientHello: %w", err)
+ }
+
+ serverHello := wire.DefaultServerHello()
+ if err := wire.ValidateClientHello(hello); err != nil {
+ serverHello.Error = err.Error()
+ _ = wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
+ return err
+ }
+ if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello); err != nil {
+ return fmt.Errorf("write ServerHello: %w", err)
+ }
+ return nil
+}
+
// HandleListener accepts connections from client and call handleConnection to handle them.
// If internal is true, it means that this listener is used for internal communication like ssh tunnel gateway.
// TODO(fatedier): Pass some parameters of listener/connection through context to avoid passing too many parameters.
@@ -577,14 +686,19 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) {
}
}
-func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, internal bool) error {
+func (svr *Service) RegisterControl(
+ ctlConn *msg.Conn,
+ loginMsg *msg.Login,
+ internal bool,
+ wireProtocol string,
+) (*Control, error) {
// If client's RunID is empty, it's a new client, we just create a new controller.
// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
var err error
if loginMsg.RunID == "" {
loginMsg.RunID, err = util.RandID()
if err != nil {
- return err
+ return nil, err
}
}
@@ -601,7 +715,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
authVerifier = auth.AlwaysPassVerifier
}
if err := authVerifier.VerifyLogin(loginMsg); err != nil {
- return err
+ return nil, err
}
ctl, err := NewControl(ctx, &SessionContext{
@@ -611,7 +725,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
AuthVerifier: authVerifier,
EncryptionKey: svr.auth.EncryptionKey(),
Conn: ctlConn,
- ConnEncrypted: !internal,
LoginMsg: loginMsg,
ServerCfg: svr.cfg,
ClientRegistry: svr.clientRegistry,
@@ -619,7 +732,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
if err != nil {
xl.Warnf("create new controller error: %v", err)
// don't return detailed errors to client
- return fmt.Errorf("unexpected error when creating new controller")
+ return nil, fmt.Errorf("unexpected error when creating new controller")
}
if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil {
@@ -630,34 +743,24 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
remoteAddr = host
}
- _, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr)
+ _, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr, wireProtocol)
if conflict {
svr.ctlManager.Del(loginMsg.RunID, ctl)
- ctl.Close()
- return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
+ return nil, fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
}
- ctl.Start()
-
- // for statistics
- metrics.Server.NewClient()
-
- go func() {
- // block until control closed
- ctl.WaitClosed()
- svr.ctlManager.Del(loginMsg.RunID, ctl)
- }()
- return nil
+ return ctl, nil
}
// RegisterWorkConn register a new work connection to control and proxies need it.
-func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error {
+func (svr *Service) RegisterWorkConn(workConn *msg.Conn, newMsg *msg.NewWorkConn) error {
xl := netpkg.NewLogFromConn(workConn)
ctl, exist := svr.ctlManager.GetByID(newMsg.RunID)
if !exist {
xl.Warnf("no client control found for run id [%s]", newMsg.RunID)
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
}
+
// server plugin hook
content := &plugin.NewWorkConnContent{
User: plugin.UserInfo{
@@ -675,12 +778,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
}
if err != nil {
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
- _ = msg.WriteMsg(workConn, &msg.StartWorkConn{
- Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
- })
- return fmt.Errorf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
+ return err
}
- return ctl.RegisterWorkConn(workConn)
+ return ctl.RegisterWorkConn(proxy.NewWorkConn(workConn))
}
func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVisitorConn) error {
diff --git a/test/e2e/mock/server/oidcserver/oidcserver.go b/test/e2e/mock/server/oidcserver/oidcserver.go
index d7aa1329..22236dc0 100644
--- a/test/e2e/mock/server/oidcserver/oidcserver.go
+++ b/test/e2e/mock/server/oidcserver/oidcserver.go
@@ -53,6 +53,8 @@ type Server struct {
tokenRequestCount atomic.Int64
}
+const maxTokenRequestBodySize = 1 << 20
+
type Option func(*Server)
func WithBindPort(port int) Option {
@@ -178,6 +180,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
+ r.Body = http.MaxBytesReader(w, r.Body, maxTokenRequestBodySize)
if err := r.ParseForm(); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
@@ -187,7 +190,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
return
}
- if r.FormValue("grant_type") != "client_credentials" {
+ if r.Form.Get("grant_type") != "client_credentials" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]any{
@@ -199,8 +202,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
// Accept credentials from Basic Auth or form body.
clientID, clientSecret, ok := r.BasicAuth()
if !ok {
- clientID = r.FormValue("client_id")
- clientSecret = r.FormValue("client_secret")
+ clientID = r.Form.Get("client_id")
+ clientSecret = r.Form.Get("client_secret")
}
if clientID != s.clientID || clientSecret != s.clientSecret {
w.Header().Set("Content-Type", "application/json")
diff --git a/test/e2e/v1/basic/wire.go b/test/e2e/v1/basic/wire.go
new file mode 100644
index 00000000..bbae9199
--- /dev/null
+++ b/test/e2e/v1/basic/wire.go
@@ -0,0 +1,163 @@
+package basic
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/onsi/ginkgo/v2"
+
+ "github.com/fatedier/frp/test/e2e/framework"
+ "github.com/fatedier/frp/test/e2e/framework/consts"
+ "github.com/fatedier/frp/test/e2e/pkg/port"
+)
+
+var _ = ginkgo.Describe("[Feature: WireProtocol]", func() {
+ f := framework.NewDefaultFramework()
+
+ ginkgo.It("v1 tcp and udp proxies", func() {
+ serverConf := consts.DefaultServerConfig
+ tcpPortName := port.GenName("WireV1TCP")
+ udpPortName := port.GenName("WireV1UDP")
+ clientConf := consts.DefaultClientConfig + fmt.Sprintf(`
+ transport.wireProtocol = "v1"
+
+ [[proxies]]
+ name = "tcp"
+ type = "tcp"
+ localPort = {{ .%s }}
+ remotePort = {{ .%s }}
+
+ [[proxies]]
+ name = "udp"
+ type = "udp"
+ localPort = {{ .%s }}
+ remotePort = {{ .%s }}
+ `, framework.TCPEchoServerPort, tcpPortName, framework.UDPEchoServerPort, udpPortName)
+
+ f.RunProcesses(serverConf, []string{clientConf})
+
+ framework.NewRequestExpect(f).PortName(tcpPortName).Ensure()
+ framework.NewRequestExpect(f).Protocol("udp").PortName(udpPortName).Ensure()
+ })
+
+ ginkgo.It("v2 tcp and udp proxies", func() {
+ serverConf := consts.DefaultServerConfig
+ tcpPortName := port.GenName("WireV2TCP")
+ udpPortName := port.GenName("WireV2UDP")
+ clientConf := consts.DefaultClientConfig + fmt.Sprintf(`
+ transport.wireProtocol = "v2"
+
+ [[proxies]]
+ name = "tcp"
+ type = "tcp"
+ localPort = {{ .%s }}
+ remotePort = {{ .%s }}
+
+ [[proxies]]
+ name = "udp"
+ type = "udp"
+ localPort = {{ .%s }}
+ remotePort = {{ .%s }}
+ `, framework.TCPEchoServerPort, tcpPortName, framework.UDPEchoServerPort, udpPortName)
+
+ f.RunProcesses(serverConf, []string{clientConf})
+
+ framework.NewRequestExpect(f).PortName(tcpPortName).Ensure()
+ framework.NewRequestExpect(f).Protocol("udp").PortName(udpPortName).Ensure()
+ })
+
+ ginkgo.It("v2 stcp visitor", func() {
+ serverConf := consts.DefaultServerConfig
+ bindPortName := port.GenName("WireV2STCP")
+ clientServerConf := consts.DefaultClientConfig + fmt.Sprintf(`
+ user = "user1"
+ transport.wireProtocol = "v2"
+
+ [[proxies]]
+ name = "stcp"
+ type = "stcp"
+ secretKey = "abc"
+ localPort = {{ .%s }}
+ `, framework.TCPEchoServerPort)
+ clientVisitorConf := consts.DefaultClientConfig + fmt.Sprintf(`
+ user = "user1"
+ transport.wireProtocol = "v2"
+
+ [[visitors]]
+ name = "stcp-visitor"
+ type = "stcp"
+ serverName = "stcp"
+ secretKey = "abc"
+ bindPort = {{ .%s }}
+ `, bindPortName)
+
+ f.RunProcesses(serverConf, []string{clientServerConf, clientVisitorConf})
+
+ framework.NewRequestExpect(f).PortName(bindPortName).Ensure()
+ })
+
+ ginkgo.It("reports client wire protocol", func() {
+ webPort := f.AllocPort()
+ serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
+ webServer.port = %d
+ `, webPort)
+
+ v1PortName := port.GenName("WireReportV1")
+ v1ClientConf := consts.DefaultClientConfig + fmt.Sprintf(`
+ clientID = "wire-v1"
+ transport.wireProtocol = "v1"
+
+ [[proxies]]
+ name = "v1"
+ type = "tcp"
+ localPort = {{ .%s }}
+ remotePort = {{ .%s }}
+ `, framework.TCPEchoServerPort, v1PortName)
+
+ v2PortName := port.GenName("WireReportV2")
+ v2ClientConf := consts.DefaultClientConfig + fmt.Sprintf(`
+ clientID = "wire-v2"
+ transport.wireProtocol = "v2"
+
+ [[proxies]]
+ name = "v2"
+ type = "tcp"
+ localPort = {{ .%s }}
+ remotePort = {{ .%s }}
+ `, framework.TCPEchoServerPort, v2PortName)
+
+ f.RunProcesses(serverConf, []string{v1ClientConf, v2ClientConf})
+
+ framework.NewRequestExpect(f).PortName(v1PortName).Ensure()
+ framework.NewRequestExpect(f).PortName(v2PortName).Ensure()
+ expectClientWireProtocol(webPort, "wire-v1", "v1")
+ expectClientWireProtocol(webPort, "wire-v2", "v2")
+ })
+})
+
+type wireClientInfo struct {
+ ClientID string `json:"clientID"`
+ WireProtocol string `json:"wireProtocol"`
+}
+
+func expectClientWireProtocol(webPort int, clientID string, wireProtocol string) {
+ resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/clients", webPort))
+ framework.ExpectNoError(err)
+ defer resp.Body.Close()
+ framework.ExpectEqual(resp.StatusCode, 200)
+
+ content, err := io.ReadAll(resp.Body)
+ framework.ExpectNoError(err)
+
+ var clients []wireClientInfo
+ framework.ExpectNoError(json.Unmarshal(content, &clients))
+ for _, client := range clients {
+ if client.ClientID == clientID {
+ framework.ExpectEqual(client.WireProtocol, wireProtocol)
+ return
+ }
+ }
+ framework.Failf("client %q not found in /api/clients response: %s", clientID, string(content))
+}
diff --git a/web/frps/src/components/ClientCard.vue b/web/frps/src/components/ClientCard.vue
index 531a210f..4ccd7441 100644
--- a/web/frps/src/components/ClientCard.vue
+++ b/web/frps/src/components/ClientCard.vue
@@ -16,6 +16,9 @@