mirror of
https://github.com/fatedier/frp.git
synced 2026-04-27 19:39:10 +08:00
246 lines
6.5 KiB
Go
246 lines
6.5 KiB
Go
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package client
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/fatedier/frp/pkg/auth"
|
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
|
"github.com/fatedier/frp/pkg/msg"
|
|
"github.com/fatedier/frp/pkg/proto/wire"
|
|
)
|
|
|
|
// testConnector is a stub Connector for tests. It hands out a single
// pre-built connection and records whether Close was called.
type testConnector struct {
	// conn is returned, unchanged, by every Connect call.
	conn net.Conn
	// closed becomes true once Close has been called.
	closed atomic.Bool
}

// Open is a no-op: the stub has nothing to prepare.
func (c *testConnector) Open() error { return nil }

// Connect returns the preconfigured connection and never fails.
func (c *testConnector) Connect() (net.Conn, error) {
	conn := c.conn
	return conn, nil
}

// Close only records that it was invoked; it does not close conn.
func (c *testConnector) Close() error {
	c.closed.Store(true)
	return nil
}
|
|
|
|
// trackingConn wraps a net.Conn and remembers whether Close was invoked, so
// tests can assert that the code under test released the connection.
type trackingConn struct {
	net.Conn // every method except Close delegates to the wrapped conn

	closed atomic.Bool // set by Close before the wrapped conn is closed
}

// Close marks the connection as closed, then closes the wrapped conn and
// returns its result.
func (c *trackingConn) Close() error {
	c.closed.Store(true)
	inner := c.Conn
	return inner.Close()
}
|
|
|
|
func newTestControlSessionDialer(t *testing.T, protocol string, connector Connector, clientSpec *msg.ClientSpec) *controlSessionDialer {
|
|
t.Helper()
|
|
|
|
authRuntime, err := auth.BuildClientAuth(&v1.AuthClientConfig{
|
|
Method: v1.AuthMethodToken,
|
|
Token: "token",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
return &controlSessionDialer{
|
|
ctx: context.Background(),
|
|
common: &v1.ClientCommonConfig{
|
|
User: "test-user",
|
|
Transport: v1.ClientTransportConfig{
|
|
Protocol: "tcp",
|
|
WireProtocol: protocol,
|
|
},
|
|
},
|
|
auth: authRuntime,
|
|
clientSpec: clientSpec,
|
|
connectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector {
|
|
return connector
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestControlSessionDialerDialV1(t *testing.T) {
|
|
clientRaw, serverRaw := net.Pipe()
|
|
defer serverRaw.Close()
|
|
|
|
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
|
serverErrCh := make(chan error, 1)
|
|
go func() {
|
|
rw := msg.NewV1ReadWriter(serverRaw)
|
|
var loginMsg msg.Login
|
|
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
if loginMsg.RunID != "previous-run-id" {
|
|
serverErrCh <- fmt.Errorf("unexpected previous run id: %s", loginMsg.RunID)
|
|
return
|
|
}
|
|
if loginMsg.User != "test-user" {
|
|
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
|
return
|
|
}
|
|
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v1"})
|
|
}()
|
|
|
|
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
|
sessionCtx, err := dialer.Dial("previous-run-id")
|
|
require.NoError(t, err)
|
|
defer sessionCtx.Conn.Close()
|
|
defer sessionCtx.Connector.Close()
|
|
|
|
require.Equal(t, "run-v1", sessionCtx.RunID)
|
|
require.NotNil(t, sessionCtx.Conn)
|
|
require.NotNil(t, sessionCtx.Connector)
|
|
require.False(t, connector.closed.Load())
|
|
require.NoError(t, <-serverErrCh)
|
|
}
|
|
|
|
func TestControlSessionDialerDialV2(t *testing.T) {
|
|
clientRaw, serverRaw := net.Pipe()
|
|
defer serverRaw.Close()
|
|
|
|
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
|
serverErrCh := make(chan error, 1)
|
|
go func() {
|
|
magic := make([]byte, len(wire.MagicV2))
|
|
if _, err := io.ReadFull(serverRaw, magic); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
if string(magic) != wire.MagicV2 {
|
|
serverErrCh <- fmt.Errorf("unexpected magic: %q", string(magic))
|
|
return
|
|
}
|
|
|
|
wireConn := wire.NewConn(serverRaw)
|
|
var hello wire.ClientHello
|
|
if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
if err := wire.ValidateClientHello(hello); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
|
|
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
|
var loginMsg msg.Login
|
|
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
if loginMsg.User != "test-user" {
|
|
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
|
return
|
|
}
|
|
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"})
|
|
}()
|
|
|
|
dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
|
|
sessionCtx, err := dialer.Dial("")
|
|
require.NoError(t, err)
|
|
defer sessionCtx.Conn.Close()
|
|
defer sessionCtx.Connector.Close()
|
|
|
|
require.Equal(t, "run-v2", sessionCtx.RunID)
|
|
require.NotNil(t, sessionCtx.Conn)
|
|
require.NotNil(t, sessionCtx.Connector)
|
|
require.False(t, connector.closed.Load())
|
|
require.NoError(t, <-serverErrCh)
|
|
}
|
|
|
|
func TestControlSessionDialerDialLoginErrorClosesResources(t *testing.T) {
|
|
clientRaw, serverRaw := net.Pipe()
|
|
defer serverRaw.Close()
|
|
|
|
clientConn := &trackingConn{Conn: clientRaw}
|
|
connector := &testConnector{conn: clientConn}
|
|
serverErrCh := make(chan error, 1)
|
|
go func() {
|
|
rw := msg.NewV1ReadWriter(serverRaw)
|
|
var loginMsg msg.Login
|
|
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
serverErrCh <- rw.WriteMsg(&msg.LoginResp{Error: "login denied"})
|
|
}()
|
|
|
|
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
|
sessionCtx, err := dialer.Dial("")
|
|
require.Nil(t, sessionCtx)
|
|
require.ErrorContains(t, err, "login denied")
|
|
require.True(t, clientConn.closed.Load())
|
|
require.True(t, connector.closed.Load())
|
|
require.NoError(t, <-serverErrCh)
|
|
}
|
|
|
|
func TestControlSessionDialerDialSSHTunnelSkipsControlEncryption(t *testing.T) {
|
|
clientRaw, serverRaw := net.Pipe()
|
|
defer serverRaw.Close()
|
|
|
|
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
|
serverErrCh := make(chan error, 1)
|
|
go func() {
|
|
rw := msg.NewV1ReadWriter(serverRaw)
|
|
var loginMsg msg.Login
|
|
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-ssh-tunnel"}); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
|
|
_ = serverRaw.SetReadDeadline(time.Now().Add(time.Second))
|
|
var ping msg.Ping
|
|
if err := rw.ReadMsgInto(&ping); err != nil {
|
|
serverErrCh <- err
|
|
return
|
|
}
|
|
serverErrCh <- nil
|
|
}()
|
|
|
|
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, &msg.ClientSpec{Type: "ssh-tunnel"})
|
|
sessionCtx, err := dialer.Dial("")
|
|
require.NoError(t, err)
|
|
defer sessionCtx.Conn.Close()
|
|
defer sessionCtx.Connector.Close()
|
|
|
|
require.Equal(t, "run-ssh-tunnel", sessionCtx.RunID)
|
|
require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{}))
|
|
require.NoError(t, <-serverErrCh)
|
|
}
|