forked from Mxmilu666/frp
Fix conflicts in fatedier/connection_pool with dev
Conflicts: src/frp/cmd/frpc/control.go src/frp/cmd/frps/control.go src/frp/models/config/config.go src/frp/models/server/server.go
This commit is contained in:
@@ -117,6 +117,16 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// if the tcpConn is different with c.TcpConn
|
||||
// you should call c.Close() first
|
||||
func (c *Conn) SetTcpConn(tcpConn net.Conn) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.TcpConn = tcpConn
|
||||
c.closeFlag = false
|
||||
c.Reader = bufio.NewReader(c.TcpConn)
|
||||
}
|
||||
|
||||
func (c *Conn) GetRemoteAddr() (addr string) {
|
||||
return c.TcpConn.RemoteAddr().String()
|
||||
}
|
||||
@@ -125,6 +135,11 @@ func (c *Conn) GetLocalAddr() (addr string) {
|
||||
return c.TcpConn.LocalAddr().String()
|
||||
}
|
||||
|
||||
func (c *Conn) Read(p []byte) (n int, err error) {
|
||||
n, err = c.Reader.Read(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) ReadLine() (buff string, err error) {
|
||||
buff, err = c.Reader.ReadString('\n')
|
||||
if err != nil {
|
||||
@@ -138,10 +153,14 @@ func (c *Conn) ReadLine() (buff string, err error) {
|
||||
return buff, err
|
||||
}
|
||||
|
||||
func (c *Conn) WriteBytes(content []byte) (n int, err error) {
|
||||
n, err = c.TcpConn.Write(content)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) Write(content string) (err error) {
|
||||
_, err = c.TcpConn.Write([]byte(content))
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
|
||||
@@ -16,8 +16,12 @@ package vhost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -42,6 +46,123 @@ func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
|
||||
}
|
||||
|
||||
func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
|
||||
mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
|
||||
mux, err := NewVhostMuxer(listener, GetHttpHostname, HttpHostNameRewrite, timeout)
|
||||
return &HttpMuxer{mux}, err
|
||||
}
|
||||
|
||||
func HttpHostNameRewrite(c *conn.Conn, rewriteHost string) (_ net.Conn, err error) {
|
||||
sc, rd := newShareConn(c.TcpConn)
|
||||
var buff []byte
|
||||
if buff, err = hostNameRewrite(rd, rewriteHost); err != nil {
|
||||
return sc, err
|
||||
}
|
||||
err = sc.WriteBuff(buff)
|
||||
return sc, err
|
||||
}
|
||||
|
||||
func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) {
|
||||
buffer := make([]byte, 1024)
|
||||
request.Read(buffer)
|
||||
retBuffer, err := parseRequest(buffer, rewriteHost)
|
||||
return retBuffer, err
|
||||
}
|
||||
|
||||
func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) {
|
||||
tp := bytes.NewBuffer(org)
|
||||
// First line: GET /index.html HTTP/1.0
|
||||
var b []byte
|
||||
if b, err = tp.ReadBytes('\n'); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req := new(http.Request)
|
||||
// we invoked ReadRequest in GetHttpHostname before, so we ignore error
|
||||
req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
|
||||
rawurl := req.RequestURI
|
||||
// CONNECT www.google.com:443 HTTP/1.1
|
||||
justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
|
||||
if justAuthority {
|
||||
rawurl = "http://" + rawurl
|
||||
}
|
||||
req.URL, _ = url.ParseRequestURI(rawurl)
|
||||
if justAuthority {
|
||||
// Strip the bogus "http://" back off.
|
||||
req.URL.Scheme = ""
|
||||
}
|
||||
|
||||
// RFC2616: first case
|
||||
// GET /index.html HTTP/1.1
|
||||
// Host: www.google.com
|
||||
if req.URL.Host == "" {
|
||||
changedBuf, err := changeHostName(tp, rewriteHost)
|
||||
buf := new(bytes.Buffer)
|
||||
buf.Write(b)
|
||||
buf.Write(changedBuf)
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
// RFC2616: second case
|
||||
// GET http://www.google.com/index.html HTTP/1.1
|
||||
// Host: doesntmatter
|
||||
// In this case, any Host line is ignored.
|
||||
hostPort := strings.Split(req.URL.Host, ":")
|
||||
if len(hostPort) == 1 {
|
||||
req.URL.Host = rewriteHost
|
||||
} else if len(hostPort) == 2 {
|
||||
req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
|
||||
}
|
||||
firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
|
||||
buf := new(bytes.Buffer)
|
||||
buf.WriteString(firstLine)
|
||||
tp.WriteTo(buf)
|
||||
return buf.Bytes(), err
|
||||
|
||||
}
|
||||
|
||||
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
||||
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
||||
s1 := strings.Index(line, " ")
|
||||
s2 := strings.Index(line[s1+1:], " ")
|
||||
if s1 < 0 || s2 < 0 {
|
||||
return
|
||||
}
|
||||
s2 += s1 + 1
|
||||
return line[:s1], line[s1+1 : s2], line[s2+1:], true
|
||||
}
|
||||
|
||||
func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
|
||||
retBuf := new(bytes.Buffer)
|
||||
|
||||
peek := buff.Bytes()
|
||||
for len(peek) > 0 {
|
||||
i := bytes.IndexByte(peek, '\n')
|
||||
if i < 3 {
|
||||
// Not present (-1) or found within the next few bytes,
|
||||
// implying we're at the end ("\r\n\r\n" or "\n\n")
|
||||
return nil, err
|
||||
}
|
||||
kv := peek[:i]
|
||||
j := bytes.IndexByte(kv, ':')
|
||||
if j < 0 {
|
||||
return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
|
||||
}
|
||||
if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
|
||||
var hostHeader string
|
||||
portPos := bytes.IndexByte(kv[j+1:], ':')
|
||||
if portPos == -1 {
|
||||
hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost)
|
||||
} else {
|
||||
hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:])
|
||||
}
|
||||
retBuf.WriteString(hostHeader)
|
||||
peek = peek[i+1:]
|
||||
break
|
||||
} else {
|
||||
retBuf.Write(peek[:i])
|
||||
retBuf.WriteByte('\n')
|
||||
}
|
||||
|
||||
peek = peek[i+1:]
|
||||
}
|
||||
retBuf.Write(peek)
|
||||
return retBuf.Bytes(), err
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ type HttpsMuxer struct {
|
||||
}
|
||||
|
||||
func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
|
||||
mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
|
||||
mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, timeout)
|
||||
return &HttpsMuxer{mux}, err
|
||||
}
|
||||
|
||||
|
||||
@@ -27,37 +27,42 @@ import (
|
||||
)
|
||||
|
||||
type muxFunc func(*conn.Conn) (net.Conn, string, error)
|
||||
type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error)
|
||||
|
||||
type VhostMuxer struct {
|
||||
listener *conn.Listener
|
||||
timeout time.Duration
|
||||
vhostFunc muxFunc
|
||||
rewriteFunc hostRewriteFunc
|
||||
registryMap map[string]*Listener
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
|
||||
func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
|
||||
mux = &VhostMuxer{
|
||||
listener: listener,
|
||||
timeout: timeout,
|
||||
vhostFunc: vhostFunc,
|
||||
rewriteFunc: rewriteFunc,
|
||||
registryMap: make(map[string]*Listener),
|
||||
}
|
||||
go mux.run()
|
||||
return mux, nil
|
||||
}
|
||||
|
||||
func (v *VhostMuxer) Listen(name string) (l *Listener, err error) {
|
||||
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost
|
||||
func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) {
|
||||
v.mutex.Lock()
|
||||
defer v.mutex.Unlock()
|
||||
if _, exist := v.registryMap[name]; exist {
|
||||
return nil, fmt.Errorf("name %s is already bound", name)
|
||||
return nil, fmt.Errorf("domain name %s is already bound", name)
|
||||
}
|
||||
|
||||
l = &Listener{
|
||||
name: name,
|
||||
mux: v,
|
||||
accept: make(chan *conn.Conn),
|
||||
name: name,
|
||||
rewriteHost: rewriteHost,
|
||||
mux: v,
|
||||
accept: make(chan *conn.Conn),
|
||||
}
|
||||
v.registryMap[name] = l
|
||||
return l, nil
|
||||
@@ -105,15 +110,16 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
|
||||
if err = sConn.SetDeadline(time.Time{}); err != nil {
|
||||
return
|
||||
}
|
||||
c.TcpConn = sConn
|
||||
c.SetTcpConn(sConn)
|
||||
|
||||
l.accept <- c
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
name string
|
||||
mux *VhostMuxer // for closing VhostMuxer
|
||||
accept chan *conn.Conn
|
||||
name string
|
||||
rewriteHost string
|
||||
mux *VhostMuxer // for closing VhostMuxer
|
||||
accept chan *conn.Conn
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (*conn.Conn, error) {
|
||||
@@ -121,6 +127,17 @@ func (l *Listener) Accept() (*conn.Conn, error) {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Listener closed")
|
||||
}
|
||||
|
||||
// if rewriteFunc is exist and rewriteHost is set
|
||||
// rewrite http requests with a modified host header
|
||||
if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
|
||||
fmt.Printf("host rewrite: %s\n", l.rewriteHost)
|
||||
sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("http host header rewrite failed")
|
||||
}
|
||||
conn.SetTcpConn(sConn)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -140,6 +157,7 @@ type sharedConn struct {
|
||||
buff *bytes.Buffer
|
||||
}
|
||||
|
||||
// the bytes you read in io.Reader, will be reserved in sharedConn
|
||||
func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
|
||||
sc := &sharedConn{
|
||||
Conn: conn,
|
||||
@@ -166,3 +184,9 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) {
|
||||
sc.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (sc *sharedConn) WriteBuff(buffer []byte) (err error) {
|
||||
sc.buff.Reset()
|
||||
_, err = sc.buff.Write(buffer)
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user