mirror of
https://github.com/fatedier/frp.git
synced 2026-04-28 03:49:09 +08:00
protocol: add v2 wire protocol with binary framing and capability negotiation (#5294)
This commit is contained in:
222
pkg/proto/wire/wire.go
Normal file
222
pkg/proto/wire/wire.go
Normal file
@@ -0,0 +1,222 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
libnet "github.com/fatedier/golib/net"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolV1 = "v1"
|
||||
ProtocolV2 = "v2"
|
||||
|
||||
WireVersionV2 = 2
|
||||
|
||||
FrameTypeClientHello uint16 = 1
|
||||
FrameTypeServerHello uint16 = 2
|
||||
FrameTypeMessage uint16 = 16
|
||||
|
||||
MessageCodecJSON = "json"
|
||||
DefaultMaxFramePayloadSize = 64 * 1024
|
||||
|
||||
MagicV2 = "FRP\x00\x02\r\n"
|
||||
)
|
||||
|
||||
type Frame struct {
|
||||
Type uint16
|
||||
Flags uint16
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
rw io.ReadWriter
|
||||
maxFramePayloadSize uint32
|
||||
}
|
||||
|
||||
func NewConn(rw io.ReadWriter) *Conn {
|
||||
return &Conn{
|
||||
rw: rw,
|
||||
maxFramePayloadSize: DefaultMaxFramePayloadSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ReadFrame() (*Frame, error) {
|
||||
header := make([]byte, 8)
|
||||
if _, err := io.ReadFull(c.rw, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
frameType := binary.BigEndian.Uint16(header[0:2])
|
||||
flags := binary.BigEndian.Uint16(header[2:4])
|
||||
length := binary.BigEndian.Uint32(header[4:8])
|
||||
if flags != 0 {
|
||||
return nil, fmt.Errorf("unsupported frame flags: %d", flags)
|
||||
}
|
||||
if length > c.maxFramePayloadSize {
|
||||
return nil, fmt.Errorf("frame payload length %d exceeds limit %d", length, c.maxFramePayloadSize)
|
||||
}
|
||||
|
||||
payload := make([]byte, length)
|
||||
if _, err := io.ReadFull(c.rw, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Frame{
|
||||
Type: frameType,
|
||||
Flags: flags,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteFrame(f *Frame) error {
|
||||
if f.Flags != 0 {
|
||||
return fmt.Errorf("unsupported frame flags: %d", f.Flags)
|
||||
}
|
||||
if len(f.Payload) > int(c.maxFramePayloadSize) {
|
||||
return fmt.Errorf("frame payload length %d exceeds limit %d", len(f.Payload), c.maxFramePayloadSize)
|
||||
}
|
||||
|
||||
header := make([]byte, 8)
|
||||
binary.BigEndian.PutUint16(header[0:2], f.Type)
|
||||
binary.BigEndian.PutUint16(header[2:4], f.Flags)
|
||||
binary.BigEndian.PutUint32(header[4:8], uint32(len(f.Payload)))
|
||||
if _, err := c.rw.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := c.rw.Write(f.Payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) ReadJSONFrame(frameType uint16, out any) error {
|
||||
f, err := c.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f.Type != frameType {
|
||||
return fmt.Errorf("unexpected frame type %d, want %d", f.Type, frameType)
|
||||
}
|
||||
return c.UnmarshalFrame(f, out)
|
||||
}
|
||||
|
||||
func (c *Conn) UnmarshalFrame(f *Frame, out any) error {
|
||||
return json.Unmarshal(f.Payload, out)
|
||||
}
|
||||
|
||||
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
|
||||
payload, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.WriteFrame(&Frame{
|
||||
Type: frameType,
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
|
||||
func WriteMagic(w io.Writer) error {
|
||||
_, err := io.WriteString(w, MagicV2)
|
||||
return err
|
||||
}
|
||||
|
||||
func WriteMagicIfV2(w io.Writer, wireProtocol string) error {
|
||||
if wireProtocol != ProtocolV2 {
|
||||
return nil
|
||||
}
|
||||
return WriteMagic(w)
|
||||
}
|
||||
|
||||
func CheckMagic(conn net.Conn) (out net.Conn, isV2 bool, err error) {
|
||||
sharedConn, r := libnet.NewSharedConnSize(conn, len(MagicV2))
|
||||
buf := make([]byte, len(MagicV2))
|
||||
if _, err = io.ReadFull(r, buf); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
for i := range MagicV2 {
|
||||
if buf[i] != MagicV2[i] {
|
||||
return sharedConn, false, nil
|
||||
}
|
||||
}
|
||||
return conn, true, nil
|
||||
}
|
||||
|
||||
type BootstrapInfo struct {
|
||||
Transport string `json:"transport,omitempty"`
|
||||
TLS bool `json:"tls,omitempty"`
|
||||
TCPMux bool `json:"tcpMux,omitempty"`
|
||||
}
|
||||
|
||||
type ClientHello struct {
|
||||
Bootstrap BootstrapInfo `json:"bootstrap,omitempty"`
|
||||
Capabilities ClientCapabilities `json:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
type ClientCapabilities struct {
|
||||
Message MessageCapabilities `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type MessageCapabilities struct {
|
||||
Codecs []string `json:"codecs,omitempty"`
|
||||
}
|
||||
|
||||
type ServerHello struct {
|
||||
Selected ServerSelection `json:"selected,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type ServerSelection struct {
|
||||
Message MessageSelection `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type MessageSelection struct {
|
||||
Codec string `json:"codec,omitempty"`
|
||||
}
|
||||
|
||||
func DefaultClientHello(bootstrap BootstrapInfo) ClientHello {
|
||||
return ClientHello{
|
||||
Bootstrap: bootstrap,
|
||||
Capabilities: ClientCapabilities{
|
||||
Message: MessageCapabilities{
|
||||
Codecs: []string{MessageCodecJSON},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultServerHello() ServerHello {
|
||||
return ServerHello{
|
||||
Selected: ServerSelection{
|
||||
Message: MessageSelection{
|
||||
Codec: MessageCodecJSON,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Supports(list []string, value string) bool {
|
||||
return slices.Contains(list, value)
|
||||
}
|
||||
|
||||
func ValidateClientHello(h ClientHello) error {
|
||||
if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) {
|
||||
return fmt.Errorf("unsupported message codec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
120
pkg/proto/wire/wire_test.go
Normal file
120
pkg/proto/wire/wire_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFrameRoundTrip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
conn := NewConn(&buf)
|
||||
|
||||
in := DefaultClientHello(BootstrapInfo{
|
||||
Transport: "tcp",
|
||||
TLS: true,
|
||||
TCPMux: true,
|
||||
})
|
||||
require.NoError(t, conn.WriteJSONFrame(FrameTypeClientHello, in))
|
||||
|
||||
var out ClientHello
|
||||
require.NoError(t, conn.ReadJSONFrame(FrameTypeClientHello, &out))
|
||||
require.Equal(t, in, out)
|
||||
}
|
||||
|
||||
func TestReadFrameRejectsUnsupportedFlags(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
header := make([]byte, 8)
|
||||
binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage)
|
||||
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||
binary.BigEndian.PutUint32(header[4:8], 0)
|
||||
buf.Write(header)
|
||||
|
||||
_, err := NewConn(&buf).ReadFrame()
|
||||
require.ErrorContains(t, err, "unsupported frame flags")
|
||||
}
|
||||
|
||||
func TestReadFrameRejectsOversizedPayload(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
header := make([]byte, 8)
|
||||
binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage)
|
||||
binary.BigEndian.PutUint32(header[4:8], DefaultMaxFramePayloadSize+1)
|
||||
buf.Write(header)
|
||||
|
||||
_, err := NewConn(&buf).ReadFrame()
|
||||
require.ErrorContains(t, err, "exceeds limit")
|
||||
}
|
||||
|
||||
func TestCheckMagicV2ConsumesMagic(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer server.Close()
|
||||
|
||||
want := []byte("payload")
|
||||
go func() {
|
||||
defer client.Close()
|
||||
_, _ = client.Write(append([]byte(MagicV2), want...))
|
||||
}()
|
||||
|
||||
out, isV2, err := CheckMagic(server)
|
||||
require.NoError(t, err)
|
||||
require.True(t, isV2)
|
||||
|
||||
got := make([]byte, len(want))
|
||||
_, err = io.ReadFull(out, got)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestWriteMagicIfV2(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
require.NoError(t, WriteMagicIfV2(&buf, ProtocolV1))
|
||||
require.Empty(t, buf.Bytes())
|
||||
|
||||
require.NoError(t, WriteMagicIfV2(&buf, ProtocolV2))
|
||||
require.Equal(t, []byte(MagicV2), buf.Bytes())
|
||||
}
|
||||
|
||||
func TestCheckMagicV1PreservesReadBytes(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer server.Close()
|
||||
|
||||
want := []byte("legacy payload")
|
||||
go func() {
|
||||
defer client.Close()
|
||||
_, _ = client.Write(want)
|
||||
}()
|
||||
|
||||
out, isV2, err := CheckMagic(server)
|
||||
require.NoError(t, err)
|
||||
require.False(t, isV2)
|
||||
|
||||
got, err := io.ReadAll(out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestValidateClientHello(t *testing.T) {
|
||||
require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{})))
|
||||
|
||||
hello := DefaultClientHello(BootstrapInfo{})
|
||||
hello.Capabilities.Message.Codecs = []string{"unknown"}
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec")
|
||||
}
|
||||
Reference in New Issue
Block a user