mirror of
https://github.com/fatedier/frp.git
synced 2026-04-28 12:09:10 +08:00
protocol: add v2 wire protocol with binary framing and capability negotiation (#5294)
This commit is contained in:
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -32,4 +32,4 @@ jobs:
|
|||||||
uses: golangci/golangci-lint-action@v9
|
uses: golangci/golangci-lint-action@v9
|
||||||
with:
|
with:
|
||||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||||
version: v2.10
|
version: v2.11
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ linters:
|
|||||||
disabled-checks:
|
disabled-checks:
|
||||||
- exitAfterDefer
|
- exitAfterDefer
|
||||||
gosec:
|
gosec:
|
||||||
excludes: ["G115", "G117", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"]
|
excludes: ["G115", "G117", "G118", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"]
|
||||||
severity: low
|
severity: low
|
||||||
confidence: low
|
confidence: low
|
||||||
govet:
|
govet:
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ import (
|
|||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
@@ -41,6 +43,39 @@ type Connector interface {
|
|||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MessageConnector interface {
|
||||||
|
Connect() (*msg.Conn, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type messageConnector struct {
|
||||||
|
connector Connector
|
||||||
|
wireProtocol string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMessageConnector(connector Connector, wireProtocol string) *messageConnector {
|
||||||
|
return &messageConnector{
|
||||||
|
connector: connector,
|
||||||
|
wireProtocol: wireProtocol,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *messageConnector) Connect() (*msg.Conn, error) {
|
||||||
|
conn, err := c.connector.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = wire.WriteMagicIfV2(conn, c.wireProtocol); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return msg.NewConn(conn, msg.NewReadWriter(conn, c.wireProtocol)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *messageConnector) Close() error {
|
||||||
|
return c.connector.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// defaultConnectorImpl is the default implementation of Connector for normal frpc.
|
// defaultConnectorImpl is the default implementation of Connector for normal frpc.
|
||||||
type defaultConnectorImpl struct {
|
type defaultConnectorImpl struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ import (
|
|||||||
"github.com/fatedier/frp/pkg/msg"
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
"github.com/fatedier/frp/pkg/naming"
|
"github.com/fatedier/frp/pkg/naming"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
|
||||||
"github.com/fatedier/frp/pkg/util/wait"
|
"github.com/fatedier/frp/pkg/util/wait"
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
"github.com/fatedier/frp/pkg/vnet"
|
"github.com/fatedier/frp/pkg/vnet"
|
||||||
@@ -41,13 +40,11 @@ type SessionContext struct {
|
|||||||
// It should be attached to the login message when reconnecting.
|
// It should be attached to the login message when reconnecting.
|
||||||
RunID string
|
RunID string
|
||||||
// Underlying control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
|
// Underlying control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
|
||||||
Conn net.Conn
|
Conn *msg.Conn
|
||||||
// Indicates whether the connection is encrypted.
|
|
||||||
ConnEncrypted bool
|
|
||||||
// Auth runtime used for login, heartbeats, and encryption.
|
// Auth runtime used for login, heartbeats, and encryption.
|
||||||
Auth *auth.ClientAuth
|
Auth *auth.ClientAuth
|
||||||
// Connector is used to create new connections, which could be real TCP connections or virtual streams.
|
// Connector is used to create message connections to frps.
|
||||||
Connector Connector
|
Connector MessageConnector
|
||||||
// Virtual net controller
|
// Virtual net controller
|
||||||
VnetController *vnet.Controller
|
VnetController *vnet.Controller
|
||||||
}
|
}
|
||||||
@@ -91,15 +88,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
|||||||
}
|
}
|
||||||
ctl.lastPong.Store(time.Now())
|
ctl.lastPong.Store(time.Now())
|
||||||
|
|
||||||
if sessionCtx.ConnEncrypted {
|
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.Auth.EncryptionKey())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
|
|
||||||
} else {
|
|
||||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
|
||||||
}
|
|
||||||
ctl.registerMsgHandlers()
|
ctl.registerMsgHandlers()
|
||||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||||
|
|
||||||
@@ -139,14 +128,14 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) {
|
|||||||
workConn.Close()
|
workConn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = msg.WriteMsg(workConn, m); err != nil {
|
if err = workConn.WriteMsg(m); err != nil {
|
||||||
xl.Warnf("work connection write to server error: %v", err)
|
xl.Warnf("work connection write to server error: %v", err)
|
||||||
workConn.Close()
|
workConn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var startMsg msg.StartWorkConn
|
var startMsg msg.StartWorkConn
|
||||||
if err = msg.ReadMsgInto(workConn, &startMsg); err != nil {
|
if err = workConn.ReadMsgInto(&startMsg); err != nil {
|
||||||
xl.Tracef("work connection closed before response StartWorkConn message: %v", err)
|
xl.Tracef("work connection closed before response StartWorkConn message: %v", err)
|
||||||
workConn.Close()
|
workConn.Close()
|
||||||
return
|
return
|
||||||
@@ -227,7 +216,7 @@ func (ctl *Control) Done() <-chan struct{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// connectServer return a new connection to frps
|
// connectServer return a new connection to frps
|
||||||
func (ctl *Control) connectServer() (net.Conn, error) {
|
func (ctl *Control) connectServer() (*msg.Conn, error) {
|
||||||
return ctl.sessionCtx.Connector.Connect()
|
return ctl.sessionCtx.Connector.Connect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
172
client/control_session.go
Normal file
172
client/control_session.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/samber/lo"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/auth"
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
|
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||||
|
"github.com/fatedier/frp/pkg/util/version"
|
||||||
|
"github.com/fatedier/frp/pkg/vnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
type controlSessionDialer struct {
|
||||||
|
ctx context.Context
|
||||||
|
|
||||||
|
common *v1.ClientCommonConfig
|
||||||
|
auth *auth.ClientAuth
|
||||||
|
clientSpec *msg.ClientSpec
|
||||||
|
vnetController *vnet.Controller
|
||||||
|
|
||||||
|
connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *controlSessionDialer) Dial(previousRunID string) (*SessionContext, error) {
|
||||||
|
connector := d.connectorCreator(d.ctx, d.common)
|
||||||
|
if err := connector.Open(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
success := false
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
_ = connector.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := connector.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
loginMsg, err := d.buildLoginMsg(previousRunID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
loginRespMsg, err := d.exchangeLogin(conn, loginMsg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if loginRespMsg.Error != "" {
|
||||||
|
return nil, errors.New(loginRespMsg.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var controlRW io.ReadWriter = conn
|
||||||
|
if d.clientSpec == nil || d.clientSpec.Type != "ssh-tunnel" {
|
||||||
|
controlRW, err = netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create control crypto read writer: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
success = true
|
||||||
|
return &SessionContext{
|
||||||
|
Common: d.common,
|
||||||
|
RunID: loginRespMsg.RunID,
|
||||||
|
Conn: msg.NewConn(conn, msg.NewReadWriter(controlRW, d.common.Transport.WireProtocol)),
|
||||||
|
Auth: d.auth,
|
||||||
|
Connector: newMessageConnector(connector, d.common.Transport.WireProtocol),
|
||||||
|
VnetController: d.vnetController,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *controlSessionDialer) buildLoginMsg(previousRunID string) (*msg.Login, error) {
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
loginMsg := &msg.Login{
|
||||||
|
Arch: runtime.GOARCH,
|
||||||
|
Os: runtime.GOOS,
|
||||||
|
Hostname: hostname,
|
||||||
|
PoolCount: d.common.Transport.PoolCount,
|
||||||
|
User: d.common.User,
|
||||||
|
ClientID: d.common.ClientID,
|
||||||
|
Version: version.Full(),
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
RunID: previousRunID,
|
||||||
|
Metas: d.common.Metadatas,
|
||||||
|
}
|
||||||
|
if d.clientSpec != nil {
|
||||||
|
loginMsg.ClientSpec = *d.clientSpec
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.auth.Setter.SetLogin(loginMsg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return loginMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*msg.LoginResp, error) {
|
||||||
|
rw := msg.NewV1ReadWriter(conn)
|
||||||
|
var wireConn *wire.Conn
|
||||||
|
|
||||||
|
if d.common.Transport.WireProtocol == wire.ProtocolV2 {
|
||||||
|
if err := wire.WriteMagic(conn); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wireConn = wire.NewConn(conn)
|
||||||
|
rw = msg.NewV2ReadWriterWithConn(wireConn)
|
||||||
|
hello := wire.DefaultClientHello(wire.BootstrapInfo{
|
||||||
|
Transport: d.common.Transport.Protocol,
|
||||||
|
TLS: lo.FromPtr(d.common.Transport.TLS.Enable) || d.common.Transport.Protocol == "wss" || d.common.Transport.Protocol == "quic",
|
||||||
|
TCPMux: lo.FromPtr(d.common.Transport.TCPMux),
|
||||||
|
})
|
||||||
|
if err := wireConn.WriteJSONFrame(wire.FrameTypeClientHello, hello); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rw.WriteMsg(loginMsg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||||
|
defer func() {
|
||||||
|
_ = conn.SetReadDeadline(time.Time{})
|
||||||
|
}()
|
||||||
|
|
||||||
|
if wireConn != nil {
|
||||||
|
var serverHello wire.ServerHello
|
||||||
|
if err := wireConn.ReadJSONFrame(wire.FrameTypeServerHello, &serverHello); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if serverHello.Error != "" {
|
||||||
|
return nil, errors.New(serverHello.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var loginRespMsg msg.LoginResp
|
||||||
|
if err := rw.ReadMsgInto(&loginRespMsg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &loginRespMsg, nil
|
||||||
|
}
|
||||||
245
client/control_session_test.go
Normal file
245
client/control_session_test.go
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/auth"
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testConnector struct {
|
||||||
|
conn net.Conn
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testConnector) Open() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testConnector) Connect() (net.Conn, error) {
|
||||||
|
return c.conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testConnector) Close() error {
|
||||||
|
c.closed.Store(true)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type trackingConn struct {
|
||||||
|
net.Conn
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackingConn) Close() error {
|
||||||
|
c.closed.Store(true)
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestControlSessionDialer(t *testing.T, protocol string, connector Connector, clientSpec *msg.ClientSpec) *controlSessionDialer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
authRuntime, err := auth.BuildClientAuth(&v1.AuthClientConfig{
|
||||||
|
Method: v1.AuthMethodToken,
|
||||||
|
Token: "token",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return &controlSessionDialer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
common: &v1.ClientCommonConfig{
|
||||||
|
User: "test-user",
|
||||||
|
Transport: v1.ClientTransportConfig{
|
||||||
|
Protocol: "tcp",
|
||||||
|
WireProtocol: protocol,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
auth: authRuntime,
|
||||||
|
clientSpec: clientSpec,
|
||||||
|
connectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector {
|
||||||
|
return connector
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestControlSessionDialerDialV1(t *testing.T) {
|
||||||
|
clientRaw, serverRaw := net.Pipe()
|
||||||
|
defer serverRaw.Close()
|
||||||
|
|
||||||
|
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
rw := msg.NewV1ReadWriter(serverRaw)
|
||||||
|
var loginMsg msg.Login
|
||||||
|
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if loginMsg.RunID != "previous-run-id" {
|
||||||
|
serverErrCh <- fmt.Errorf("unexpected previous run id: %s", loginMsg.RunID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if loginMsg.User != "test-user" {
|
||||||
|
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v1"})
|
||||||
|
}()
|
||||||
|
|
||||||
|
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
||||||
|
sessionCtx, err := dialer.Dial("previous-run-id")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sessionCtx.Conn.Close()
|
||||||
|
defer sessionCtx.Connector.Close()
|
||||||
|
|
||||||
|
require.Equal(t, "run-v1", sessionCtx.RunID)
|
||||||
|
require.NotNil(t, sessionCtx.Conn)
|
||||||
|
require.NotNil(t, sessionCtx.Connector)
|
||||||
|
require.False(t, connector.closed.Load())
|
||||||
|
require.NoError(t, <-serverErrCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestControlSessionDialerDialV2(t *testing.T) {
|
||||||
|
clientRaw, serverRaw := net.Pipe()
|
||||||
|
defer serverRaw.Close()
|
||||||
|
|
||||||
|
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
magic := make([]byte, len(wire.MagicV2))
|
||||||
|
if _, err := io.ReadFull(serverRaw, magic); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if string(magic) != wire.MagicV2 {
|
||||||
|
serverErrCh <- fmt.Errorf("unexpected magic: %q", string(magic))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wireConn := wire.NewConn(serverRaw)
|
||||||
|
var hello wire.ClientHello
|
||||||
|
if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := wire.ValidateClientHello(hello); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
||||||
|
var loginMsg msg.Login
|
||||||
|
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if loginMsg.User != "test-user" {
|
||||||
|
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"})
|
||||||
|
}()
|
||||||
|
|
||||||
|
dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
|
||||||
|
sessionCtx, err := dialer.Dial("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sessionCtx.Conn.Close()
|
||||||
|
defer sessionCtx.Connector.Close()
|
||||||
|
|
||||||
|
require.Equal(t, "run-v2", sessionCtx.RunID)
|
||||||
|
require.NotNil(t, sessionCtx.Conn)
|
||||||
|
require.NotNil(t, sessionCtx.Connector)
|
||||||
|
require.False(t, connector.closed.Load())
|
||||||
|
require.NoError(t, <-serverErrCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestControlSessionDialerDialLoginErrorClosesResources(t *testing.T) {
|
||||||
|
clientRaw, serverRaw := net.Pipe()
|
||||||
|
defer serverRaw.Close()
|
||||||
|
|
||||||
|
clientConn := &trackingConn{Conn: clientRaw}
|
||||||
|
connector := &testConnector{conn: clientConn}
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
rw := msg.NewV1ReadWriter(serverRaw)
|
||||||
|
var loginMsg msg.Login
|
||||||
|
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverErrCh <- rw.WriteMsg(&msg.LoginResp{Error: "login denied"})
|
||||||
|
}()
|
||||||
|
|
||||||
|
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
||||||
|
sessionCtx, err := dialer.Dial("")
|
||||||
|
require.Nil(t, sessionCtx)
|
||||||
|
require.ErrorContains(t, err, "login denied")
|
||||||
|
require.True(t, clientConn.closed.Load())
|
||||||
|
require.True(t, connector.closed.Load())
|
||||||
|
require.NoError(t, <-serverErrCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestControlSessionDialerDialSSHTunnelSkipsControlEncryption(t *testing.T) {
|
||||||
|
clientRaw, serverRaw := net.Pipe()
|
||||||
|
defer serverRaw.Close()
|
||||||
|
|
||||||
|
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
rw := msg.NewV1ReadWriter(serverRaw)
|
||||||
|
var loginMsg msg.Login
|
||||||
|
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-ssh-tunnel"}); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = serverRaw.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
var ping msg.Ping
|
||||||
|
if err := rw.ReadMsgInto(&ping); err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverErrCh <- nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, &msg.ClientSpec{Type: "ssh-tunnel"})
|
||||||
|
sessionCtx, err := dialer.Dial("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sessionCtx.Conn.Close()
|
||||||
|
defer sessionCtx.Connector.Close()
|
||||||
|
|
||||||
|
require.Equal(t, "run-ssh-tunnel", sessionCtx.RunID)
|
||||||
|
require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{}))
|
||||||
|
require.NoError(t, <-serverErrCh)
|
||||||
|
}
|
||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -38,7 +37,6 @@ import (
|
|||||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
"github.com/fatedier/frp/pkg/util/log"
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||||
"github.com/fatedier/frp/pkg/util/version"
|
|
||||||
"github.com/fatedier/frp/pkg/util/wait"
|
"github.com/fatedier/frp/pkg/util/wait"
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
"github.com/fatedier/frp/pkg/vnet"
|
"github.com/fatedier/frp/pkg/vnet"
|
||||||
@@ -303,80 +301,20 @@ func (svr *Service) keepControllerWorking() {
|
|||||||
), true, svr.ctx.Done())
|
), true, svr.ctx.Done())
|
||||||
}
|
}
|
||||||
|
|
||||||
// login creates a connection to frps and registers it self as a client
|
|
||||||
// conn: control connection
|
|
||||||
// session: if it's not nil, using tcp mux
|
|
||||||
func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
|
|
||||||
xl := xlog.FromContextSafe(svr.ctx)
|
|
||||||
connector = svr.connectorCreator(svr.ctx, svr.common)
|
|
||||||
if err = connector.Open(); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
connector.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
conn, err = connector.Connect()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname, _ := os.Hostname()
|
|
||||||
|
|
||||||
loginMsg := &msg.Login{
|
|
||||||
Arch: runtime.GOARCH,
|
|
||||||
Os: runtime.GOOS,
|
|
||||||
Hostname: hostname,
|
|
||||||
PoolCount: svr.common.Transport.PoolCount,
|
|
||||||
User: svr.common.User,
|
|
||||||
ClientID: svr.common.ClientID,
|
|
||||||
Version: version.Full(),
|
|
||||||
Timestamp: time.Now().Unix(),
|
|
||||||
RunID: svr.runID,
|
|
||||||
Metas: svr.common.Metadatas,
|
|
||||||
}
|
|
||||||
if svr.clientSpec != nil {
|
|
||||||
loginMsg.ClientSpec = *svr.clientSpec
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add auth
|
|
||||||
if err = svr.auth.Setter.SetLogin(loginMsg); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = msg.WriteMsg(conn, loginMsg); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var loginRespMsg msg.LoginResp
|
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
|
||||||
if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = conn.SetReadDeadline(time.Time{})
|
|
||||||
|
|
||||||
if loginRespMsg.Error != "" {
|
|
||||||
err = fmt.Errorf("%s", loginRespMsg.Error)
|
|
||||||
xl.Errorf("%s", loginRespMsg.Error)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
svr.runID = loginRespMsg.RunID
|
|
||||||
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
|
||||||
|
|
||||||
xl.Infof("login to server success, get run id [%s]", loginRespMsg.RunID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) {
|
func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) {
|
||||||
xl := xlog.FromContextSafe(svr.ctx)
|
xl := xlog.FromContextSafe(svr.ctx)
|
||||||
|
|
||||||
loginFunc := func() (bool, error) {
|
loginFunc := func() (bool, error) {
|
||||||
xl.Infof("try to connect to server...")
|
xl.Infof("try to connect to server...")
|
||||||
conn, connector, err := svr.login()
|
dialer := &controlSessionDialer{
|
||||||
|
ctx: svr.ctx,
|
||||||
|
common: svr.common,
|
||||||
|
auth: svr.auth,
|
||||||
|
clientSpec: svr.clientSpec,
|
||||||
|
vnetController: svr.vnetController,
|
||||||
|
connectorCreator: svr.connectorCreator,
|
||||||
|
}
|
||||||
|
sessionCtx, err := dialer.Dial(svr.runID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warnf("connect to server error: %v", err)
|
xl.Warnf("connect to server error: %v", err)
|
||||||
if firstLoginExit {
|
if firstLoginExit {
|
||||||
@@ -385,25 +323,19 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
svr.runID = sessionCtx.RunID
|
||||||
|
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
||||||
|
xl.Infof("login to server success, get run id [%s]", svr.runID)
|
||||||
|
|
||||||
svr.cfgMu.RLock()
|
svr.cfgMu.RLock()
|
||||||
proxyCfgs := svr.proxyCfgs
|
proxyCfgs := svr.proxyCfgs
|
||||||
visitorCfgs := svr.visitorCfgs
|
visitorCfgs := svr.visitorCfgs
|
||||||
svr.cfgMu.RUnlock()
|
svr.cfgMu.RUnlock()
|
||||||
|
|
||||||
connEncrypted := svr.clientSpec == nil || svr.clientSpec.Type != "ssh-tunnel"
|
|
||||||
|
|
||||||
sessionCtx := &SessionContext{
|
|
||||||
Common: svr.common,
|
|
||||||
RunID: svr.runID,
|
|
||||||
Conn: conn,
|
|
||||||
ConnEncrypted: connEncrypted,
|
|
||||||
Auth: svr.auth,
|
|
||||||
Connector: connector,
|
|
||||||
VnetController: svr.vnetController,
|
|
||||||
}
|
|
||||||
ctl, err := NewControl(svr.ctx, sessionCtx)
|
ctl, err := NewControl(svr.ctx, sessionCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
sessionCtx.Conn.Close()
|
||||||
|
sessionCtx.Connector.Close()
|
||||||
xl.Errorf("new control error: %v", err)
|
xl.Errorf("new control error: %v", err)
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ import (
|
|||||||
// Helper wraps some functions for visitor to use.
|
// Helper wraps some functions for visitor to use.
|
||||||
type Helper interface {
|
type Helper interface {
|
||||||
// ConnectServer directly connects to the frp server.
|
// ConnectServer directly connects to the frp server.
|
||||||
ConnectServer() (net.Conn, error)
|
ConnectServer() (*msg.Conn, error)
|
||||||
// TransferConn transfers the connection to another visitor.
|
// TransferConn transfers the connection to another visitor.
|
||||||
TransferConn(string, net.Conn) error
|
TransferConn(string, net.Conn) error
|
||||||
// MsgTransporter returns the message transporter that is used to send and receive messages
|
// MsgTransporter returns the message transporter that is used to send and receive messages
|
||||||
@@ -167,15 +167,15 @@ func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, e
|
|||||||
UseEncryption: cfg.Transport.UseEncryption,
|
UseEncryption: cfg.Transport.UseEncryption,
|
||||||
UseCompression: cfg.Transport.UseCompression,
|
UseCompression: cfg.Transport.UseCompression,
|
||||||
}
|
}
|
||||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
err = visitorConn.WriteMsg(newVisitorConnMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
visitorConn.Close()
|
visitorConn.Close()
|
||||||
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
|
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
|
||||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||||
|
err = visitorConn.ReadMsgInto(&newVisitorConnRespMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
visitorConn.Close()
|
visitorConn.Close()
|
||||||
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
|
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
"github.com/fatedier/frp/pkg/vnet"
|
"github.com/fatedier/frp/pkg/vnet"
|
||||||
@@ -49,7 +50,7 @@ func NewManager(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
runID string,
|
runID string,
|
||||||
clientCfg *v1.ClientCommonConfig,
|
clientCfg *v1.ClientCommonConfig,
|
||||||
connectServer func() (net.Conn, error),
|
connectServer func() (*msg.Conn, error),
|
||||||
msgTransporter transport.MessageTransporter,
|
msgTransporter transport.MessageTransporter,
|
||||||
vnetController *vnet.Controller,
|
vnetController *vnet.Controller,
|
||||||
) *Manager {
|
) *Manager {
|
||||||
@@ -199,14 +200,14 @@ func (vm *Manager) GetVisitorCfg(name string) (v1.VisitorConfigurer, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type visitorHelperImpl struct {
|
type visitorHelperImpl struct {
|
||||||
connectServerFn func() (net.Conn, error)
|
connectServerFn func() (*msg.Conn, error)
|
||||||
msgTransporter transport.MessageTransporter
|
msgTransporter transport.MessageTransporter
|
||||||
vnetController *vnet.Controller
|
vnetController *vnet.Controller
|
||||||
transferConnFn func(name string, conn net.Conn) error
|
transferConnFn func(name string, conn net.Conn) error
|
||||||
runID string
|
runID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) {
|
func (v *visitorHelperImpl) ConnectServer() (*msg.Conn, error) {
|
||||||
return v.connectServerFn()
|
return v.connectServerFn()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -103,6 +103,10 @@ transport.poolCount = 5
|
|||||||
# supports tcp, kcp, quic, websocket and wss now, default is tcp
|
# supports tcp, kcp, quic, websocket and wss now, default is tcp
|
||||||
transport.protocol = "tcp"
|
transport.protocol = "tcp"
|
||||||
|
|
||||||
|
# FRP wire protocol used inside the selected transport.
|
||||||
|
# supports v1 and v2, default is v1. v2 requires frps support and must be enabled explicitly.
|
||||||
|
# transport.wireProtocol = "v1"
|
||||||
|
|
||||||
# set client binding ip when connect server, default is empty.
|
# set client binding ip when connect server, default is empty.
|
||||||
# only when protocol = tcp or websocket, the value will be used.
|
# only when protocol = tcp or websocket, the value will be used.
|
||||||
transport.connectServerLocalIP = "0.0.0.0"
|
transport.connectServerLocalIP = "0.0.0.0"
|
||||||
|
|||||||
@@ -104,6 +104,9 @@ type ClientTransportConfig struct {
|
|||||||
// Valid values are "tcp", "kcp", "quic", "websocket" and "wss". By default, this value
|
// Valid values are "tcp", "kcp", "quic", "websocket" and "wss". By default, this value
|
||||||
// is "tcp".
|
// is "tcp".
|
||||||
Protocol string `json:"protocol,omitempty"`
|
Protocol string `json:"protocol,omitempty"`
|
||||||
|
// WireProtocol specifies the frpc/frps internal wire protocol version.
|
||||||
|
// Valid values are "v1" and "v2". By default, this value is "v1".
|
||||||
|
WireProtocol string `json:"wireProtocol,omitempty"`
|
||||||
// The maximum amount of time a dial to server will wait for a connect to complete.
|
// The maximum amount of time a dial to server will wait for a connect to complete.
|
||||||
DialServerTimeout int64 `json:"dialServerTimeout,omitempty"`
|
DialServerTimeout int64 `json:"dialServerTimeout,omitempty"`
|
||||||
// DialServerKeepAlive specifies the interval between keep-alive probes for an active network connection between frpc and frps.
|
// DialServerKeepAlive specifies the interval between keep-alive probes for an active network connection between frpc and frps.
|
||||||
@@ -143,6 +146,7 @@ type ClientTransportConfig struct {
|
|||||||
|
|
||||||
func (c *ClientTransportConfig) Complete() {
|
func (c *ClientTransportConfig) Complete() {
|
||||||
c.Protocol = util.EmptyOr(c.Protocol, "tcp")
|
c.Protocol = util.EmptyOr(c.Protocol, "tcp")
|
||||||
|
c.WireProtocol = util.EmptyOr(c.WireProtocol, "v1")
|
||||||
c.DialServerTimeout = util.EmptyOr(c.DialServerTimeout, 10)
|
c.DialServerTimeout = util.EmptyOr(c.DialServerTimeout, 10)
|
||||||
c.DialServerKeepAlive = util.EmptyOr(c.DialServerKeepAlive, 7200)
|
c.DialServerKeepAlive = util.EmptyOr(c.DialServerKeepAlive, 7200)
|
||||||
c.ProxyURL = util.EmptyOr(c.ProxyURL, os.Getenv("http_proxy"))
|
c.ProxyURL = util.EmptyOr(c.ProxyURL, os.Getenv("http_proxy"))
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ func TestClientConfigComplete(t *testing.T) {
|
|||||||
|
|
||||||
require.EqualValues("token", c.Auth.Method)
|
require.EqualValues("token", c.Auth.Method)
|
||||||
require.Equal(true, lo.FromPtr(c.Transport.TCPMux))
|
require.Equal(true, lo.FromPtr(c.Transport.TCPMux))
|
||||||
|
require.Equal("v1", c.Transport.WireProtocol)
|
||||||
require.Equal(true, lo.FromPtr(c.LoginFailExit))
|
require.Equal(true, lo.FromPtr(c.LoginFailExit))
|
||||||
require.Equal(true, lo.FromPtr(c.Transport.TLS.Enable))
|
require.Equal(true, lo.FromPtr(c.Transport.TLS.Enable))
|
||||||
require.Equal(true, lo.FromPtr(c.Transport.TLS.DisableCustomTLSFirstByte))
|
require.Equal(true, lo.FromPtr(c.Transport.TLS.DisableCustomTLSFirstByte))
|
||||||
|
|||||||
@@ -146,6 +146,9 @@ func validateTransportConfig(c *v1.ClientTransportConfig) (Warning, error) {
|
|||||||
if !slices.Contains(SupportedTransportProtocols, c.Protocol) {
|
if !slices.Contains(SupportedTransportProtocols, c.Protocol) {
|
||||||
errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, optional values are %v", SupportedTransportProtocols))
|
errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, optional values are %v", SupportedTransportProtocols))
|
||||||
}
|
}
|
||||||
|
if !slices.Contains(SupportedWireProtocols, c.WireProtocol) {
|
||||||
|
errs = AppendError(errs, fmt.Errorf("invalid transport.wireProtocol, optional values are %v", SupportedWireProtocols))
|
||||||
|
}
|
||||||
return warnings, errs
|
return warnings, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ var (
|
|||||||
"websocket",
|
"websocket",
|
||||||
"wss",
|
"wss",
|
||||||
}
|
}
|
||||||
|
SupportedWireProtocols = []string{
|
||||||
|
"v1",
|
||||||
|
"v2",
|
||||||
|
}
|
||||||
|
|
||||||
SupportedAuthMethods = []v1.AuthMethod{
|
SupportedAuthMethods = []v1.AuthMethod{
|
||||||
"token",
|
"token",
|
||||||
|
|||||||
56
pkg/msg/conn_test.go
Normal file
56
pkg/msg/conn_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package msg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnReadWriteMsg(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
protocol string
|
||||||
|
}{
|
||||||
|
{name: "v1", protocol: wire.ProtocolV1},
|
||||||
|
{name: "v2", protocol: wire.ProtocolV2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
clientConn := NewConn(client, NewReadWriter(client, tt.protocol))
|
||||||
|
serverConn := NewConn(server, NewReadWriter(server, tt.protocol))
|
||||||
|
|
||||||
|
in := &Ping{PrivilegeKey: "key", Timestamp: 123}
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- clientConn.WriteMsg(in)
|
||||||
|
}()
|
||||||
|
|
||||||
|
out, err := serverConn.ReadMsg()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, in, out)
|
||||||
|
require.NoError(t, <-errCh)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,10 +15,90 @@
|
|||||||
package msg
|
package msg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ReadWriter interface {
|
||||||
|
ReadMsg() (Message, error)
|
||||||
|
ReadMsgInto(Message) error
|
||||||
|
WriteMsg(Message) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
rw ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(conn net.Conn, rw ReadWriter) *Conn {
|
||||||
|
return &Conn{
|
||||||
|
Conn: conn,
|
||||||
|
rw: rw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ReadMsg() (Message, error) {
|
||||||
|
return c.rw.ReadMsg()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ReadMsgInto(m Message) error {
|
||||||
|
return c.rw.ReadMsgInto(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) WriteMsg(m Message) error {
|
||||||
|
return c.rw.WriteMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Context() context.Context {
|
||||||
|
if getter, ok := c.Conn.(interface{ Context() context.Context }); ok {
|
||||||
|
return getter.Context()
|
||||||
|
}
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) WithContext(ctx context.Context) {
|
||||||
|
if setter, ok := c.Conn.(interface{ WithContext(context.Context) }); ok {
|
||||||
|
setter.WithContext(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type V1ReadWriter struct {
|
||||||
|
rw io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewV1ReadWriter(rw io.ReadWriter) ReadWriter {
|
||||||
|
return &V1ReadWriter{rw: rw}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReadWriter wraps rw with the message codec for the selected wire protocol.
|
||||||
|
// An empty protocol keeps the historical v1 behavior for tests and older call sites.
|
||||||
|
func NewReadWriter(rw io.ReadWriter, wireProtocol string) ReadWriter {
|
||||||
|
switch wireProtocol {
|
||||||
|
case wire.ProtocolV2:
|
||||||
|
return NewV2ReadWriter(rw)
|
||||||
|
case "", wire.ProtocolV1:
|
||||||
|
return NewV1ReadWriter(rw)
|
||||||
|
default:
|
||||||
|
return NewV1ReadWriter(rw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *V1ReadWriter) ReadMsg() (Message, error) {
|
||||||
|
return ReadMsg(rw.rw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *V1ReadWriter) ReadMsgInto(m Message) error {
|
||||||
|
return ReadMsgInto(rw.rw, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *V1ReadWriter) WriteMsg(m Message) error {
|
||||||
|
return WriteMsg(rw.rw, m)
|
||||||
|
}
|
||||||
|
|
||||||
func AsyncHandler(f func(Message)) func(Message) {
|
func AsyncHandler(f func(Message)) func(Message) {
|
||||||
return func(m Message) {
|
return func(m Message) {
|
||||||
go f(m)
|
go f(m)
|
||||||
@@ -27,7 +107,7 @@ func AsyncHandler(f func(Message)) func(Message) {
|
|||||||
|
|
||||||
// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn.
|
// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn.
|
||||||
type Dispatcher struct {
|
type Dispatcher struct {
|
||||||
rw io.ReadWriter
|
rw ReadWriter
|
||||||
|
|
||||||
sendCh chan Message
|
sendCh chan Message
|
||||||
doneCh chan struct{}
|
doneCh chan struct{}
|
||||||
@@ -35,7 +115,7 @@ type Dispatcher struct {
|
|||||||
defaultHandler func(Message)
|
defaultHandler func(Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDispatcher(rw io.ReadWriter) *Dispatcher {
|
func NewDispatcher(rw ReadWriter) *Dispatcher {
|
||||||
return &Dispatcher{
|
return &Dispatcher{
|
||||||
rw: rw,
|
rw: rw,
|
||||||
sendCh: make(chan Message, 100),
|
sendCh: make(chan Message, 100),
|
||||||
@@ -56,14 +136,14 @@ func (d *Dispatcher) sendLoop() {
|
|||||||
case <-d.doneCh:
|
case <-d.doneCh:
|
||||||
return
|
return
|
||||||
case m := <-d.sendCh:
|
case m := <-d.sendCh:
|
||||||
_ = WriteMsg(d.rw, m)
|
_ = d.rw.WriteMsg(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Dispatcher) readLoop() {
|
func (d *Dispatcher) readLoop() {
|
||||||
for {
|
for {
|
||||||
m, err := ReadMsg(d.rw)
|
m, err := d.rw.ReadMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
close(d.doneCh)
|
close(d.doneCh)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -20,24 +20,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TypeLogin = 'o'
|
TypeLogin byte = 'o'
|
||||||
TypeLoginResp = '1'
|
TypeLoginResp byte = '1'
|
||||||
TypeNewProxy = 'p'
|
TypeNewProxy byte = 'p'
|
||||||
TypeNewProxyResp = '2'
|
TypeNewProxyResp byte = '2'
|
||||||
TypeCloseProxy = 'c'
|
TypeCloseProxy byte = 'c'
|
||||||
TypeNewWorkConn = 'w'
|
TypeNewWorkConn byte = 'w'
|
||||||
TypeReqWorkConn = 'r'
|
TypeReqWorkConn byte = 'r'
|
||||||
TypeStartWorkConn = 's'
|
TypeStartWorkConn byte = 's'
|
||||||
TypeNewVisitorConn = 'v'
|
TypeNewVisitorConn byte = 'v'
|
||||||
TypeNewVisitorConnResp = '3'
|
TypeNewVisitorConnResp byte = '3'
|
||||||
TypePing = 'h'
|
TypePing byte = 'h'
|
||||||
TypePong = '4'
|
TypePong byte = '4'
|
||||||
TypeUDPPacket = 'u'
|
TypeUDPPacket byte = 'u'
|
||||||
TypeNatHoleVisitor = 'i'
|
TypeNatHoleVisitor byte = 'i'
|
||||||
TypeNatHoleClient = 'n'
|
TypeNatHoleClient byte = 'n'
|
||||||
TypeNatHoleResp = 'm'
|
TypeNatHoleResp byte = 'm'
|
||||||
TypeNatHoleSid = '5'
|
TypeNatHoleSid byte = '5'
|
||||||
TypeNatHoleReport = '6'
|
TypeNatHoleReport byte = '6'
|
||||||
)
|
)
|
||||||
|
|
||||||
var msgTypeMap = map[byte]any{
|
var msgTypeMap = map[byte]any{
|
||||||
|
|||||||
55
pkg/msg/msg_test.go
Normal file
55
pkg/msg/msg_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package msg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestV1MessageTypeIDsAreStable(t *testing.T) {
|
||||||
|
require.Equal(t, byte('o'), TypeLogin)
|
||||||
|
require.Equal(t, byte('1'), TypeLoginResp)
|
||||||
|
require.Equal(t, byte('p'), TypeNewProxy)
|
||||||
|
require.Equal(t, byte('2'), TypeNewProxyResp)
|
||||||
|
require.Equal(t, byte('c'), TypeCloseProxy)
|
||||||
|
require.Equal(t, byte('w'), TypeNewWorkConn)
|
||||||
|
require.Equal(t, byte('r'), TypeReqWorkConn)
|
||||||
|
require.Equal(t, byte('s'), TypeStartWorkConn)
|
||||||
|
require.Equal(t, byte('v'), TypeNewVisitorConn)
|
||||||
|
require.Equal(t, byte('3'), TypeNewVisitorConnResp)
|
||||||
|
require.Equal(t, byte('h'), TypePing)
|
||||||
|
require.Equal(t, byte('4'), TypePong)
|
||||||
|
require.Equal(t, byte('u'), TypeUDPPacket)
|
||||||
|
require.Equal(t, byte('i'), TypeNatHoleVisitor)
|
||||||
|
require.Equal(t, byte('n'), TypeNatHoleClient)
|
||||||
|
require.Equal(t, byte('m'), TypeNatHoleResp)
|
||||||
|
require.Equal(t, byte('5'), TypeNatHoleSid)
|
||||||
|
require.Equal(t, byte('6'), TypeNatHoleReport)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageTypeMapIsCompleteAndUnique(t *testing.T) {
|
||||||
|
require.Len(t, msgTypeMap, 18)
|
||||||
|
|
||||||
|
msgTypes := make(map[reflect.Type]struct{}, len(msgTypeMap))
|
||||||
|
|
||||||
|
for _, m := range msgTypeMap {
|
||||||
|
msgType := reflect.TypeOf(m)
|
||||||
|
require.NotContains(t, msgTypes, msgType)
|
||||||
|
msgTypes[msgType] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
192
pkg/msg/wire_v2.go
Normal file
192
pkg/msg/wire_v2.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package msg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
V2TypeLogin uint16 = 1
|
||||||
|
V2TypeLoginResp uint16 = 2
|
||||||
|
V2TypeNewProxy uint16 = 3
|
||||||
|
V2TypeNewProxyResp uint16 = 4
|
||||||
|
V2TypeCloseProxy uint16 = 5
|
||||||
|
V2TypeNewWorkConn uint16 = 6
|
||||||
|
V2TypeReqWorkConn uint16 = 7
|
||||||
|
V2TypeStartWorkConn uint16 = 8
|
||||||
|
V2TypeNewVisitorConn uint16 = 9
|
||||||
|
V2TypeNewVisitorConnResp uint16 = 10
|
||||||
|
V2TypePing uint16 = 11
|
||||||
|
V2TypePong uint16 = 12
|
||||||
|
V2TypeUDPPacket uint16 = 13
|
||||||
|
V2TypeNatHoleVisitor uint16 = 14
|
||||||
|
V2TypeNatHoleClient uint16 = 15
|
||||||
|
V2TypeNatHoleResp uint16 = 16
|
||||||
|
V2TypeNatHoleSid uint16 = 17
|
||||||
|
V2TypeNatHoleReport uint16 = 18
|
||||||
|
)
|
||||||
|
|
||||||
|
var v2MsgTypeMap = map[uint16]any{
|
||||||
|
V2TypeLogin: Login{},
|
||||||
|
V2TypeLoginResp: LoginResp{},
|
||||||
|
V2TypeNewProxy: NewProxy{},
|
||||||
|
V2TypeNewProxyResp: NewProxyResp{},
|
||||||
|
V2TypeCloseProxy: CloseProxy{},
|
||||||
|
V2TypeNewWorkConn: NewWorkConn{},
|
||||||
|
V2TypeReqWorkConn: ReqWorkConn{},
|
||||||
|
V2TypeStartWorkConn: StartWorkConn{},
|
||||||
|
V2TypeNewVisitorConn: NewVisitorConn{},
|
||||||
|
V2TypeNewVisitorConnResp: NewVisitorConnResp{},
|
||||||
|
V2TypePing: Ping{},
|
||||||
|
V2TypePong: Pong{},
|
||||||
|
V2TypeUDPPacket: UDPPacket{},
|
||||||
|
V2TypeNatHoleVisitor: NatHoleVisitor{},
|
||||||
|
V2TypeNatHoleClient: NatHoleClient{},
|
||||||
|
V2TypeNatHoleResp: NatHoleResp{},
|
||||||
|
V2TypeNatHoleSid: NatHoleSid{},
|
||||||
|
V2TypeNatHoleReport: NatHoleReport{},
|
||||||
|
}
|
||||||
|
|
||||||
|
var v2MsgReflectTypeMap, v2MsgTypeIDMap = buildV2MsgTypeMaps()
|
||||||
|
|
||||||
|
func buildV2MsgTypeMaps() (map[uint16]reflect.Type, map[reflect.Type]uint16) {
|
||||||
|
reflectTypeMap := make(map[uint16]reflect.Type, len(v2MsgTypeMap))
|
||||||
|
typeIDMap := make(map[reflect.Type]uint16, len(v2MsgTypeMap))
|
||||||
|
for typeID, m := range v2MsgTypeMap {
|
||||||
|
t := reflect.TypeOf(m)
|
||||||
|
reflectTypeMap[typeID] = t
|
||||||
|
typeIDMap[t] = typeID
|
||||||
|
}
|
||||||
|
return reflectTypeMap, typeIDMap
|
||||||
|
}
|
||||||
|
|
||||||
|
type V2ReadWriter struct {
|
||||||
|
conn *wire.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewV2ReadWriter(rw io.ReadWriter) *V2ReadWriter {
|
||||||
|
return NewV2ReadWriterWithConn(wire.NewConn(rw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewV2ReadWriterWithConn(conn *wire.Conn) *V2ReadWriter {
|
||||||
|
return &V2ReadWriter{conn: conn}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *V2ReadWriter) WireConn() *wire.Conn {
|
||||||
|
return rw.conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *V2ReadWriter) ReadMsg() (Message, error) {
|
||||||
|
f, err := rw.conn.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return DecodeV2MessageFrame(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *V2ReadWriter) ReadMsgInto(m Message) error {
|
||||||
|
f, err := rw.conn.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return DecodeV2MessageFrameInto(f, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *V2ReadWriter) WriteMsg(m Message) error {
|
||||||
|
f, err := EncodeV2MessageFrame(m)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return rw.conn.WriteFrame(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeV2MessageFrame(f *wire.Frame) (Message, error) {
|
||||||
|
if f.Type != wire.FrameTypeMessage {
|
||||||
|
return nil, fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage)
|
||||||
|
}
|
||||||
|
if len(f.Payload) < 2 {
|
||||||
|
return nil, fmt.Errorf("message frame payload too short")
|
||||||
|
}
|
||||||
|
typeID := binary.BigEndian.Uint16(f.Payload[:2])
|
||||||
|
t, ok := v2MsgReflectTypeMap[typeID]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown v2 message type %d", typeID)
|
||||||
|
}
|
||||||
|
m := reflect.New(t).Interface()
|
||||||
|
if err := json.Unmarshal(f.Payload[2:], m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeV2MessageFrameInto(f *wire.Frame, out Message) error {
|
||||||
|
if f.Type != wire.FrameTypeMessage {
|
||||||
|
return fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage)
|
||||||
|
}
|
||||||
|
if len(f.Payload) < 2 {
|
||||||
|
return fmt.Errorf("message frame payload too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
typeID := binary.BigEndian.Uint16(f.Payload[:2])
|
||||||
|
outType := reflect.TypeOf(out)
|
||||||
|
if outType == nil || outType.Kind() != reflect.Pointer {
|
||||||
|
return fmt.Errorf("message target must be a pointer")
|
||||||
|
}
|
||||||
|
elemType := outType.Elem()
|
||||||
|
expectedTypeID, ok := v2MsgTypeIDMap[elemType]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unknown v2 message type %s", elemType.String())
|
||||||
|
}
|
||||||
|
if typeID != expectedTypeID {
|
||||||
|
actualType, ok := v2MsgReflectTypeMap[typeID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unknown v2 message type %d", typeID)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unexpected message type %s, want %s", actualType.String(), elemType.String())
|
||||||
|
}
|
||||||
|
return json.Unmarshal(f.Payload[2:], out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeV2MessageFrame(m Message) (*wire.Frame, error) {
|
||||||
|
t := reflect.TypeOf(m)
|
||||||
|
if t == nil {
|
||||||
|
return nil, fmt.Errorf("nil message")
|
||||||
|
}
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
typeID, ok := v2MsgTypeIDMap[t]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown v2 message type %s", t.String())
|
||||||
|
}
|
||||||
|
content, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload := make([]byte, 2+len(content))
|
||||||
|
binary.BigEndian.PutUint16(payload[:2], typeID)
|
||||||
|
copy(payload[2:], content)
|
||||||
|
return &wire.Frame{
|
||||||
|
Type: wire.FrameTypeMessage,
|
||||||
|
Payload: payload,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
121
pkg/msg/wire_v2_test.go
Normal file
121
pkg/msg/wire_v2_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package msg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestV2ReadWriterRoundTrip(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
rw := NewV2ReadWriter(&buf)
|
||||||
|
|
||||||
|
in := &Login{
|
||||||
|
Version: "test-version",
|
||||||
|
RunID: "run-id",
|
||||||
|
User: "user",
|
||||||
|
}
|
||||||
|
require.NoError(t, rw.WriteMsg(in))
|
||||||
|
|
||||||
|
out, err := rw.ReadMsg()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, in, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewReadWriter(t *testing.T) {
|
||||||
|
require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, ""))
|
||||||
|
require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV1))
|
||||||
|
require.IsType(t, &V2ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestV2MessageTypeIDsAreStable(t *testing.T) {
|
||||||
|
require.Equal(t, uint16(1), V2TypeLogin)
|
||||||
|
require.Equal(t, uint16(2), V2TypeLoginResp)
|
||||||
|
require.Equal(t, uint16(3), V2TypeNewProxy)
|
||||||
|
require.Equal(t, uint16(4), V2TypeNewProxyResp)
|
||||||
|
require.Equal(t, uint16(5), V2TypeCloseProxy)
|
||||||
|
require.Equal(t, uint16(6), V2TypeNewWorkConn)
|
||||||
|
require.Equal(t, uint16(7), V2TypeReqWorkConn)
|
||||||
|
require.Equal(t, uint16(8), V2TypeStartWorkConn)
|
||||||
|
require.Equal(t, uint16(9), V2TypeNewVisitorConn)
|
||||||
|
require.Equal(t, uint16(10), V2TypeNewVisitorConnResp)
|
||||||
|
require.Equal(t, uint16(11), V2TypePing)
|
||||||
|
require.Equal(t, uint16(12), V2TypePong)
|
||||||
|
require.Equal(t, uint16(13), V2TypeUDPPacket)
|
||||||
|
require.Equal(t, uint16(14), V2TypeNatHoleVisitor)
|
||||||
|
require.Equal(t, uint16(15), V2TypeNatHoleClient)
|
||||||
|
require.Equal(t, uint16(16), V2TypeNatHoleResp)
|
||||||
|
require.Equal(t, uint16(17), V2TypeNatHoleSid)
|
||||||
|
require.Equal(t, uint16(18), V2TypeNatHoleReport)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestV2MessageFrameEncoding(t *testing.T) {
|
||||||
|
frame, err := EncodeV2MessageFrame(&ReqWorkConn{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, wire.FrameTypeMessage, frame.Type)
|
||||||
|
require.Len(t, frame.Payload, 4)
|
||||||
|
require.Equal(t, V2TypeReqWorkConn, binary.BigEndian.Uint16(frame.Payload[:2]))
|
||||||
|
|
||||||
|
out, err := DecodeV2MessageFrame(frame)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.IsType(t, &ReqWorkConn{}, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeV2MessageFrameInto(t *testing.T) {
|
||||||
|
in := &StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234}
|
||||||
|
frame, err := EncodeV2MessageFrame(in)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var out StartWorkConn
|
||||||
|
require.NoError(t, DecodeV2MessageFrameInto(frame, &out))
|
||||||
|
require.Equal(t, *in, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeV2MessageFrameRejectsInvalidFrame(t *testing.T) {
|
||||||
|
_, err := DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeClientHello})
|
||||||
|
require.ErrorContains(t, err, "unexpected frame type")
|
||||||
|
|
||||||
|
_, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: []byte{0}})
|
||||||
|
require.ErrorContains(t, err, "payload too short")
|
||||||
|
|
||||||
|
payload := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint16(payload[:2], 65535)
|
||||||
|
copy(payload[2:], []byte("{}"))
|
||||||
|
_, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: payload})
|
||||||
|
require.ErrorContains(t, err, "unknown v2 message type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeV2MessageFrameIntoRejectsWrongTarget(t *testing.T) {
|
||||||
|
frame, err := EncodeV2MessageFrame(&ReqWorkConn{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var out StartWorkConn
|
||||||
|
err = DecodeV2MessageFrameInto(frame, &out)
|
||||||
|
require.ErrorContains(t, err, "unexpected message type")
|
||||||
|
|
||||||
|
err = DecodeV2MessageFrameInto(frame, StartWorkConn{})
|
||||||
|
require.ErrorContains(t, err, "must be a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeV2MessageFrameRejectsUnknownMessage(t *testing.T) {
|
||||||
|
_, err := EncodeV2MessageFrame(struct{}{})
|
||||||
|
require.ErrorContains(t, err, "unknown v2 message type")
|
||||||
|
}
|
||||||
222
pkg/proto/wire/wire.go
Normal file
222
pkg/proto/wire/wire.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package wire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
libnet "github.com/fatedier/golib/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtocolV1 = "v1"
|
||||||
|
ProtocolV2 = "v2"
|
||||||
|
|
||||||
|
WireVersionV2 = 2
|
||||||
|
|
||||||
|
FrameTypeClientHello uint16 = 1
|
||||||
|
FrameTypeServerHello uint16 = 2
|
||||||
|
FrameTypeMessage uint16 = 16
|
||||||
|
|
||||||
|
MessageCodecJSON = "json"
|
||||||
|
DefaultMaxFramePayloadSize = 64 * 1024
|
||||||
|
|
||||||
|
MagicV2 = "FRP\x00\x02\r\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Frame struct {
|
||||||
|
Type uint16
|
||||||
|
Flags uint16
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
rw io.ReadWriter
|
||||||
|
maxFramePayloadSize uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(rw io.ReadWriter) *Conn {
|
||||||
|
return &Conn{
|
||||||
|
rw: rw,
|
||||||
|
maxFramePayloadSize: DefaultMaxFramePayloadSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ReadFrame() (*Frame, error) {
|
||||||
|
header := make([]byte, 8)
|
||||||
|
if _, err := io.ReadFull(c.rw, header); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
frameType := binary.BigEndian.Uint16(header[0:2])
|
||||||
|
flags := binary.BigEndian.Uint16(header[2:4])
|
||||||
|
length := binary.BigEndian.Uint32(header[4:8])
|
||||||
|
if flags != 0 {
|
||||||
|
return nil, fmt.Errorf("unsupported frame flags: %d", flags)
|
||||||
|
}
|
||||||
|
if length > c.maxFramePayloadSize {
|
||||||
|
return nil, fmt.Errorf("frame payload length %d exceeds limit %d", length, c.maxFramePayloadSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, length)
|
||||||
|
if _, err := io.ReadFull(c.rw, payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Frame{
|
||||||
|
Type: frameType,
|
||||||
|
Flags: flags,
|
||||||
|
Payload: payload,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) WriteFrame(f *Frame) error {
|
||||||
|
if f.Flags != 0 {
|
||||||
|
return fmt.Errorf("unsupported frame flags: %d", f.Flags)
|
||||||
|
}
|
||||||
|
if len(f.Payload) > int(c.maxFramePayloadSize) {
|
||||||
|
return fmt.Errorf("frame payload length %d exceeds limit %d", len(f.Payload), c.maxFramePayloadSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
header := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint16(header[0:2], f.Type)
|
||||||
|
binary.BigEndian.PutUint16(header[2:4], f.Flags)
|
||||||
|
binary.BigEndian.PutUint32(header[4:8], uint32(len(f.Payload)))
|
||||||
|
if _, err := c.rw.Write(header); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := c.rw.Write(f.Payload)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ReadJSONFrame(frameType uint16, out any) error {
|
||||||
|
f, err := c.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if f.Type != frameType {
|
||||||
|
return fmt.Errorf("unexpected frame type %d, want %d", f.Type, frameType)
|
||||||
|
}
|
||||||
|
return c.UnmarshalFrame(f, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) UnmarshalFrame(f *Frame, out any) error {
|
||||||
|
return json.Unmarshal(f.Payload, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
|
||||||
|
payload, err := json.Marshal(in)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.WriteFrame(&Frame{
|
||||||
|
Type: frameType,
|
||||||
|
Payload: payload,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteMagic(w io.Writer) error {
|
||||||
|
_, err := io.WriteString(w, MagicV2)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteMagicIfV2(w io.Writer, wireProtocol string) error {
|
||||||
|
if wireProtocol != ProtocolV2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return WriteMagic(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckMagic(conn net.Conn) (out net.Conn, isV2 bool, err error) {
|
||||||
|
sharedConn, r := libnet.NewSharedConnSize(conn, len(MagicV2))
|
||||||
|
buf := make([]byte, len(MagicV2))
|
||||||
|
if _, err = io.ReadFull(r, buf); err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
for i := range MagicV2 {
|
||||||
|
if buf[i] != MagicV2[i] {
|
||||||
|
return sharedConn, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type BootstrapInfo struct {
|
||||||
|
Transport string `json:"transport,omitempty"`
|
||||||
|
TLS bool `json:"tls,omitempty"`
|
||||||
|
TCPMux bool `json:"tcpMux,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientHello struct {
|
||||||
|
Bootstrap BootstrapInfo `json:"bootstrap,omitempty"`
|
||||||
|
Capabilities ClientCapabilities `json:"capabilities,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientCapabilities struct {
|
||||||
|
Message MessageCapabilities `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageCapabilities struct {
|
||||||
|
Codecs []string `json:"codecs,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerHello struct {
|
||||||
|
Selected ServerSelection `json:"selected,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerSelection struct {
|
||||||
|
Message MessageSelection `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageSelection struct {
|
||||||
|
Codec string `json:"codec,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultClientHello(bootstrap BootstrapInfo) ClientHello {
|
||||||
|
return ClientHello{
|
||||||
|
Bootstrap: bootstrap,
|
||||||
|
Capabilities: ClientCapabilities{
|
||||||
|
Message: MessageCapabilities{
|
||||||
|
Codecs: []string{MessageCodecJSON},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultServerHello() ServerHello {
|
||||||
|
return ServerHello{
|
||||||
|
Selected: ServerSelection{
|
||||||
|
Message: MessageSelection{
|
||||||
|
Codec: MessageCodecJSON,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Supports(list []string, value string) bool {
|
||||||
|
return slices.Contains(list, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateClientHello(h ClientHello) error {
|
||||||
|
if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) {
|
||||||
|
return fmt.Errorf("unsupported message codec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
120
pkg/proto/wire/wire_test.go
Normal file
120
pkg/proto/wire/wire_test.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package wire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFrameRoundTrip(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
conn := NewConn(&buf)
|
||||||
|
|
||||||
|
in := DefaultClientHello(BootstrapInfo{
|
||||||
|
Transport: "tcp",
|
||||||
|
TLS: true,
|
||||||
|
TCPMux: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, conn.WriteJSONFrame(FrameTypeClientHello, in))
|
||||||
|
|
||||||
|
var out ClientHello
|
||||||
|
require.NoError(t, conn.ReadJSONFrame(FrameTypeClientHello, &out))
|
||||||
|
require.Equal(t, in, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFrameRejectsUnsupportedFlags(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
header := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage)
|
||||||
|
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||||
|
binary.BigEndian.PutUint32(header[4:8], 0)
|
||||||
|
buf.Write(header)
|
||||||
|
|
||||||
|
_, err := NewConn(&buf).ReadFrame()
|
||||||
|
require.ErrorContains(t, err, "unsupported frame flags")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFrameRejectsOversizedPayload(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
header := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage)
|
||||||
|
binary.BigEndian.PutUint32(header[4:8], DefaultMaxFramePayloadSize+1)
|
||||||
|
buf.Write(header)
|
||||||
|
|
||||||
|
_, err := NewConn(&buf).ReadFrame()
|
||||||
|
require.ErrorContains(t, err, "exceeds limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckMagicV2ConsumesMagic(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
want := []byte("payload")
|
||||||
|
go func() {
|
||||||
|
defer client.Close()
|
||||||
|
_, _ = client.Write(append([]byte(MagicV2), want...))
|
||||||
|
}()
|
||||||
|
|
||||||
|
out, isV2, err := CheckMagic(server)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, isV2)
|
||||||
|
|
||||||
|
got := make([]byte, len(want))
|
||||||
|
_, err = io.ReadFull(out, got)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMagicIfV2(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
require.NoError(t, WriteMagicIfV2(&buf, ProtocolV1))
|
||||||
|
require.Empty(t, buf.Bytes())
|
||||||
|
|
||||||
|
require.NoError(t, WriteMagicIfV2(&buf, ProtocolV2))
|
||||||
|
require.Equal(t, []byte(MagicV2), buf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckMagicV1PreservesReadBytes(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
want := []byte("legacy payload")
|
||||||
|
go func() {
|
||||||
|
defer client.Close()
|
||||||
|
_, _ = client.Write(want)
|
||||||
|
}()
|
||||||
|
|
||||||
|
out, isV2, err := CheckMagic(server)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, isV2)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(out)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateClientHello(t *testing.T) {
|
||||||
|
require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{})))
|
||||||
|
|
||||||
|
hello := DefaultClientHello(BootstrapInfo{})
|
||||||
|
hello.Capabilities.Message.Codecs = []string{"unknown"}
|
||||||
|
require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec")
|
||||||
|
}
|
||||||
@@ -133,7 +133,7 @@ type CloseNotifyConn struct {
|
|||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
// 1 means closed
|
// 1 means closed
|
||||||
closeFlag int32
|
closeFlag atomic.Int32
|
||||||
|
|
||||||
closeFn func(error)
|
closeFn func(error)
|
||||||
}
|
}
|
||||||
@@ -147,7 +147,7 @@ func WrapCloseNotifyConn(c net.Conn, closeFn func(error)) *CloseNotifyConn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cc *CloseNotifyConn) Close() (err error) {
|
func (cc *CloseNotifyConn) Close() (err error) {
|
||||||
pflag := atomic.SwapInt32(&cc.closeFlag, 1)
|
pflag := cc.closeFlag.Swap(1)
|
||||||
if pflag == 0 {
|
if pflag == 0 {
|
||||||
err = cc.Conn.Close()
|
err = cc.Conn.Close()
|
||||||
if cc.closeFn != nil {
|
if cc.closeFn != nil {
|
||||||
@@ -159,7 +159,7 @@ func (cc *CloseNotifyConn) Close() (err error) {
|
|||||||
|
|
||||||
// CloseWithError closes the connection and passes the error to the close callback.
|
// CloseWithError closes the connection and passes the error to the close callback.
|
||||||
func (cc *CloseNotifyConn) CloseWithError(err error) error {
|
func (cc *CloseNotifyConn) CloseWithError(err error) error {
|
||||||
pflag := atomic.SwapInt32(&cc.closeFlag, 1)
|
pflag := cc.closeFlag.Swap(1)
|
||||||
if pflag == 0 {
|
if pflag == 0 {
|
||||||
closeErr := cc.Conn.Close()
|
closeErr := cc.Conn.Close()
|
||||||
if cc.closeFn != nil {
|
if cc.closeFn != nil {
|
||||||
@@ -173,7 +173,7 @@ func (cc *CloseNotifyConn) CloseWithError(err error) error {
|
|||||||
type StatsConn struct {
|
type StatsConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
closed int64 // 1 means closed
|
closed atomic.Int64 // 1 means closed
|
||||||
totalRead int64
|
totalRead int64
|
||||||
totalWrite int64
|
totalWrite int64
|
||||||
statsFunc func(totalRead, totalWrite int64)
|
statsFunc func(totalRead, totalWrite int64)
|
||||||
@@ -199,7 +199,7 @@ func (statsConn *StatsConn) Write(p []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (statsConn *StatsConn) Close() (err error) {
|
func (statsConn *StatsConn) Close() (err error) {
|
||||||
old := atomic.SwapInt64(&statsConn.closed, 1)
|
old := statsConn.closed.Swap(1)
|
||||||
if old != 1 {
|
if old != 1 {
|
||||||
err = statsConn.Conn.Close()
|
err = statsConn.Conn.Close()
|
||||||
if statsConn.statsFunc != nil {
|
if statsConn.statsFunc != nil {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -32,9 +31,7 @@ import (
|
|||||||
"github.com/fatedier/frp/pkg/msg"
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
|
||||||
"github.com/fatedier/frp/pkg/util/util"
|
"github.com/fatedier/frp/pkg/util/util"
|
||||||
"github.com/fatedier/frp/pkg/util/version"
|
|
||||||
"github.com/fatedier/frp/pkg/util/wait"
|
"github.com/fatedier/frp/pkg/util/wait"
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
"github.com/fatedier/frp/server/controller"
|
"github.com/fatedier/frp/server/controller"
|
||||||
@@ -108,9 +105,7 @@ type SessionContext struct {
|
|||||||
// key used for connection encryption
|
// key used for connection encryption
|
||||||
EncryptionKey []byte
|
EncryptionKey []byte
|
||||||
// control connection
|
// control connection
|
||||||
Conn net.Conn
|
Conn *msg.Conn
|
||||||
// indicates whether the connection is encrypted
|
|
||||||
ConnEncrypted bool
|
|
||||||
// login message
|
// login message
|
||||||
LoginMsg *msg.Login
|
LoginMsg *msg.Login
|
||||||
// server configuration
|
// server configuration
|
||||||
@@ -131,7 +126,7 @@ type Control struct {
|
|||||||
msgDispatcher *msg.Dispatcher
|
msgDispatcher *msg.Dispatcher
|
||||||
|
|
||||||
// work connections
|
// work connections
|
||||||
workConnCh chan net.Conn
|
workConnCh chan *proxy.WorkConn
|
||||||
|
|
||||||
// proxies in one client
|
// proxies in one client
|
||||||
proxies map[string]proxy.Proxy
|
proxies map[string]proxy.Proxy
|
||||||
@@ -161,7 +156,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
|||||||
poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
|
poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
|
||||||
ctl := &Control{
|
ctl := &Control{
|
||||||
sessionCtx: sessionCtx,
|
sessionCtx: sessionCtx,
|
||||||
workConnCh: make(chan net.Conn, poolCount+10),
|
workConnCh: make(chan *proxy.WorkConn, poolCount+10),
|
||||||
proxies: make(map[string]proxy.Proxy),
|
proxies: make(map[string]proxy.Proxy),
|
||||||
poolCount: poolCount,
|
poolCount: poolCount,
|
||||||
portsUsedNum: 0,
|
portsUsedNum: 0,
|
||||||
@@ -172,29 +167,14 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
|||||||
}
|
}
|
||||||
ctl.lastPing.Store(time.Now())
|
ctl.lastPing.Store(time.Now())
|
||||||
|
|
||||||
if sessionCtx.ConnEncrypted {
|
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
|
|
||||||
} else {
|
|
||||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
|
||||||
}
|
|
||||||
ctl.registerMsgHandlers()
|
ctl.registerMsgHandlers()
|
||||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||||
return ctl, nil
|
return ctl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start send a login success message to client and start working.
|
// Start starts the control session workers after login succeeds.
|
||||||
func (ctl *Control) Start() {
|
func (ctl *Control) Start() {
|
||||||
loginRespMsg := &msg.LoginResp{
|
|
||||||
Version: version.Full(),
|
|
||||||
RunID: ctl.runID,
|
|
||||||
Error: "",
|
|
||||||
}
|
|
||||||
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < ctl.poolCount; i++ {
|
for i := 0; i < ctl.poolCount; i++ {
|
||||||
// ignore error here, that means that this control is closed
|
// ignore error here, that means that this control is closed
|
||||||
@@ -216,7 +196,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
|
|||||||
ctl.sessionCtx.Conn.Close()
|
ctl.sessionCtx.Conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
|
func (ctl *Control) RegisterWorkConn(conn *proxy.WorkConn) error {
|
||||||
xl := ctl.xl
|
xl := ctl.xl
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
@@ -239,7 +219,7 @@ func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
|
|||||||
// If no workConn available in the pool, send message to frpc to get one or more
|
// If no workConn available in the pool, send message to frpc to get one or more
|
||||||
// and wait until it is available.
|
// and wait until it is available.
|
||||||
// return an error if wait timeout
|
// return an error if wait timeout
|
||||||
func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
|
func (ctl *Control) GetWorkConn() (workConn *proxy.WorkConn, err error) {
|
||||||
xl := ctl.xl
|
xl := ctl.xl
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ type HTTPGroup struct {
|
|||||||
// CreateConnFuncs indexed by proxy name
|
// CreateConnFuncs indexed by proxy name
|
||||||
createFuncs map[string]vhost.CreateConnFunc
|
createFuncs map[string]vhost.CreateConnFunc
|
||||||
pxyNames []string
|
pxyNames []string
|
||||||
index uint64
|
index atomic.Uint64
|
||||||
ctl *HTTPGroupController
|
ctl *HTTPGroupController
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
@@ -136,7 +136,7 @@ func (g *HTTPGroup) UnRegister(proxyName string) {
|
|||||||
|
|
||||||
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||||
var f vhost.CreateConnFunc
|
var f vhost.CreateConnFunc
|
||||||
newIndex := atomic.AddUint64(&g.index, 1)
|
newIndex := g.index.Add(1)
|
||||||
|
|
||||||
g.mu.RLock()
|
g.mu.RLock()
|
||||||
group := g.group
|
group := g.group
|
||||||
@@ -158,7 +158,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *HTTPGroup) chooseEndpoint() (string, error) {
|
func (g *HTTPGroup) chooseEndpoint() (string, error) {
|
||||||
newIndex := atomic.AddUint64(&g.index, 1)
|
newIndex := g.index.Add(1)
|
||||||
name := ""
|
name := ""
|
||||||
|
|
||||||
g.mu.RLock()
|
g.mu.RLock()
|
||||||
|
|||||||
@@ -287,6 +287,7 @@ func buildClientInfoResp(info registry.ClientInfo) model.ClientInfoResp {
|
|||||||
ClientID: info.ClientID(),
|
ClientID: info.ClientID(),
|
||||||
RunID: info.RunID,
|
RunID: info.RunID,
|
||||||
Version: info.Version,
|
Version: info.Version,
|
||||||
|
WireProtocol: info.WireProtocol,
|
||||||
Hostname: info.Hostname,
|
Hostname: info.Hostname,
|
||||||
ClientIP: info.IP,
|
ClientIP: info.IP,
|
||||||
FirstConnectedAt: toUnix(info.FirstConnectedAt),
|
FirstConnectedAt: toUnix(info.FirstConnectedAt),
|
||||||
|
|||||||
@@ -17,8 +17,11 @@ package http
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
|
"github.com/fatedier/frp/server/registry"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) {
|
func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) {
|
||||||
@@ -69,3 +72,24 @@ func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) {
|
|||||||
t.Fatalf("plugin httpPassword mismatch, want %q got %#v", "password", got)
|
t.Fatalf("plugin httpPassword mismatch, want %q got %#v", "password", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildClientInfoRespIncludesWireProtocol(t *testing.T) {
|
||||||
|
info := registry.ClientInfo{
|
||||||
|
Key: "user.client",
|
||||||
|
User: "user",
|
||||||
|
RawClientID: "client",
|
||||||
|
RunID: "run-id",
|
||||||
|
Version: "1.0.0",
|
||||||
|
WireProtocol: wire.ProtocolV2,
|
||||||
|
Hostname: "host",
|
||||||
|
IP: "127.0.0.1",
|
||||||
|
FirstConnectedAt: time.Unix(1, 0),
|
||||||
|
LastConnectedAt: time.Unix(2, 0),
|
||||||
|
Online: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := buildClientInfoResp(info)
|
||||||
|
if resp.WireProtocol != wire.ProtocolV2 {
|
||||||
|
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, resp.WireProtocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ type ClientInfoResp struct {
|
|||||||
ClientID string `json:"clientID"`
|
ClientID string `json:"clientID"`
|
||||||
RunID string `json:"runID"`
|
RunID string `json:"runID"`
|
||||||
Version string `json:"version,omitempty"`
|
Version string `json:"version,omitempty"`
|
||||||
|
WireProtocol string `json:"wireProtocol,omitempty"`
|
||||||
Hostname string `json:"hostname"`
|
Hostname string `json:"hostname"`
|
||||||
ClientIP string `json:"clientIP,omitempty"`
|
ClientIP string `json:"clientIP,omitempty"`
|
||||||
FirstConnectedAt int64 `json:"firstConnectedAt"`
|
FirstConnectedAt int64 `json:"firstConnectedAt"`
|
||||||
|
|||||||
@@ -44,7 +44,26 @@ func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy) P
|
|||||||
proxyFactoryRegistry[proxyConfType] = factory
|
proxyFactoryRegistry[proxyConfType] = factory
|
||||||
}
|
}
|
||||||
|
|
||||||
type GetWorkConnFn func() (net.Conn, error)
|
type WorkConn struct {
|
||||||
|
conn *msg.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWorkConn(conn *msg.Conn) *WorkConn {
|
||||||
|
return &WorkConn{conn: conn}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WorkConn) Start(m *msg.StartWorkConn) (net.Conn, error) {
|
||||||
|
if err := c.conn.WriteMsg(m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WorkConn) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetWorkConnFn func() (*WorkConn, error)
|
||||||
|
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
@@ -125,13 +144,13 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
|||||||
xl := xlog.FromContextSafe(pxy.ctx)
|
xl := xlog.FromContextSafe(pxy.ctx)
|
||||||
// try all connections from the pool
|
// try all connections from the pool
|
||||||
for i := 0; i < pxy.poolCount+1; i++ {
|
for i := 0; i < pxy.poolCount+1; i++ {
|
||||||
if workConn, err = pxy.getWorkConnFn(); err != nil {
|
var pxyWorkConn *WorkConn
|
||||||
|
if pxyWorkConn, err = pxy.getWorkConnFn(); err != nil {
|
||||||
xl.Warnf("failed to get work connection: %v", err)
|
xl.Warnf("failed to get work connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
xl.Debugf("get a new work connection: [%s]", workConn.RemoteAddr().String())
|
xl.Debugf("get a new work connection: [%s]", pxyWorkConn.conn.RemoteAddr().String())
|
||||||
xl.Spawn().AppendPrefix(pxy.GetName())
|
xl.Spawn().AppendPrefix(pxy.GetName())
|
||||||
workConn = netpkg.NewContextConn(pxy.ctx, workConn)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
srcAddr string
|
srcAddr string
|
||||||
@@ -150,7 +169,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
|||||||
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
|
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
|
||||||
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
|
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
|
||||||
}
|
}
|
||||||
err = msg.WriteMsg(workConn, &msg.StartWorkConn{
|
workConn, err = pxyWorkConn.Start(&msg.StartWorkConn{
|
||||||
ProxyName: pxy.GetName(),
|
ProxyName: pxy.GetName(),
|
||||||
SrcAddr: srcAddr,
|
SrcAddr: srcAddr,
|
||||||
SrcPort: uint16(srcPort),
|
SrcPort: uint16(srcPort),
|
||||||
@@ -160,9 +179,10 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
|
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
|
||||||
workConn.Close()
|
pxyWorkConn.Close()
|
||||||
workConn = nil
|
workConn = nil
|
||||||
} else {
|
} else {
|
||||||
|
workConn = netpkg.NewContextConn(pxy.ctx, workConn)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
53
server/proxy/proxy_test.go
Normal file
53
server/proxy/proxy_test.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWorkConnStartWritesStartWorkConn(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
serverMsgConn := msg.NewConn(server, msg.NewV2ReadWriter(server))
|
||||||
|
clientMsgConn := msg.NewConn(client, msg.NewV2ReadWriter(client))
|
||||||
|
workConn := NewWorkConn(serverMsgConn)
|
||||||
|
|
||||||
|
in := &msg.StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234}
|
||||||
|
type startResult struct {
|
||||||
|
conn net.Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan startResult, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := workConn.Start(in)
|
||||||
|
resultCh <- startResult{conn: conn, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
out, err := clientMsgConn.ReadMsg()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, in, out)
|
||||||
|
|
||||||
|
result := <-resultCh
|
||||||
|
require.NoError(t, result.err)
|
||||||
|
require.Same(t, serverMsgConn, result.conn)
|
||||||
|
}
|
||||||
@@ -29,6 +29,7 @@ type ClientInfo struct {
|
|||||||
Hostname string
|
Hostname string
|
||||||
IP string
|
IP string
|
||||||
Version string
|
Version string
|
||||||
|
WireProtocol string
|
||||||
FirstConnectedAt time.Time
|
FirstConnectedAt time.Time
|
||||||
LastConnectedAt time.Time
|
LastConnectedAt time.Time
|
||||||
DisconnectedAt time.Time
|
DisconnectedAt time.Time
|
||||||
@@ -51,7 +52,7 @@ func NewClientRegistry() *ClientRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register stores/updates metadata for a client and returns the registry key plus whether it conflicts with an online client.
|
// Register stores/updates metadata for a client and returns the registry key plus whether it conflicts with an online client.
|
||||||
func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, remoteAddr string) (key string, conflict bool) {
|
func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, remoteAddr, wireProtocol string) (key string, conflict bool) {
|
||||||
if runID == "" {
|
if runID == "" {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
@@ -88,6 +89,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version,
|
|||||||
info.Hostname = hostname
|
info.Hostname = hostname
|
||||||
info.IP = remoteAddr
|
info.IP = remoteAddr
|
||||||
info.Version = version
|
info.Version = version
|
||||||
|
info.WireProtocol = wireProtocol
|
||||||
if info.FirstConnectedAt.IsZero() {
|
if info.FirstConnectedAt.IsZero() {
|
||||||
info.FirstConnectedAt = now
|
info.FirstConnectedAt = now
|
||||||
}
|
}
|
||||||
|
|||||||
37
server/registry/registry_test.go
Normal file
37
server/registry/registry_test.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) {
|
||||||
|
registry := NewClientRegistry()
|
||||||
|
key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2)
|
||||||
|
if conflict {
|
||||||
|
t.Fatal("unexpected client conflict")
|
||||||
|
}
|
||||||
|
|
||||||
|
info, ok := registry.GetByKey(key)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("client %q not found", key)
|
||||||
|
}
|
||||||
|
if info.WireProtocol != wire.ProtocolV2 {
|
||||||
|
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -37,6 +38,7 @@ import (
|
|||||||
"github.com/fatedier/frp/pkg/msg"
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
"github.com/fatedier/frp/pkg/nathole"
|
"github.com/fatedier/frp/pkg/nathole"
|
||||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||||
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
"github.com/fatedier/frp/pkg/ssh"
|
"github.com/fatedier/frp/pkg/ssh"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||||
@@ -432,20 +434,15 @@ func (svr *Service) Close() error {
|
|||||||
func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, internal bool) {
|
func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, internal bool) {
|
||||||
xl := xlog.FromContextSafe(ctx)
|
xl := xlog.FromContextSafe(ctx)
|
||||||
|
|
||||||
var (
|
acceptedConn, err := svr.acceptConnection(ctx, conn)
|
||||||
rawMsg msg.Message
|
if err != nil {
|
||||||
err error
|
log.Tracef("failed to accept frp connection: %v", err)
|
||||||
)
|
|
||||||
|
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(connReadTimeout))
|
|
||||||
if rawMsg, err = msg.ReadMsg(conn); err != nil {
|
|
||||||
log.Tracef("failed to read message: %v", err)
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_ = conn.SetReadDeadline(time.Time{})
|
conn = acceptedConn.conn
|
||||||
|
|
||||||
switch m := rawMsg.(type) {
|
switch m := acceptedConn.firstMsg.(type) {
|
||||||
case *msg.Login:
|
case *msg.Login:
|
||||||
// server plugin hook
|
// server plugin hook
|
||||||
content := &plugin.LoginContent{
|
content := &plugin.LoginContent{
|
||||||
@@ -453,35 +450,66 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
|||||||
ClientAddress: conn.RemoteAddr().String(),
|
ClientAddress: conn.RemoteAddr().String(),
|
||||||
}
|
}
|
||||||
retContent, err := svr.pluginManager.Login(content)
|
retContent, err := svr.pluginManager.Login(content)
|
||||||
|
var ctl *Control
|
||||||
if err == nil {
|
if err == nil {
|
||||||
m = &retContent.Login
|
m = &retContent.Login
|
||||||
err = svr.RegisterControl(conn, m, internal)
|
controlConn := acceptedConn.conn
|
||||||
|
if !internal {
|
||||||
|
var controlRW io.ReadWriter
|
||||||
|
controlRW, err = netpkg.NewCryptoReadWriter(conn, svr.auth.EncryptionKey())
|
||||||
|
if err == nil {
|
||||||
|
controlConn = acceptedConn.messageConnFor(controlRW)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
ctl, err = svr.RegisterControl(controlConn, m, internal, acceptedConn.wireProtocol)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If login failed, send error message there.
|
|
||||||
// Otherwise send success message in control's work goroutine.
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warnf("register control error: %v", err)
|
xl.Warnf("register control error: %v", err)
|
||||||
_ = msg.WriteMsg(conn, &msg.LoginResp{
|
_ = acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||||
Version: version.Full(),
|
Version: version.Full(),
|
||||||
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||||
})
|
})
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
if err = acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||||
|
Version: version.Full(),
|
||||||
|
RunID: ctl.runID,
|
||||||
|
Error: "",
|
||||||
|
}); err != nil {
|
||||||
|
xl.Warnf("write login response error: %v", err)
|
||||||
|
svr.ctlManager.Del(m.RunID, ctl)
|
||||||
|
svr.clientRegistry.MarkOfflineByRunID(m.RunID)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctl.Start()
|
||||||
|
metrics.Server.NewClient()
|
||||||
|
go func() {
|
||||||
|
// block until control closed
|
||||||
|
ctl.WaitClosed()
|
||||||
|
svr.ctlManager.Del(m.RunID, ctl)
|
||||||
|
}()
|
||||||
case *msg.NewWorkConn:
|
case *msg.NewWorkConn:
|
||||||
if err := svr.RegisterWorkConn(conn, m); err != nil {
|
if err := svr.RegisterWorkConn(acceptedConn.conn, m); err != nil {
|
||||||
|
_ = acceptedConn.conn.WriteMsg(&msg.StartWorkConn{
|
||||||
|
Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||||
|
})
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
case *msg.NewVisitorConn:
|
case *msg.NewVisitorConn:
|
||||||
if err = svr.RegisterVisitorConn(conn, m); err != nil {
|
if err = svr.RegisterVisitorConn(conn, m); err != nil {
|
||||||
xl.Warnf("register visitor conn error: %v", err)
|
xl.Warnf("register visitor conn error: %v", err)
|
||||||
_ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
_ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{
|
||||||
ProxyName: m.ProxyName,
|
ProxyName: m.ProxyName,
|
||||||
Error: util.GenerateResponseErrorString("register visitor conn error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
Error: util.GenerateResponseErrorString("register visitor conn error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||||
})
|
})
|
||||||
conn.Close()
|
conn.Close()
|
||||||
} else {
|
} else {
|
||||||
_ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
_ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{
|
||||||
ProxyName: m.ProxyName,
|
ProxyName: m.ProxyName,
|
||||||
Error: "",
|
Error: "",
|
||||||
})
|
})
|
||||||
@@ -492,6 +520,87 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type acceptedConnection struct {
|
||||||
|
conn *msg.Conn
|
||||||
|
wireProtocol string
|
||||||
|
firstMsg msg.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*acceptedConnection, error) {
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(connReadTimeout))
|
||||||
|
checkedConn, isV2, err := wire.CheckMagic(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read wire protocol magic: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wireProtocol := wire.ProtocolV1
|
||||||
|
if isV2 {
|
||||||
|
wireProtocol = wire.ProtocolV2
|
||||||
|
}
|
||||||
|
|
||||||
|
conn = netpkg.NewContextConn(ctx, checkedConn)
|
||||||
|
acceptedConn := &acceptedConnection{wireProtocol: wireProtocol}
|
||||||
|
if isV2 {
|
||||||
|
wireConn := wire.NewConn(conn)
|
||||||
|
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
||||||
|
acceptedConn.conn = msg.NewConn(conn, rw)
|
||||||
|
acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(wireConn)
|
||||||
|
} else {
|
||||||
|
rw := msg.NewV1ReadWriter(conn)
|
||||||
|
acceptedConn.conn = msg.NewConn(conn, rw)
|
||||||
|
acceptedConn.firstMsg, err = acceptedConn.conn.ReadMsg()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = conn.SetReadDeadline(time.Time{})
|
||||||
|
return acceptedConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn {
|
||||||
|
return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, error) {
|
||||||
|
frame, err := wireConn.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read v2 frame: %w", err)
|
||||||
|
}
|
||||||
|
if frame.Type == wire.FrameTypeClientHello {
|
||||||
|
if err := ac.handleClientHello(wireConn, frame); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
frame, err = wireConn.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read first v2 message frame: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := msg.DecodeV2MessageFrame(frame)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode v2 message: %w", err)
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ac *acceptedConnection) handleClientHello(wireConn *wire.Conn, frame *wire.Frame) error {
|
||||||
|
var hello wire.ClientHello
|
||||||
|
if err := wireConn.UnmarshalFrame(frame, &hello); err != nil {
|
||||||
|
return fmt.Errorf("decode ClientHello: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverHello := wire.DefaultServerHello()
|
||||||
|
if err := wire.ValidateClientHello(hello); err != nil {
|
||||||
|
serverHello.Error = err.Error()
|
||||||
|
_ = wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello); err != nil {
|
||||||
|
return fmt.Errorf("write ServerHello: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// HandleListener accepts connections from client and call handleConnection to handle them.
|
// HandleListener accepts connections from client and call handleConnection to handle them.
|
||||||
// If internal is true, it means that this listener is used for internal communication like ssh tunnel gateway.
|
// If internal is true, it means that this listener is used for internal communication like ssh tunnel gateway.
|
||||||
// TODO(fatedier): Pass some parameters of listener/connection through context to avoid passing too many parameters.
|
// TODO(fatedier): Pass some parameters of listener/connection through context to avoid passing too many parameters.
|
||||||
@@ -577,14 +686,19 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, internal bool) error {
|
func (svr *Service) RegisterControl(
|
||||||
|
ctlConn *msg.Conn,
|
||||||
|
loginMsg *msg.Login,
|
||||||
|
internal bool,
|
||||||
|
wireProtocol string,
|
||||||
|
) (*Control, error) {
|
||||||
// If client's RunID is empty, it's a new client, we just create a new controller.
|
// If client's RunID is empty, it's a new client, we just create a new controller.
|
||||||
// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
|
// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
|
||||||
var err error
|
var err error
|
||||||
if loginMsg.RunID == "" {
|
if loginMsg.RunID == "" {
|
||||||
loginMsg.RunID, err = util.RandID()
|
loginMsg.RunID, err = util.RandID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -601,7 +715,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
|||||||
authVerifier = auth.AlwaysPassVerifier
|
authVerifier = auth.AlwaysPassVerifier
|
||||||
}
|
}
|
||||||
if err := authVerifier.VerifyLogin(loginMsg); err != nil {
|
if err := authVerifier.VerifyLogin(loginMsg); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctl, err := NewControl(ctx, &SessionContext{
|
ctl, err := NewControl(ctx, &SessionContext{
|
||||||
@@ -611,7 +725,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
|||||||
AuthVerifier: authVerifier,
|
AuthVerifier: authVerifier,
|
||||||
EncryptionKey: svr.auth.EncryptionKey(),
|
EncryptionKey: svr.auth.EncryptionKey(),
|
||||||
Conn: ctlConn,
|
Conn: ctlConn,
|
||||||
ConnEncrypted: !internal,
|
|
||||||
LoginMsg: loginMsg,
|
LoginMsg: loginMsg,
|
||||||
ServerCfg: svr.cfg,
|
ServerCfg: svr.cfg,
|
||||||
ClientRegistry: svr.clientRegistry,
|
ClientRegistry: svr.clientRegistry,
|
||||||
@@ -619,7 +732,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warnf("create new controller error: %v", err)
|
xl.Warnf("create new controller error: %v", err)
|
||||||
// don't return detailed errors to client
|
// don't return detailed errors to client
|
||||||
return fmt.Errorf("unexpected error when creating new controller")
|
return nil, fmt.Errorf("unexpected error when creating new controller")
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil {
|
if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil {
|
||||||
@@ -630,34 +743,24 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
|||||||
if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
||||||
remoteAddr = host
|
remoteAddr = host
|
||||||
}
|
}
|
||||||
_, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr)
|
_, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr, wireProtocol)
|
||||||
if conflict {
|
if conflict {
|
||||||
svr.ctlManager.Del(loginMsg.RunID, ctl)
|
svr.ctlManager.Del(loginMsg.RunID, ctl)
|
||||||
ctl.Close()
|
return nil, fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
|
||||||
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctl.Start()
|
return ctl, nil
|
||||||
|
|
||||||
// for statistics
|
|
||||||
metrics.Server.NewClient()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// block until control closed
|
|
||||||
ctl.WaitClosed()
|
|
||||||
svr.ctlManager.Del(loginMsg.RunID, ctl)
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterWorkConn register a new work connection to control and proxies need it.
|
// RegisterWorkConn register a new work connection to control and proxies need it.
|
||||||
func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error {
|
func (svr *Service) RegisterWorkConn(workConn *msg.Conn, newMsg *msg.NewWorkConn) error {
|
||||||
xl := netpkg.NewLogFromConn(workConn)
|
xl := netpkg.NewLogFromConn(workConn)
|
||||||
ctl, exist := svr.ctlManager.GetByID(newMsg.RunID)
|
ctl, exist := svr.ctlManager.GetByID(newMsg.RunID)
|
||||||
if !exist {
|
if !exist {
|
||||||
xl.Warnf("no client control found for run id [%s]", newMsg.RunID)
|
xl.Warnf("no client control found for run id [%s]", newMsg.RunID)
|
||||||
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
|
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// server plugin hook
|
// server plugin hook
|
||||||
content := &plugin.NewWorkConnContent{
|
content := &plugin.NewWorkConnContent{
|
||||||
User: plugin.UserInfo{
|
User: plugin.UserInfo{
|
||||||
@@ -675,12 +778,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
||||||
_ = msg.WriteMsg(workConn, &msg.StartWorkConn{
|
return err
|
||||||
Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
|
||||||
})
|
|
||||||
return fmt.Errorf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
|
||||||
}
|
}
|
||||||
return ctl.RegisterWorkConn(workConn)
|
return ctl.RegisterWorkConn(proxy.NewWorkConn(workConn))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVisitorConn) error {
|
func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVisitorConn) error {
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ type Server struct {
|
|||||||
tokenRequestCount atomic.Int64
|
tokenRequestCount atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const maxTokenRequestBodySize = 1 << 20
|
||||||
|
|
||||||
type Option func(*Server)
|
type Option func(*Server)
|
||||||
|
|
||||||
func WithBindPort(port int) Option {
|
func WithBindPort(port int) Option {
|
||||||
@@ -178,6 +180,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, maxTokenRequestBodySize)
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
@@ -187,7 +190,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.FormValue("grant_type") != "client_credentials" {
|
if r.Form.Get("grant_type") != "client_credentials" {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
@@ -199,8 +202,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Accept credentials from Basic Auth or form body.
|
// Accept credentials from Basic Auth or form body.
|
||||||
clientID, clientSecret, ok := r.BasicAuth()
|
clientID, clientSecret, ok := r.BasicAuth()
|
||||||
if !ok {
|
if !ok {
|
||||||
clientID = r.FormValue("client_id")
|
clientID = r.Form.Get("client_id")
|
||||||
clientSecret = r.FormValue("client_secret")
|
clientSecret = r.Form.Get("client_secret")
|
||||||
}
|
}
|
||||||
if clientID != s.clientID || clientSecret != s.clientSecret {
|
if clientID != s.clientID || clientSecret != s.clientSecret {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|||||||
163
test/e2e/v1/basic/wire.go
Normal file
163
test/e2e/v1/basic/wire.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
package basic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/onsi/ginkgo/v2"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/test/e2e/framework"
|
||||||
|
"github.com/fatedier/frp/test/e2e/framework/consts"
|
||||||
|
"github.com/fatedier/frp/test/e2e/pkg/port"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = ginkgo.Describe("[Feature: WireProtocol]", func() {
|
||||||
|
f := framework.NewDefaultFramework()
|
||||||
|
|
||||||
|
ginkgo.It("v1 tcp and udp proxies", func() {
|
||||||
|
serverConf := consts.DefaultServerConfig
|
||||||
|
tcpPortName := port.GenName("WireV1TCP")
|
||||||
|
udpPortName := port.GenName("WireV1UDP")
|
||||||
|
clientConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||||
|
transport.wireProtocol = "v1"
|
||||||
|
|
||||||
|
[[proxies]]
|
||||||
|
name = "tcp"
|
||||||
|
type = "tcp"
|
||||||
|
localPort = {{ .%s }}
|
||||||
|
remotePort = {{ .%s }}
|
||||||
|
|
||||||
|
[[proxies]]
|
||||||
|
name = "udp"
|
||||||
|
type = "udp"
|
||||||
|
localPort = {{ .%s }}
|
||||||
|
remotePort = {{ .%s }}
|
||||||
|
`, framework.TCPEchoServerPort, tcpPortName, framework.UDPEchoServerPort, udpPortName)
|
||||||
|
|
||||||
|
f.RunProcesses(serverConf, []string{clientConf})
|
||||||
|
|
||||||
|
framework.NewRequestExpect(f).PortName(tcpPortName).Ensure()
|
||||||
|
framework.NewRequestExpect(f).Protocol("udp").PortName(udpPortName).Ensure()
|
||||||
|
})
|
||||||
|
|
||||||
|
ginkgo.It("v2 tcp and udp proxies", func() {
|
||||||
|
serverConf := consts.DefaultServerConfig
|
||||||
|
tcpPortName := port.GenName("WireV2TCP")
|
||||||
|
udpPortName := port.GenName("WireV2UDP")
|
||||||
|
clientConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||||
|
transport.wireProtocol = "v2"
|
||||||
|
|
||||||
|
[[proxies]]
|
||||||
|
name = "tcp"
|
||||||
|
type = "tcp"
|
||||||
|
localPort = {{ .%s }}
|
||||||
|
remotePort = {{ .%s }}
|
||||||
|
|
||||||
|
[[proxies]]
|
||||||
|
name = "udp"
|
||||||
|
type = "udp"
|
||||||
|
localPort = {{ .%s }}
|
||||||
|
remotePort = {{ .%s }}
|
||||||
|
`, framework.TCPEchoServerPort, tcpPortName, framework.UDPEchoServerPort, udpPortName)
|
||||||
|
|
||||||
|
f.RunProcesses(serverConf, []string{clientConf})
|
||||||
|
|
||||||
|
framework.NewRequestExpect(f).PortName(tcpPortName).Ensure()
|
||||||
|
framework.NewRequestExpect(f).Protocol("udp").PortName(udpPortName).Ensure()
|
||||||
|
})
|
||||||
|
|
||||||
|
ginkgo.It("v2 stcp visitor", func() {
|
||||||
|
serverConf := consts.DefaultServerConfig
|
||||||
|
bindPortName := port.GenName("WireV2STCP")
|
||||||
|
clientServerConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||||
|
user = "user1"
|
||||||
|
transport.wireProtocol = "v2"
|
||||||
|
|
||||||
|
[[proxies]]
|
||||||
|
name = "stcp"
|
||||||
|
type = "stcp"
|
||||||
|
secretKey = "abc"
|
||||||
|
localPort = {{ .%s }}
|
||||||
|
`, framework.TCPEchoServerPort)
|
||||||
|
clientVisitorConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||||
|
user = "user1"
|
||||||
|
transport.wireProtocol = "v2"
|
||||||
|
|
||||||
|
[[visitors]]
|
||||||
|
name = "stcp-visitor"
|
||||||
|
type = "stcp"
|
||||||
|
serverName = "stcp"
|
||||||
|
secretKey = "abc"
|
||||||
|
bindPort = {{ .%s }}
|
||||||
|
`, bindPortName)
|
||||||
|
|
||||||
|
f.RunProcesses(serverConf, []string{clientServerConf, clientVisitorConf})
|
||||||
|
|
||||||
|
framework.NewRequestExpect(f).PortName(bindPortName).Ensure()
|
||||||
|
})
|
||||||
|
|
||||||
|
ginkgo.It("reports client wire protocol", func() {
|
||||||
|
webPort := f.AllocPort()
|
||||||
|
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||||
|
webServer.port = %d
|
||||||
|
`, webPort)
|
||||||
|
|
||||||
|
v1PortName := port.GenName("WireReportV1")
|
||||||
|
v1ClientConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||||
|
clientID = "wire-v1"
|
||||||
|
transport.wireProtocol = "v1"
|
||||||
|
|
||||||
|
[[proxies]]
|
||||||
|
name = "v1"
|
||||||
|
type = "tcp"
|
||||||
|
localPort = {{ .%s }}
|
||||||
|
remotePort = {{ .%s }}
|
||||||
|
`, framework.TCPEchoServerPort, v1PortName)
|
||||||
|
|
||||||
|
v2PortName := port.GenName("WireReportV2")
|
||||||
|
v2ClientConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||||
|
clientID = "wire-v2"
|
||||||
|
transport.wireProtocol = "v2"
|
||||||
|
|
||||||
|
[[proxies]]
|
||||||
|
name = "v2"
|
||||||
|
type = "tcp"
|
||||||
|
localPort = {{ .%s }}
|
||||||
|
remotePort = {{ .%s }}
|
||||||
|
`, framework.TCPEchoServerPort, v2PortName)
|
||||||
|
|
||||||
|
f.RunProcesses(serverConf, []string{v1ClientConf, v2ClientConf})
|
||||||
|
|
||||||
|
framework.NewRequestExpect(f).PortName(v1PortName).Ensure()
|
||||||
|
framework.NewRequestExpect(f).PortName(v2PortName).Ensure()
|
||||||
|
expectClientWireProtocol(webPort, "wire-v1", "v1")
|
||||||
|
expectClientWireProtocol(webPort, "wire-v2", "v2")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
type wireClientInfo struct {
|
||||||
|
ClientID string `json:"clientID"`
|
||||||
|
WireProtocol string `json:"wireProtocol"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectClientWireProtocol(webPort int, clientID string, wireProtocol string) {
|
||||||
|
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/clients", webPort))
|
||||||
|
framework.ExpectNoError(err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
framework.ExpectEqual(resp.StatusCode, 200)
|
||||||
|
|
||||||
|
content, err := io.ReadAll(resp.Body)
|
||||||
|
framework.ExpectNoError(err)
|
||||||
|
|
||||||
|
var clients []wireClientInfo
|
||||||
|
framework.ExpectNoError(json.Unmarshal(content, &clients))
|
||||||
|
for _, client := range clients {
|
||||||
|
if client.ClientID == clientID {
|
||||||
|
framework.ExpectEqual(client.WireProtocol, wireProtocol)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
framework.Failf("client %q not found in /api/clients response: %s", clientID, string(content))
|
||||||
|
}
|
||||||
@@ -16,6 +16,9 @@
|
|||||||
<el-tag v-if="client.version" size="small" type="success"
|
<el-tag v-if="client.version" size="small" type="success"
|
||||||
>v{{ client.version }}</el-tag
|
>v{{ client.version }}</el-tag
|
||||||
>
|
>
|
||||||
|
<el-tag v-if="client.wireProtocolLabel" size="small" type="info">
|
||||||
|
{{ client.wireProtocolLabel }}
|
||||||
|
</el-tag>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="card-meta">
|
<div class="card-meta">
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ export interface ClientInfoData {
|
|||||||
clientID: string
|
clientID: string
|
||||||
runID: string
|
runID: string
|
||||||
version?: string
|
version?: string
|
||||||
|
wireProtocol?: string
|
||||||
hostname: string
|
hostname: string
|
||||||
clientIP?: string
|
clientIP?: string
|
||||||
metas?: Record<string, string>
|
metas?: Record<string, string>
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ export class Client {
|
|||||||
clientID: string
|
clientID: string
|
||||||
runID: string
|
runID: string
|
||||||
version: string
|
version: string
|
||||||
|
wireProtocol: string
|
||||||
hostname: string
|
hostname: string
|
||||||
ip: string
|
ip: string
|
||||||
metas: Map<string, string>
|
metas: Map<string, string>
|
||||||
@@ -21,6 +22,7 @@ export class Client {
|
|||||||
this.clientID = data.clientID
|
this.clientID = data.clientID
|
||||||
this.runID = data.runID
|
this.runID = data.runID
|
||||||
this.version = data.version || ''
|
this.version = data.version || ''
|
||||||
|
this.wireProtocol = data.wireProtocol || ''
|
||||||
this.hostname = data.hostname
|
this.hostname = data.hostname
|
||||||
this.ip = data.clientIP || ''
|
this.ip = data.clientIP || ''
|
||||||
this.metas = new Map<string, string>()
|
this.metas = new Map<string, string>()
|
||||||
@@ -48,6 +50,11 @@ export class Client {
|
|||||||
return this.runID.substring(0, 8)
|
return this.runID.substring(0, 8)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
get wireProtocolLabel(): string {
|
||||||
|
if (!this.wireProtocol) return ''
|
||||||
|
return `Protocol ${this.wireProtocol}`
|
||||||
|
}
|
||||||
|
|
||||||
get firstConnectedAgo(): string {
|
get firstConnectedAgo(): string {
|
||||||
return formatDistanceToNow(this.firstConnectedAt)
|
return formatDistanceToNow(this.firstConnectedAt)
|
||||||
}
|
}
|
||||||
@@ -80,6 +87,7 @@ export class Client {
|
|||||||
this.user.toLowerCase().includes(search) ||
|
this.user.toLowerCase().includes(search) ||
|
||||||
this.clientID.toLowerCase().includes(search) ||
|
this.clientID.toLowerCase().includes(search) ||
|
||||||
this.runID.toLowerCase().includes(search) ||
|
this.runID.toLowerCase().includes(search) ||
|
||||||
|
this.wireProtocol.toLowerCase().includes(search) ||
|
||||||
this.hostname.toLowerCase().includes(search)
|
this.hostname.toLowerCase().includes(search)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,9 @@
|
|||||||
<el-tag v-if="client.version" size="small" type="success"
|
<el-tag v-if="client.version" size="small" type="success"
|
||||||
>v{{ client.version }}</el-tag
|
>v{{ client.version }}</el-tag
|
||||||
>
|
>
|
||||||
|
<el-tag v-if="client.wireProtocolLabel" size="small" type="info">
|
||||||
|
{{ client.wireProtocolLabel }}
|
||||||
|
</el-tag>
|
||||||
</div>
|
</div>
|
||||||
<div class="client-meta">
|
<div class="client-meta">
|
||||||
<span v-if="client.ip" class="meta-item">{{
|
<span v-if="client.ip" class="meta-item">{{
|
||||||
@@ -58,6 +61,10 @@
|
|||||||
<span class="info-label">Run ID</span>
|
<span class="info-label">Run ID</span>
|
||||||
<span class="info-value">{{ client.runID }}</span>
|
<span class="info-value">{{ client.runID }}</span>
|
||||||
</div>
|
</div>
|
||||||
|
<div v-if="client.wireProtocol" class="info-item">
|
||||||
|
<span class="info-label">Protocol</span>
|
||||||
|
<span class="info-value">{{ client.wireProtocol }}</span>
|
||||||
|
</div>
|
||||||
<div class="info-item">
|
<div class="info-item">
|
||||||
<span class="info-label">First Connected</span>
|
<span class="info-label">First Connected</span>
|
||||||
<span class="info-value">{{ client.firstConnectedAgo }}</span>
|
<span class="info-value">{{ client.firstConnectedAgo }}</span>
|
||||||
|
|||||||
7
web/package-lock.json
generated
7
web/package-lock.json
generated
@@ -5753,9 +5753,9 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/postcss": {
|
"node_modules/postcss": {
|
||||||
"version": "8.5.8",
|
"version": "8.5.12",
|
||||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz",
|
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
|
||||||
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==",
|
"integrity": "sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==",
|
||||||
"funding": [
|
"funding": [
|
||||||
{
|
{
|
||||||
"type": "opencollective",
|
"type": "opencollective",
|
||||||
@@ -5770,7 +5770,6 @@
|
|||||||
"url": "https://github.com/sponsors/ai"
|
"url": "https://github.com/sponsors/ai"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"license": "MIT",
|
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"nanoid": "^3.3.11",
|
"nanoid": "^3.3.11",
|
||||||
"picocolors": "^1.1.1",
|
"picocolors": "^1.1.1",
|
||||||
|
|||||||
Reference in New Issue
Block a user