mirror of
https://github.com/fatedier/frp.git
synced 2026-04-27 11:29:09 +08:00
180 lines
3.7 KiB
Go
180 lines
3.7 KiB
Go
// Copyright 2023 The frp Authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package msg
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"reflect"
|
|
|
|
"github.com/fatedier/frp/pkg/proto/wire"
|
|
)
|
|
|
|
type ReadWriter interface {
|
|
ReadMsg() (Message, error)
|
|
ReadMsgInto(Message) error
|
|
WriteMsg(Message) error
|
|
}
|
|
|
|
type Conn struct {
|
|
net.Conn
|
|
rw ReadWriter
|
|
}
|
|
|
|
func NewConn(conn net.Conn, rw ReadWriter) *Conn {
|
|
return &Conn{
|
|
Conn: conn,
|
|
rw: rw,
|
|
}
|
|
}
|
|
|
|
func (c *Conn) ReadMsg() (Message, error) {
|
|
return c.rw.ReadMsg()
|
|
}
|
|
|
|
func (c *Conn) ReadMsgInto(m Message) error {
|
|
return c.rw.ReadMsgInto(m)
|
|
}
|
|
|
|
func (c *Conn) WriteMsg(m Message) error {
|
|
return c.rw.WriteMsg(m)
|
|
}
|
|
|
|
func (c *Conn) Context() context.Context {
|
|
if getter, ok := c.Conn.(interface{ Context() context.Context }); ok {
|
|
return getter.Context()
|
|
}
|
|
return context.Background()
|
|
}
|
|
|
|
func (c *Conn) WithContext(ctx context.Context) {
|
|
if setter, ok := c.Conn.(interface{ WithContext(context.Context) }); ok {
|
|
setter.WithContext(ctx)
|
|
}
|
|
}
|
|
|
|
type V1ReadWriter struct {
|
|
rw io.ReadWriter
|
|
}
|
|
|
|
func NewV1ReadWriter(rw io.ReadWriter) ReadWriter {
|
|
return &V1ReadWriter{rw: rw}
|
|
}
|
|
|
|
// NewReadWriter wraps rw with the message codec for the selected wire protocol.
|
|
// An empty protocol keeps the historical v1 behavior for tests and older call sites.
|
|
func NewReadWriter(rw io.ReadWriter, wireProtocol string) ReadWriter {
|
|
switch wireProtocol {
|
|
case wire.ProtocolV2:
|
|
return NewV2ReadWriter(rw)
|
|
case "", wire.ProtocolV1:
|
|
return NewV1ReadWriter(rw)
|
|
default:
|
|
return NewV1ReadWriter(rw)
|
|
}
|
|
}
|
|
|
|
func (rw *V1ReadWriter) ReadMsg() (Message, error) {
|
|
return ReadMsg(rw.rw)
|
|
}
|
|
|
|
func (rw *V1ReadWriter) ReadMsgInto(m Message) error {
|
|
return ReadMsgInto(rw.rw, m)
|
|
}
|
|
|
|
func (rw *V1ReadWriter) WriteMsg(m Message) error {
|
|
return WriteMsg(rw.rw, m)
|
|
}
|
|
|
|
func AsyncHandler(f func(Message)) func(Message) {
|
|
return func(m Message) {
|
|
go f(m)
|
|
}
|
|
}
|
|
|
|
// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn.
|
|
type Dispatcher struct {
|
|
rw ReadWriter
|
|
|
|
sendCh chan Message
|
|
doneCh chan struct{}
|
|
msgHandlers map[reflect.Type]func(Message)
|
|
defaultHandler func(Message)
|
|
}
|
|
|
|
func NewDispatcher(rw ReadWriter) *Dispatcher {
|
|
return &Dispatcher{
|
|
rw: rw,
|
|
sendCh: make(chan Message, 100),
|
|
doneCh: make(chan struct{}),
|
|
msgHandlers: make(map[reflect.Type]func(Message)),
|
|
}
|
|
}
|
|
|
|
// Run will block until io.EOF or some error occurs.
|
|
func (d *Dispatcher) Run() {
|
|
go d.sendLoop()
|
|
go d.readLoop()
|
|
}
|
|
|
|
func (d *Dispatcher) sendLoop() {
|
|
for {
|
|
select {
|
|
case <-d.doneCh:
|
|
return
|
|
case m := <-d.sendCh:
|
|
_ = d.rw.WriteMsg(m)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Dispatcher) readLoop() {
|
|
for {
|
|
m, err := d.rw.ReadMsg()
|
|
if err != nil {
|
|
close(d.doneCh)
|
|
return
|
|
}
|
|
|
|
if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok {
|
|
handler(m)
|
|
} else if d.defaultHandler != nil {
|
|
d.defaultHandler(m)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Dispatcher) Send(m Message) error {
|
|
select {
|
|
case <-d.doneCh:
|
|
return io.EOF
|
|
case d.sendCh <- m:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) {
|
|
d.msgHandlers[reflect.TypeOf(msg)] = handler
|
|
}
|
|
|
|
func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) {
|
|
d.defaultHandler = handler
|
|
}
|
|
|
|
func (d *Dispatcher) Done() chan struct{} {
|
|
return d.doneCh
|
|
}
|