mirror of
https://github.com/fatedier/frp.git
synced 2026-04-28 03:49:09 +08:00
protocol: add v2 wire protocol with binary framing and capability negotiation (#5294)
This commit is contained in:
@@ -19,6 +19,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -37,6 +38,7 @@ import (
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/nathole"
|
||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
"github.com/fatedier/frp/pkg/ssh"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
@@ -432,20 +434,15 @@ func (svr *Service) Close() error {
|
||||
func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, internal bool) {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
|
||||
var (
|
||||
rawMsg msg.Message
|
||||
err error
|
||||
)
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(connReadTimeout))
|
||||
if rawMsg, err = msg.ReadMsg(conn); err != nil {
|
||||
log.Tracef("failed to read message: %v", err)
|
||||
acceptedConn, err := svr.acceptConnection(ctx, conn)
|
||||
if err != nil {
|
||||
log.Tracef("failed to accept frp connection: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
conn = acceptedConn.conn
|
||||
|
||||
switch m := rawMsg.(type) {
|
||||
switch m := acceptedConn.firstMsg.(type) {
|
||||
case *msg.Login:
|
||||
// server plugin hook
|
||||
content := &plugin.LoginContent{
|
||||
@@ -453,35 +450,66 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
||||
ClientAddress: conn.RemoteAddr().String(),
|
||||
}
|
||||
retContent, err := svr.pluginManager.Login(content)
|
||||
var ctl *Control
|
||||
if err == nil {
|
||||
m = &retContent.Login
|
||||
err = svr.RegisterControl(conn, m, internal)
|
||||
controlConn := acceptedConn.conn
|
||||
if !internal {
|
||||
var controlRW io.ReadWriter
|
||||
controlRW, err = netpkg.NewCryptoReadWriter(conn, svr.auth.EncryptionKey())
|
||||
if err == nil {
|
||||
controlConn = acceptedConn.messageConnFor(controlRW)
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
ctl, err = svr.RegisterControl(controlConn, m, internal, acceptedConn.wireProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
// If login failed, send error message there.
|
||||
// Otherwise send success message in control's work goroutine.
|
||||
if err != nil {
|
||||
xl.Warnf("register control error: %v", err)
|
||||
_ = msg.WriteMsg(conn, &msg.LoginResp{
|
||||
_ = acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||
Version: version.Full(),
|
||||
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if err = acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||
Version: version.Full(),
|
||||
RunID: ctl.runID,
|
||||
Error: "",
|
||||
}); err != nil {
|
||||
xl.Warnf("write login response error: %v", err)
|
||||
svr.ctlManager.Del(m.RunID, ctl)
|
||||
svr.clientRegistry.MarkOfflineByRunID(m.RunID)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
ctl.Start()
|
||||
metrics.Server.NewClient()
|
||||
go func() {
|
||||
// block until control closed
|
||||
ctl.WaitClosed()
|
||||
svr.ctlManager.Del(m.RunID, ctl)
|
||||
}()
|
||||
case *msg.NewWorkConn:
|
||||
if err := svr.RegisterWorkConn(conn, m); err != nil {
|
||||
if err := svr.RegisterWorkConn(acceptedConn.conn, m); err != nil {
|
||||
_ = acceptedConn.conn.WriteMsg(&msg.StartWorkConn{
|
||||
Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
conn.Close()
|
||||
}
|
||||
case *msg.NewVisitorConn:
|
||||
if err = svr.RegisterVisitorConn(conn, m); err != nil {
|
||||
xl.Warnf("register visitor conn error: %v", err)
|
||||
_ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
||||
_ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{
|
||||
ProxyName: m.ProxyName,
|
||||
Error: util.GenerateResponseErrorString("register visitor conn error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
conn.Close()
|
||||
} else {
|
||||
_ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
||||
_ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{
|
||||
ProxyName: m.ProxyName,
|
||||
Error: "",
|
||||
})
|
||||
@@ -492,6 +520,87 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
||||
}
|
||||
}
|
||||
|
||||
type acceptedConnection struct {
|
||||
conn *msg.Conn
|
||||
wireProtocol string
|
||||
firstMsg msg.Message
|
||||
}
|
||||
|
||||
func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*acceptedConnection, error) {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(connReadTimeout))
|
||||
checkedConn, isV2, err := wire.CheckMagic(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read wire protocol magic: %w", err)
|
||||
}
|
||||
|
||||
wireProtocol := wire.ProtocolV1
|
||||
if isV2 {
|
||||
wireProtocol = wire.ProtocolV2
|
||||
}
|
||||
|
||||
conn = netpkg.NewContextConn(ctx, checkedConn)
|
||||
acceptedConn := &acceptedConnection{wireProtocol: wireProtocol}
|
||||
if isV2 {
|
||||
wireConn := wire.NewConn(conn)
|
||||
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
||||
acceptedConn.conn = msg.NewConn(conn, rw)
|
||||
acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(wireConn)
|
||||
} else {
|
||||
rw := msg.NewV1ReadWriter(conn)
|
||||
acceptedConn.conn = msg.NewConn(conn, rw)
|
||||
acceptedConn.firstMsg, err = acceptedConn.conn.ReadMsg()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
return acceptedConn, nil
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn {
|
||||
return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol))
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, error) {
|
||||
frame, err := wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read v2 frame: %w", err)
|
||||
}
|
||||
if frame.Type == wire.FrameTypeClientHello {
|
||||
if err := ac.handleClientHello(wireConn, frame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame, err = wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read first v2 message frame: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m, err := msg.DecodeV2MessageFrame(frame)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode v2 message: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) handleClientHello(wireConn *wire.Conn, frame *wire.Frame) error {
|
||||
var hello wire.ClientHello
|
||||
if err := wireConn.UnmarshalFrame(frame, &hello); err != nil {
|
||||
return fmt.Errorf("decode ClientHello: %w", err)
|
||||
}
|
||||
|
||||
serverHello := wire.DefaultServerHello()
|
||||
if err := wire.ValidateClientHello(hello); err != nil {
|
||||
serverHello.Error = err.Error()
|
||||
_ = wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
return err
|
||||
}
|
||||
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello); err != nil {
|
||||
return fmt.Errorf("write ServerHello: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleListener accepts connections from client and call handleConnection to handle them.
|
||||
// If internal is true, it means that this listener is used for internal communication like ssh tunnel gateway.
|
||||
// TODO(fatedier): Pass some parameters of listener/connection through context to avoid passing too many parameters.
|
||||
@@ -577,14 +686,19 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, internal bool) error {
|
||||
func (svr *Service) RegisterControl(
|
||||
ctlConn *msg.Conn,
|
||||
loginMsg *msg.Login,
|
||||
internal bool,
|
||||
wireProtocol string,
|
||||
) (*Control, error) {
|
||||
// If client's RunID is empty, it's a new client, we just create a new controller.
|
||||
// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
|
||||
var err error
|
||||
if loginMsg.RunID == "" {
|
||||
loginMsg.RunID, err = util.RandID()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -601,7 +715,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
authVerifier = auth.AlwaysPassVerifier
|
||||
}
|
||||
if err := authVerifier.VerifyLogin(loginMsg); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctl, err := NewControl(ctx, &SessionContext{
|
||||
@@ -611,7 +725,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
AuthVerifier: authVerifier,
|
||||
EncryptionKey: svr.auth.EncryptionKey(),
|
||||
Conn: ctlConn,
|
||||
ConnEncrypted: !internal,
|
||||
LoginMsg: loginMsg,
|
||||
ServerCfg: svr.cfg,
|
||||
ClientRegistry: svr.clientRegistry,
|
||||
@@ -619,7 +732,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
if err != nil {
|
||||
xl.Warnf("create new controller error: %v", err)
|
||||
// don't return detailed errors to client
|
||||
return fmt.Errorf("unexpected error when creating new controller")
|
||||
return nil, fmt.Errorf("unexpected error when creating new controller")
|
||||
}
|
||||
|
||||
if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil {
|
||||
@@ -630,34 +743,24 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
||||
remoteAddr = host
|
||||
}
|
||||
_, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr)
|
||||
_, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr, wireProtocol)
|
||||
if conflict {
|
||||
svr.ctlManager.Del(loginMsg.RunID, ctl)
|
||||
ctl.Close()
|
||||
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
|
||||
return nil, fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
|
||||
}
|
||||
|
||||
ctl.Start()
|
||||
|
||||
// for statistics
|
||||
metrics.Server.NewClient()
|
||||
|
||||
go func() {
|
||||
// block until control closed
|
||||
ctl.WaitClosed()
|
||||
svr.ctlManager.Del(loginMsg.RunID, ctl)
|
||||
}()
|
||||
return nil
|
||||
return ctl, nil
|
||||
}
|
||||
|
||||
// RegisterWorkConn register a new work connection to control and proxies need it.
|
||||
func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error {
|
||||
func (svr *Service) RegisterWorkConn(workConn *msg.Conn, newMsg *msg.NewWorkConn) error {
|
||||
xl := netpkg.NewLogFromConn(workConn)
|
||||
ctl, exist := svr.ctlManager.GetByID(newMsg.RunID)
|
||||
if !exist {
|
||||
xl.Warnf("no client control found for run id [%s]", newMsg.RunID)
|
||||
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
|
||||
}
|
||||
|
||||
// server plugin hook
|
||||
content := &plugin.NewWorkConnContent{
|
||||
User: plugin.UserInfo{
|
||||
@@ -675,12 +778,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
||||
_ = msg.WriteMsg(workConn, &msg.StartWorkConn{
|
||||
Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
return fmt.Errorf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
||||
return err
|
||||
}
|
||||
return ctl.RegisterWorkConn(workConn)
|
||||
return ctl.RegisterWorkConn(proxy.NewWorkConn(workConn))
|
||||
}
|
||||
|
||||
func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVisitorConn) error {
|
||||
|
||||
Reference in New Issue
Block a user