From e9464919d1a3ab2e4ce0ed9d47ba172dfc0dfb3f Mon Sep 17 00:00:00 2001 From: fatedier Date: Mon, 27 Apr 2026 00:17:00 +0800 Subject: [PATCH] protocol: add v2 wire protocol with binary framing and capability negotiation (#5294) --- .github/workflows/golangci-lint.yml | 2 +- .golangci.yml | 2 +- client/connector.go | 35 +++ client/control.go | 25 +- client/control_session.go | 172 ++++++++++++ client/control_session_test.go | 245 ++++++++++++++++++ client/service.go | 98 ++----- client/visitor/visitor.go | 8 +- client/visitor/visitor_manager.go | 7 +- conf/frpc_full_example.toml | 4 + pkg/config/v1/client.go | 4 + pkg/config/v1/client_test.go | 1 + pkg/config/v1/validation/client.go | 3 + pkg/config/v1/validation/validation.go | 4 + pkg/msg/conn_test.go | 56 ++++ pkg/msg/handler.go | 88 ++++++- pkg/msg/msg.go | 36 +-- pkg/msg/msg_test.go | 55 ++++ pkg/msg/wire_v2.go | 192 ++++++++++++++ pkg/msg/wire_v2_test.go | 121 +++++++++ pkg/proto/wire/wire.go | 222 ++++++++++++++++ pkg/proto/wire/wire_test.go | 120 +++++++++ pkg/util/net/conn.go | 10 +- server/control.go | 34 +-- server/group/http.go | 6 +- server/http/controller.go | 1 + server/http/controller_test.go | 24 ++ server/http/model/types.go | 1 + server/proxy/proxy.go | 32 ++- server/proxy/proxy_test.go | 53 ++++ server/registry/registry.go | 4 +- server/registry/registry_test.go | 37 +++ server/service.go | 184 ++++++++++--- test/e2e/mock/server/oidcserver/oidcserver.go | 9 +- test/e2e/v1/basic/wire.go | 163 ++++++++++++ web/frps/src/components/ClientCard.vue | 3 + web/frps/src/types/client.ts | 1 + web/frps/src/utils/client.ts | 8 + web/frps/src/views/ClientDetail.vue | 7 + web/package-lock.json | 7 +- 40 files changed, 1861 insertions(+), 223 deletions(-) create mode 100644 client/control_session.go create mode 100644 client/control_session_test.go create mode 100644 pkg/msg/conn_test.go create mode 100644 pkg/msg/msg_test.go create mode 100644 pkg/msg/wire_v2.go create mode 100644 pkg/msg/wire_v2_test.go create mode 100644 pkg/proto/wire/wire.go create mode 100644 pkg/proto/wire/wire_test.go create mode 100644 server/proxy/proxy_test.go create mode 100644 server/registry/registry_test.go create mode 100644 test/e2e/v1/basic/wire.go diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 35cb5f65..3217c234 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -32,4 +32,4 @@ jobs: uses: golangci/golangci-lint-action@v9 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: v2.10 + version: v2.11 diff --git a/.golangci.yml b/.golangci.yml index 89222745..886b55fb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -34,7 +34,7 @@ linters: disabled-checks: - exitAfterDefer gosec: - excludes: ["G115", "G117", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"] + excludes: ["G115", "G117", "G118", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"] severity: low confidence: low govet: diff --git a/client/connector.go b/client/connector.go index 51536750..34b24f0d 100644 --- a/client/connector.go +++ b/client/connector.go @@ -29,6 +29,8 @@ import ( "github.com/samber/lo" v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/proto/wire" "github.com/fatedier/frp/pkg/transport" netpkg "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/xlog" @@ -41,6 +43,39 @@ type Connector interface { Close() error } +type MessageConnector interface { + Connect() (*msg.Conn, error) + Close() error +} + +type messageConnector struct { + connector Connector + wireProtocol string +} + +func newMessageConnector(connector Connector, wireProtocol string) *messageConnector { + return &messageConnector{ + connector: connector, + wireProtocol: wireProtocol, + } +} + +func (c *messageConnector) Connect() (*msg.Conn, error) { + conn, err := c.connector.Connect() + if err != nil { + return nil, err + } + if err = wire.WriteMagicIfV2(conn, c.wireProtocol); err != nil { + conn.Close() + return nil, err + } + return msg.NewConn(conn, msg.NewReadWriter(conn, c.wireProtocol)), nil +} + +func (c *messageConnector) Close() error { + return c.connector.Close() +} + // defaultConnectorImpl is the default implementation of Connector for normal frpc. type defaultConnectorImpl struct { ctx context.Context diff --git a/client/control.go b/client/control.go index 020ac94f..6e3002bc 100644 --- a/client/control.go +++ b/client/control.go @@ -27,7 +27,6 @@ import ( "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/naming" "github.com/fatedier/frp/pkg/transport" - netpkg "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/vnet" @@ -41,13 +40,11 @@ type SessionContext struct { // It should be attached to the login message when reconnecting. RunID string // Underlying control connection. Once conn is closed, the msgDispatcher and the entire Control will exit. - Conn net.Conn - // Indicates whether the connection is encrypted. - ConnEncrypted bool + Conn *msg.Conn // Auth runtime used for login, heartbeats, and encryption. Auth *auth.ClientAuth - // Connector is used to create new connections, which could be real TCP connections or virtual streams. - Connector Connector + // Connector is used to create message connections to frps. + Connector MessageConnector // Virtual net controller VnetController *vnet.Controller } @@ -91,15 +88,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro } ctl.lastPong.Store(time.Now()) - if sessionCtx.ConnEncrypted { - cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.Auth.EncryptionKey()) - if err != nil { - return nil, err - } - ctl.msgDispatcher = msg.NewDispatcher(cryptoRW) - } else { - ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn) - } + ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn) ctl.registerMsgHandlers() ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher) @@ -139,14 +128,14 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) { workConn.Close() return } - if err = msg.WriteMsg(workConn, m); err != nil { + if err = workConn.WriteMsg(m); err != nil { xl.Warnf("work connection write to server error: %v", err) workConn.Close() return } var startMsg msg.StartWorkConn - if err = msg.ReadMsgInto(workConn, &startMsg); err != nil { + if err = workConn.ReadMsgInto(&startMsg); err != nil { xl.Tracef("work connection closed before response StartWorkConn message: %v", err) workConn.Close() return @@ -227,7 +216,7 @@ func (ctl *Control) Done() <-chan struct{} { } // connectServer return a new connection to frps -func (ctl *Control) connectServer() (net.Conn, error) { +func (ctl *Control) connectServer() (*msg.Conn, error) { return ctl.sessionCtx.Connector.Connect() } diff --git a/client/control_session.go b/client/control_session.go new file mode 100644 index 00000000..4438acfe --- /dev/null +++ b/client/control_session.go @@ -0,0 +1,172 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "runtime" + "time" + + "github.com/samber/lo" + + "github.com/fatedier/frp/pkg/auth" + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/proto/wire" + netpkg "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/version" + "github.com/fatedier/frp/pkg/vnet" +) + +type controlSessionDialer struct { + ctx context.Context + + common *v1.ClientCommonConfig + auth *auth.ClientAuth + clientSpec *msg.ClientSpec + vnetController *vnet.Controller + + connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector +} + +func (d *controlSessionDialer) Dial(previousRunID string) (*SessionContext, error) { + connector := d.connectorCreator(d.ctx, d.common) + if err := connector.Open(); err != nil { + return nil, err + } + + success := false + defer func() { + if !success { + _ = connector.Close() + } + }() + + conn, err := connector.Connect() + if err != nil { + return nil, err + } + defer func() { + if !success { + _ = conn.Close() + } + }() + + loginMsg, err := d.buildLoginMsg(previousRunID) + if err != nil { + return nil, err + } + + loginRespMsg, err := d.exchangeLogin(conn, loginMsg) + if err != nil { + return nil, err + } + if loginRespMsg.Error != "" { + return nil, errors.New(loginRespMsg.Error) + } + + var controlRW io.ReadWriter = conn + if d.clientSpec == nil || d.clientSpec.Type != "ssh-tunnel" { + controlRW, err = netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey()) + if err != nil { + return nil, fmt.Errorf("create control crypto read writer: %w", err) + } + } + + success = true + return &SessionContext{ + Common: d.common, + RunID: loginRespMsg.RunID, + Conn: msg.NewConn(conn, msg.NewReadWriter(controlRW, d.common.Transport.WireProtocol)), + Auth: d.auth, + Connector: newMessageConnector(connector, d.common.Transport.WireProtocol), + VnetController: d.vnetController, + }, nil +} + +func (d *controlSessionDialer) buildLoginMsg(previousRunID string) (*msg.Login, error) { + hostname, _ := os.Hostname() + loginMsg := &msg.Login{ + Arch: runtime.GOARCH, + Os: runtime.GOOS, + Hostname: hostname, + PoolCount: d.common.Transport.PoolCount, + User: d.common.User, + ClientID: d.common.ClientID, + Version: version.Full(), + Timestamp: time.Now().Unix(), + RunID: previousRunID, + Metas: d.common.Metadatas, + } + if d.clientSpec != nil { + loginMsg.ClientSpec = *d.clientSpec + } + + if err := d.auth.Setter.SetLogin(loginMsg); err != nil { + return nil, err + } + return loginMsg, nil +} + +func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*msg.LoginResp, error) { + rw := msg.NewV1ReadWriter(conn) + var wireConn *wire.Conn + + if d.common.Transport.WireProtocol == wire.ProtocolV2 { + if err := wire.WriteMagic(conn); err != nil { + return nil, err + } + + wireConn = wire.NewConn(conn) + rw = msg.NewV2ReadWriterWithConn(wireConn) + hello := wire.DefaultClientHello(wire.BootstrapInfo{ + Transport: d.common.Transport.Protocol, + TLS: lo.FromPtr(d.common.Transport.TLS.Enable) || d.common.Transport.Protocol == "wss" || d.common.Transport.Protocol == "quic", + TCPMux: lo.FromPtr(d.common.Transport.TCPMux), + }) + if err := wireConn.WriteJSONFrame(wire.FrameTypeClientHello, hello); err != nil { + return nil, err + } + } + if err := rw.WriteMsg(loginMsg); err != nil { + return nil, err + } + + _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + defer func() { + _ = conn.SetReadDeadline(time.Time{}) + }() + + if wireConn != nil { + var serverHello wire.ServerHello + if err := wireConn.ReadJSONFrame(wire.FrameTypeServerHello, &serverHello); err != nil { + return nil, err + } + if serverHello.Error != "" { + return nil, errors.New(serverHello.Error) + } + } + + var loginRespMsg msg.LoginResp + if err := rw.ReadMsgInto(&loginRespMsg); err != nil { + return nil, err + } + return &loginRespMsg, nil +} diff --git a/client/control_session_test.go b/client/control_session_test.go new file mode 100644 index 00000000..9e59c9cd --- /dev/null +++ b/client/control_session_test.go @@ -0,0 +1,245 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "fmt" + "io" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/fatedier/frp/pkg/auth" + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/proto/wire" +) + +type testConnector struct { + conn net.Conn + closed atomic.Bool +} + +func (c *testConnector) Open() error { + return nil +} + +func (c *testConnector) Connect() (net.Conn, error) { + return c.conn, nil +} + +func (c *testConnector) Close() error { + c.closed.Store(true) + return nil +} + +type trackingConn struct { + net.Conn + closed atomic.Bool +} + +func (c *trackingConn) Close() error { + c.closed.Store(true) + return c.Conn.Close() +} + +func newTestControlSessionDialer(t *testing.T, protocol string, connector Connector, clientSpec *msg.ClientSpec) *controlSessionDialer { + t.Helper() + + authRuntime, err := auth.BuildClientAuth(&v1.AuthClientConfig{ + Method: v1.AuthMethodToken, + Token: "token", + }) + require.NoError(t, err) + + return &controlSessionDialer{ + ctx: context.Background(), + common: &v1.ClientCommonConfig{ + User: "test-user", + Transport: v1.ClientTransportConfig{ + Protocol: "tcp", + WireProtocol: protocol, + }, + }, + auth: authRuntime, + clientSpec: clientSpec, + connectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector { + return connector + }, + } +} + +func TestControlSessionDialerDialV1(t *testing.T) { + clientRaw, serverRaw := net.Pipe() + defer serverRaw.Close() + + connector := &testConnector{conn: &trackingConn{Conn: clientRaw}} + serverErrCh := make(chan error, 1) + go func() { + rw := msg.NewV1ReadWriter(serverRaw) + var loginMsg msg.Login + if err := rw.ReadMsgInto(&loginMsg); err != nil { + serverErrCh <- err + return + } + if loginMsg.RunID != "previous-run-id" { + serverErrCh <- fmt.Errorf("unexpected previous run id: %s", loginMsg.RunID) + return + } + if loginMsg.User != "test-user" { + serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User) + return + } + serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v1"}) + }() + + dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil) + sessionCtx, err := dialer.Dial("previous-run-id") + require.NoError(t, err) + defer sessionCtx.Conn.Close() + defer sessionCtx.Connector.Close() + + require.Equal(t, "run-v1", sessionCtx.RunID) + require.NotNil(t, sessionCtx.Conn) + require.NotNil(t, sessionCtx.Connector) + require.False(t, connector.closed.Load()) + require.NoError(t, <-serverErrCh) +} + +func TestControlSessionDialerDialV2(t *testing.T) { + clientRaw, serverRaw := net.Pipe() + defer serverRaw.Close() + + connector := &testConnector{conn: &trackingConn{Conn: clientRaw}} + serverErrCh := make(chan error, 1) + go func() { + magic := make([]byte, len(wire.MagicV2)) + if _, err := io.ReadFull(serverRaw, magic); err != nil { + serverErrCh <- err + return + } + if string(magic) != wire.MagicV2 { + serverErrCh <- fmt.Errorf("unexpected magic: %q", string(magic)) + return + } + + wireConn := wire.NewConn(serverRaw) + var hello wire.ClientHello + if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil { + serverErrCh <- err + return + } + if err := wire.ValidateClientHello(hello); err != nil { + serverErrCh <- err + return + } + + rw := msg.NewV2ReadWriterWithConn(wireConn) + var loginMsg msg.Login + if err := rw.ReadMsgInto(&loginMsg); err != nil { + serverErrCh <- err + return + } + if loginMsg.User != "test-user" { + serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User) + return + } + if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil { + serverErrCh <- err + return + } + serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"}) + }() + + dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil) + sessionCtx, err := dialer.Dial("") + require.NoError(t, err) + defer sessionCtx.Conn.Close() + defer sessionCtx.Connector.Close() + + require.Equal(t, "run-v2", sessionCtx.RunID) + require.NotNil(t, sessionCtx.Conn) + require.NotNil(t, sessionCtx.Connector) + require.False(t, connector.closed.Load()) + require.NoError(t, <-serverErrCh) +} + +func TestControlSessionDialerDialLoginErrorClosesResources(t *testing.T) { + clientRaw, serverRaw := net.Pipe() + defer serverRaw.Close() + + clientConn := &trackingConn{Conn: clientRaw} + connector := &testConnector{conn: clientConn} + serverErrCh := make(chan error, 1) + go func() { + rw := msg.NewV1ReadWriter(serverRaw) + var loginMsg msg.Login + if err := rw.ReadMsgInto(&loginMsg); err != nil { + serverErrCh <- err + return + } + serverErrCh <- rw.WriteMsg(&msg.LoginResp{Error: "login denied"}) + }() + + dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil) + sessionCtx, err := dialer.Dial("") + require.Nil(t, sessionCtx) + require.ErrorContains(t, err, "login denied") + require.True(t, clientConn.closed.Load()) + require.True(t, connector.closed.Load()) + require.NoError(t, <-serverErrCh) +} + +func TestControlSessionDialerDialSSHTunnelSkipsControlEncryption(t *testing.T) { + clientRaw, serverRaw := net.Pipe() + defer serverRaw.Close() + + connector := &testConnector{conn: &trackingConn{Conn: clientRaw}} + serverErrCh := make(chan error, 1) + go func() { + rw := msg.NewV1ReadWriter(serverRaw) + var loginMsg msg.Login + if err := rw.ReadMsgInto(&loginMsg); err != nil { + serverErrCh <- err + return + } + if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-ssh-tunnel"}); err != nil { + serverErrCh <- err + return + } + + _ = serverRaw.SetReadDeadline(time.Now().Add(time.Second)) + var ping msg.Ping + if err := rw.ReadMsgInto(&ping); err != nil { + serverErrCh <- err + return + } + serverErrCh <- nil + }() + + dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, &msg.ClientSpec{Type: "ssh-tunnel"}) + sessionCtx, err := dialer.Dial("") + require.NoError(t, err) + defer sessionCtx.Conn.Close() + defer sessionCtx.Connector.Close() + + require.Equal(t, "run-ssh-tunnel", sessionCtx.RunID) + require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{})) + require.NoError(t, <-serverErrCh) +} diff --git a/client/service.go b/client/service.go index 5c51fd67..e6e42e46 100644 --- a/client/service.go +++ b/client/service.go @@ -21,7 +21,6 @@ import ( "net" "net/http" "os" - "runtime" "sync" "time" @@ -38,7 +37,6 @@ import ( httppkg "github.com/fatedier/frp/pkg/util/http" "github.com/fatedier/frp/pkg/util/log" netpkg "github.com/fatedier/frp/pkg/util/net" - "github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/vnet" @@ -303,80 +301,20 @@ func (svr *Service) keepControllerWorking() { ), true, svr.ctx.Done()) } -// login creates a connection to frps and registers it self as a client -// conn: control connection -// session: if it's not nil, using tcp mux -func (svr *Service) login() (conn net.Conn, connector Connector, err error) { - xl := xlog.FromContextSafe(svr.ctx) - connector = svr.connectorCreator(svr.ctx, svr.common) - if err = connector.Open(); err != nil { - return nil, nil, err - } - - defer func() { - if err != nil { - connector.Close() - } - }() - - conn, err = connector.Connect() - if err != nil { - return - } - - hostname, _ := os.Hostname() - - loginMsg := &msg.Login{ - Arch: runtime.GOARCH, - Os: runtime.GOOS, - Hostname: hostname, - PoolCount: svr.common.Transport.PoolCount, - User: svr.common.User, - ClientID: svr.common.ClientID, - Version: version.Full(), - Timestamp: time.Now().Unix(), - RunID: svr.runID, - Metas: svr.common.Metadatas, - } - if svr.clientSpec != nil { - loginMsg.ClientSpec = *svr.clientSpec - } - - // Add auth - if err = svr.auth.Setter.SetLogin(loginMsg); err != nil { - return - } - - if err = msg.WriteMsg(conn, loginMsg); err != nil { - return - } - - var loginRespMsg msg.LoginResp - _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second)) - if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil { - return - } - _ = conn.SetReadDeadline(time.Time{}) - - if loginRespMsg.Error != "" { - err = fmt.Errorf("%s", loginRespMsg.Error) - xl.Errorf("%s", loginRespMsg.Error) - return - } - - svr.runID = loginRespMsg.RunID - xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID}) - - xl.Infof("login to server success, get run id [%s]", loginRespMsg.RunID) - return -} - func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) { xl := xlog.FromContextSafe(svr.ctx) loginFunc := func() (bool, error) { xl.Infof("try to connect to server...") - conn, connector, err := svr.login() + dialer := &controlSessionDialer{ + ctx: svr.ctx, + common: svr.common, + auth: svr.auth, + clientSpec: svr.clientSpec, + vnetController: svr.vnetController, + connectorCreator: svr.connectorCreator, + } + sessionCtx, err := dialer.Dial(svr.runID) if err != nil { xl.Warnf("connect to server error: %v", err) if firstLoginExit { @@ -385,25 +323,19 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE return false, err } + svr.runID = sessionCtx.RunID + xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID}) + xl.Infof("login to server success, get run id [%s]", svr.runID) + svr.cfgMu.RLock() proxyCfgs := svr.proxyCfgs visitorCfgs := svr.visitorCfgs svr.cfgMu.RUnlock() - connEncrypted := svr.clientSpec == nil || svr.clientSpec.Type != "ssh-tunnel" - - sessionCtx := &SessionContext{ - Common: svr.common, - RunID: svr.runID, - Conn: conn, - ConnEncrypted: connEncrypted, - Auth: svr.auth, - Connector: connector, - VnetController: svr.vnetController, - } ctl, err := NewControl(svr.ctx, sessionCtx) if err != nil { - conn.Close() + sessionCtx.Conn.Close() + sessionCtx.Connector.Close() xl.Errorf("new control error: %v", err) return false, err } diff --git a/client/visitor/visitor.go b/client/visitor/visitor.go index 14a7aa37..dff0bb94 100644 --- a/client/visitor/visitor.go +++ b/client/visitor/visitor.go @@ -38,7 +38,7 @@ import ( // Helper wraps some functions for visitor to use. type Helper interface { // ConnectServer directly connects to the frp server. - ConnectServer() (net.Conn, error) + ConnectServer() (*msg.Conn, error) // TransferConn transfers the connection to another visitor. TransferConn(string, net.Conn) error // MsgTransporter returns the message transporter that is used to send and receive messages @@ -167,15 +167,15 @@ func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, e UseEncryption: cfg.Transport.UseEncryption, UseCompression: cfg.Transport.UseCompression, } - err = msg.WriteMsg(visitorConn, newVisitorConnMsg) + err = visitorConn.WriteMsg(newVisitorConnMsg) if err != nil { visitorConn.Close() return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err) } - var newVisitorConnRespMsg msg.NewVisitorConnResp _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) - err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) + var newVisitorConnRespMsg msg.NewVisitorConnResp + err = visitorConn.ReadMsgInto(&newVisitorConnRespMsg) if err != nil { visitorConn.Close() return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err) diff --git a/client/visitor/visitor_manager.go b/client/visitor/visitor_manager.go index b60f7047..1ca194bd 100644 --- a/client/visitor/visitor_manager.go +++ b/client/visitor/visitor_manager.go @@ -25,6 +25,7 @@ import ( "github.com/samber/lo" v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/vnet" @@ -49,7 +50,7 @@ func NewManager( ctx context.Context, runID string, clientCfg *v1.ClientCommonConfig, - connectServer func() (net.Conn, error), + connectServer func() (*msg.Conn, error), msgTransporter transport.MessageTransporter, vnetController *vnet.Controller, ) *Manager { @@ -199,14 +200,14 @@ func (vm *Manager) GetVisitorCfg(name string) (v1.VisitorConfigurer, bool) { } type visitorHelperImpl struct { - connectServerFn func() (net.Conn, error) + connectServerFn func() (*msg.Conn, error) msgTransporter transport.MessageTransporter vnetController *vnet.Controller transferConnFn func(name string, conn net.Conn) error runID string } -func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) { +func (v *visitorHelperImpl) ConnectServer() (*msg.Conn, error) { return v.connectServerFn() } diff --git a/conf/frpc_full_example.toml b/conf/frpc_full_example.toml index 50a00cfd..b141b0ed 100644 --- a/conf/frpc_full_example.toml +++ b/conf/frpc_full_example.toml @@ -103,6 +103,10 @@ transport.poolCount = 5 # supports tcp, kcp, quic, websocket and wss now, default is tcp transport.protocol = "tcp" +# FRP wire protocol used inside the selected transport. +# supports v1 and v2, default is v1. v2 requires frps support and must be enabled explicitly. +# transport.wireProtocol = "v1" + # set client binding ip when connect server, default is empty. # only when protocol = tcp or websocket, the value will be used. transport.connectServerLocalIP = "0.0.0.0" diff --git a/pkg/config/v1/client.go b/pkg/config/v1/client.go index 783eee8e..05d02c94 100644 --- a/pkg/config/v1/client.go +++ b/pkg/config/v1/client.go @@ -104,6 +104,9 @@ type ClientTransportConfig struct { // Valid values are "tcp", "kcp", "quic", "websocket" and "wss". By default, this value // is "tcp". Protocol string `json:"protocol,omitempty"` + // WireProtocol specifies the frpc/frps internal wire protocol version. + // Valid values are "v1" and "v2". By default, this value is "v1". + WireProtocol string `json:"wireProtocol,omitempty"` // The maximum amount of time a dial to server will wait for a connect to complete. DialServerTimeout int64 `json:"dialServerTimeout,omitempty"` // DialServerKeepAlive specifies the interval between keep-alive probes for an active network connection between frpc and frps. @@ -143,6 +146,7 @@ type ClientTransportConfig struct { func (c *ClientTransportConfig) Complete() { c.Protocol = util.EmptyOr(c.Protocol, "tcp") + c.WireProtocol = util.EmptyOr(c.WireProtocol, "v1") c.DialServerTimeout = util.EmptyOr(c.DialServerTimeout, 10) c.DialServerKeepAlive = util.EmptyOr(c.DialServerKeepAlive, 7200) c.ProxyURL = util.EmptyOr(c.ProxyURL, os.Getenv("http_proxy")) diff --git a/pkg/config/v1/client_test.go b/pkg/config/v1/client_test.go index 5473a5f6..67214a5d 100644 --- a/pkg/config/v1/client_test.go +++ b/pkg/config/v1/client_test.go @@ -29,6 +29,7 @@ func TestClientConfigComplete(t *testing.T) { require.EqualValues("token", c.Auth.Method) require.Equal(true, lo.FromPtr(c.Transport.TCPMux)) + require.Equal("v1", c.Transport.WireProtocol) require.Equal(true, lo.FromPtr(c.LoginFailExit)) require.Equal(true, lo.FromPtr(c.Transport.TLS.Enable)) require.Equal(true, lo.FromPtr(c.Transport.TLS.DisableCustomTLSFirstByte)) diff --git a/pkg/config/v1/validation/client.go b/pkg/config/v1/validation/client.go index c90d525d..5c8433d5 100644 --- a/pkg/config/v1/validation/client.go +++ b/pkg/config/v1/validation/client.go @@ -146,6 +146,9 @@ func validateTransportConfig(c *v1.ClientTransportConfig) (Warning, error) { if !slices.Contains(SupportedTransportProtocols, c.Protocol) { errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, optional values are %v", SupportedTransportProtocols)) } + if !slices.Contains(SupportedWireProtocols, c.WireProtocol) { + errs = AppendError(errs, fmt.Errorf("invalid transport.wireProtocol, optional values are %v", SupportedWireProtocols)) + } return warnings, errs } diff --git a/pkg/config/v1/validation/validation.go b/pkg/config/v1/validation/validation.go index 4ca6b67f..417cec0f 100644 --- a/pkg/config/v1/validation/validation.go +++ b/pkg/config/v1/validation/validation.go @@ -29,6 +29,10 @@ var ( "websocket", "wss", } + SupportedWireProtocols = []string{ + "v1", + "v2", + } SupportedAuthMethods = []v1.AuthMethod{ "token", diff --git a/pkg/msg/conn_test.go b/pkg/msg/conn_test.go new file mode 100644 index 00000000..f3498f03 --- /dev/null +++ b/pkg/msg/conn_test.go @@ -0,0 +1,56 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/fatedier/frp/pkg/proto/wire" +) + +func TestConnReadWriteMsg(t *testing.T) { + tests := []struct { + name string + protocol string + }{ + {name: "v1", protocol: wire.ProtocolV1}, + {name: "v2", protocol: wire.ProtocolV2}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + clientConn := NewConn(client, NewReadWriter(client, tt.protocol)) + serverConn := NewConn(server, NewReadWriter(server, tt.protocol)) + + in := &Ping{PrivilegeKey: "key", Timestamp: 123} + errCh := make(chan error, 1) + go func() { + errCh <- clientConn.WriteMsg(in) + }() + + out, err := serverConn.ReadMsg() + require.NoError(t, err) + require.Equal(t, in, out) + require.NoError(t, <-errCh) + }) + } +} diff --git a/pkg/msg/handler.go b/pkg/msg/handler.go index 243e599a..b073e59b 100644 --- a/pkg/msg/handler.go +++ b/pkg/msg/handler.go @@ -15,10 +15,90 @@ package msg import ( + "context" "io" + "net" "reflect" + + "github.com/fatedier/frp/pkg/proto/wire" ) +type ReadWriter interface { + ReadMsg() (Message, error) + ReadMsgInto(Message) error + WriteMsg(Message) error +} + +type Conn struct { + net.Conn + rw ReadWriter +} + +func NewConn(conn net.Conn, rw ReadWriter) *Conn { + return &Conn{ + Conn: conn, + rw: rw, + } +} + +func (c *Conn) ReadMsg() (Message, error) { + return c.rw.ReadMsg() +} + +func (c *Conn) ReadMsgInto(m Message) error { + return c.rw.ReadMsgInto(m) +} + +func (c *Conn) WriteMsg(m Message) error { + return c.rw.WriteMsg(m) +} + +func (c *Conn) Context() context.Context { + if getter, ok := c.Conn.(interface{ Context() context.Context }); ok { + return getter.Context() + } + return context.Background() +} + +func (c *Conn) WithContext(ctx context.Context) { + if setter, ok := c.Conn.(interface{ WithContext(context.Context) }); ok { + setter.WithContext(ctx) + } +} + +type V1ReadWriter struct { + rw io.ReadWriter +} + +func NewV1ReadWriter(rw io.ReadWriter) ReadWriter { + return &V1ReadWriter{rw: rw} +} + +// NewReadWriter wraps rw with the message codec for the selected wire protocol. +// An empty protocol keeps the historical v1 behavior for tests and older call sites. +func NewReadWriter(rw io.ReadWriter, wireProtocol string) ReadWriter { + switch wireProtocol { + case wire.ProtocolV2: + return NewV2ReadWriter(rw) + case "", wire.ProtocolV1: + return NewV1ReadWriter(rw) + default: + return NewV1ReadWriter(rw) + } +} + +func (rw *V1ReadWriter) ReadMsg() (Message, error) { + return ReadMsg(rw.rw) +} + +func (rw *V1ReadWriter) ReadMsgInto(m Message) error { + return ReadMsgInto(rw.rw, m) +} + +func (rw *V1ReadWriter) WriteMsg(m Message) error { + return WriteMsg(rw.rw, m) +} + func AsyncHandler(f func(Message)) func(Message) { return func(m Message) { go f(m) @@ -27,7 +107,7 @@ func AsyncHandler(f func(Message)) func(Message) { // Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn. type Dispatcher struct { - rw io.ReadWriter + rw ReadWriter sendCh chan Message doneCh chan struct{} @@ -35,7 +115,7 @@ type Dispatcher struct { defaultHandler func(Message) } -func NewDispatcher(rw io.ReadWriter) *Dispatcher { +func NewDispatcher(rw ReadWriter) *Dispatcher { return &Dispatcher{ rw: rw, sendCh: make(chan Message, 100), @@ -56,14 +136,14 @@ func (d *Dispatcher) sendLoop() { case <-d.doneCh: return case m := <-d.sendCh: - _ = WriteMsg(d.rw, m) + _ = d.rw.WriteMsg(m) } } } func (d *Dispatcher) readLoop() { for { - m, err := ReadMsg(d.rw) + m, err := d.rw.ReadMsg() if err != nil { close(d.doneCh) return diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go index e8bcbc35..52948bd5 100644 --- a/pkg/msg/msg.go +++ b/pkg/msg/msg.go @@ -20,24 +20,24 @@ import ( ) const ( - TypeLogin = 'o' - TypeLoginResp = '1' - TypeNewProxy = 'p' - TypeNewProxyResp = '2' - TypeCloseProxy = 'c' - TypeNewWorkConn = 'w' - TypeReqWorkConn = 'r' - TypeStartWorkConn = 's' - TypeNewVisitorConn = 'v' - TypeNewVisitorConnResp = '3' - TypePing = 'h' - TypePong = '4' - TypeUDPPacket = 'u' - TypeNatHoleVisitor = 'i' - TypeNatHoleClient = 'n' - TypeNatHoleResp = 'm' - TypeNatHoleSid = '5' - TypeNatHoleReport = '6' + TypeLogin byte = 'o' + TypeLoginResp byte = '1' + TypeNewProxy byte = 'p' + TypeNewProxyResp byte = '2' + TypeCloseProxy byte = 'c' + TypeNewWorkConn byte = 'w' + TypeReqWorkConn byte = 'r' + TypeStartWorkConn byte = 's' + TypeNewVisitorConn byte = 'v' + TypeNewVisitorConnResp byte = '3' + TypePing byte = 'h' + TypePong byte = '4' + TypeUDPPacket byte = 'u' + TypeNatHoleVisitor byte = 'i' + TypeNatHoleClient byte = 'n' + TypeNatHoleResp byte = 'm' + TypeNatHoleSid byte = '5' + TypeNatHoleReport byte = '6' ) var msgTypeMap = map[byte]any{ diff --git a/pkg/msg/msg_test.go b/pkg/msg/msg_test.go new file mode 100644 index 00000000..06272faf --- /dev/null +++ b/pkg/msg/msg_test.go @@ -0,0 +1,55 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestV1MessageTypeIDsAreStable(t *testing.T) { + require.Equal(t, byte('o'), TypeLogin) + require.Equal(t, byte('1'), TypeLoginResp) + require.Equal(t, byte('p'), TypeNewProxy) + require.Equal(t, byte('2'), TypeNewProxyResp) + require.Equal(t, byte('c'), TypeCloseProxy) + require.Equal(t, byte('w'), TypeNewWorkConn) + require.Equal(t, byte('r'), TypeReqWorkConn) + require.Equal(t, byte('s'), TypeStartWorkConn) + require.Equal(t, byte('v'), TypeNewVisitorConn) + require.Equal(t, byte('3'), TypeNewVisitorConnResp) + require.Equal(t, byte('h'), TypePing) + require.Equal(t, byte('4'), TypePong) + require.Equal(t, byte('u'), TypeUDPPacket) + require.Equal(t, byte('i'), TypeNatHoleVisitor) + require.Equal(t, byte('n'), TypeNatHoleClient) + require.Equal(t, byte('m'), TypeNatHoleResp) + require.Equal(t, byte('5'), TypeNatHoleSid) + require.Equal(t, byte('6'), TypeNatHoleReport) +} + +func TestMessageTypeMapIsCompleteAndUnique(t *testing.T) { + require.Len(t, msgTypeMap, 18) + + msgTypes := make(map[reflect.Type]struct{}, len(msgTypeMap)) + + for _, m := range msgTypeMap { + msgType := reflect.TypeOf(m) + require.NotContains(t, msgTypes, msgType) + msgTypes[msgType] = struct{}{} + } +} diff --git a/pkg/msg/wire_v2.go b/pkg/msg/wire_v2.go new file mode 100644 index 00000000..8d2cd88d --- /dev/null +++ b/pkg/msg/wire_v2.go @@ -0,0 +1,192 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "reflect" + + "github.com/fatedier/frp/pkg/proto/wire" +) + +const ( + V2TypeLogin uint16 = 1 + V2TypeLoginResp uint16 = 2 + V2TypeNewProxy uint16 = 3 + V2TypeNewProxyResp uint16 = 4 + V2TypeCloseProxy uint16 = 5 + V2TypeNewWorkConn uint16 = 6 + V2TypeReqWorkConn uint16 = 7 + V2TypeStartWorkConn uint16 = 8 + V2TypeNewVisitorConn uint16 = 9 + V2TypeNewVisitorConnResp uint16 = 10 + V2TypePing uint16 = 11 + V2TypePong uint16 = 12 + V2TypeUDPPacket uint16 = 13 + V2TypeNatHoleVisitor uint16 = 14 + V2TypeNatHoleClient uint16 = 15 + V2TypeNatHoleResp uint16 = 16 + V2TypeNatHoleSid uint16 = 17 + V2TypeNatHoleReport uint16 = 18 +) + +var v2MsgTypeMap = map[uint16]any{ + V2TypeLogin: Login{}, + V2TypeLoginResp: LoginResp{}, + V2TypeNewProxy: NewProxy{}, + V2TypeNewProxyResp: NewProxyResp{}, + V2TypeCloseProxy: CloseProxy{}, + V2TypeNewWorkConn: NewWorkConn{}, + V2TypeReqWorkConn: ReqWorkConn{}, + V2TypeStartWorkConn: StartWorkConn{}, + V2TypeNewVisitorConn: NewVisitorConn{}, + V2TypeNewVisitorConnResp: NewVisitorConnResp{}, + V2TypePing: Ping{}, + V2TypePong: Pong{}, + V2TypeUDPPacket: UDPPacket{}, + V2TypeNatHoleVisitor: NatHoleVisitor{}, + V2TypeNatHoleClient: NatHoleClient{}, + V2TypeNatHoleResp: NatHoleResp{}, + V2TypeNatHoleSid: NatHoleSid{}, + V2TypeNatHoleReport: NatHoleReport{}, +} + +var v2MsgReflectTypeMap, v2MsgTypeIDMap = buildV2MsgTypeMaps() + +func buildV2MsgTypeMaps() (map[uint16]reflect.Type, map[reflect.Type]uint16) { + reflectTypeMap := make(map[uint16]reflect.Type, len(v2MsgTypeMap)) + typeIDMap := make(map[reflect.Type]uint16, len(v2MsgTypeMap)) + for typeID, m := range v2MsgTypeMap { + t := reflect.TypeOf(m) + reflectTypeMap[typeID] = t + typeIDMap[t] = typeID + } + return reflectTypeMap, typeIDMap +} + +type V2ReadWriter struct { + conn *wire.Conn +} + +func NewV2ReadWriter(rw io.ReadWriter) *V2ReadWriter { + return NewV2ReadWriterWithConn(wire.NewConn(rw)) +} + +func NewV2ReadWriterWithConn(conn *wire.Conn) *V2ReadWriter { + return &V2ReadWriter{conn: conn} +} + +func (rw *V2ReadWriter) WireConn() *wire.Conn { + return rw.conn +} + +func (rw *V2ReadWriter) ReadMsg() (Message, error) { + f, err := rw.conn.ReadFrame() + if err != nil { + return nil, err + } + return DecodeV2MessageFrame(f) +} + +func (rw *V2ReadWriter) ReadMsgInto(m Message) error { + f, err := rw.conn.ReadFrame() + if err != nil { + return err + } + return DecodeV2MessageFrameInto(f, m) +} + +func (rw *V2ReadWriter) WriteMsg(m Message) error { + f, err := EncodeV2MessageFrame(m) + if err != nil { + return err + } + return rw.conn.WriteFrame(f) +} + +func DecodeV2MessageFrame(f *wire.Frame) (Message, error) { + if f.Type != wire.FrameTypeMessage { + return nil, fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage) + } + if len(f.Payload) < 2 { + return nil, fmt.Errorf("message frame payload too short") + } + typeID := binary.BigEndian.Uint16(f.Payload[:2]) + t, ok := v2MsgReflectTypeMap[typeID] + if !ok { + return nil, fmt.Errorf("unknown v2 message type %d", typeID) + } + m := reflect.New(t).Interface() + if err := json.Unmarshal(f.Payload[2:], m); err != nil { + return nil, err + } + return m, nil +} + +func DecodeV2MessageFrameInto(f *wire.Frame, out Message) error { + if f.Type != wire.FrameTypeMessage { + return fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage) + } + if len(f.Payload) < 2 { + return fmt.Errorf("message frame payload too short") + } + + typeID := binary.BigEndian.Uint16(f.Payload[:2]) + outType := reflect.TypeOf(out) + if outType == nil || outType.Kind() != reflect.Pointer { + return fmt.Errorf("message target must be a pointer") + } + elemType := outType.Elem() + expectedTypeID, ok := v2MsgTypeIDMap[elemType] + if !ok { + return fmt.Errorf("unknown v2 message type %s", elemType.String()) + } + if typeID != expectedTypeID { + actualType, ok := v2MsgReflectTypeMap[typeID] + if !ok { + return fmt.Errorf("unknown v2 message type %d", typeID) + } + return fmt.Errorf("unexpected message type %s, want %s", actualType.String(), elemType.String()) + } + return json.Unmarshal(f.Payload[2:], out) +} + +func EncodeV2MessageFrame(m Message) (*wire.Frame, error) { + t := reflect.TypeOf(m) + if t == nil { + return nil, fmt.Errorf("nil message") + } + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + typeID, ok := v2MsgTypeIDMap[t] + if !ok { + return nil, fmt.Errorf("unknown v2 message type %s", t.String()) + } + content, err := json.Marshal(m) + if err != nil { + return nil, err + } + payload := make([]byte, 2+len(content)) + binary.BigEndian.PutUint16(payload[:2], typeID) + copy(payload[2:], content) + return &wire.Frame{ + Type: wire.FrameTypeMessage, + Payload: payload, + }, nil +} diff --git a/pkg/msg/wire_v2_test.go b/pkg/msg/wire_v2_test.go new file mode 100644 index 00000000..5879fc37 --- /dev/null +++ b/pkg/msg/wire_v2_test.go @@ -0,0 +1,121 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/fatedier/frp/pkg/proto/wire" +) + +func TestV2ReadWriterRoundTrip(t *testing.T) { + var buf bytes.Buffer + rw := NewV2ReadWriter(&buf) + + in := &Login{ + Version: "test-version", + RunID: "run-id", + User: "user", + } + require.NoError(t, rw.WriteMsg(in)) + + out, err := rw.ReadMsg() + require.NoError(t, err) + require.Equal(t, in, out) +} + +func TestNewReadWriter(t *testing.T) { + require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, "")) + require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV1)) + require.IsType(t, &V2ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV2)) +} + +func TestV2MessageTypeIDsAreStable(t *testing.T) { + require.Equal(t, uint16(1), V2TypeLogin) + require.Equal(t, uint16(2), V2TypeLoginResp) + require.Equal(t, uint16(3), V2TypeNewProxy) + require.Equal(t, uint16(4), V2TypeNewProxyResp) + require.Equal(t, uint16(5), V2TypeCloseProxy) + require.Equal(t, uint16(6), V2TypeNewWorkConn) + require.Equal(t, uint16(7), V2TypeReqWorkConn) + require.Equal(t, uint16(8), V2TypeStartWorkConn) + require.Equal(t, uint16(9), V2TypeNewVisitorConn) + require.Equal(t, uint16(10), V2TypeNewVisitorConnResp) + require.Equal(t, uint16(11), V2TypePing) + require.Equal(t, uint16(12), V2TypePong) + require.Equal(t, uint16(13), V2TypeUDPPacket) + require.Equal(t, uint16(14), V2TypeNatHoleVisitor) + require.Equal(t, uint16(15), V2TypeNatHoleClient) + require.Equal(t, uint16(16), V2TypeNatHoleResp) + require.Equal(t, uint16(17), V2TypeNatHoleSid) + require.Equal(t, uint16(18), V2TypeNatHoleReport) +} + +func TestV2MessageFrameEncoding(t *testing.T) { + frame, err := EncodeV2MessageFrame(&ReqWorkConn{}) + require.NoError(t, err) + require.Equal(t, wire.FrameTypeMessage, frame.Type) + require.Len(t, frame.Payload, 4) + require.Equal(t, V2TypeReqWorkConn, binary.BigEndian.Uint16(frame.Payload[:2])) + + out, err := DecodeV2MessageFrame(frame) + require.NoError(t, err) + require.IsType(t, &ReqWorkConn{}, out) +} + +func TestDecodeV2MessageFrameInto(t *testing.T) { + in := &StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234} + frame, err := EncodeV2MessageFrame(in) + require.NoError(t, err) + + var out StartWorkConn + require.NoError(t, DecodeV2MessageFrameInto(frame, &out)) + require.Equal(t, *in, out) +} + +func TestDecodeV2MessageFrameRejectsInvalidFrame(t *testing.T) { + _, err := DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeClientHello}) + require.ErrorContains(t, err, "unexpected frame type") + + _, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: []byte{0}}) + require.ErrorContains(t, err, "payload too short") + + payload := make([]byte, 4) + binary.BigEndian.PutUint16(payload[:2], 65535) + copy(payload[2:], []byte("{}")) + _, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: payload}) + require.ErrorContains(t, err, "unknown v2 message type") +} + +func TestDecodeV2MessageFrameIntoRejectsWrongTarget(t *testing.T) { + frame, err := EncodeV2MessageFrame(&ReqWorkConn{}) + require.NoError(t, err) + + var out StartWorkConn + err = DecodeV2MessageFrameInto(frame, &out) + require.ErrorContains(t, err, "unexpected message type") + + err = DecodeV2MessageFrameInto(frame, StartWorkConn{}) + require.ErrorContains(t, err, "must be a pointer") +} + +func TestEncodeV2MessageFrameRejectsUnknownMessage(t *testing.T) { + _, err := EncodeV2MessageFrame(struct{}{}) + require.ErrorContains(t, err, "unknown v2 message type") +} diff --git a/pkg/proto/wire/wire.go b/pkg/proto/wire/wire.go new file mode 100644 index 00000000..5c38292d --- /dev/null +++ b/pkg/proto/wire/wire.go @@ -0,0 +1,222 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net" + "slices" + + libnet "github.com/fatedier/golib/net" +) + +const ( + ProtocolV1 = "v1" + ProtocolV2 = "v2" + + WireVersionV2 = 2 + + FrameTypeClientHello uint16 = 1 + FrameTypeServerHello uint16 = 2 + FrameTypeMessage uint16 = 16 + + MessageCodecJSON = "json" + DefaultMaxFramePayloadSize = 64 * 1024 + + MagicV2 = "FRP\x00\x02\r\n" +) + +type Frame struct { + Type uint16 + Flags uint16 + Payload []byte +} + +type Conn struct { + rw io.ReadWriter + maxFramePayloadSize uint32 +} + +func NewConn(rw io.ReadWriter) *Conn { + return &Conn{ + rw: rw, + maxFramePayloadSize: DefaultMaxFramePayloadSize, + } +} + +func (c *Conn) ReadFrame() (*Frame, error) { + header := make([]byte, 8) + if _, err := io.ReadFull(c.rw, header); err != nil { + return nil, err + } + + frameType := binary.BigEndian.Uint16(header[0:2]) + flags := binary.BigEndian.Uint16(header[2:4]) + length := binary.BigEndian.Uint32(header[4:8]) + if flags != 0 { + return nil, fmt.Errorf("unsupported frame flags: %d", flags) + } + if length > c.maxFramePayloadSize { + return nil, fmt.Errorf("frame payload length %d exceeds limit %d", length, c.maxFramePayloadSize) + } + + payload := make([]byte, length) + if _, err := io.ReadFull(c.rw, payload); err != nil { + return nil, err + } + return &Frame{ + Type: frameType, + Flags: flags, + Payload: payload, + }, nil +} + +func (c *Conn) WriteFrame(f *Frame) error { + if f.Flags != 0 { + return fmt.Errorf("unsupported frame flags: %d", f.Flags) + } + if len(f.Payload) > int(c.maxFramePayloadSize) { + return fmt.Errorf("frame payload length %d exceeds limit %d", len(f.Payload), c.maxFramePayloadSize) + } + + header := make([]byte, 8) + binary.BigEndian.PutUint16(header[0:2], f.Type) + binary.BigEndian.PutUint16(header[2:4], f.Flags) + binary.BigEndian.PutUint32(header[4:8], uint32(len(f.Payload))) + if _, err := c.rw.Write(header); err != nil { + return err + } + _, err := c.rw.Write(f.Payload) + return err +} + +func (c *Conn) ReadJSONFrame(frameType uint16, out any) error { + f, err := c.ReadFrame() + if err != nil { + return err + } + if f.Type != frameType { + return fmt.Errorf("unexpected frame type %d, want %d", f.Type, frameType) + } + return c.UnmarshalFrame(f, out) +} + +func (c *Conn) UnmarshalFrame(f *Frame, out any) error { + return json.Unmarshal(f.Payload, out) +} + +func (c *Conn) WriteJSONFrame(frameType uint16, in any) error { + payload, err := json.Marshal(in) + if err != nil { + return err + } + return c.WriteFrame(&Frame{ + Type: frameType, + Payload: payload, + }) +} + +func WriteMagic(w io.Writer) error { + _, err := io.WriteString(w, MagicV2) + return err +} + +func WriteMagicIfV2(w io.Writer, wireProtocol string) error { + if wireProtocol != ProtocolV2 { + return nil + } + return WriteMagic(w) +} + +func CheckMagic(conn net.Conn) (out net.Conn, isV2 bool, err error) { + sharedConn, r := libnet.NewSharedConnSize(conn, len(MagicV2)) + buf := make([]byte, len(MagicV2)) + if _, err = io.ReadFull(r, buf); err != nil { + return nil, false, err + } + for i := range MagicV2 { + if buf[i] != MagicV2[i] { + return sharedConn, false, nil + } + } + return conn, true, nil +} + +type BootstrapInfo struct { + Transport string `json:"transport,omitempty"` + TLS bool `json:"tls,omitempty"` + TCPMux bool `json:"tcpMux,omitempty"` +} + +type ClientHello struct { + Bootstrap BootstrapInfo `json:"bootstrap,omitempty"` + Capabilities ClientCapabilities `json:"capabilities,omitempty"` +} + +type ClientCapabilities struct { + Message MessageCapabilities `json:"message,omitempty"` +} + +type MessageCapabilities struct { + Codecs []string `json:"codecs,omitempty"` +} + +type ServerHello struct { + Selected ServerSelection `json:"selected,omitempty"` + Error string `json:"error,omitempty"` +} + +type ServerSelection struct { + Message MessageSelection `json:"message,omitempty"` +} + +type MessageSelection struct { + Codec string `json:"codec,omitempty"` +} + +func DefaultClientHello(bootstrap BootstrapInfo) ClientHello { + return ClientHello{ + Bootstrap: bootstrap, + Capabilities: ClientCapabilities{ + Message: MessageCapabilities{ + Codecs: []string{MessageCodecJSON}, + }, + }, + } +} + +func DefaultServerHello() ServerHello { + return ServerHello{ + Selected: ServerSelection{ + Message: MessageSelection{ + Codec: MessageCodecJSON, + }, + }, + } +} + +func Supports(list []string, value string) bool { + return slices.Contains(list, value) +} + +func ValidateClientHello(h ClientHello) error { + if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) { + return fmt.Errorf("unsupported message codec") + } + return nil +} diff --git a/pkg/proto/wire/wire_test.go b/pkg/proto/wire/wire_test.go new file mode 100644 index 00000000..19adb0be --- /dev/null +++ b/pkg/proto/wire/wire_test.go @@ -0,0 +1,120 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFrameRoundTrip(t *testing.T) { + var buf bytes.Buffer + conn := NewConn(&buf) + + in := DefaultClientHello(BootstrapInfo{ + Transport: "tcp", + TLS: true, + TCPMux: true, + }) + require.NoError(t, conn.WriteJSONFrame(FrameTypeClientHello, in)) + + var out ClientHello + require.NoError(t, conn.ReadJSONFrame(FrameTypeClientHello, &out)) + require.Equal(t, in, out) +} + +func TestReadFrameRejectsUnsupportedFlags(t *testing.T) { + var buf bytes.Buffer + header := make([]byte, 8) + binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage) + binary.BigEndian.PutUint16(header[2:4], 1) + binary.BigEndian.PutUint32(header[4:8], 0) + buf.Write(header) + + _, err := NewConn(&buf).ReadFrame() + require.ErrorContains(t, err, "unsupported frame flags") +} + +func TestReadFrameRejectsOversizedPayload(t *testing.T) { + var buf bytes.Buffer + header := make([]byte, 8) + binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage) + binary.BigEndian.PutUint32(header[4:8], DefaultMaxFramePayloadSize+1) + buf.Write(header) + + _, err := NewConn(&buf).ReadFrame() + require.ErrorContains(t, err, "exceeds limit") +} + +func TestCheckMagicV2ConsumesMagic(t *testing.T) { + client, server := net.Pipe() + defer server.Close() + + want := []byte("payload") + go func() { + defer client.Close() + _, _ = client.Write(append([]byte(MagicV2), want...)) + }() + + out, isV2, err := CheckMagic(server) + require.NoError(t, err) + require.True(t, isV2) + + got := make([]byte, len(want)) + _, err = io.ReadFull(out, got) + require.NoError(t, err) + require.Equal(t, want, got) +} + +func TestWriteMagicIfV2(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, WriteMagicIfV2(&buf, ProtocolV1)) + require.Empty(t, buf.Bytes()) + + require.NoError(t, WriteMagicIfV2(&buf, ProtocolV2)) + require.Equal(t, []byte(MagicV2), buf.Bytes()) +} + +func TestCheckMagicV1PreservesReadBytes(t *testing.T) { + client, server := net.Pipe() + defer server.Close() + + want := []byte("legacy payload") + go func() { + defer client.Close() + _, _ = client.Write(want) + }() + + out, isV2, err := CheckMagic(server) + require.NoError(t, err) + require.False(t, isV2) + + got, err := io.ReadAll(out) + require.NoError(t, err) + require.Equal(t, want, got) +} + +func TestValidateClientHello(t *testing.T) { + require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{}))) + + hello := DefaultClientHello(BootstrapInfo{}) + hello.Capabilities.Message.Codecs = []string{"unknown"} + require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec") +} diff --git a/pkg/util/net/conn.go b/pkg/util/net/conn.go index 914a7bb5..aa83b409 100644 --- a/pkg/util/net/conn.go +++ b/pkg/util/net/conn.go @@ -133,7 +133,7 @@ type CloseNotifyConn struct { net.Conn // 1 means closed - closeFlag int32 + closeFlag atomic.Int32 closeFn func(error) } @@ -147,7 +147,7 @@ func WrapCloseNotifyConn(c net.Conn, closeFn func(error)) *CloseNotifyConn { } func (cc *CloseNotifyConn) Close() (err error) { - pflag := atomic.SwapInt32(&cc.closeFlag, 1) + pflag := cc.closeFlag.Swap(1) if pflag == 0 { err = cc.Conn.Close() if cc.closeFn != nil { @@ -159,7 +159,7 @@ func (cc *CloseNotifyConn) Close() (err error) { // CloseWithError closes the connection and passes the error to the close callback. func (cc *CloseNotifyConn) CloseWithError(err error) error { - pflag := atomic.SwapInt32(&cc.closeFlag, 1) + pflag := cc.closeFlag.Swap(1) if pflag == 0 { closeErr := cc.Conn.Close() if cc.closeFn != nil { @@ -173,7 +173,7 @@ func (cc *CloseNotifyConn) CloseWithError(err error) error { type StatsConn struct { net.Conn - closed int64 // 1 means closed + closed atomic.Int64 // 1 means closed totalRead int64 totalWrite int64 statsFunc func(totalRead, totalWrite int64) @@ -199,7 +199,7 @@ func (statsConn *StatsConn) Write(p []byte) (n int, err error) { } func (statsConn *StatsConn) Close() (err error) { - old := atomic.SwapInt64(&statsConn.closed, 1) + old := statsConn.closed.Swap(1) if old != 1 { err = statsConn.Conn.Close() if statsConn.statsFunc != nil { diff --git a/server/control.go b/server/control.go index 314669ff..d7b41c43 100644 --- a/server/control.go +++ b/server/control.go @@ -17,7 +17,6 @@ package server import ( "context" "fmt" - "net" "runtime/debug" "sync" "sync/atomic" @@ -32,9 +31,7 @@ import ( "github.com/fatedier/frp/pkg/msg" plugin "github.com/fatedier/frp/pkg/plugin/server" "github.com/fatedier/frp/pkg/transport" - netpkg "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/util" - "github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/server/controller" @@ -108,9 +105,7 @@ type SessionContext struct { // key used for connection encryption EncryptionKey []byte // control connection - Conn net.Conn - // indicates whether the connection is encrypted - ConnEncrypted bool + Conn *msg.Conn // login message LoginMsg *msg.Login // server configuration @@ -131,7 +126,7 @@ type Control struct { msgDispatcher *msg.Dispatcher // work connections - workConnCh chan net.Conn + workConnCh chan *proxy.WorkConn // proxies in one client proxies map[string]proxy.Proxy @@ -161,7 +156,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount)) ctl := &Control{ sessionCtx: sessionCtx, - workConnCh: make(chan net.Conn, poolCount+10), + workConnCh: make(chan *proxy.WorkConn, poolCount+10), proxies: make(map[string]proxy.Proxy), poolCount: poolCount, portsUsedNum: 0, @@ -172,29 +167,14 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro } ctl.lastPing.Store(time.Now()) - if sessionCtx.ConnEncrypted { - cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey) - if err != nil { - return nil, err - } - ctl.msgDispatcher = msg.NewDispatcher(cryptoRW) - } else { - ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn) - } + ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn) ctl.registerMsgHandlers() ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher) return ctl, nil } -// Start send a login success message to client and start working. +// Start starts the control session workers after login succeeds. func (ctl *Control) Start() { - loginRespMsg := &msg.LoginResp{ - Version: version.Full(), - RunID: ctl.runID, - Error: "", - } - _ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg) - go func() { for i := 0; i < ctl.poolCount; i++ { // ignore error here, that means that this control is closed @@ -216,7 +196,7 @@ func (ctl *Control) Replaced(newCtl *Control) { ctl.sessionCtx.Conn.Close() } -func (ctl *Control) RegisterWorkConn(conn net.Conn) error { +func (ctl *Control) RegisterWorkConn(conn *proxy.WorkConn) error { xl := ctl.xl defer func() { if err := recover(); err != nil { @@ -239,7 +219,7 @@ func (ctl *Control) RegisterWorkConn(conn net.Conn) error { // If no workConn available in the pool, send message to frpc to get one or more // and wait until it is available. // return an error if wait timeout -func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) { +func (ctl *Control) GetWorkConn() (workConn *proxy.WorkConn, err error) { xl := ctl.xl defer func() { if err := recover(); err != nil { diff --git a/server/group/http.go b/server/group/http.go index dd905581..21bbfd29 100644 --- a/server/group/http.go +++ b/server/group/http.go @@ -57,7 +57,7 @@ type HTTPGroup struct { // CreateConnFuncs indexed by proxy name createFuncs map[string]vhost.CreateConnFunc pxyNames []string - index uint64 + index atomic.Uint64 ctl *HTTPGroupController mu sync.RWMutex } @@ -136,7 +136,7 @@ func (g *HTTPGroup) UnRegister(proxyName string) { func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) { var f vhost.CreateConnFunc - newIndex := atomic.AddUint64(&g.index, 1) + newIndex := g.index.Add(1) g.mu.RLock() group := g.group @@ -158,7 +158,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) { } func (g *HTTPGroup) chooseEndpoint() (string, error) { - newIndex := atomic.AddUint64(&g.index, 1) + newIndex := g.index.Add(1) name := "" g.mu.RLock() diff --git a/server/http/controller.go b/server/http/controller.go index a1842788..65f6c00f 100644 --- a/server/http/controller.go +++ b/server/http/controller.go @@ -287,6 +287,7 @@ func buildClientInfoResp(info registry.ClientInfo) model.ClientInfoResp { ClientID: info.ClientID(), RunID: info.RunID, Version: info.Version, + WireProtocol: info.WireProtocol, Hostname: info.Hostname, ClientIP: info.IP, FirstConnectedAt: toUnix(info.FirstConnectedAt), diff --git a/server/http/controller_test.go b/server/http/controller_test.go index 3ef50776..44509371 100644 --- a/server/http/controller_test.go +++ b/server/http/controller_test.go @@ -17,8 +17,11 @@ package http import ( "encoding/json" "testing" + "time" v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/proto/wire" + "github.com/fatedier/frp/server/registry" ) func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) { @@ -69,3 +72,24 @@ func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) { t.Fatalf("plugin httpPassword mismatch, want %q got %#v", "password", got) } } + +func TestBuildClientInfoRespIncludesWireProtocol(t *testing.T) { + info := registry.ClientInfo{ + Key: "user.client", + User: "user", + RawClientID: "client", + RunID: "run-id", + Version: "1.0.0", + WireProtocol: wire.ProtocolV2, + Hostname: "host", + IP: "127.0.0.1", + FirstConnectedAt: time.Unix(1, 0), + LastConnectedAt: time.Unix(2, 0), + Online: true, + } + + resp := buildClientInfoResp(info) + if resp.WireProtocol != wire.ProtocolV2 { + t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, resp.WireProtocol) + } +} diff --git a/server/http/model/types.go b/server/http/model/types.go index 92467e4b..010c5079 100644 --- a/server/http/model/types.go +++ b/server/http/model/types.go @@ -46,6 +46,7 @@ type ClientInfoResp struct { ClientID string `json:"clientID"` RunID string `json:"runID"` Version string `json:"version,omitempty"` + WireProtocol string `json:"wireProtocol,omitempty"` Hostname string `json:"hostname"` ClientIP string `json:"clientIP,omitempty"` FirstConnectedAt int64 `json:"firstConnectedAt"` diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go index 5b7898eb..7175a606 100644 --- a/server/proxy/proxy.go +++ b/server/proxy/proxy.go @@ -44,7 +44,26 @@ func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy) P proxyFactoryRegistry[proxyConfType] = factory } -type GetWorkConnFn func() (net.Conn, error) +type WorkConn struct { + conn *msg.Conn +} + +func NewWorkConn(conn *msg.Conn) *WorkConn { + return &WorkConn{conn: conn} +} + +func (c *WorkConn) Start(m *msg.StartWorkConn) (net.Conn, error) { + if err := c.conn.WriteMsg(m); err != nil { + return nil, err + } + return c.conn, nil +} + +func (c *WorkConn) Close() error { + return c.conn.Close() +} + +type GetWorkConnFn func() (*WorkConn, error) type Proxy interface { Context() context.Context @@ -125,13 +144,13 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn, xl := xlog.FromContextSafe(pxy.ctx) // try all connections from the pool for i := 0; i < pxy.poolCount+1; i++ { - if workConn, err = pxy.getWorkConnFn(); err != nil { + var pxyWorkConn *WorkConn + if pxyWorkConn, err = pxy.getWorkConnFn(); err != nil { xl.Warnf("failed to get work connection: %v", err) return } - xl.Debugf("get a new work connection: [%s]", workConn.RemoteAddr().String()) + xl.Debugf("get a new work connection: [%s]", pxyWorkConn.conn.RemoteAddr().String()) xl.Spawn().AppendPrefix(pxy.GetName()) - workConn = netpkg.NewContextConn(pxy.ctx, workConn) var ( srcAddr string @@ -150,7 +169,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn, dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String()) dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16) } - err = msg.WriteMsg(workConn, &msg.StartWorkConn{ + workConn, err = pxyWorkConn.Start(&msg.StartWorkConn{ ProxyName: pxy.GetName(), SrcAddr: srcAddr, SrcPort: uint16(srcPort), @@ -160,9 +179,10 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn, }) if err != nil { xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i) - workConn.Close() + pxyWorkConn.Close() workConn = nil } else { + workConn = netpkg.NewContextConn(pxy.ctx, workConn) break } } diff --git a/server/proxy/proxy_test.go b/server/proxy/proxy_test.go new file mode 100644 index 00000000..38f9dd42 --- /dev/null +++ b/server/proxy/proxy_test.go @@ -0,0 +1,53 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/fatedier/frp/pkg/msg" +) + +func TestWorkConnStartWritesStartWorkConn(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + serverMsgConn := msg.NewConn(server, msg.NewV2ReadWriter(server)) + clientMsgConn := msg.NewConn(client, msg.NewV2ReadWriter(client)) + workConn := NewWorkConn(serverMsgConn) + + in := &msg.StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234} + type startResult struct { + conn net.Conn + err error + } + resultCh := make(chan startResult, 1) + go func() { + conn, err := workConn.Start(in) + resultCh <- startResult{conn: conn, err: err} + }() + + out, err := clientMsgConn.ReadMsg() + require.NoError(t, err) + require.Equal(t, in, out) + + result := <-resultCh + require.NoError(t, result.err) + require.Same(t, serverMsgConn, result.conn) +} diff --git a/server/registry/registry.go b/server/registry/registry.go index 01c44947..c521d1c2 100644 --- a/server/registry/registry.go +++ b/server/registry/registry.go @@ -29,6 +29,7 @@ type ClientInfo struct { Hostname string IP string Version string + WireProtocol string FirstConnectedAt time.Time LastConnectedAt time.Time DisconnectedAt time.Time @@ -51,7 +52,7 @@ func NewClientRegistry() *ClientRegistry { } // Register stores/updates metadata for a client and returns the registry key plus whether it conflicts with an online client. -func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, remoteAddr string) (key string, conflict bool) { +func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, remoteAddr, wireProtocol string) (key string, conflict bool) { if runID == "" { return "", false } @@ -88,6 +89,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, info.Hostname = hostname info.IP = remoteAddr info.Version = version + info.WireProtocol = wireProtocol if info.FirstConnectedAt.IsZero() { info.FirstConnectedAt = now } diff --git a/server/registry/registry_test.go b/server/registry/registry_test.go new file mode 100644 index 00000000..cf428964 --- /dev/null +++ b/server/registry/registry_test.go @@ -0,0 +1,37 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package registry + +import ( + "testing" + + "github.com/fatedier/frp/pkg/proto/wire" +) + +func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) { + registry := NewClientRegistry() + key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2) + if conflict { + t.Fatal("unexpected client conflict") + } + + info, ok := registry.GetByKey(key) + if !ok { + t.Fatalf("client %q not found", key) + } + if info.WireProtocol != wire.ProtocolV2 { + t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol) + } +} diff --git a/server/service.go b/server/service.go index 28ccb451..5077c21f 100644 --- a/server/service.go +++ b/server/service.go @@ -19,6 +19,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "net" "net/http" "os" @@ -37,6 +38,7 @@ import ( "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/nathole" plugin "github.com/fatedier/frp/pkg/plugin/server" + "github.com/fatedier/frp/pkg/proto/wire" "github.com/fatedier/frp/pkg/ssh" "github.com/fatedier/frp/pkg/transport" httppkg "github.com/fatedier/frp/pkg/util/http" @@ -432,20 +434,15 @@ func (svr *Service) Close() error { func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, internal bool) { xl := xlog.FromContextSafe(ctx) - var ( - rawMsg msg.Message - err error - ) - - _ = conn.SetReadDeadline(time.Now().Add(connReadTimeout)) - if rawMsg, err = msg.ReadMsg(conn); err != nil { - log.Tracef("failed to read message: %v", err) + acceptedConn, err := svr.acceptConnection(ctx, conn) + if err != nil { + log.Tracef("failed to accept frp connection: %v", err) conn.Close() return } - _ = conn.SetReadDeadline(time.Time{}) + conn = acceptedConn.conn - switch m := rawMsg.(type) { + switch m := acceptedConn.firstMsg.(type) { case *msg.Login: // server plugin hook content := &plugin.LoginContent{ @@ -453,35 +450,66 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna ClientAddress: conn.RemoteAddr().String(), } retContent, err := svr.pluginManager.Login(content) + var ctl *Control if err == nil { m = &retContent.Login - err = svr.RegisterControl(conn, m, internal) + controlConn := acceptedConn.conn + if !internal { + var controlRW io.ReadWriter + controlRW, err = netpkg.NewCryptoReadWriter(conn, svr.auth.EncryptionKey()) + if err == nil { + controlConn = acceptedConn.messageConnFor(controlRW) + } + } + if err == nil { + ctl, err = svr.RegisterControl(controlConn, m, internal, acceptedConn.wireProtocol) + } } - // If login failed, send error message there. - // Otherwise send success message in control's work goroutine. if err != nil { xl.Warnf("register control error: %v", err) - _ = msg.WriteMsg(conn, &msg.LoginResp{ + _ = acceptedConn.conn.WriteMsg(&msg.LoginResp{ Version: version.Full(), Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)), }) conn.Close() + return } + if err = acceptedConn.conn.WriteMsg(&msg.LoginResp{ + Version: version.Full(), + RunID: ctl.runID, + Error: "", + }); err != nil { + xl.Warnf("write login response error: %v", err) + svr.ctlManager.Del(m.RunID, ctl) + svr.clientRegistry.MarkOfflineByRunID(m.RunID) + conn.Close() + return + } + ctl.Start() + metrics.Server.NewClient() + go func() { + // block until control closed + ctl.WaitClosed() + svr.ctlManager.Del(m.RunID, ctl) + }() case *msg.NewWorkConn: - if err := svr.RegisterWorkConn(conn, m); err != nil { + if err := svr.RegisterWorkConn(acceptedConn.conn, m); err != nil { + _ = acceptedConn.conn.WriteMsg(&msg.StartWorkConn{ + Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)), + }) conn.Close() } case *msg.NewVisitorConn: if err = svr.RegisterVisitorConn(conn, m); err != nil { xl.Warnf("register visitor conn error: %v", err) - _ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{ + _ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{ ProxyName: m.ProxyName, Error: util.GenerateResponseErrorString("register visitor conn error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)), }) conn.Close() } else { - _ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{ + _ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{ ProxyName: m.ProxyName, Error: "", }) @@ -492,6 +520,87 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna } } +type acceptedConnection struct { + conn *msg.Conn + wireProtocol string + firstMsg msg.Message +} + +func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*acceptedConnection, error) { + _ = conn.SetReadDeadline(time.Now().Add(connReadTimeout)) + checkedConn, isV2, err := wire.CheckMagic(conn) + if err != nil { + return nil, fmt.Errorf("read wire protocol magic: %w", err) + } + + wireProtocol := wire.ProtocolV1 + if isV2 { + wireProtocol = wire.ProtocolV2 + } + + conn = netpkg.NewContextConn(ctx, checkedConn) + acceptedConn := &acceptedConnection{wireProtocol: wireProtocol} + if isV2 { + wireConn := wire.NewConn(conn) + rw := msg.NewV2ReadWriterWithConn(wireConn) + acceptedConn.conn = msg.NewConn(conn, rw) + acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(wireConn) + } else { + rw := msg.NewV1ReadWriter(conn) + acceptedConn.conn = msg.NewConn(conn, rw) + acceptedConn.firstMsg, err = acceptedConn.conn.ReadMsg() + } + if err != nil { + return nil, err + } + _ = conn.SetReadDeadline(time.Time{}) + return acceptedConn, nil +} + +func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn { + return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol)) +} + +func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, error) { + frame, err := wireConn.ReadFrame() + if err != nil { + return nil, fmt.Errorf("read v2 frame: %w", err) + } + if frame.Type == wire.FrameTypeClientHello { + if err := ac.handleClientHello(wireConn, frame); err != nil { + return nil, err + } + frame, err = wireConn.ReadFrame() + if err != nil { + return nil, fmt.Errorf("read first v2 message frame: %w", err) + } + } + + m, err := msg.DecodeV2MessageFrame(frame) + if err != nil { + return nil, fmt.Errorf("decode v2 message: %w", err) + } + return m, nil +} + +func (ac *acceptedConnection) handleClientHello(wireConn *wire.Conn, frame *wire.Frame) error { + var hello wire.ClientHello + if err := wireConn.UnmarshalFrame(frame, &hello); err != nil { + return fmt.Errorf("decode ClientHello: %w", err) + } + + serverHello := wire.DefaultServerHello() + if err := wire.ValidateClientHello(hello); err != nil { + serverHello.Error = err.Error() + _ = wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello) + return err + } + if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello); err != nil { + return fmt.Errorf("write ServerHello: %w", err) + } + return nil +} + // HandleListener accepts connections from client and call handleConnection to handle them. // If internal is true, it means that this listener is used for internal communication like ssh tunnel gateway. // TODO(fatedier): Pass some parameters of listener/connection through context to avoid passing too many parameters. @@ -577,14 +686,19 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) { } } -func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, internal bool) error { +func (svr *Service) RegisterControl( + ctlConn *msg.Conn, + loginMsg *msg.Login, + internal bool, + wireProtocol string, +) (*Control, error) { // If client's RunID is empty, it's a new client, we just create a new controller. // Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one. var err error if loginMsg.RunID == "" { loginMsg.RunID, err = util.RandID() if err != nil { - return err + return nil, err } } @@ -601,7 +715,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter authVerifier = auth.AlwaysPassVerifier } if err := authVerifier.VerifyLogin(loginMsg); err != nil { - return err + return nil, err } ctl, err := NewControl(ctx, &SessionContext{ @@ -611,7 +725,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter AuthVerifier: authVerifier, EncryptionKey: svr.auth.EncryptionKey(), Conn: ctlConn, - ConnEncrypted: !internal, LoginMsg: loginMsg, ServerCfg: svr.cfg, ClientRegistry: svr.clientRegistry, @@ -619,7 +732,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter if err != nil { xl.Warnf("create new controller error: %v", err) // don't return detailed errors to client - return fmt.Errorf("unexpected error when creating new controller") + return nil, fmt.Errorf("unexpected error when creating new controller") } if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil { @@ -630,34 +743,24 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter if host, _, err := net.SplitHostPort(remoteAddr); err == nil { remoteAddr = host } - _, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr) + _, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr, wireProtocol) if conflict { svr.ctlManager.Del(loginMsg.RunID, ctl) - ctl.Close() - return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User) + return nil, fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User) } - ctl.Start() - - // for statistics - metrics.Server.NewClient() - - go func() { - // block until control closed - ctl.WaitClosed() - svr.ctlManager.Del(loginMsg.RunID, ctl) - }() - return nil + return ctl, nil } // RegisterWorkConn register a new work connection to control and proxies need it. -func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error { +func (svr *Service) RegisterWorkConn(workConn *msg.Conn, newMsg *msg.NewWorkConn) error { xl := netpkg.NewLogFromConn(workConn) ctl, exist := svr.ctlManager.GetByID(newMsg.RunID) if !exist { xl.Warnf("no client control found for run id [%s]", newMsg.RunID) return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID) } + // server plugin hook content := &plugin.NewWorkConnContent{ User: plugin.UserInfo{ @@ -675,12 +778,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) } if err != nil { xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID) - _ = msg.WriteMsg(workConn, &msg.StartWorkConn{ - Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)), - }) - return fmt.Errorf("invalid NewWorkConn with run id [%s]", newMsg.RunID) + return err } - return ctl.RegisterWorkConn(workConn) + return ctl.RegisterWorkConn(proxy.NewWorkConn(workConn)) } func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVisitorConn) error { diff --git a/test/e2e/mock/server/oidcserver/oidcserver.go b/test/e2e/mock/server/oidcserver/oidcserver.go index d7aa1329..22236dc0 100644 --- a/test/e2e/mock/server/oidcserver/oidcserver.go +++ b/test/e2e/mock/server/oidcserver/oidcserver.go @@ -53,6 +53,8 @@ type Server struct { tokenRequestCount atomic.Int64 } +const maxTokenRequestBodySize = 1 << 20 + type Option func(*Server) func WithBindPort(port int) Option { @@ -178,6 +180,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } + r.Body = http.MaxBytesReader(w, r.Body, maxTokenRequestBodySize) if err := r.ParseForm(); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) @@ -187,7 +190,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { return } - if r.FormValue("grant_type") != "client_credentials" { + if r.Form.Get("grant_type") != "client_credentials" { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) _ = json.NewEncoder(w).Encode(map[string]any{ @@ -199,8 +202,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { // Accept credentials from Basic Auth or form body. clientID, clientSecret, ok := r.BasicAuth() if !ok { - clientID = r.FormValue("client_id") - clientSecret = r.FormValue("client_secret") + clientID = r.Form.Get("client_id") + clientSecret = r.Form.Get("client_secret") } if clientID != s.clientID || clientSecret != s.clientSecret { w.Header().Set("Content-Type", "application/json") diff --git a/test/e2e/v1/basic/wire.go b/test/e2e/v1/basic/wire.go new file mode 100644 index 00000000..bbae9199 --- /dev/null +++ b/test/e2e/v1/basic/wire.go @@ -0,0 +1,163 @@ +package basic + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/onsi/ginkgo/v2" + + "github.com/fatedier/frp/test/e2e/framework" + "github.com/fatedier/frp/test/e2e/framework/consts" + "github.com/fatedier/frp/test/e2e/pkg/port" +) + +var _ = ginkgo.Describe("[Feature: WireProtocol]", func() { + f := framework.NewDefaultFramework() + + ginkgo.It("v1 tcp and udp proxies", func() { + serverConf := consts.DefaultServerConfig + tcpPortName := port.GenName("WireV1TCP") + udpPortName := port.GenName("WireV1UDP") + clientConf := consts.DefaultClientConfig + fmt.Sprintf(` + transport.wireProtocol = "v1" + + [[proxies]] + name = "tcp" + type = "tcp" + localPort = {{ .%s }} + remotePort = {{ .%s }} + + [[proxies]] + name = "udp" + type = "udp" + localPort = {{ .%s }} + remotePort = {{ .%s }} + `, framework.TCPEchoServerPort, tcpPortName, framework.UDPEchoServerPort, udpPortName) + + f.RunProcesses(serverConf, []string{clientConf}) + + framework.NewRequestExpect(f).PortName(tcpPortName).Ensure() + framework.NewRequestExpect(f).Protocol("udp").PortName(udpPortName).Ensure() + }) + + ginkgo.It("v2 tcp and udp proxies", func() { + serverConf := consts.DefaultServerConfig + tcpPortName := port.GenName("WireV2TCP") + udpPortName := port.GenName("WireV2UDP") + clientConf := consts.DefaultClientConfig + fmt.Sprintf(` + transport.wireProtocol = "v2" + + [[proxies]] + name = "tcp" + type = "tcp" + localPort = {{ .%s }} + remotePort = {{ .%s }} + + [[proxies]] + name = "udp" + type = "udp" + localPort = {{ .%s }} + remotePort = {{ .%s }} + `, framework.TCPEchoServerPort, tcpPortName, framework.UDPEchoServerPort, udpPortName) + + f.RunProcesses(serverConf, []string{clientConf}) + + framework.NewRequestExpect(f).PortName(tcpPortName).Ensure() + framework.NewRequestExpect(f).Protocol("udp").PortName(udpPortName).Ensure() + }) + + ginkgo.It("v2 stcp visitor", func() { + serverConf := consts.DefaultServerConfig + bindPortName := port.GenName("WireV2STCP") + clientServerConf := consts.DefaultClientConfig + fmt.Sprintf(` + user = "user1" + transport.wireProtocol = "v2" + + [[proxies]] + name = "stcp" + type = "stcp" + secretKey = "abc" + localPort = {{ .%s }} + `, framework.TCPEchoServerPort) + clientVisitorConf := consts.DefaultClientConfig + fmt.Sprintf(` + user = "user1" + transport.wireProtocol = "v2" + + [[visitors]] + name = "stcp-visitor" + type = "stcp" + serverName = "stcp" + secretKey = "abc" + bindPort = {{ .%s }} + `, bindPortName) + + f.RunProcesses(serverConf, []string{clientServerConf, clientVisitorConf}) + + framework.NewRequestExpect(f).PortName(bindPortName).Ensure() + }) + + ginkgo.It("reports client wire protocol", func() { + webPort := f.AllocPort() + serverConf := consts.DefaultServerConfig + fmt.Sprintf(` + webServer.port = %d + `, webPort) + + v1PortName := port.GenName("WireReportV1") + v1ClientConf := consts.DefaultClientConfig + fmt.Sprintf(` + clientID = "wire-v1" + transport.wireProtocol = "v1" + + [[proxies]] + name = "v1" + type = "tcp" + localPort = {{ .%s }} + remotePort = {{ .%s }} + `, framework.TCPEchoServerPort, v1PortName) + + v2PortName := port.GenName("WireReportV2") + v2ClientConf := consts.DefaultClientConfig + fmt.Sprintf(` + clientID = "wire-v2" + transport.wireProtocol = "v2" + + [[proxies]] + name = "v2" + type = "tcp" + localPort = {{ .%s }} + remotePort = {{ .%s }} + `, framework.TCPEchoServerPort, v2PortName) + + f.RunProcesses(serverConf, []string{v1ClientConf, v2ClientConf}) + + framework.NewRequestExpect(f).PortName(v1PortName).Ensure() + framework.NewRequestExpect(f).PortName(v2PortName).Ensure() + expectClientWireProtocol(webPort, "wire-v1", "v1") + expectClientWireProtocol(webPort, "wire-v2", "v2") + }) +}) + +type wireClientInfo struct { + ClientID string `json:"clientID"` + WireProtocol string `json:"wireProtocol"` +} + +func expectClientWireProtocol(webPort int, clientID string, wireProtocol string) { + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/clients", webPort)) + framework.ExpectNoError(err) + defer resp.Body.Close() + framework.ExpectEqual(resp.StatusCode, 200) + + content, err := io.ReadAll(resp.Body) + framework.ExpectNoError(err) + + var clients []wireClientInfo + framework.ExpectNoError(json.Unmarshal(content, &clients)) + for _, client := range clients { + if client.ClientID == clientID { + framework.ExpectEqual(client.WireProtocol, wireProtocol) + return + } + } + framework.Failf("client %q not found in /api/clients response: %s", clientID, string(content)) +} diff --git a/web/frps/src/components/ClientCard.vue b/web/frps/src/components/ClientCard.vue index 531a210f..4ccd7441 100644 --- a/web/frps/src/components/ClientCard.vue +++ b/web/frps/src/components/ClientCard.vue @@ -16,6 +16,9 @@ v{{ client.version }} + + {{ client.wireProtocolLabel }} +
diff --git a/web/frps/src/types/client.ts b/web/frps/src/types/client.ts index 31d846a7..85413ff0 100644 --- a/web/frps/src/types/client.ts +++ b/web/frps/src/types/client.ts @@ -4,6 +4,7 @@ export interface ClientInfoData { clientID: string runID: string version?: string + wireProtocol?: string hostname: string clientIP?: string metas?: Record diff --git a/web/frps/src/utils/client.ts b/web/frps/src/utils/client.ts index d7f95a3f..8d26eeb8 100644 --- a/web/frps/src/utils/client.ts +++ b/web/frps/src/utils/client.ts @@ -7,6 +7,7 @@ export class Client { clientID: string runID: string version: string + wireProtocol: string hostname: string ip: string metas: Map @@ -21,6 +22,7 @@ export class Client { this.clientID = data.clientID this.runID = data.runID this.version = data.version || '' + this.wireProtocol = data.wireProtocol || '' this.hostname = data.hostname this.ip = data.clientIP || '' this.metas = new Map() @@ -48,6 +50,11 @@ export class Client { return this.runID.substring(0, 8) } + get wireProtocolLabel(): string { + if (!this.wireProtocol) return '' + return `Protocol ${this.wireProtocol}` + } + get firstConnectedAgo(): string { return formatDistanceToNow(this.firstConnectedAt) } @@ -80,6 +87,7 @@ export class Client { this.user.toLowerCase().includes(search) || this.clientID.toLowerCase().includes(search) || this.runID.toLowerCase().includes(search) || + this.wireProtocol.toLowerCase().includes(search) || this.hostname.toLowerCase().includes(search) ) } diff --git a/web/frps/src/views/ClientDetail.vue b/web/frps/src/views/ClientDetail.vue index 1468a243..42c9a25e 100644 --- a/web/frps/src/views/ClientDetail.vue +++ b/web/frps/src/views/ClientDetail.vue @@ -27,6 +27,9 @@ v{{ client.version }} + + {{ client.wireProtocolLabel }} +
{{ @@ -58,6 +61,10 @@ Run ID {{ client.runID }}
+
+ Protocol + {{ client.wireProtocol }} +
First Connected {{ client.firstConnectedAgo }} diff --git a/web/package-lock.json b/web/package-lock.json index bd274637..d9a66bc8 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -5753,9 +5753,9 @@ } }, "node_modules/postcss": { - "version": "8.5.8", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz", - "integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==", + "version": "8.5.12", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz", + "integrity": "sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==", "funding": [ { "type": "opencollective", @@ -5770,7 +5770,6 @@ "url": "https://github.com/sponsors/ai" } ], - "license": "MIT", "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1",