mirror of
https://github.com/fatedier/frp.git
synced 2026-04-27 19:39:10 +08:00
protocol: add v2 wire protocol with binary framing and capability negotiation (#5294)
This commit is contained in:
@@ -29,6 +29,8 @@ import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
@@ -41,6 +43,39 @@ type Connector interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
type MessageConnector interface {
|
||||
Connect() (*msg.Conn, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type messageConnector struct {
|
||||
connector Connector
|
||||
wireProtocol string
|
||||
}
|
||||
|
||||
func newMessageConnector(connector Connector, wireProtocol string) *messageConnector {
|
||||
return &messageConnector{
|
||||
connector: connector,
|
||||
wireProtocol: wireProtocol,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *messageConnector) Connect() (*msg.Conn, error) {
|
||||
conn, err := c.connector.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = wire.WriteMagicIfV2(conn, c.wireProtocol); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return msg.NewConn(conn, msg.NewReadWriter(conn, c.wireProtocol)), nil
|
||||
}
|
||||
|
||||
func (c *messageConnector) Close() error {
|
||||
return c.connector.Close()
|
||||
}
|
||||
|
||||
// defaultConnectorImpl is the default implementation of Connector for normal frpc.
|
||||
type defaultConnectorImpl struct {
|
||||
ctx context.Context
|
||||
|
||||
@@ -27,7 +27,6 @@ import (
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/wait"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
@@ -41,13 +40,11 @@ type SessionContext struct {
|
||||
// It should be attached to the login message when reconnecting.
|
||||
RunID string
|
||||
// Underlying control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
|
||||
Conn net.Conn
|
||||
// Indicates whether the connection is encrypted.
|
||||
ConnEncrypted bool
|
||||
Conn *msg.Conn
|
||||
// Auth runtime used for login, heartbeats, and encryption.
|
||||
Auth *auth.ClientAuth
|
||||
// Connector is used to create new connections, which could be real TCP connections or virtual streams.
|
||||
Connector Connector
|
||||
// Connector is used to create message connections to frps.
|
||||
Connector MessageConnector
|
||||
// Virtual net controller
|
||||
VnetController *vnet.Controller
|
||||
}
|
||||
@@ -91,15 +88,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
||||
}
|
||||
ctl.lastPong.Store(time.Now())
|
||||
|
||||
if sessionCtx.ConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.Auth.EncryptionKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
|
||||
} else {
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
ctl.registerMsgHandlers()
|
||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||
|
||||
@@ -139,14 +128,14 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) {
|
||||
workConn.Close()
|
||||
return
|
||||
}
|
||||
if err = msg.WriteMsg(workConn, m); err != nil {
|
||||
if err = workConn.WriteMsg(m); err != nil {
|
||||
xl.Warnf("work connection write to server error: %v", err)
|
||||
workConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var startMsg msg.StartWorkConn
|
||||
if err = msg.ReadMsgInto(workConn, &startMsg); err != nil {
|
||||
if err = workConn.ReadMsgInto(&startMsg); err != nil {
|
||||
xl.Tracef("work connection closed before response StartWorkConn message: %v", err)
|
||||
workConn.Close()
|
||||
return
|
||||
@@ -227,7 +216,7 @@ func (ctl *Control) Done() <-chan struct{} {
|
||||
}
|
||||
|
||||
// connectServer return a new connection to frps
|
||||
func (ctl *Control) connectServer() (net.Conn, error) {
|
||||
func (ctl *Control) connectServer() (*msg.Conn, error) {
|
||||
return ctl.sessionCtx.Connector.Connect()
|
||||
}
|
||||
|
||||
|
||||
172
client/control_session.go
Normal file
172
client/control_session.go
Normal file
@@ -0,0 +1,172 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/version"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
)
|
||||
|
||||
type controlSessionDialer struct {
|
||||
ctx context.Context
|
||||
|
||||
common *v1.ClientCommonConfig
|
||||
auth *auth.ClientAuth
|
||||
clientSpec *msg.ClientSpec
|
||||
vnetController *vnet.Controller
|
||||
|
||||
connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) Dial(previousRunID string) (*SessionContext, error) {
|
||||
connector := d.connectorCreator(d.ctx, d.common)
|
||||
if err := connector.Open(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = connector.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := connector.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
loginMsg, err := d.buildLoginMsg(previousRunID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRespMsg, err := d.exchangeLogin(conn, loginMsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if loginRespMsg.Error != "" {
|
||||
return nil, errors.New(loginRespMsg.Error)
|
||||
}
|
||||
|
||||
var controlRW io.ReadWriter = conn
|
||||
if d.clientSpec == nil || d.clientSpec.Type != "ssh-tunnel" {
|
||||
controlRW, err = netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create control crypto read writer: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
success = true
|
||||
return &SessionContext{
|
||||
Common: d.common,
|
||||
RunID: loginRespMsg.RunID,
|
||||
Conn: msg.NewConn(conn, msg.NewReadWriter(controlRW, d.common.Transport.WireProtocol)),
|
||||
Auth: d.auth,
|
||||
Connector: newMessageConnector(connector, d.common.Transport.WireProtocol),
|
||||
VnetController: d.vnetController,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) buildLoginMsg(previousRunID string) (*msg.Login, error) {
|
||||
hostname, _ := os.Hostname()
|
||||
loginMsg := &msg.Login{
|
||||
Arch: runtime.GOARCH,
|
||||
Os: runtime.GOOS,
|
||||
Hostname: hostname,
|
||||
PoolCount: d.common.Transport.PoolCount,
|
||||
User: d.common.User,
|
||||
ClientID: d.common.ClientID,
|
||||
Version: version.Full(),
|
||||
Timestamp: time.Now().Unix(),
|
||||
RunID: previousRunID,
|
||||
Metas: d.common.Metadatas,
|
||||
}
|
||||
if d.clientSpec != nil {
|
||||
loginMsg.ClientSpec = *d.clientSpec
|
||||
}
|
||||
|
||||
if err := d.auth.Setter.SetLogin(loginMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return loginMsg, nil
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*msg.LoginResp, error) {
|
||||
rw := msg.NewV1ReadWriter(conn)
|
||||
var wireConn *wire.Conn
|
||||
|
||||
if d.common.Transport.WireProtocol == wire.ProtocolV2 {
|
||||
if err := wire.WriteMagic(conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wireConn = wire.NewConn(conn)
|
||||
rw = msg.NewV2ReadWriterWithConn(wireConn)
|
||||
hello := wire.DefaultClientHello(wire.BootstrapInfo{
|
||||
Transport: d.common.Transport.Protocol,
|
||||
TLS: lo.FromPtr(d.common.Transport.TLS.Enable) || d.common.Transport.Protocol == "wss" || d.common.Transport.Protocol == "quic",
|
||||
TCPMux: lo.FromPtr(d.common.Transport.TCPMux),
|
||||
})
|
||||
if err := wireConn.WriteJSONFrame(wire.FrameTypeClientHello, hello); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := rw.WriteMsg(loginMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
defer func() {
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
if wireConn != nil {
|
||||
var serverHello wire.ServerHello
|
||||
if err := wireConn.ReadJSONFrame(wire.FrameTypeServerHello, &serverHello); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if serverHello.Error != "" {
|
||||
return nil, errors.New(serverHello.Error)
|
||||
}
|
||||
}
|
||||
|
||||
var loginRespMsg msg.LoginResp
|
||||
if err := rw.ReadMsgInto(&loginRespMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &loginRespMsg, nil
|
||||
}
|
||||
245
client/control_session_test.go
Normal file
245
client/control_session_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
type testConnector struct {
|
||||
conn net.Conn
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (c *testConnector) Open() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testConnector) Connect() (net.Conn, error) {
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func (c *testConnector) Close() error {
|
||||
c.closed.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
type trackingConn struct {
|
||||
net.Conn
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (c *trackingConn) Close() error {
|
||||
c.closed.Store(true)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func newTestControlSessionDialer(t *testing.T, protocol string, connector Connector, clientSpec *msg.ClientSpec) *controlSessionDialer {
|
||||
t.Helper()
|
||||
|
||||
authRuntime, err := auth.BuildClientAuth(&v1.AuthClientConfig{
|
||||
Method: v1.AuthMethodToken,
|
||||
Token: "token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return &controlSessionDialer{
|
||||
ctx: context.Background(),
|
||||
common: &v1.ClientCommonConfig{
|
||||
User: "test-user",
|
||||
Transport: v1.ClientTransportConfig{
|
||||
Protocol: "tcp",
|
||||
WireProtocol: protocol,
|
||||
},
|
||||
},
|
||||
auth: authRuntime,
|
||||
clientSpec: clientSpec,
|
||||
connectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector {
|
||||
return connector
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialV1(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if loginMsg.RunID != "previous-run-id" {
|
||||
serverErrCh <- fmt.Errorf("unexpected previous run id: %s", loginMsg.RunID)
|
||||
return
|
||||
}
|
||||
if loginMsg.User != "test-user" {
|
||||
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v1"})
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("previous-run-id")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-v1", sessionCtx.RunID)
|
||||
require.NotNil(t, sessionCtx.Conn)
|
||||
require.NotNil(t, sessionCtx.Connector)
|
||||
require.False(t, connector.closed.Load())
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialV2(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
magic := make([]byte, len(wire.MagicV2))
|
||||
if _, err := io.ReadFull(serverRaw, magic); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if string(magic) != wire.MagicV2 {
|
||||
serverErrCh <- fmt.Errorf("unexpected magic: %q", string(magic))
|
||||
return
|
||||
}
|
||||
|
||||
wireConn := wire.NewConn(serverRaw)
|
||||
var hello wire.ClientHello
|
||||
if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := wire.ValidateClientHello(hello); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if loginMsg.User != "test-user" {
|
||||
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||
return
|
||||
}
|
||||
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"})
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-v2", sessionCtx.RunID)
|
||||
require.NotNil(t, sessionCtx.Conn)
|
||||
require.NotNil(t, sessionCtx.Connector)
|
||||
require.False(t, connector.closed.Load())
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialLoginErrorClosesResources(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn := &trackingConn{Conn: clientRaw}
|
||||
connector := &testConnector{conn: clientConn}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{Error: "login denied"})
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.Nil(t, sessionCtx)
|
||||
require.ErrorContains(t, err, "login denied")
|
||||
require.True(t, clientConn.closed.Load())
|
||||
require.True(t, connector.closed.Load())
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialSSHTunnelSkipsControlEncryption(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-ssh-tunnel"}); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
_ = serverRaw.SetReadDeadline(time.Now().Add(time.Second))
|
||||
var ping msg.Ping
|
||||
if err := rw.ReadMsgInto(&ping); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, &msg.ClientSpec{Type: "ssh-tunnel"})
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-ssh-tunnel", sessionCtx.RunID)
|
||||
require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{}))
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -38,7 +37,6 @@ import (
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/version"
|
||||
"github.com/fatedier/frp/pkg/util/wait"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
@@ -303,80 +301,20 @@ func (svr *Service) keepControllerWorking() {
|
||||
), true, svr.ctx.Done())
|
||||
}
|
||||
|
||||
// login creates a connection to frps and registers it self as a client
|
||||
// conn: control connection
|
||||
// session: if it's not nil, using tcp mux
|
||||
func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
|
||||
xl := xlog.FromContextSafe(svr.ctx)
|
||||
connector = svr.connectorCreator(svr.ctx, svr.common)
|
||||
if err = connector.Open(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
connector.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err = connector.Connect()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
hostname, _ := os.Hostname()
|
||||
|
||||
loginMsg := &msg.Login{
|
||||
Arch: runtime.GOARCH,
|
||||
Os: runtime.GOOS,
|
||||
Hostname: hostname,
|
||||
PoolCount: svr.common.Transport.PoolCount,
|
||||
User: svr.common.User,
|
||||
ClientID: svr.common.ClientID,
|
||||
Version: version.Full(),
|
||||
Timestamp: time.Now().Unix(),
|
||||
RunID: svr.runID,
|
||||
Metas: svr.common.Metadatas,
|
||||
}
|
||||
if svr.clientSpec != nil {
|
||||
loginMsg.ClientSpec = *svr.clientSpec
|
||||
}
|
||||
|
||||
// Add auth
|
||||
if err = svr.auth.Setter.SetLogin(loginMsg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = msg.WriteMsg(conn, loginMsg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var loginRespMsg msg.LoginResp
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil {
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if loginRespMsg.Error != "" {
|
||||
err = fmt.Errorf("%s", loginRespMsg.Error)
|
||||
xl.Errorf("%s", loginRespMsg.Error)
|
||||
return
|
||||
}
|
||||
|
||||
svr.runID = loginRespMsg.RunID
|
||||
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
||||
|
||||
xl.Infof("login to server success, get run id [%s]", loginRespMsg.RunID)
|
||||
return
|
||||
}
|
||||
|
||||
func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) {
|
||||
xl := xlog.FromContextSafe(svr.ctx)
|
||||
|
||||
loginFunc := func() (bool, error) {
|
||||
xl.Infof("try to connect to server...")
|
||||
conn, connector, err := svr.login()
|
||||
dialer := &controlSessionDialer{
|
||||
ctx: svr.ctx,
|
||||
common: svr.common,
|
||||
auth: svr.auth,
|
||||
clientSpec: svr.clientSpec,
|
||||
vnetController: svr.vnetController,
|
||||
connectorCreator: svr.connectorCreator,
|
||||
}
|
||||
sessionCtx, err := dialer.Dial(svr.runID)
|
||||
if err != nil {
|
||||
xl.Warnf("connect to server error: %v", err)
|
||||
if firstLoginExit {
|
||||
@@ -385,25 +323,19 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
||||
return false, err
|
||||
}
|
||||
|
||||
svr.runID = sessionCtx.RunID
|
||||
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
||||
xl.Infof("login to server success, get run id [%s]", svr.runID)
|
||||
|
||||
svr.cfgMu.RLock()
|
||||
proxyCfgs := svr.proxyCfgs
|
||||
visitorCfgs := svr.visitorCfgs
|
||||
svr.cfgMu.RUnlock()
|
||||
|
||||
connEncrypted := svr.clientSpec == nil || svr.clientSpec.Type != "ssh-tunnel"
|
||||
|
||||
sessionCtx := &SessionContext{
|
||||
Common: svr.common,
|
||||
RunID: svr.runID,
|
||||
Conn: conn,
|
||||
ConnEncrypted: connEncrypted,
|
||||
Auth: svr.auth,
|
||||
Connector: connector,
|
||||
VnetController: svr.vnetController,
|
||||
}
|
||||
ctl, err := NewControl(svr.ctx, sessionCtx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
sessionCtx.Conn.Close()
|
||||
sessionCtx.Connector.Close()
|
||||
xl.Errorf("new control error: %v", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ import (
|
||||
// Helper wraps some functions for visitor to use.
|
||||
type Helper interface {
|
||||
// ConnectServer directly connects to the frp server.
|
||||
ConnectServer() (net.Conn, error)
|
||||
ConnectServer() (*msg.Conn, error)
|
||||
// TransferConn transfers the connection to another visitor.
|
||||
TransferConn(string, net.Conn) error
|
||||
// MsgTransporter returns the message transporter that is used to send and receive messages
|
||||
@@ -167,15 +167,15 @@ func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, e
|
||||
UseEncryption: cfg.Transport.UseEncryption,
|
||||
UseCompression: cfg.Transport.UseCompression,
|
||||
}
|
||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||
err = visitorConn.WriteMsg(newVisitorConnMsg)
|
||||
if err != nil {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
|
||||
}
|
||||
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
err = visitorConn.ReadMsgInto(&newVisitorConnRespMsg)
|
||||
if err != nil {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
@@ -49,7 +50,7 @@ func NewManager(
|
||||
ctx context.Context,
|
||||
runID string,
|
||||
clientCfg *v1.ClientCommonConfig,
|
||||
connectServer func() (net.Conn, error),
|
||||
connectServer func() (*msg.Conn, error),
|
||||
msgTransporter transport.MessageTransporter,
|
||||
vnetController *vnet.Controller,
|
||||
) *Manager {
|
||||
@@ -199,14 +200,14 @@ func (vm *Manager) GetVisitorCfg(name string) (v1.VisitorConfigurer, bool) {
|
||||
}
|
||||
|
||||
type visitorHelperImpl struct {
|
||||
connectServerFn func() (net.Conn, error)
|
||||
connectServerFn func() (*msg.Conn, error)
|
||||
msgTransporter transport.MessageTransporter
|
||||
vnetController *vnet.Controller
|
||||
transferConnFn func(name string, conn net.Conn) error
|
||||
runID string
|
||||
}
|
||||
|
||||
func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) {
|
||||
func (v *visitorHelperImpl) ConnectServer() (*msg.Conn, error) {
|
||||
return v.connectServerFn()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user