mirror of
https://github.com/fatedier/frp.git
synced 2026-04-27 19:39:10 +08:00
protocol: add v2 wire protocol with binary framing and capability negotiation (#5294)
This commit is contained in:
245
client/control_session_test.go
Normal file
245
client/control_session_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
type testConnector struct {
|
||||
conn net.Conn
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (c *testConnector) Open() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testConnector) Connect() (net.Conn, error) {
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func (c *testConnector) Close() error {
|
||||
c.closed.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
type trackingConn struct {
|
||||
net.Conn
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (c *trackingConn) Close() error {
|
||||
c.closed.Store(true)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func newTestControlSessionDialer(t *testing.T, protocol string, connector Connector, clientSpec *msg.ClientSpec) *controlSessionDialer {
|
||||
t.Helper()
|
||||
|
||||
authRuntime, err := auth.BuildClientAuth(&v1.AuthClientConfig{
|
||||
Method: v1.AuthMethodToken,
|
||||
Token: "token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return &controlSessionDialer{
|
||||
ctx: context.Background(),
|
||||
common: &v1.ClientCommonConfig{
|
||||
User: "test-user",
|
||||
Transport: v1.ClientTransportConfig{
|
||||
Protocol: "tcp",
|
||||
WireProtocol: protocol,
|
||||
},
|
||||
},
|
||||
auth: authRuntime,
|
||||
clientSpec: clientSpec,
|
||||
connectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector {
|
||||
return connector
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialV1(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if loginMsg.RunID != "previous-run-id" {
|
||||
serverErrCh <- fmt.Errorf("unexpected previous run id: %s", loginMsg.RunID)
|
||||
return
|
||||
}
|
||||
if loginMsg.User != "test-user" {
|
||||
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v1"})
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("previous-run-id")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-v1", sessionCtx.RunID)
|
||||
require.NotNil(t, sessionCtx.Conn)
|
||||
require.NotNil(t, sessionCtx.Connector)
|
||||
require.False(t, connector.closed.Load())
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialV2(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
magic := make([]byte, len(wire.MagicV2))
|
||||
if _, err := io.ReadFull(serverRaw, magic); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if string(magic) != wire.MagicV2 {
|
||||
serverErrCh <- fmt.Errorf("unexpected magic: %q", string(magic))
|
||||
return
|
||||
}
|
||||
|
||||
wireConn := wire.NewConn(serverRaw)
|
||||
var hello wire.ClientHello
|
||||
if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := wire.ValidateClientHello(hello); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if loginMsg.User != "test-user" {
|
||||
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||
return
|
||||
}
|
||||
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"})
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-v2", sessionCtx.RunID)
|
||||
require.NotNil(t, sessionCtx.Conn)
|
||||
require.NotNil(t, sessionCtx.Connector)
|
||||
require.False(t, connector.closed.Load())
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialLoginErrorClosesResources(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn := &trackingConn{Conn: clientRaw}
|
||||
connector := &testConnector{conn: clientConn}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{Error: "login denied"})
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.Nil(t, sessionCtx)
|
||||
require.ErrorContains(t, err, "login denied")
|
||||
require.True(t, clientConn.closed.Load())
|
||||
require.True(t, connector.closed.Load())
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialSSHTunnelSkipsControlEncryption(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-ssh-tunnel"}); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
_ = serverRaw.SetReadDeadline(time.Now().Add(time.Second))
|
||||
var ping msg.Ping
|
||||
if err := rw.ReadMsgInto(&ping); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, &msg.ClientSpec{Type: "ssh-tunnel"})
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-ssh-tunnel", sessionCtx.RunID)
|
||||
require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{}))
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
Reference in New Issue
Block a user