mirror of
https://github.com/fatedier/frp.git
synced 2026-04-28 03:49:09 +08:00
protocol: add v2 wire protocol with binary framing and capability negotiation (#5294)
This commit is contained in:
56
pkg/msg/conn_test.go
Normal file
56
pkg/msg/conn_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
func TestConnReadWriteMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
}{
|
||||
{name: "v1", protocol: wire.ProtocolV1},
|
||||
{name: "v2", protocol: wire.ProtocolV2},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
clientConn := NewConn(client, NewReadWriter(client, tt.protocol))
|
||||
serverConn := NewConn(server, NewReadWriter(server, tt.protocol))
|
||||
|
||||
in := &Ping{PrivilegeKey: "key", Timestamp: 123}
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- clientConn.WriteMsg(in)
|
||||
}()
|
||||
|
||||
out, err := serverConn.ReadMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, in, out)
|
||||
require.NoError(t, <-errCh)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,10 +15,90 @@
|
||||
package msg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
type ReadWriter interface {
|
||||
ReadMsg() (Message, error)
|
||||
ReadMsgInto(Message) error
|
||||
WriteMsg(Message) error
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
rw ReadWriter
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, rw ReadWriter) *Conn {
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
rw: rw,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ReadMsg() (Message, error) {
|
||||
return c.rw.ReadMsg()
|
||||
}
|
||||
|
||||
func (c *Conn) ReadMsgInto(m Message) error {
|
||||
return c.rw.ReadMsgInto(m)
|
||||
}
|
||||
|
||||
func (c *Conn) WriteMsg(m Message) error {
|
||||
return c.rw.WriteMsg(m)
|
||||
}
|
||||
|
||||
func (c *Conn) Context() context.Context {
|
||||
if getter, ok := c.Conn.(interface{ Context() context.Context }); ok {
|
||||
return getter.Context()
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
func (c *Conn) WithContext(ctx context.Context) {
|
||||
if setter, ok := c.Conn.(interface{ WithContext(context.Context) }); ok {
|
||||
setter.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
type V1ReadWriter struct {
|
||||
rw io.ReadWriter
|
||||
}
|
||||
|
||||
func NewV1ReadWriter(rw io.ReadWriter) ReadWriter {
|
||||
return &V1ReadWriter{rw: rw}
|
||||
}
|
||||
|
||||
// NewReadWriter wraps rw with the message codec for the selected wire protocol.
|
||||
// An empty protocol keeps the historical v1 behavior for tests and older call sites.
|
||||
func NewReadWriter(rw io.ReadWriter, wireProtocol string) ReadWriter {
|
||||
switch wireProtocol {
|
||||
case wire.ProtocolV2:
|
||||
return NewV2ReadWriter(rw)
|
||||
case "", wire.ProtocolV1:
|
||||
return NewV1ReadWriter(rw)
|
||||
default:
|
||||
return NewV1ReadWriter(rw)
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *V1ReadWriter) ReadMsg() (Message, error) {
|
||||
return ReadMsg(rw.rw)
|
||||
}
|
||||
|
||||
func (rw *V1ReadWriter) ReadMsgInto(m Message) error {
|
||||
return ReadMsgInto(rw.rw, m)
|
||||
}
|
||||
|
||||
func (rw *V1ReadWriter) WriteMsg(m Message) error {
|
||||
return WriteMsg(rw.rw, m)
|
||||
}
|
||||
|
||||
func AsyncHandler(f func(Message)) func(Message) {
|
||||
return func(m Message) {
|
||||
go f(m)
|
||||
@@ -27,7 +107,7 @@ func AsyncHandler(f func(Message)) func(Message) {
|
||||
|
||||
// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn.
|
||||
type Dispatcher struct {
|
||||
rw io.ReadWriter
|
||||
rw ReadWriter
|
||||
|
||||
sendCh chan Message
|
||||
doneCh chan struct{}
|
||||
@@ -35,7 +115,7 @@ type Dispatcher struct {
|
||||
defaultHandler func(Message)
|
||||
}
|
||||
|
||||
func NewDispatcher(rw io.ReadWriter) *Dispatcher {
|
||||
func NewDispatcher(rw ReadWriter) *Dispatcher {
|
||||
return &Dispatcher{
|
||||
rw: rw,
|
||||
sendCh: make(chan Message, 100),
|
||||
@@ -56,14 +136,14 @@ func (d *Dispatcher) sendLoop() {
|
||||
case <-d.doneCh:
|
||||
return
|
||||
case m := <-d.sendCh:
|
||||
_ = WriteMsg(d.rw, m)
|
||||
_ = d.rw.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) readLoop() {
|
||||
for {
|
||||
m, err := ReadMsg(d.rw)
|
||||
m, err := d.rw.ReadMsg()
|
||||
if err != nil {
|
||||
close(d.doneCh)
|
||||
return
|
||||
|
||||
@@ -20,24 +20,24 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TypeLogin = 'o'
|
||||
TypeLoginResp = '1'
|
||||
TypeNewProxy = 'p'
|
||||
TypeNewProxyResp = '2'
|
||||
TypeCloseProxy = 'c'
|
||||
TypeNewWorkConn = 'w'
|
||||
TypeReqWorkConn = 'r'
|
||||
TypeStartWorkConn = 's'
|
||||
TypeNewVisitorConn = 'v'
|
||||
TypeNewVisitorConnResp = '3'
|
||||
TypePing = 'h'
|
||||
TypePong = '4'
|
||||
TypeUDPPacket = 'u'
|
||||
TypeNatHoleVisitor = 'i'
|
||||
TypeNatHoleClient = 'n'
|
||||
TypeNatHoleResp = 'm'
|
||||
TypeNatHoleSid = '5'
|
||||
TypeNatHoleReport = '6'
|
||||
TypeLogin byte = 'o'
|
||||
TypeLoginResp byte = '1'
|
||||
TypeNewProxy byte = 'p'
|
||||
TypeNewProxyResp byte = '2'
|
||||
TypeCloseProxy byte = 'c'
|
||||
TypeNewWorkConn byte = 'w'
|
||||
TypeReqWorkConn byte = 'r'
|
||||
TypeStartWorkConn byte = 's'
|
||||
TypeNewVisitorConn byte = 'v'
|
||||
TypeNewVisitorConnResp byte = '3'
|
||||
TypePing byte = 'h'
|
||||
TypePong byte = '4'
|
||||
TypeUDPPacket byte = 'u'
|
||||
TypeNatHoleVisitor byte = 'i'
|
||||
TypeNatHoleClient byte = 'n'
|
||||
TypeNatHoleResp byte = 'm'
|
||||
TypeNatHoleSid byte = '5'
|
||||
TypeNatHoleReport byte = '6'
|
||||
)
|
||||
|
||||
var msgTypeMap = map[byte]any{
|
||||
|
||||
55
pkg/msg/msg_test.go
Normal file
55
pkg/msg/msg_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestV1MessageTypeIDsAreStable(t *testing.T) {
|
||||
require.Equal(t, byte('o'), TypeLogin)
|
||||
require.Equal(t, byte('1'), TypeLoginResp)
|
||||
require.Equal(t, byte('p'), TypeNewProxy)
|
||||
require.Equal(t, byte('2'), TypeNewProxyResp)
|
||||
require.Equal(t, byte('c'), TypeCloseProxy)
|
||||
require.Equal(t, byte('w'), TypeNewWorkConn)
|
||||
require.Equal(t, byte('r'), TypeReqWorkConn)
|
||||
require.Equal(t, byte('s'), TypeStartWorkConn)
|
||||
require.Equal(t, byte('v'), TypeNewVisitorConn)
|
||||
require.Equal(t, byte('3'), TypeNewVisitorConnResp)
|
||||
require.Equal(t, byte('h'), TypePing)
|
||||
require.Equal(t, byte('4'), TypePong)
|
||||
require.Equal(t, byte('u'), TypeUDPPacket)
|
||||
require.Equal(t, byte('i'), TypeNatHoleVisitor)
|
||||
require.Equal(t, byte('n'), TypeNatHoleClient)
|
||||
require.Equal(t, byte('m'), TypeNatHoleResp)
|
||||
require.Equal(t, byte('5'), TypeNatHoleSid)
|
||||
require.Equal(t, byte('6'), TypeNatHoleReport)
|
||||
}
|
||||
|
||||
func TestMessageTypeMapIsCompleteAndUnique(t *testing.T) {
|
||||
require.Len(t, msgTypeMap, 18)
|
||||
|
||||
msgTypes := make(map[reflect.Type]struct{}, len(msgTypeMap))
|
||||
|
||||
for _, m := range msgTypeMap {
|
||||
msgType := reflect.TypeOf(m)
|
||||
require.NotContains(t, msgTypes, msgType)
|
||||
msgTypes[msgType] = struct{}{}
|
||||
}
|
||||
}
|
||||
192
pkg/msg/wire_v2.go
Normal file
192
pkg/msg/wire_v2.go
Normal file
@@ -0,0 +1,192 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
const (
|
||||
V2TypeLogin uint16 = 1
|
||||
V2TypeLoginResp uint16 = 2
|
||||
V2TypeNewProxy uint16 = 3
|
||||
V2TypeNewProxyResp uint16 = 4
|
||||
V2TypeCloseProxy uint16 = 5
|
||||
V2TypeNewWorkConn uint16 = 6
|
||||
V2TypeReqWorkConn uint16 = 7
|
||||
V2TypeStartWorkConn uint16 = 8
|
||||
V2TypeNewVisitorConn uint16 = 9
|
||||
V2TypeNewVisitorConnResp uint16 = 10
|
||||
V2TypePing uint16 = 11
|
||||
V2TypePong uint16 = 12
|
||||
V2TypeUDPPacket uint16 = 13
|
||||
V2TypeNatHoleVisitor uint16 = 14
|
||||
V2TypeNatHoleClient uint16 = 15
|
||||
V2TypeNatHoleResp uint16 = 16
|
||||
V2TypeNatHoleSid uint16 = 17
|
||||
V2TypeNatHoleReport uint16 = 18
|
||||
)
|
||||
|
||||
var v2MsgTypeMap = map[uint16]any{
|
||||
V2TypeLogin: Login{},
|
||||
V2TypeLoginResp: LoginResp{},
|
||||
V2TypeNewProxy: NewProxy{},
|
||||
V2TypeNewProxyResp: NewProxyResp{},
|
||||
V2TypeCloseProxy: CloseProxy{},
|
||||
V2TypeNewWorkConn: NewWorkConn{},
|
||||
V2TypeReqWorkConn: ReqWorkConn{},
|
||||
V2TypeStartWorkConn: StartWorkConn{},
|
||||
V2TypeNewVisitorConn: NewVisitorConn{},
|
||||
V2TypeNewVisitorConnResp: NewVisitorConnResp{},
|
||||
V2TypePing: Ping{},
|
||||
V2TypePong: Pong{},
|
||||
V2TypeUDPPacket: UDPPacket{},
|
||||
V2TypeNatHoleVisitor: NatHoleVisitor{},
|
||||
V2TypeNatHoleClient: NatHoleClient{},
|
||||
V2TypeNatHoleResp: NatHoleResp{},
|
||||
V2TypeNatHoleSid: NatHoleSid{},
|
||||
V2TypeNatHoleReport: NatHoleReport{},
|
||||
}
|
||||
|
||||
var v2MsgReflectTypeMap, v2MsgTypeIDMap = buildV2MsgTypeMaps()
|
||||
|
||||
func buildV2MsgTypeMaps() (map[uint16]reflect.Type, map[reflect.Type]uint16) {
|
||||
reflectTypeMap := make(map[uint16]reflect.Type, len(v2MsgTypeMap))
|
||||
typeIDMap := make(map[reflect.Type]uint16, len(v2MsgTypeMap))
|
||||
for typeID, m := range v2MsgTypeMap {
|
||||
t := reflect.TypeOf(m)
|
||||
reflectTypeMap[typeID] = t
|
||||
typeIDMap[t] = typeID
|
||||
}
|
||||
return reflectTypeMap, typeIDMap
|
||||
}
|
||||
|
||||
type V2ReadWriter struct {
|
||||
conn *wire.Conn
|
||||
}
|
||||
|
||||
func NewV2ReadWriter(rw io.ReadWriter) *V2ReadWriter {
|
||||
return NewV2ReadWriterWithConn(wire.NewConn(rw))
|
||||
}
|
||||
|
||||
func NewV2ReadWriterWithConn(conn *wire.Conn) *V2ReadWriter {
|
||||
return &V2ReadWriter{conn: conn}
|
||||
}
|
||||
|
||||
func (rw *V2ReadWriter) WireConn() *wire.Conn {
|
||||
return rw.conn
|
||||
}
|
||||
|
||||
func (rw *V2ReadWriter) ReadMsg() (Message, error) {
|
||||
f, err := rw.conn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecodeV2MessageFrame(f)
|
||||
}
|
||||
|
||||
func (rw *V2ReadWriter) ReadMsgInto(m Message) error {
|
||||
f, err := rw.conn.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return DecodeV2MessageFrameInto(f, m)
|
||||
}
|
||||
|
||||
func (rw *V2ReadWriter) WriteMsg(m Message) error {
|
||||
f, err := EncodeV2MessageFrame(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return rw.conn.WriteFrame(f)
|
||||
}
|
||||
|
||||
func DecodeV2MessageFrame(f *wire.Frame) (Message, error) {
|
||||
if f.Type != wire.FrameTypeMessage {
|
||||
return nil, fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage)
|
||||
}
|
||||
if len(f.Payload) < 2 {
|
||||
return nil, fmt.Errorf("message frame payload too short")
|
||||
}
|
||||
typeID := binary.BigEndian.Uint16(f.Payload[:2])
|
||||
t, ok := v2MsgReflectTypeMap[typeID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown v2 message type %d", typeID)
|
||||
}
|
||||
m := reflect.New(t).Interface()
|
||||
if err := json.Unmarshal(f.Payload[2:], m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func DecodeV2MessageFrameInto(f *wire.Frame, out Message) error {
|
||||
if f.Type != wire.FrameTypeMessage {
|
||||
return fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage)
|
||||
}
|
||||
if len(f.Payload) < 2 {
|
||||
return fmt.Errorf("message frame payload too short")
|
||||
}
|
||||
|
||||
typeID := binary.BigEndian.Uint16(f.Payload[:2])
|
||||
outType := reflect.TypeOf(out)
|
||||
if outType == nil || outType.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("message target must be a pointer")
|
||||
}
|
||||
elemType := outType.Elem()
|
||||
expectedTypeID, ok := v2MsgTypeIDMap[elemType]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown v2 message type %s", elemType.String())
|
||||
}
|
||||
if typeID != expectedTypeID {
|
||||
actualType, ok := v2MsgReflectTypeMap[typeID]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown v2 message type %d", typeID)
|
||||
}
|
||||
return fmt.Errorf("unexpected message type %s, want %s", actualType.String(), elemType.String())
|
||||
}
|
||||
return json.Unmarshal(f.Payload[2:], out)
|
||||
}
|
||||
|
||||
func EncodeV2MessageFrame(m Message) (*wire.Frame, error) {
|
||||
t := reflect.TypeOf(m)
|
||||
if t == nil {
|
||||
return nil, fmt.Errorf("nil message")
|
||||
}
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
typeID, ok := v2MsgTypeIDMap[t]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown v2 message type %s", t.String())
|
||||
}
|
||||
content, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := make([]byte, 2+len(content))
|
||||
binary.BigEndian.PutUint16(payload[:2], typeID)
|
||||
copy(payload[2:], content)
|
||||
return &wire.Frame{
|
||||
Type: wire.FrameTypeMessage,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
121
pkg/msg/wire_v2_test.go
Normal file
121
pkg/msg/wire_v2_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
func TestV2ReadWriterRoundTrip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rw := NewV2ReadWriter(&buf)
|
||||
|
||||
in := &Login{
|
||||
Version: "test-version",
|
||||
RunID: "run-id",
|
||||
User: "user",
|
||||
}
|
||||
require.NoError(t, rw.WriteMsg(in))
|
||||
|
||||
out, err := rw.ReadMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, in, out)
|
||||
}
|
||||
|
||||
func TestNewReadWriter(t *testing.T) {
|
||||
require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, ""))
|
||||
require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV1))
|
||||
require.IsType(t, &V2ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV2))
|
||||
}
|
||||
|
||||
func TestV2MessageTypeIDsAreStable(t *testing.T) {
|
||||
require.Equal(t, uint16(1), V2TypeLogin)
|
||||
require.Equal(t, uint16(2), V2TypeLoginResp)
|
||||
require.Equal(t, uint16(3), V2TypeNewProxy)
|
||||
require.Equal(t, uint16(4), V2TypeNewProxyResp)
|
||||
require.Equal(t, uint16(5), V2TypeCloseProxy)
|
||||
require.Equal(t, uint16(6), V2TypeNewWorkConn)
|
||||
require.Equal(t, uint16(7), V2TypeReqWorkConn)
|
||||
require.Equal(t, uint16(8), V2TypeStartWorkConn)
|
||||
require.Equal(t, uint16(9), V2TypeNewVisitorConn)
|
||||
require.Equal(t, uint16(10), V2TypeNewVisitorConnResp)
|
||||
require.Equal(t, uint16(11), V2TypePing)
|
||||
require.Equal(t, uint16(12), V2TypePong)
|
||||
require.Equal(t, uint16(13), V2TypeUDPPacket)
|
||||
require.Equal(t, uint16(14), V2TypeNatHoleVisitor)
|
||||
require.Equal(t, uint16(15), V2TypeNatHoleClient)
|
||||
require.Equal(t, uint16(16), V2TypeNatHoleResp)
|
||||
require.Equal(t, uint16(17), V2TypeNatHoleSid)
|
||||
require.Equal(t, uint16(18), V2TypeNatHoleReport)
|
||||
}
|
||||
|
||||
func TestV2MessageFrameEncoding(t *testing.T) {
|
||||
frame, err := EncodeV2MessageFrame(&ReqWorkConn{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wire.FrameTypeMessage, frame.Type)
|
||||
require.Len(t, frame.Payload, 4)
|
||||
require.Equal(t, V2TypeReqWorkConn, binary.BigEndian.Uint16(frame.Payload[:2]))
|
||||
|
||||
out, err := DecodeV2MessageFrame(frame)
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, &ReqWorkConn{}, out)
|
||||
}
|
||||
|
||||
func TestDecodeV2MessageFrameInto(t *testing.T) {
|
||||
in := &StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234}
|
||||
frame, err := EncodeV2MessageFrame(in)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out StartWorkConn
|
||||
require.NoError(t, DecodeV2MessageFrameInto(frame, &out))
|
||||
require.Equal(t, *in, out)
|
||||
}
|
||||
|
||||
func TestDecodeV2MessageFrameRejectsInvalidFrame(t *testing.T) {
|
||||
_, err := DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeClientHello})
|
||||
require.ErrorContains(t, err, "unexpected frame type")
|
||||
|
||||
_, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: []byte{0}})
|
||||
require.ErrorContains(t, err, "payload too short")
|
||||
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint16(payload[:2], 65535)
|
||||
copy(payload[2:], []byte("{}"))
|
||||
_, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: payload})
|
||||
require.ErrorContains(t, err, "unknown v2 message type")
|
||||
}
|
||||
|
||||
func TestDecodeV2MessageFrameIntoRejectsWrongTarget(t *testing.T) {
|
||||
frame, err := EncodeV2MessageFrame(&ReqWorkConn{})
|
||||
require.NoError(t, err)
|
||||
|
||||
var out StartWorkConn
|
||||
err = DecodeV2MessageFrameInto(frame, &out)
|
||||
require.ErrorContains(t, err, "unexpected message type")
|
||||
|
||||
err = DecodeV2MessageFrameInto(frame, StartWorkConn{})
|
||||
require.ErrorContains(t, err, "must be a pointer")
|
||||
}
|
||||
|
||||
func TestEncodeV2MessageFrameRejectsUnknownMessage(t *testing.T) {
|
||||
_, err := EncodeV2MessageFrame(struct{}{})
|
||||
require.ErrorContains(t, err, "unknown v2 message type")
|
||||
}
|
||||
Reference in New Issue
Block a user