mirror of
https://github.com/fatedier/frp.git
synced 2026-03-09 19:39:11 +08:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
01413c3853 | ||
|
|
adcd2e64b6 | ||
|
|
48e8901466 | ||
|
|
bcd2424c24 | ||
|
|
c7ac12ea0f | ||
|
|
eeb0dacfc1 | ||
|
|
535eb3db35 | ||
|
|
605f3bdece | ||
|
|
764a626b6e | ||
|
|
c2454e7114 | ||
|
|
017d71717f | ||
|
|
bd200b1a3b | ||
|
|
c70ceff370 | ||
|
|
bb3d0e7140 | ||
|
|
cf396563f8 | ||
|
|
0b4f83cd04 | ||
|
|
e9f7a1a9f2 | ||
|
|
d644593342 | ||
|
|
427c4ca3ae | ||
|
|
f2d1f3739a | ||
|
|
c23894f156 | ||
|
|
cb459b02b6 | ||
|
|
8f633fe363 | ||
|
|
c62a1da161 | ||
|
|
f22f7d539c | ||
|
|
462c987f6d | ||
|
|
541878af4d | ||
|
|
b7435967b0 | ||
|
|
774478d071 | ||
|
|
fbeb6ca43a | ||
|
|
381245a439 | ||
|
|
01997deb98 |
@@ -2,7 +2,7 @@ version: 2
|
||||
jobs:
|
||||
go-version-latest:
|
||||
docker:
|
||||
- image: cimg/go:1.24-node
|
||||
- image: cimg/go:1.25-node
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
|
||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
cache: false
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
@@ -29,7 +29,7 @@ jobs:
|
||||
run: make build
|
||||
working-directory: web/frpc
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||
version: v2.3
|
||||
version: v2.10
|
||||
|
||||
2
.github/workflows/goreleaser.yml
vendored
2
.github/workflows/goreleaser.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version: '1.25'
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -30,4 +30,5 @@ client.key
|
||||
|
||||
# AI
|
||||
CLAUDE.md
|
||||
AGENTS.md
|
||||
.sisyphus/
|
||||
|
||||
@@ -18,6 +18,7 @@ linters:
|
||||
- lll
|
||||
- makezero
|
||||
- misspell
|
||||
- modernize
|
||||
- prealloc
|
||||
- predeclared
|
||||
- revive
|
||||
@@ -33,13 +34,7 @@ linters:
|
||||
disabled-checks:
|
||||
- exitAfterDefer
|
||||
gosec:
|
||||
excludes:
|
||||
- G401
|
||||
- G402
|
||||
- G404
|
||||
- G501
|
||||
- G115
|
||||
- G204
|
||||
excludes: ["G115", "G117", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"]
|
||||
severity: low
|
||||
confidence: low
|
||||
govet:
|
||||
@@ -53,6 +48,9 @@ linters:
|
||||
ignore-rules:
|
||||
- cancelled
|
||||
- marshalled
|
||||
modernize:
|
||||
disable:
|
||||
- omitzero
|
||||
unparam:
|
||||
check-exported: false
|
||||
exclusions:
|
||||
@@ -77,6 +75,9 @@ linters:
|
||||
- linters:
|
||||
- revive
|
||||
text: "avoid meaningless package names"
|
||||
- linters:
|
||||
- revive
|
||||
text: "Go standard library package names"
|
||||
- linters:
|
||||
- unparam
|
||||
text: is always false
|
||||
|
||||
21
Makefile
21
Makefile
@@ -1,6 +1,7 @@
|
||||
export PATH := $(PATH):`go env GOPATH`/bin
|
||||
export GO111MODULE=on
|
||||
LDFLAGS := -s -w
|
||||
NOWEB_TAG = $(shell [ ! -d web/frps/dist ] || [ ! -d web/frpc/dist ] && echo ',noweb')
|
||||
|
||||
.PHONY: web frps-web frpc-web frps frpc
|
||||
|
||||
@@ -28,23 +29,23 @@ fmt-more:
|
||||
gci:
|
||||
gci write -s standard -s default -s "prefix(github.com/fatedier/frp/)" ./
|
||||
|
||||
vet: web
|
||||
go vet ./...
|
||||
vet:
|
||||
go vet -tags "$(NOWEB_TAG)" ./...
|
||||
|
||||
frps:
|
||||
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frps -o bin/frps ./cmd/frps
|
||||
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags "frps$(NOWEB_TAG)" -o bin/frps ./cmd/frps
|
||||
|
||||
frpc:
|
||||
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frpc -o bin/frpc ./cmd/frpc
|
||||
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags "frpc$(NOWEB_TAG)" -o bin/frpc ./cmd/frpc
|
||||
|
||||
test: gotest
|
||||
|
||||
gotest: web
|
||||
go test -v --cover ./assets/...
|
||||
go test -v --cover ./cmd/...
|
||||
go test -v --cover ./client/...
|
||||
go test -v --cover ./server/...
|
||||
go test -v --cover ./pkg/...
|
||||
gotest:
|
||||
go test -tags "$(NOWEB_TAG)" -v --cover ./assets/...
|
||||
go test -tags "$(NOWEB_TAG)" -v --cover ./cmd/...
|
||||
go test -tags "$(NOWEB_TAG)" -v --cover ./client/...
|
||||
go test -tags "$(NOWEB_TAG)" -v --cover ./server/...
|
||||
go test -tags "$(NOWEB_TAG)" -v --cover ./pkg/...
|
||||
|
||||
e2e:
|
||||
./hack/run-e2e.sh
|
||||
|
||||
27
README.md
27
README.md
@@ -13,6 +13,16 @@ frp is an open source project with its ongoing development made possible entirel
|
||||
|
||||
<h3 align="center">Gold Sponsors</h3>
|
||||
<!--gold sponsors start-->
|
||||
<div align="center">
|
||||
|
||||
## Recall.ai - API for meeting recordings
|
||||
|
||||
If you're looking for a meeting recording API, consider checking out [Recall.ai](https://www.recall.ai/?utm_source=github&utm_medium=sponsorship&utm_campaign=fatedier-frp),
|
||||
|
||||
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
|
||||
|
||||
</div>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://requestly.com/?utm_source=github&utm_medium=partnered&utm_campaign=frp" target="_blank">
|
||||
<img width="480px" src="https://github.com/user-attachments/assets/24670320-997d-4d62-9bca-955c59fe883d">
|
||||
@@ -40,15 +50,6 @@ frp is an open source project with its ongoing development made possible entirel
|
||||
<sub>An open source, self-hosted alternative to public clouds, built for data ownership and privacy</sub>
|
||||
</a>
|
||||
</p>
|
||||
<div align="center">
|
||||
|
||||
## Recall.ai - API for meeting recordings
|
||||
|
||||
If you're looking for a meeting recording API, consider checking out [Recall.ai](https://www.recall.ai/?utm_source=github&utm_medium=sponsorship&utm_campaign=fatedier-frp),
|
||||
|
||||
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
|
||||
|
||||
</div>
|
||||
<!--gold sponsors end-->
|
||||
|
||||
## What is frp?
|
||||
@@ -800,6 +801,14 @@ Then run command `frpc reload -c ./frpc.toml` and wait for about 10 seconds to l
|
||||
|
||||
**Note that global client parameters won't be modified except 'start'.**
|
||||
|
||||
`start` is a global allowlist evaluated after all sources are merged (config file/include/store).
|
||||
If `start` is non-empty, any proxy or visitor not listed there will not be started, including
|
||||
entries created via Store API.
|
||||
|
||||
`start` is kept mainly for compatibility and is generally not recommended for new configurations.
|
||||
Prefer per-proxy/per-visitor `enabled`, and keep `start` empty unless you explicitly want this
|
||||
global allowlist behavior.
|
||||
|
||||
You can run command `frpc verify -c ./frpc.toml` before reloading to check if there are config errors.
|
||||
|
||||
### Get proxy status from client
|
||||
|
||||
19
README_zh.md
19
README_zh.md
@@ -15,6 +15,16 @@ frp 是一个完全开源的项目,我们的开发工作完全依靠赞助者
|
||||
|
||||
<h3 align="center">Gold Sponsors</h3>
|
||||
<!--gold sponsors start-->
|
||||
<div align="center">
|
||||
|
||||
## Recall.ai - API for meeting recordings
|
||||
|
||||
If you're looking for a meeting recording API, consider checking out [Recall.ai](https://www.recall.ai/?utm_source=github&utm_medium=sponsorship&utm_campaign=fatedier-frp),
|
||||
|
||||
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
|
||||
|
||||
</div>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://requestly.com/?utm_source=github&utm_medium=partnered&utm_campaign=frp" target="_blank">
|
||||
<img width="480px" src="https://github.com/user-attachments/assets/24670320-997d-4d62-9bca-955c59fe883d">
|
||||
@@ -42,15 +52,6 @@ frp 是一个完全开源的项目,我们的开发工作完全依靠赞助者
|
||||
<sub>An open source, self-hosted alternative to public clouds, built for data ownership and privacy</sub>
|
||||
</a>
|
||||
</p>
|
||||
<div align="center">
|
||||
|
||||
## Recall.ai - API for meeting recordings
|
||||
|
||||
If you're looking for a meeting recording API, consider checking out [Recall.ai](https://www.recall.ai/?utm_source=github&utm_medium=sponsorship&utm_campaign=fatedier-frp),
|
||||
|
||||
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
|
||||
|
||||
</div>
|
||||
<!--gold sponsors end-->
|
||||
|
||||
## 为什么使用 frp ?
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
## Features
|
||||
|
||||
* frpc now supports a `clientID` option to uniquely identify client instances. The server dashboard displays all connected clients with their online/offline status, connection history, and metadata, making it easier to monitor and manage multiple frpc deployments.
|
||||
* Redesigned the frp web dashboard with a modern UI, dark mode support, and improved navigation.
|
||||
* Added a built-in `store` capability for frpc, including persisted store source (`[store] path = "..."`), Store CRUD admin APIs (`/api/store/proxies*`, `/api/store/visitors*`) with runtime reload, and Store management pages in the frpc web dashboard.
|
||||
|
||||
## Fixes
|
||||
## Improvements
|
||||
|
||||
* Fixed UDP proxy protocol sending header on every packet instead of only the first packet of each session.
|
||||
* Kept proxy/visitor names as raw config names during completion; moved user-prefix handling to explicit wire-level naming logic.
|
||||
* Added `noweb` build tag to allow compiling without frontend assets. `make build` now auto-detects missing `web/*/dist` directories and skips embedding, so a fresh clone can build without running `make web` first. The dashboard gracefully returns 404 when assets are not embedded.
|
||||
* Improved config parsing errors: for `.toml` files, syntax errors now return immediately with parser position details (line/column when available) instead of falling through to YAML/JSON parsing, and TOML type mismatches report field-level errors without misleading line numbers.
|
||||
|
||||
@@ -29,14 +29,23 @@ var (
|
||||
prefixPath string
|
||||
)
|
||||
|
||||
type emptyFS struct{}
|
||||
|
||||
func (emptyFS) Open(name string) (http.File, error) {
|
||||
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
|
||||
}
|
||||
|
||||
// if path is empty, load assets in memory
|
||||
// or set FileSystem using disk files
|
||||
func Load(path string) {
|
||||
prefixPath = path
|
||||
if prefixPath != "" {
|
||||
switch {
|
||||
case prefixPath != "":
|
||||
FileSystem = http.Dir(prefixPath)
|
||||
} else {
|
||||
case content != nil:
|
||||
FileSystem = http.FS(content)
|
||||
default:
|
||||
FileSystem = emptyFS{}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,498 +0,0 @@
|
||||
// Copyright 2025 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/client/proxy"
|
||||
"github.com/fatedier/frp/pkg/config"
|
||||
"github.com/fatedier/frp/pkg/config/source"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/config/v1/validation"
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
)
|
||||
|
||||
// Controller handles HTTP API requests for frpc.
|
||||
type Controller struct {
|
||||
getProxyStatus func() []*proxy.WorkingStatus
|
||||
serverAddr string
|
||||
configFilePath string
|
||||
unsafeFeatures *security.UnsafeFeatures
|
||||
updateConfig func(common *v1.ClientCommonConfig, proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error
|
||||
reloadFromSources func() error
|
||||
gracefulClose func(d time.Duration)
|
||||
storeSource *source.StoreSource
|
||||
}
|
||||
|
||||
// ControllerParams contains parameters for creating an APIController.
|
||||
type ControllerParams struct {
|
||||
GetProxyStatus func() []*proxy.WorkingStatus
|
||||
ServerAddr string
|
||||
ConfigFilePath string
|
||||
UnsafeFeatures *security.UnsafeFeatures
|
||||
UpdateConfig func(common *v1.ClientCommonConfig, proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error
|
||||
ReloadFromSources func() error
|
||||
GracefulClose func(d time.Duration)
|
||||
StoreSource *source.StoreSource
|
||||
}
|
||||
|
||||
func NewController(params ControllerParams) *Controller {
|
||||
return &Controller{
|
||||
getProxyStatus: params.GetProxyStatus,
|
||||
serverAddr: params.ServerAddr,
|
||||
configFilePath: params.ConfigFilePath,
|
||||
unsafeFeatures: params.UnsafeFeatures,
|
||||
updateConfig: params.UpdateConfig,
|
||||
reloadFromSources: params.ReloadFromSources,
|
||||
gracefulClose: params.GracefulClose,
|
||||
storeSource: params.StoreSource,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) reloadFromSourcesOrError() error {
|
||||
if err := c.reloadFromSources(); err != nil {
|
||||
return httppkg.NewError(http.StatusInternalServerError, fmt.Sprintf("failed to apply config: %v", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reload handles GET /api/reload
|
||||
func (c *Controller) Reload(ctx *httppkg.Context) (any, error) {
|
||||
strictConfigMode := false
|
||||
strictStr := ctx.Query("strictConfig")
|
||||
if strictStr != "" {
|
||||
strictConfigMode, _ = strconv.ParseBool(strictStr)
|
||||
}
|
||||
|
||||
result, err := config.LoadClientConfigResult(c.configFilePath, strictConfigMode)
|
||||
if err != nil {
|
||||
log.Warnf("reload frpc proxy config error: %s", err.Error())
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
proxyCfgs := result.Proxies
|
||||
visitorCfgs := result.Visitors
|
||||
|
||||
proxyCfgsForValidation, visitorCfgsForValidation := config.FilterClientConfigurers(
|
||||
result.Common,
|
||||
proxyCfgs,
|
||||
visitorCfgs,
|
||||
)
|
||||
proxyCfgsForValidation = config.CompleteProxyConfigurers(proxyCfgsForValidation)
|
||||
visitorCfgsForValidation = config.CompleteVisitorConfigurers(visitorCfgsForValidation)
|
||||
|
||||
if _, err := validation.ValidateAllClientConfig(result.Common, proxyCfgsForValidation, visitorCfgsForValidation, c.unsafeFeatures); err != nil {
|
||||
log.Warnf("reload frpc proxy config error: %s", err.Error())
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
if err := c.updateConfig(result.Common, proxyCfgs, visitorCfgs); err != nil {
|
||||
log.Warnf("reload frpc proxy config error: %s", err.Error())
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
log.Infof("success reload conf")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Stop handles POST /api/stop
|
||||
func (c *Controller) Stop(ctx *httppkg.Context) (any, error) {
|
||||
go c.gracefulClose(100 * time.Millisecond)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Status handles GET /api/status
|
||||
func (c *Controller) Status(ctx *httppkg.Context) (any, error) {
|
||||
res := make(StatusResp)
|
||||
ps := c.getProxyStatus()
|
||||
if ps == nil {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
for _, status := range ps {
|
||||
res[status.Type] = append(res[status.Type], c.buildProxyStatusResp(status))
|
||||
}
|
||||
|
||||
for _, arrs := range res {
|
||||
if len(arrs) <= 1 {
|
||||
continue
|
||||
}
|
||||
slices.SortFunc(arrs, func(a, b ProxyStatusResp) int {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// GetConfig handles GET /api/config
|
||||
func (c *Controller) GetConfig(ctx *httppkg.Context) (any, error) {
|
||||
if c.configFilePath == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "frpc has no config file path")
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(c.configFilePath)
|
||||
if err != nil {
|
||||
log.Warnf("load frpc config file error: %s", err.Error())
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
// PutConfig handles PUT /api/config
|
||||
func (c *Controller) PutConfig(ctx *httppkg.Context) (any, error) {
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read request body error: %v", err))
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "body can't be empty")
|
||||
}
|
||||
|
||||
if err := os.WriteFile(c.configFilePath, body, 0o600); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, fmt.Sprintf("write content to frpc config file error: %v", err))
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Controller) buildProxyStatusResp(status *proxy.WorkingStatus) ProxyStatusResp {
|
||||
psr := ProxyStatusResp{
|
||||
Name: status.Name,
|
||||
Type: status.Type,
|
||||
Status: status.Phase,
|
||||
Err: status.Err,
|
||||
}
|
||||
baseCfg := status.Cfg.GetBaseConfig()
|
||||
if baseCfg.LocalPort != 0 {
|
||||
psr.LocalAddr = net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort))
|
||||
}
|
||||
psr.Plugin = baseCfg.Plugin.Type
|
||||
|
||||
if status.Err == "" {
|
||||
psr.RemoteAddr = status.RemoteAddr
|
||||
if slices.Contains([]string{"tcp", "udp"}, status.Type) {
|
||||
psr.RemoteAddr = c.serverAddr + psr.RemoteAddr
|
||||
}
|
||||
}
|
||||
|
||||
// Check if proxy is from store
|
||||
if c.storeSource != nil {
|
||||
if c.storeSource.GetProxy(status.Name) != nil {
|
||||
psr.Source = "store"
|
||||
}
|
||||
}
|
||||
return psr
|
||||
}
|
||||
|
||||
func (c *Controller) ListStoreProxies(ctx *httppkg.Context) (any, error) {
|
||||
proxies, err := c.storeSource.GetAllProxies()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, fmt.Sprintf("failed to list proxies: %v", err))
|
||||
}
|
||||
resp := ProxyListResp{Proxies: make([]ProxyConfig, 0, len(proxies))}
|
||||
|
||||
for _, p := range proxies {
|
||||
cfg, err := proxyConfigurerToMap(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
resp.Proxies = append(resp.Proxies, ProxyConfig{
|
||||
Name: p.GetBaseConfig().Name,
|
||||
Type: p.GetBaseConfig().Type,
|
||||
Config: cfg,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetStoreProxy(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
|
||||
}
|
||||
|
||||
p := c.storeSource.GetProxy(name)
|
||||
if p == nil {
|
||||
return nil, httppkg.NewError(http.StatusNotFound, fmt.Sprintf("proxy %q not found", name))
|
||||
}
|
||||
|
||||
cfg, err := proxyConfigurerToMap(p)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return ProxyConfig{
|
||||
Name: p.GetBaseConfig().Name,
|
||||
Type: p.GetBaseConfig().Type,
|
||||
Config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Controller) CreateStoreProxy(ctx *httppkg.Context) (any, error) {
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
|
||||
}
|
||||
|
||||
var typed v1.TypedProxyConfig
|
||||
if err := json.Unmarshal(body, &typed); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
|
||||
}
|
||||
|
||||
if typed.ProxyConfigurer == nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "invalid proxy config: type is required")
|
||||
}
|
||||
|
||||
typed.Complete()
|
||||
if err := validation.ValidateProxyConfigurerForClient(typed.ProxyConfigurer); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("validation error: %v", err))
|
||||
}
|
||||
|
||||
if err := c.storeSource.AddProxy(typed.ProxyConfigurer); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusConflict, err.Error())
|
||||
}
|
||||
if err := c.reloadFromSourcesOrError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: created proxy %q", typed.ProxyConfigurer.GetBaseConfig().Name)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateStoreProxy(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
|
||||
}
|
||||
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
|
||||
}
|
||||
|
||||
var typed v1.TypedProxyConfig
|
||||
if err := json.Unmarshal(body, &typed); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
|
||||
}
|
||||
|
||||
if typed.ProxyConfigurer == nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "invalid proxy config: type is required")
|
||||
}
|
||||
|
||||
bodyName := typed.ProxyConfigurer.GetBaseConfig().Name
|
||||
if bodyName != name {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name in URL must match name in body")
|
||||
}
|
||||
|
||||
typed.Complete()
|
||||
if err := validation.ValidateProxyConfigurerForClient(typed.ProxyConfigurer); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("validation error: %v", err))
|
||||
}
|
||||
|
||||
if err := c.storeSource.UpdateProxy(typed.ProxyConfigurer); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
if err := c.reloadFromSourcesOrError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: updated proxy %q", name)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeleteStoreProxy(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
|
||||
}
|
||||
|
||||
if err := c.storeSource.RemoveProxy(name); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
if err := c.reloadFromSourcesOrError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: deleted proxy %q", name)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Controller) ListStoreVisitors(ctx *httppkg.Context) (any, error) {
|
||||
visitors, err := c.storeSource.GetAllVisitors()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, fmt.Sprintf("failed to list visitors: %v", err))
|
||||
}
|
||||
resp := VisitorListResp{Visitors: make([]VisitorConfig, 0, len(visitors))}
|
||||
|
||||
for _, v := range visitors {
|
||||
cfg, err := visitorConfigurerToMap(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
resp.Visitors = append(resp.Visitors, VisitorConfig{
|
||||
Name: v.GetBaseConfig().Name,
|
||||
Type: v.GetBaseConfig().Type,
|
||||
Config: cfg,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetStoreVisitor(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
|
||||
}
|
||||
|
||||
v := c.storeSource.GetVisitor(name)
|
||||
if v == nil {
|
||||
return nil, httppkg.NewError(http.StatusNotFound, fmt.Sprintf("visitor %q not found", name))
|
||||
}
|
||||
|
||||
cfg, err := visitorConfigurerToMap(v)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return VisitorConfig{
|
||||
Name: v.GetBaseConfig().Name,
|
||||
Type: v.GetBaseConfig().Type,
|
||||
Config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Controller) CreateStoreVisitor(ctx *httppkg.Context) (any, error) {
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
|
||||
}
|
||||
|
||||
var typed v1.TypedVisitorConfig
|
||||
if err := json.Unmarshal(body, &typed); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
|
||||
}
|
||||
|
||||
if typed.VisitorConfigurer == nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "invalid visitor config: type is required")
|
||||
}
|
||||
|
||||
typed.Complete()
|
||||
if err := validation.ValidateVisitorConfigurer(typed.VisitorConfigurer); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("validation error: %v", err))
|
||||
}
|
||||
|
||||
if err := c.storeSource.AddVisitor(typed.VisitorConfigurer); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusConflict, err.Error())
|
||||
}
|
||||
if err := c.reloadFromSourcesOrError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: created visitor %q", typed.VisitorConfigurer.GetBaseConfig().Name)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateStoreVisitor(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
|
||||
}
|
||||
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
|
||||
}
|
||||
|
||||
var typed v1.TypedVisitorConfig
|
||||
if err := json.Unmarshal(body, &typed); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
|
||||
}
|
||||
|
||||
if typed.VisitorConfigurer == nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "invalid visitor config: type is required")
|
||||
}
|
||||
|
||||
bodyName := typed.VisitorConfigurer.GetBaseConfig().Name
|
||||
if bodyName != name {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name in URL must match name in body")
|
||||
}
|
||||
|
||||
typed.Complete()
|
||||
if err := validation.ValidateVisitorConfigurer(typed.VisitorConfigurer); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("validation error: %v", err))
|
||||
}
|
||||
|
||||
if err := c.storeSource.UpdateVisitor(typed.VisitorConfigurer); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
if err := c.reloadFromSourcesOrError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: updated visitor %q", name)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeleteStoreVisitor(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
|
||||
}
|
||||
|
||||
if err := c.storeSource.RemoveVisitor(name); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
if err := c.reloadFromSourcesOrError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: deleted visitor %q", name)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func proxyConfigurerToMap(p v1.ProxyConfigurer) (map[string]any, error) {
|
||||
data, err := json.Marshal(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func visitorConfigurerToMap(v v1.VisitorConfigurer) (map[string]any, error) {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -17,7 +17,7 @@ package client
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/fatedier/frp/client/api"
|
||||
adminapi "github.com/fatedier/frp/client/http"
|
||||
"github.com/fatedier/frp/client/proxy"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
@@ -65,16 +65,11 @@ func healthz(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func newAPIController(svr *Service) *api.Controller {
|
||||
return api.NewController(api.ControllerParams{
|
||||
GetProxyStatus: svr.getAllProxyStatus,
|
||||
ServerAddr: svr.common.ServerAddr,
|
||||
ConfigFilePath: svr.configFilePath,
|
||||
UnsafeFeatures: svr.unsafeFeatures,
|
||||
UpdateConfig: svr.UpdateConfigSource,
|
||||
ReloadFromSources: svr.reloadConfigFromSources,
|
||||
GracefulClose: svr.GracefulClose,
|
||||
StoreSource: svr.storeSource,
|
||||
func newAPIController(svr *Service) *adminapi.Controller {
|
||||
manager := newServiceConfigManager(svr)
|
||||
return adminapi.NewController(adminapi.ControllerParams{
|
||||
ServerAddr: svr.common.ServerAddr,
|
||||
Manager: manager,
|
||||
})
|
||||
}
|
||||
|
||||
422
client/config_manager.go
Normal file
422
client/config_manager.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/client/configmgmt"
|
||||
"github.com/fatedier/frp/client/proxy"
|
||||
"github.com/fatedier/frp/pkg/config"
|
||||
"github.com/fatedier/frp/pkg/config/source"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/config/v1/validation"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
)
|
||||
|
||||
type serviceConfigManager struct {
|
||||
svr *Service
|
||||
}
|
||||
|
||||
func newServiceConfigManager(svr *Service) configmgmt.ConfigManager {
|
||||
return &serviceConfigManager{svr: svr}
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) ReloadFromFile(strict bool) error {
|
||||
if m.svr.configFilePath == "" {
|
||||
return fmt.Errorf("%w: frpc has no config file path", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
|
||||
result, err := config.LoadClientConfigResult(m.svr.configFilePath, strict)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrInvalidArgument, err)
|
||||
}
|
||||
|
||||
proxyCfgsForValidation, visitorCfgsForValidation := config.FilterClientConfigurers(
|
||||
result.Common,
|
||||
result.Proxies,
|
||||
result.Visitors,
|
||||
)
|
||||
proxyCfgsForValidation = config.CompleteProxyConfigurers(proxyCfgsForValidation)
|
||||
visitorCfgsForValidation = config.CompleteVisitorConfigurers(visitorCfgsForValidation)
|
||||
|
||||
if _, err := validation.ValidateAllClientConfig(result.Common, proxyCfgsForValidation, visitorCfgsForValidation, m.svr.unsafeFeatures); err != nil {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrInvalidArgument, err)
|
||||
}
|
||||
|
||||
if err := m.svr.UpdateConfigSource(result.Common, result.Proxies, result.Visitors); err != nil {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrApplyConfig, err)
|
||||
}
|
||||
|
||||
log.Infof("success reload conf")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) ReadConfigFile() (string, error) {
|
||||
if m.svr.configFilePath == "" {
|
||||
return "", fmt.Errorf("%w: frpc has no config file path", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(m.svr.configFilePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %v", configmgmt.ErrInvalidArgument, err)
|
||||
}
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) WriteConfigFile(content []byte) error {
|
||||
if len(content) == 0 {
|
||||
return fmt.Errorf("%w: body can't be empty", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(m.svr.configFilePath, content, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) GetProxyStatus() []*proxy.WorkingStatus {
|
||||
return m.svr.getAllProxyStatus()
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) IsStoreProxyEnabled(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
m.svr.reloadMu.Lock()
|
||||
storeSource := m.svr.storeSource
|
||||
m.svr.reloadMu.Unlock()
|
||||
|
||||
if storeSource == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
cfg := storeSource.GetProxy(name)
|
||||
if cfg == nil {
|
||||
return false
|
||||
}
|
||||
enabled := cfg.GetBaseConfig().Enabled
|
||||
return enabled == nil || *enabled
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) StoreEnabled() bool {
|
||||
m.svr.reloadMu.Lock()
|
||||
storeSource := m.svr.storeSource
|
||||
m.svr.reloadMu.Unlock()
|
||||
return storeSource != nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) ListStoreProxies() ([]v1.ProxyConfigurer, error) {
|
||||
storeSource, err := m.storeSourceOrError()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storeSource.GetAllProxies()
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) GetStoreProxy(name string) (v1.ProxyConfigurer, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("%w: proxy name is required", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
|
||||
storeSource, err := m.storeSourceOrError()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := storeSource.GetProxy(name)
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("%w: proxy %q", configmgmt.ErrNotFound, name)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) CreateStoreProxy(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
|
||||
if err := m.validateStoreProxyConfigurer(cfg); err != nil {
|
||||
return nil, fmt.Errorf("%w: validation error: %v", configmgmt.ErrInvalidArgument, err)
|
||||
}
|
||||
|
||||
name := cfg.GetBaseConfig().Name
|
||||
persisted, err := m.withStoreProxyMutationAndReload(name, func(storeSource *source.StoreSource) error {
|
||||
if err := storeSource.AddProxy(cfg); err != nil {
|
||||
if errors.Is(err, source.ErrAlreadyExists) {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrConflict, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("store: created proxy %q", name)
|
||||
return persisted, nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) UpdateStoreProxy(name string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("%w: proxy name is required", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("%w: invalid proxy config: type is required", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
bodyName := cfg.GetBaseConfig().Name
|
||||
if bodyName != name {
|
||||
return nil, fmt.Errorf("%w: proxy name in URL must match name in body", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
if err := m.validateStoreProxyConfigurer(cfg); err != nil {
|
||||
return nil, fmt.Errorf("%w: validation error: %v", configmgmt.ErrInvalidArgument, err)
|
||||
}
|
||||
|
||||
persisted, err := m.withStoreProxyMutationAndReload(name, func(storeSource *source.StoreSource) error {
|
||||
if err := storeSource.UpdateProxy(cfg); err != nil {
|
||||
if errors.Is(err, source.ErrNotFound) {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrNotFound, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: updated proxy %q", name)
|
||||
return persisted, nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) DeleteStoreProxy(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("%w: proxy name is required", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
|
||||
if err := m.withStoreMutationAndReload(func(storeSource *source.StoreSource) error {
|
||||
if err := storeSource.RemoveProxy(name); err != nil {
|
||||
if errors.Is(err, source.ErrNotFound) {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrNotFound, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("store: deleted proxy %q", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) ListStoreVisitors() ([]v1.VisitorConfigurer, error) {
|
||||
storeSource, err := m.storeSourceOrError()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storeSource.GetAllVisitors()
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) GetStoreVisitor(name string) (v1.VisitorConfigurer, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("%w: visitor name is required", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
|
||||
storeSource, err := m.storeSourceOrError()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := storeSource.GetVisitor(name)
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("%w: visitor %q", configmgmt.ErrNotFound, name)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) CreateStoreVisitor(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
|
||||
if err := m.validateStoreVisitorConfigurer(cfg); err != nil {
|
||||
return nil, fmt.Errorf("%w: validation error: %v", configmgmt.ErrInvalidArgument, err)
|
||||
}
|
||||
|
||||
name := cfg.GetBaseConfig().Name
|
||||
persisted, err := m.withStoreVisitorMutationAndReload(name, func(storeSource *source.StoreSource) error {
|
||||
if err := storeSource.AddVisitor(cfg); err != nil {
|
||||
if errors.Is(err, source.ErrAlreadyExists) {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrConflict, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: created visitor %q", name)
|
||||
return persisted, nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) UpdateStoreVisitor(name string, cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("%w: visitor name is required", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("%w: invalid visitor config: type is required", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
bodyName := cfg.GetBaseConfig().Name
|
||||
if bodyName != name {
|
||||
return nil, fmt.Errorf("%w: visitor name in URL must match name in body", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
if err := m.validateStoreVisitorConfigurer(cfg); err != nil {
|
||||
return nil, fmt.Errorf("%w: validation error: %v", configmgmt.ErrInvalidArgument, err)
|
||||
}
|
||||
|
||||
persisted, err := m.withStoreVisitorMutationAndReload(name, func(storeSource *source.StoreSource) error {
|
||||
if err := storeSource.UpdateVisitor(cfg); err != nil {
|
||||
if errors.Is(err, source.ErrNotFound) {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrNotFound, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("store: updated visitor %q", name)
|
||||
return persisted, nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) DeleteStoreVisitor(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("%w: visitor name is required", configmgmt.ErrInvalidArgument)
|
||||
}
|
||||
|
||||
if err := m.withStoreMutationAndReload(func(storeSource *source.StoreSource) error {
|
||||
if err := storeSource.RemoveVisitor(name); err != nil {
|
||||
if errors.Is(err, source.ErrNotFound) {
|
||||
return fmt.Errorf("%w: %v", configmgmt.ErrNotFound, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("store: deleted visitor %q", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) GracefulClose(d time.Duration) {
|
||||
m.svr.GracefulClose(d)
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) storeSourceOrError() (*source.StoreSource, error) {
|
||||
m.svr.reloadMu.Lock()
|
||||
storeSource := m.svr.storeSource
|
||||
m.svr.reloadMu.Unlock()
|
||||
|
||||
if storeSource == nil {
|
||||
return nil, fmt.Errorf("%w: store API is disabled", configmgmt.ErrStoreDisabled)
|
||||
}
|
||||
return storeSource, nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) withStoreMutationAndReload(
|
||||
fn func(storeSource *source.StoreSource) error,
|
||||
) error {
|
||||
m.svr.reloadMu.Lock()
|
||||
defer m.svr.reloadMu.Unlock()
|
||||
|
||||
storeSource := m.svr.storeSource
|
||||
if storeSource == nil {
|
||||
return fmt.Errorf("%w: store API is disabled", configmgmt.ErrStoreDisabled)
|
||||
}
|
||||
|
||||
if err := fn(storeSource); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.svr.reloadConfigFromSourcesLocked(); err != nil {
|
||||
return fmt.Errorf("%w: failed to apply config: %v", configmgmt.ErrApplyConfig, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) withStoreProxyMutationAndReload(
|
||||
name string,
|
||||
fn func(storeSource *source.StoreSource) error,
|
||||
) (v1.ProxyConfigurer, error) {
|
||||
m.svr.reloadMu.Lock()
|
||||
defer m.svr.reloadMu.Unlock()
|
||||
|
||||
storeSource := m.svr.storeSource
|
||||
if storeSource == nil {
|
||||
return nil, fmt.Errorf("%w: store API is disabled", configmgmt.ErrStoreDisabled)
|
||||
}
|
||||
|
||||
if err := fn(storeSource); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.svr.reloadConfigFromSourcesLocked(); err != nil {
|
||||
return nil, fmt.Errorf("%w: failed to apply config: %v", configmgmt.ErrApplyConfig, err)
|
||||
}
|
||||
|
||||
persisted := storeSource.GetProxy(name)
|
||||
if persisted == nil {
|
||||
return nil, fmt.Errorf("%w: proxy %q not found in store after mutation", configmgmt.ErrApplyConfig, name)
|
||||
}
|
||||
return persisted.Clone(), nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) withStoreVisitorMutationAndReload(
|
||||
name string,
|
||||
fn func(storeSource *source.StoreSource) error,
|
||||
) (v1.VisitorConfigurer, error) {
|
||||
m.svr.reloadMu.Lock()
|
||||
defer m.svr.reloadMu.Unlock()
|
||||
|
||||
storeSource := m.svr.storeSource
|
||||
if storeSource == nil {
|
||||
return nil, fmt.Errorf("%w: store API is disabled", configmgmt.ErrStoreDisabled)
|
||||
}
|
||||
|
||||
if err := fn(storeSource); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.svr.reloadConfigFromSourcesLocked(); err != nil {
|
||||
return nil, fmt.Errorf("%w: failed to apply config: %v", configmgmt.ErrApplyConfig, err)
|
||||
}
|
||||
|
||||
persisted := storeSource.GetVisitor(name)
|
||||
if persisted == nil {
|
||||
return nil, fmt.Errorf("%w: visitor %q not found in store after mutation", configmgmt.ErrApplyConfig, name)
|
||||
}
|
||||
return persisted.Clone(), nil
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) validateStoreProxyConfigurer(cfg v1.ProxyConfigurer) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("invalid proxy config")
|
||||
}
|
||||
runtimeCfg := cfg.Clone()
|
||||
if runtimeCfg == nil {
|
||||
return fmt.Errorf("invalid proxy config")
|
||||
}
|
||||
runtimeCfg.Complete()
|
||||
return validation.ValidateProxyConfigurerForClient(runtimeCfg)
|
||||
}
|
||||
|
||||
func (m *serviceConfigManager) validateStoreVisitorConfigurer(cfg v1.VisitorConfigurer) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("invalid visitor config")
|
||||
}
|
||||
runtimeCfg := cfg.Clone()
|
||||
if runtimeCfg == nil {
|
||||
return fmt.Errorf("invalid visitor config")
|
||||
}
|
||||
runtimeCfg.Complete()
|
||||
return validation.ValidateVisitorConfigurer(runtimeCfg)
|
||||
}
|
||||
137
client/config_manager_test.go
Normal file
137
client/config_manager_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/fatedier/frp/client/configmgmt"
|
||||
"github.com/fatedier/frp/pkg/config/source"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func newTestRawTCPProxyConfig(name string) *v1.TCPProxyConfig {
|
||||
return &v1.TCPProxyConfig{
|
||||
ProxyBaseConfig: v1.ProxyBaseConfig{
|
||||
Name: name,
|
||||
Type: "tcp",
|
||||
ProxyBackend: v1.ProxyBackend{
|
||||
LocalPort: 10080,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceConfigManagerCreateStoreProxyConflict(t *testing.T) {
|
||||
storeSource, err := source.NewStoreSource(source.StoreSourceConfig{
|
||||
Path: filepath.Join(t.TempDir(), "store.json"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("new store source: %v", err)
|
||||
}
|
||||
if err := storeSource.AddProxy(newTestRawTCPProxyConfig("p1")); err != nil {
|
||||
t.Fatalf("seed proxy: %v", err)
|
||||
}
|
||||
|
||||
agg := source.NewAggregator(source.NewConfigSource())
|
||||
agg.SetStoreSource(storeSource)
|
||||
|
||||
mgr := &serviceConfigManager{
|
||||
svr: &Service{
|
||||
aggregator: agg,
|
||||
configSource: agg.ConfigSource(),
|
||||
storeSource: storeSource,
|
||||
reloadCommon: &v1.ClientCommonConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = mgr.CreateStoreProxy(newTestRawTCPProxyConfig("p1"))
|
||||
if err == nil {
|
||||
t.Fatal("expected conflict error")
|
||||
}
|
||||
if !errors.Is(err, configmgmt.ErrConflict) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceConfigManagerCreateStoreProxyKeepsStoreOnReloadFailure(t *testing.T) {
|
||||
storeSource, err := source.NewStoreSource(source.StoreSourceConfig{
|
||||
Path: filepath.Join(t.TempDir(), "store.json"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("new store source: %v", err)
|
||||
}
|
||||
|
||||
mgr := &serviceConfigManager{
|
||||
svr: &Service{
|
||||
storeSource: storeSource,
|
||||
reloadCommon: &v1.ClientCommonConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = mgr.CreateStoreProxy(newTestRawTCPProxyConfig("p1"))
|
||||
if err == nil {
|
||||
t.Fatal("expected apply config error")
|
||||
}
|
||||
if !errors.Is(err, configmgmt.ErrApplyConfig) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if storeSource.GetProxy("p1") == nil {
|
||||
t.Fatal("proxy should remain in store after reload failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceConfigManagerCreateStoreProxyStoreDisabled(t *testing.T) {
|
||||
mgr := &serviceConfigManager{
|
||||
svr: &Service{
|
||||
reloadCommon: &v1.ClientCommonConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.CreateStoreProxy(newTestRawTCPProxyConfig("p1"))
|
||||
if err == nil {
|
||||
t.Fatal("expected store disabled error")
|
||||
}
|
||||
if !errors.Is(err, configmgmt.ErrStoreDisabled) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceConfigManagerCreateStoreProxyDoesNotPersistRuntimeDefaults(t *testing.T) {
|
||||
storeSource, err := source.NewStoreSource(source.StoreSourceConfig{
|
||||
Path: filepath.Join(t.TempDir(), "store.json"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("new store source: %v", err)
|
||||
}
|
||||
agg := source.NewAggregator(source.NewConfigSource())
|
||||
agg.SetStoreSource(storeSource)
|
||||
|
||||
mgr := &serviceConfigManager{
|
||||
svr: &Service{
|
||||
aggregator: agg,
|
||||
configSource: agg.ConfigSource(),
|
||||
storeSource: storeSource,
|
||||
reloadCommon: &v1.ClientCommonConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
persisted, err := mgr.CreateStoreProxy(newTestRawTCPProxyConfig("raw-proxy"))
|
||||
if err != nil {
|
||||
t.Fatalf("create store proxy: %v", err)
|
||||
}
|
||||
if persisted == nil {
|
||||
t.Fatal("expected persisted proxy to be returned")
|
||||
}
|
||||
|
||||
got := storeSource.GetProxy("raw-proxy")
|
||||
if got == nil {
|
||||
t.Fatal("proxy not found in store")
|
||||
}
|
||||
if got.GetBaseConfig().LocalIP != "" {
|
||||
t.Fatalf("localIP was persisted with runtime default: %q", got.GetBaseConfig().LocalIP)
|
||||
}
|
||||
if got.GetBaseConfig().Transport.BandwidthLimitMode != "" {
|
||||
t.Fatalf("bandwidthLimitMode was persisted with runtime default: %q", got.GetBaseConfig().Transport.BandwidthLimitMode)
|
||||
}
|
||||
}
|
||||
42
client/configmgmt/types.go
Normal file
42
client/configmgmt/types.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package configmgmt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/client/proxy"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidArgument = errors.New("invalid argument")
|
||||
ErrNotFound = errors.New("not found")
|
||||
ErrConflict = errors.New("conflict")
|
||||
ErrStoreDisabled = errors.New("store disabled")
|
||||
ErrApplyConfig = errors.New("apply config failed")
|
||||
)
|
||||
|
||||
type ConfigManager interface {
|
||||
ReloadFromFile(strict bool) error
|
||||
|
||||
ReadConfigFile() (string, error)
|
||||
WriteConfigFile(content []byte) error
|
||||
|
||||
GetProxyStatus() []*proxy.WorkingStatus
|
||||
IsStoreProxyEnabled(name string) bool
|
||||
StoreEnabled() bool
|
||||
|
||||
ListStoreProxies() ([]v1.ProxyConfigurer, error)
|
||||
GetStoreProxy(name string) (v1.ProxyConfigurer, error)
|
||||
CreateStoreProxy(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error)
|
||||
UpdateStoreProxy(name string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error)
|
||||
DeleteStoreProxy(name string) error
|
||||
|
||||
ListStoreVisitors() ([]v1.VisitorConfigurer, error)
|
||||
GetStoreVisitor(name string) (v1.VisitorConfigurer, error)
|
||||
CreateStoreVisitor(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error)
|
||||
UpdateStoreVisitor(name string, cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error)
|
||||
DeleteStoreVisitor(name string) error
|
||||
|
||||
GracefulClose(d time.Duration)
|
||||
}
|
||||
@@ -25,9 +25,9 @@ import (
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/wait"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
@@ -157,7 +157,7 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) {
|
||||
return
|
||||
}
|
||||
|
||||
startMsg.ProxyName = util.StripUserPrefix(ctl.sessionCtx.Common.User, startMsg.ProxyName)
|
||||
startMsg.ProxyName = naming.StripUserPrefix(ctl.sessionCtx.Common.User, startMsg.ProxyName)
|
||||
|
||||
// dispatch this work connection to related proxy
|
||||
ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn, &startMsg)
|
||||
@@ -168,7 +168,7 @@ func (ctl *Control) handleNewProxyResp(m msg.Message) {
|
||||
inMsg := m.(*msg.NewProxyResp)
|
||||
// Server will return NewProxyResp message to each NewProxy message.
|
||||
// Start a new proxy handler if no error got
|
||||
proxyName := util.StripUserPrefix(ctl.sessionCtx.Common.User, inMsg.ProxyName)
|
||||
proxyName := naming.StripUserPrefix(ctl.sessionCtx.Common.User, inMsg.ProxyName)
|
||||
err := ctl.pm.StartProxy(proxyName, inMsg.RemoteAddr, inMsg.Error)
|
||||
if err != nil {
|
||||
xl.Warnf("[%s] start error: %v", proxyName, err)
|
||||
|
||||
395
client/http/controller.go
Normal file
395
client/http/controller.go
Normal file
@@ -0,0 +1,395 @@
|
||||
// Copyright 2025 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/client/configmgmt"
|
||||
"github.com/fatedier/frp/client/http/model"
|
||||
"github.com/fatedier/frp/client/proxy"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
)
|
||||
|
||||
// Controller handles HTTP API requests for frpc.
|
||||
type Controller struct {
|
||||
serverAddr string
|
||||
manager configmgmt.ConfigManager
|
||||
}
|
||||
|
||||
// ControllerParams contains parameters for creating an APIController.
|
||||
type ControllerParams struct {
|
||||
ServerAddr string
|
||||
Manager configmgmt.ConfigManager
|
||||
}
|
||||
|
||||
func NewController(params ControllerParams) *Controller {
|
||||
return &Controller{
|
||||
serverAddr: params.ServerAddr,
|
||||
manager: params.Manager,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) toHTTPError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
code := http.StatusInternalServerError
|
||||
switch {
|
||||
case errors.Is(err, configmgmt.ErrInvalidArgument):
|
||||
code = http.StatusBadRequest
|
||||
case errors.Is(err, configmgmt.ErrNotFound), errors.Is(err, configmgmt.ErrStoreDisabled):
|
||||
code = http.StatusNotFound
|
||||
case errors.Is(err, configmgmt.ErrConflict):
|
||||
code = http.StatusConflict
|
||||
}
|
||||
return httppkg.NewError(code, err.Error())
|
||||
}
|
||||
|
||||
// Reload handles GET /api/reload
|
||||
func (c *Controller) Reload(ctx *httppkg.Context) (any, error) {
|
||||
strictConfigMode := false
|
||||
strictStr := ctx.Query("strictConfig")
|
||||
if strictStr != "" {
|
||||
strictConfigMode, _ = strconv.ParseBool(strictStr)
|
||||
}
|
||||
|
||||
if err := c.manager.ReloadFromFile(strictConfigMode); err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Stop handles POST /api/stop
|
||||
func (c *Controller) Stop(ctx *httppkg.Context) (any, error) {
|
||||
go c.manager.GracefulClose(100 * time.Millisecond)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Status handles GET /api/status
|
||||
func (c *Controller) Status(ctx *httppkg.Context) (any, error) {
|
||||
res := make(model.StatusResp)
|
||||
ps := c.manager.GetProxyStatus()
|
||||
if ps == nil {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
for _, status := range ps {
|
||||
res[status.Type] = append(res[status.Type], c.buildProxyStatusResp(status))
|
||||
}
|
||||
|
||||
for _, arrs := range res {
|
||||
if len(arrs) <= 1 {
|
||||
continue
|
||||
}
|
||||
slices.SortFunc(arrs, func(a, b model.ProxyStatusResp) int {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// GetConfig handles GET /api/config
|
||||
func (c *Controller) GetConfig(ctx *httppkg.Context) (any, error) {
|
||||
content, err := c.manager.ReadConfigFile()
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// PutConfig handles PUT /api/config
|
||||
func (c *Controller) PutConfig(ctx *httppkg.Context) (any, error) {
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read request body error: %v", err))
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "body can't be empty")
|
||||
}
|
||||
|
||||
if err := c.manager.WriteConfigFile(body); err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Controller) buildProxyStatusResp(status *proxy.WorkingStatus) model.ProxyStatusResp {
|
||||
psr := model.ProxyStatusResp{
|
||||
Name: status.Name,
|
||||
Type: status.Type,
|
||||
Status: status.Phase,
|
||||
Err: status.Err,
|
||||
}
|
||||
baseCfg := status.Cfg.GetBaseConfig()
|
||||
if baseCfg.LocalPort != 0 {
|
||||
psr.LocalAddr = net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort))
|
||||
}
|
||||
psr.Plugin = baseCfg.Plugin.Type
|
||||
|
||||
if status.Err == "" {
|
||||
psr.RemoteAddr = status.RemoteAddr
|
||||
if slices.Contains([]string{"tcp", "udp"}, status.Type) {
|
||||
psr.RemoteAddr = c.serverAddr + psr.RemoteAddr
|
||||
}
|
||||
}
|
||||
|
||||
if c.manager.IsStoreProxyEnabled(status.Name) {
|
||||
psr.Source = model.SourceStore
|
||||
}
|
||||
return psr
|
||||
}
|
||||
|
||||
func (c *Controller) ListStoreProxies(ctx *httppkg.Context) (any, error) {
|
||||
proxies, err := c.manager.ListStoreProxies()
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
|
||||
resp := model.ProxyListResp{Proxies: make([]model.ProxyDefinition, 0, len(proxies))}
|
||||
for _, p := range proxies {
|
||||
payload, err := model.ProxyDefinitionFromConfigurer(p)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
resp.Proxies = append(resp.Proxies, payload)
|
||||
}
|
||||
slices.SortFunc(resp.Proxies, func(a, b model.ProxyDefinition) int {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetStoreProxy(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
|
||||
}
|
||||
|
||||
p, err := c.manager.GetStoreProxy(name)
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
|
||||
payload, err := model.ProxyDefinitionFromConfigurer(p)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (c *Controller) CreateStoreProxy(ctx *httppkg.Context) (any, error) {
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
|
||||
}
|
||||
|
||||
var payload model.ProxyDefinition
|
||||
if err := jsonx.Unmarshal(body, &payload); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
|
||||
}
|
||||
|
||||
if err := payload.Validate("", false); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
cfg, err := payload.ToConfigurer()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
created, err := c.manager.CreateStoreProxy(cfg)
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
|
||||
resp, err := model.ProxyDefinitionFromConfigurer(created)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateStoreProxy(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
|
||||
}
|
||||
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
|
||||
}
|
||||
|
||||
var payload model.ProxyDefinition
|
||||
if err := jsonx.Unmarshal(body, &payload); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
|
||||
}
|
||||
|
||||
if err := payload.Validate(name, true); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
cfg, err := payload.ToConfigurer()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
updated, err := c.manager.UpdateStoreProxy(name, cfg)
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
|
||||
resp, err := model.ProxyDefinitionFromConfigurer(updated)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeleteStoreProxy(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
|
||||
}
|
||||
|
||||
if err := c.manager.DeleteStoreProxy(name); err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Controller) ListStoreVisitors(ctx *httppkg.Context) (any, error) {
|
||||
visitors, err := c.manager.ListStoreVisitors()
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
|
||||
resp := model.VisitorListResp{Visitors: make([]model.VisitorDefinition, 0, len(visitors))}
|
||||
for _, v := range visitors {
|
||||
payload, err := model.VisitorDefinitionFromConfigurer(v)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
resp.Visitors = append(resp.Visitors, payload)
|
||||
}
|
||||
slices.SortFunc(resp.Visitors, func(a, b model.VisitorDefinition) int {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetStoreVisitor(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
|
||||
}
|
||||
|
||||
v, err := c.manager.GetStoreVisitor(name)
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
|
||||
payload, err := model.VisitorDefinitionFromConfigurer(v)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (c *Controller) CreateStoreVisitor(ctx *httppkg.Context) (any, error) {
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
|
||||
}
|
||||
|
||||
var payload model.VisitorDefinition
|
||||
if err := jsonx.Unmarshal(body, &payload); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
|
||||
}
|
||||
|
||||
if err := payload.Validate("", false); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
cfg, err := payload.ToConfigurer()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
created, err := c.manager.CreateStoreVisitor(cfg)
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
|
||||
resp, err := model.VisitorDefinitionFromConfigurer(created)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateStoreVisitor(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
|
||||
}
|
||||
|
||||
body, err := ctx.Body()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
|
||||
}
|
||||
|
||||
var payload model.VisitorDefinition
|
||||
if err := jsonx.Unmarshal(body, &payload); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
|
||||
}
|
||||
|
||||
if err := payload.Validate(name, true); err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
cfg, err := payload.ToConfigurer()
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
updated, err := c.manager.UpdateStoreVisitor(name, cfg)
|
||||
if err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
|
||||
resp, err := model.VisitorDefinitionFromConfigurer(updated)
|
||||
if err != nil {
|
||||
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeleteStoreVisitor(ctx *httppkg.Context) (any, error) {
|
||||
name := ctx.Param("name")
|
||||
if name == "" {
|
||||
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
|
||||
}
|
||||
|
||||
if err := c.manager.DeleteStoreVisitor(name); err != nil {
|
||||
return nil, c.toHTTPError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
531
client/http/controller_test.go
Normal file
531
client/http/controller_test.go
Normal file
@@ -0,0 +1,531 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/fatedier/frp/client/configmgmt"
|
||||
"github.com/fatedier/frp/client/http/model"
|
||||
"github.com/fatedier/frp/client/proxy"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
)
|
||||
|
||||
type fakeConfigManager struct {
|
||||
reloadFromFileFn func(strict bool) error
|
||||
readConfigFileFn func() (string, error)
|
||||
writeConfigFileFn func(content []byte) error
|
||||
getProxyStatusFn func() []*proxy.WorkingStatus
|
||||
isStoreProxyEnabledFn func(name string) bool
|
||||
storeEnabledFn func() bool
|
||||
|
||||
listStoreProxiesFn func() ([]v1.ProxyConfigurer, error)
|
||||
getStoreProxyFn func(name string) (v1.ProxyConfigurer, error)
|
||||
createStoreProxyFn func(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error)
|
||||
updateStoreProxyFn func(name string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error)
|
||||
deleteStoreProxyFn func(name string) error
|
||||
listStoreVisitorsFn func() ([]v1.VisitorConfigurer, error)
|
||||
getStoreVisitorFn func(name string) (v1.VisitorConfigurer, error)
|
||||
createStoreVisitFn func(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error)
|
||||
updateStoreVisitFn func(name string, cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error)
|
||||
deleteStoreVisitFn func(name string) error
|
||||
gracefulCloseFn func(d time.Duration)
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) ReloadFromFile(strict bool) error {
|
||||
if m.reloadFromFileFn != nil {
|
||||
return m.reloadFromFileFn(strict)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) ReadConfigFile() (string, error) {
|
||||
if m.readConfigFileFn != nil {
|
||||
return m.readConfigFileFn()
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) WriteConfigFile(content []byte) error {
|
||||
if m.writeConfigFileFn != nil {
|
||||
return m.writeConfigFileFn(content)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) GetProxyStatus() []*proxy.WorkingStatus {
|
||||
if m.getProxyStatusFn != nil {
|
||||
return m.getProxyStatusFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) IsStoreProxyEnabled(name string) bool {
|
||||
if m.isStoreProxyEnabledFn != nil {
|
||||
return m.isStoreProxyEnabledFn(name)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) StoreEnabled() bool {
|
||||
if m.storeEnabledFn != nil {
|
||||
return m.storeEnabledFn()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) ListStoreProxies() ([]v1.ProxyConfigurer, error) {
|
||||
if m.listStoreProxiesFn != nil {
|
||||
return m.listStoreProxiesFn()
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) GetStoreProxy(name string) (v1.ProxyConfigurer, error) {
|
||||
if m.getStoreProxyFn != nil {
|
||||
return m.getStoreProxyFn(name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) CreateStoreProxy(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
|
||||
if m.createStoreProxyFn != nil {
|
||||
return m.createStoreProxyFn(cfg)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) UpdateStoreProxy(name string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
|
||||
if m.updateStoreProxyFn != nil {
|
||||
return m.updateStoreProxyFn(name, cfg)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) DeleteStoreProxy(name string) error {
|
||||
if m.deleteStoreProxyFn != nil {
|
||||
return m.deleteStoreProxyFn(name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) ListStoreVisitors() ([]v1.VisitorConfigurer, error) {
|
||||
if m.listStoreVisitorsFn != nil {
|
||||
return m.listStoreVisitorsFn()
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) GetStoreVisitor(name string) (v1.VisitorConfigurer, error) {
|
||||
if m.getStoreVisitorFn != nil {
|
||||
return m.getStoreVisitorFn(name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) CreateStoreVisitor(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
|
||||
if m.createStoreVisitFn != nil {
|
||||
return m.createStoreVisitFn(cfg)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) UpdateStoreVisitor(name string, cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
|
||||
if m.updateStoreVisitFn != nil {
|
||||
return m.updateStoreVisitFn(name, cfg)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) DeleteStoreVisitor(name string) error {
|
||||
if m.deleteStoreVisitFn != nil {
|
||||
return m.deleteStoreVisitFn(name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *fakeConfigManager) GracefulClose(d time.Duration) {
|
||||
if m.gracefulCloseFn != nil {
|
||||
m.gracefulCloseFn(d)
|
||||
}
|
||||
}
|
||||
|
||||
func newRawTCPProxyConfig(name string) *v1.TCPProxyConfig {
|
||||
return &v1.TCPProxyConfig{
|
||||
ProxyBaseConfig: v1.ProxyBaseConfig{
|
||||
Name: name,
|
||||
Type: "tcp",
|
||||
ProxyBackend: v1.ProxyBackend{
|
||||
LocalPort: 10080,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProxyStatusRespStoreSourceEnabled(t *testing.T) {
|
||||
status := &proxy.WorkingStatus{
|
||||
Name: "shared-proxy",
|
||||
Type: "tcp",
|
||||
Phase: proxy.ProxyPhaseRunning,
|
||||
RemoteAddr: ":8080",
|
||||
Cfg: newRawTCPProxyConfig("shared-proxy"),
|
||||
}
|
||||
|
||||
controller := &Controller{
|
||||
serverAddr: "127.0.0.1",
|
||||
manager: &fakeConfigManager{
|
||||
isStoreProxyEnabledFn: func(name string) bool {
|
||||
return name == "shared-proxy"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := controller.buildProxyStatusResp(status)
|
||||
if resp.Source != "store" {
|
||||
t.Fatalf("unexpected source: %q", resp.Source)
|
||||
}
|
||||
if resp.RemoteAddr != "127.0.0.1:8080" {
|
||||
t.Fatalf("unexpected remote addr: %q", resp.RemoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadErrorMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedCode int
|
||||
}{
|
||||
{name: "invalid arg", err: fmtError(configmgmt.ErrInvalidArgument, "bad cfg"), expectedCode: http.StatusBadRequest},
|
||||
{name: "apply fail", err: fmtError(configmgmt.ErrApplyConfig, "reload failed"), expectedCode: http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{reloadFromFileFn: func(bool) error { return tc.err }},
|
||||
}
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/api/reload", nil))
|
||||
_, err := controller.Reload(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
assertHTTPCode(t, err, tc.expectedCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreProxyErrorMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedCode int
|
||||
}{
|
||||
{name: "not found", err: fmtError(configmgmt.ErrNotFound, "not found"), expectedCode: http.StatusNotFound},
|
||||
{name: "conflict", err: fmtError(configmgmt.ErrConflict, "exists"), expectedCode: http.StatusConflict},
|
||||
{name: "internal", err: errors.New("persist failed"), expectedCode: http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := []byte(`{"name":"shared-proxy","type":"tcp","tcp":{"localPort":10080}}`)
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/store/proxies/shared-proxy", bytes.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"name": "shared-proxy"})
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{
|
||||
updateStoreProxyFn: func(_ string, _ v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
|
||||
return nil, tc.err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := controller.UpdateStoreProxy(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
assertHTTPCode(t, err, tc.expectedCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreVisitorErrorMapping(t *testing.T) {
|
||||
body := []byte(`{"name":"shared-visitor","type":"xtcp","xtcp":{"serverName":"server","bindPort":10081,"secretKey":"secret"}}`)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/store/visitors/shared-visitor", bytes.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"name": "shared-visitor"})
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{
|
||||
deleteStoreVisitFn: func(string) error {
|
||||
return fmtError(configmgmt.ErrStoreDisabled, "disabled")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := controller.DeleteStoreVisitor(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
assertHTTPCode(t, err, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func TestCreateStoreProxyIgnoresUnknownFields(t *testing.T) {
|
||||
var gotName string
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{
|
||||
createStoreProxyFn: func(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
|
||||
gotName = cfg.GetBaseConfig().Name
|
||||
return cfg, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"name":"raw-proxy","type":"tcp","unexpected":"value","tcp":{"localPort":10080,"unknownInBlock":"value"}}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/store/proxies", bytes.NewReader(body))
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
resp, err := controller.CreateStoreProxy(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("create store proxy: %v", err)
|
||||
}
|
||||
if gotName != "raw-proxy" {
|
||||
t.Fatalf("unexpected proxy name: %q", gotName)
|
||||
}
|
||||
|
||||
payload, ok := resp.(model.ProxyDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected response type: %T", resp)
|
||||
}
|
||||
if payload.Type != "tcp" || payload.TCP == nil {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateStoreVisitorIgnoresUnknownFields(t *testing.T) {
|
||||
var gotName string
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{
|
||||
createStoreVisitFn: func(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
|
||||
gotName = cfg.GetBaseConfig().Name
|
||||
return cfg, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{
|
||||
"name":"raw-visitor","type":"xtcp","unexpected":"value",
|
||||
"xtcp":{"serverName":"server","bindPort":10081,"secretKey":"secret","unknownInBlock":"value"}
|
||||
}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/store/visitors", bytes.NewReader(body))
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
resp, err := controller.CreateStoreVisitor(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("create store visitor: %v", err)
|
||||
}
|
||||
if gotName != "raw-visitor" {
|
||||
t.Fatalf("unexpected visitor name: %q", gotName)
|
||||
}
|
||||
|
||||
payload, ok := resp.(model.VisitorDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected response type: %T", resp)
|
||||
}
|
||||
if payload.Type != "xtcp" || payload.XTCP == nil {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateStoreProxyPluginUnknownFieldsAreIgnored(t *testing.T) {
|
||||
var gotPluginType string
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{
|
||||
createStoreProxyFn: func(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
|
||||
gotPluginType = cfg.GetBaseConfig().Plugin.Type
|
||||
return cfg, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"name":"plugin-proxy","type":"tcp","tcp":{"plugin":{"type":"http2https","localAddr":"127.0.0.1:8080","unknownInPlugin":"value"}}}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/store/proxies", bytes.NewReader(body))
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
resp, err := controller.CreateStoreProxy(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("create store proxy: %v", err)
|
||||
}
|
||||
if gotPluginType != "http2https" {
|
||||
t.Fatalf("unexpected plugin type: %q", gotPluginType)
|
||||
}
|
||||
payload, ok := resp.(model.ProxyDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected response type: %T", resp)
|
||||
}
|
||||
if payload.TCP == nil {
|
||||
t.Fatalf("unexpected response payload: %#v", payload)
|
||||
}
|
||||
pluginType := payload.TCP.Plugin.Type
|
||||
|
||||
if pluginType != "http2https" {
|
||||
t.Fatalf("unexpected plugin type in response payload: %q", pluginType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateStoreVisitorPluginUnknownFieldsAreIgnored(t *testing.T) {
|
||||
var gotPluginType string
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{
|
||||
createStoreVisitFn: func(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
|
||||
gotPluginType = cfg.GetBaseConfig().Plugin.Type
|
||||
return cfg, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{
|
||||
"name":"plugin-visitor","type":"stcp",
|
||||
"stcp":{"serverName":"server","bindPort":10081,"plugin":{"type":"virtual_net","destinationIP":"10.0.0.1","unknownInPlugin":"value"}}
|
||||
}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/store/visitors", bytes.NewReader(body))
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
resp, err := controller.CreateStoreVisitor(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("create store visitor: %v", err)
|
||||
}
|
||||
if gotPluginType != "virtual_net" {
|
||||
t.Fatalf("unexpected plugin type: %q", gotPluginType)
|
||||
}
|
||||
payload, ok := resp.(model.VisitorDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected response type: %T", resp)
|
||||
}
|
||||
if payload.STCP == nil {
|
||||
t.Fatalf("unexpected response payload: %#v", payload)
|
||||
}
|
||||
pluginType := payload.STCP.Plugin.Type
|
||||
|
||||
if pluginType != "virtual_net" {
|
||||
t.Fatalf("unexpected plugin type in response payload: %q", pluginType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateStoreProxyRejectsMismatchedTypeBlock(t *testing.T) {
|
||||
controller := &Controller{manager: &fakeConfigManager{}}
|
||||
body := []byte(`{"name":"p1","type":"tcp","udp":{"localPort":10080}}`)
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/store/proxies/p1", bytes.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"name": "p1"})
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
_, err := controller.UpdateStoreProxy(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
assertHTTPCode(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestUpdateStoreProxyRejectsNameMismatch(t *testing.T) {
|
||||
controller := &Controller{manager: &fakeConfigManager{}}
|
||||
body := []byte(`{"name":"p2","type":"tcp","tcp":{"localPort":10080}}`)
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/store/proxies/p1", bytes.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"name": "p1"})
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
_, err := controller.UpdateStoreProxy(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
assertHTTPCode(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestListStoreProxiesReturnsSortedPayload(t *testing.T) {
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{
|
||||
listStoreProxiesFn: func() ([]v1.ProxyConfigurer, error) {
|
||||
b := newRawTCPProxyConfig("b")
|
||||
a := newRawTCPProxyConfig("a")
|
||||
return []v1.ProxyConfigurer{b, a}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/api/store/proxies", nil))
|
||||
|
||||
resp, err := controller.ListStoreProxies(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list store proxies: %v", err)
|
||||
}
|
||||
out, ok := resp.(model.ProxyListResp)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected response type: %T", resp)
|
||||
}
|
||||
if len(out.Proxies) != 2 {
|
||||
t.Fatalf("unexpected proxy count: %d", len(out.Proxies))
|
||||
}
|
||||
if out.Proxies[0].Name != "a" || out.Proxies[1].Name != "b" {
|
||||
t.Fatalf("proxies are not sorted by name: %#v", out.Proxies)
|
||||
}
|
||||
}
|
||||
|
||||
func fmtError(sentinel error, msg string) error {
|
||||
return fmt.Errorf("%w: %s", sentinel, msg)
|
||||
}
|
||||
|
||||
func assertHTTPCode(t *testing.T, err error, expected int) {
|
||||
t.Helper()
|
||||
var httpErr *httppkg.Error
|
||||
if !errors.As(err, &httpErr) {
|
||||
t.Fatalf("unexpected error type: %T", err)
|
||||
}
|
||||
if httpErr.Code != expected {
|
||||
t.Fatalf("unexpected status code: got %d, want %d", httpErr.Code, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateStoreProxyReturnsTypedPayload(t *testing.T) {
|
||||
controller := &Controller{
|
||||
manager: &fakeConfigManager{
|
||||
updateStoreProxyFn: func(_ string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
|
||||
return cfg, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"name": "shared-proxy",
|
||||
"type": "tcp",
|
||||
"tcp": map[string]any{
|
||||
"localPort": 10080,
|
||||
"remotePort": 7000,
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal request: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/store/proxies/shared-proxy", bytes.NewReader(data))
|
||||
req = mux.SetURLVars(req, map[string]string{"name": "shared-proxy"})
|
||||
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
|
||||
|
||||
resp, err := controller.UpdateStoreProxy(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("update store proxy: %v", err)
|
||||
}
|
||||
payload, ok := resp.(model.ProxyDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected response type: %T", resp)
|
||||
}
|
||||
if payload.TCP == nil || payload.TCP.RemotePort != 7000 {
|
||||
t.Fatalf("unexpected response payload: %#v", payload)
|
||||
}
|
||||
}
|
||||
148
client/http/model/proxy_definition.go
Normal file
148
client/http/model/proxy_definition.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
type ProxyDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
|
||||
TCP *v1.TCPProxyConfig `json:"tcp,omitempty"`
|
||||
UDP *v1.UDPProxyConfig `json:"udp,omitempty"`
|
||||
HTTP *v1.HTTPProxyConfig `json:"http,omitempty"`
|
||||
HTTPS *v1.HTTPSProxyConfig `json:"https,omitempty"`
|
||||
TCPMux *v1.TCPMuxProxyConfig `json:"tcpmux,omitempty"`
|
||||
STCP *v1.STCPProxyConfig `json:"stcp,omitempty"`
|
||||
SUDP *v1.SUDPProxyConfig `json:"sudp,omitempty"`
|
||||
XTCP *v1.XTCPProxyConfig `json:"xtcp,omitempty"`
|
||||
}
|
||||
|
||||
func (p *ProxyDefinition) Validate(pathName string, isUpdate bool) error {
|
||||
if strings.TrimSpace(p.Name) == "" {
|
||||
return fmt.Errorf("proxy name is required")
|
||||
}
|
||||
if !IsProxyType(p.Type) {
|
||||
return fmt.Errorf("invalid proxy type: %s", p.Type)
|
||||
}
|
||||
if isUpdate && pathName != "" && pathName != p.Name {
|
||||
return fmt.Errorf("proxy name in URL must match name in body")
|
||||
}
|
||||
|
||||
_, blockType, blockCount := p.activeBlock()
|
||||
if blockCount != 1 {
|
||||
return fmt.Errorf("exactly one proxy type block is required")
|
||||
}
|
||||
if blockType != p.Type {
|
||||
return fmt.Errorf("proxy type block %q does not match type %q", blockType, p.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyDefinition) ToConfigurer() (v1.ProxyConfigurer, error) {
|
||||
block, _, _ := p.activeBlock()
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("exactly one proxy type block is required")
|
||||
}
|
||||
|
||||
cfg := block
|
||||
cfg.GetBaseConfig().Name = p.Name
|
||||
cfg.GetBaseConfig().Type = p.Type
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func ProxyDefinitionFromConfigurer(cfg v1.ProxyConfigurer) (ProxyDefinition, error) {
|
||||
if cfg == nil {
|
||||
return ProxyDefinition{}, fmt.Errorf("proxy config is nil")
|
||||
}
|
||||
|
||||
base := cfg.GetBaseConfig()
|
||||
payload := ProxyDefinition{
|
||||
Name: base.Name,
|
||||
Type: base.Type,
|
||||
}
|
||||
|
||||
switch c := cfg.(type) {
|
||||
case *v1.TCPProxyConfig:
|
||||
payload.TCP = c
|
||||
case *v1.UDPProxyConfig:
|
||||
payload.UDP = c
|
||||
case *v1.HTTPProxyConfig:
|
||||
payload.HTTP = c
|
||||
case *v1.HTTPSProxyConfig:
|
||||
payload.HTTPS = c
|
||||
case *v1.TCPMuxProxyConfig:
|
||||
payload.TCPMux = c
|
||||
case *v1.STCPProxyConfig:
|
||||
payload.STCP = c
|
||||
case *v1.SUDPProxyConfig:
|
||||
payload.SUDP = c
|
||||
case *v1.XTCPProxyConfig:
|
||||
payload.XTCP = c
|
||||
default:
|
||||
return ProxyDefinition{}, fmt.Errorf("unsupported proxy configurer type %T", cfg)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (p *ProxyDefinition) activeBlock() (v1.ProxyConfigurer, string, int) {
|
||||
count := 0
|
||||
var block v1.ProxyConfigurer
|
||||
var blockType string
|
||||
|
||||
if p.TCP != nil {
|
||||
count++
|
||||
block = p.TCP
|
||||
blockType = "tcp"
|
||||
}
|
||||
if p.UDP != nil {
|
||||
count++
|
||||
block = p.UDP
|
||||
blockType = "udp"
|
||||
}
|
||||
if p.HTTP != nil {
|
||||
count++
|
||||
block = p.HTTP
|
||||
blockType = "http"
|
||||
}
|
||||
if p.HTTPS != nil {
|
||||
count++
|
||||
block = p.HTTPS
|
||||
blockType = "https"
|
||||
}
|
||||
if p.TCPMux != nil {
|
||||
count++
|
||||
block = p.TCPMux
|
||||
blockType = "tcpmux"
|
||||
}
|
||||
if p.STCP != nil {
|
||||
count++
|
||||
block = p.STCP
|
||||
blockType = "stcp"
|
||||
}
|
||||
if p.SUDP != nil {
|
||||
count++
|
||||
block = p.SUDP
|
||||
blockType = "sudp"
|
||||
}
|
||||
if p.XTCP != nil {
|
||||
count++
|
||||
block = p.XTCP
|
||||
blockType = "xtcp"
|
||||
}
|
||||
|
||||
return block, blockType, count
|
||||
}
|
||||
|
||||
func IsProxyType(typ string) bool {
|
||||
switch typ {
|
||||
case "tcp", "udp", "http", "https", "tcpmux", "stcp", "sudp", "xtcp":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package api
|
||||
package model
|
||||
|
||||
const SourceStore = "store"
|
||||
|
||||
// StatusResp is the response for GET /api/status
|
||||
type StatusResp map[string][]ProxyStatusResp
|
||||
@@ -29,31 +31,12 @@ type ProxyStatusResp struct {
|
||||
Source string `json:"source,omitempty"` // "store" or "config"
|
||||
}
|
||||
|
||||
// ProxyConfig wraps proxy configuration for API requests/responses.
|
||||
type ProxyConfig struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
// VisitorConfig wraps visitor configuration for API requests/responses.
|
||||
type VisitorConfig struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
// ProxyListResp is the response for GET /api/store/proxies
|
||||
type ProxyListResp struct {
|
||||
Proxies []ProxyConfig `json:"proxies"`
|
||||
Proxies []ProxyDefinition `json:"proxies"`
|
||||
}
|
||||
|
||||
// VisitorListResp is the response for GET /api/store/visitors
|
||||
type VisitorListResp struct {
|
||||
Visitors []VisitorConfig `json:"visitors"`
|
||||
}
|
||||
|
||||
// ErrorResp represents an error response
|
||||
type ErrorResp struct {
|
||||
Error string `json:"error"`
|
||||
Visitors []VisitorDefinition `json:"visitors"`
|
||||
}
|
||||
107
client/http/model/visitor_definition.go
Normal file
107
client/http/model/visitor_definition.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
type VisitorDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
|
||||
STCP *v1.STCPVisitorConfig `json:"stcp,omitempty"`
|
||||
SUDP *v1.SUDPVisitorConfig `json:"sudp,omitempty"`
|
||||
XTCP *v1.XTCPVisitorConfig `json:"xtcp,omitempty"`
|
||||
}
|
||||
|
||||
func (p *VisitorDefinition) Validate(pathName string, isUpdate bool) error {
|
||||
if strings.TrimSpace(p.Name) == "" {
|
||||
return fmt.Errorf("visitor name is required")
|
||||
}
|
||||
if !IsVisitorType(p.Type) {
|
||||
return fmt.Errorf("invalid visitor type: %s", p.Type)
|
||||
}
|
||||
if isUpdate && pathName != "" && pathName != p.Name {
|
||||
return fmt.Errorf("visitor name in URL must match name in body")
|
||||
}
|
||||
|
||||
_, blockType, blockCount := p.activeBlock()
|
||||
if blockCount != 1 {
|
||||
return fmt.Errorf("exactly one visitor type block is required")
|
||||
}
|
||||
if blockType != p.Type {
|
||||
return fmt.Errorf("visitor type block %q does not match type %q", blockType, p.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *VisitorDefinition) ToConfigurer() (v1.VisitorConfigurer, error) {
|
||||
block, _, _ := p.activeBlock()
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("exactly one visitor type block is required")
|
||||
}
|
||||
|
||||
cfg := block
|
||||
cfg.GetBaseConfig().Name = p.Name
|
||||
cfg.GetBaseConfig().Type = p.Type
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func VisitorDefinitionFromConfigurer(cfg v1.VisitorConfigurer) (VisitorDefinition, error) {
|
||||
if cfg == nil {
|
||||
return VisitorDefinition{}, fmt.Errorf("visitor config is nil")
|
||||
}
|
||||
|
||||
base := cfg.GetBaseConfig()
|
||||
payload := VisitorDefinition{
|
||||
Name: base.Name,
|
||||
Type: base.Type,
|
||||
}
|
||||
|
||||
switch c := cfg.(type) {
|
||||
case *v1.STCPVisitorConfig:
|
||||
payload.STCP = c
|
||||
case *v1.SUDPVisitorConfig:
|
||||
payload.SUDP = c
|
||||
case *v1.XTCPVisitorConfig:
|
||||
payload.XTCP = c
|
||||
default:
|
||||
return VisitorDefinition{}, fmt.Errorf("unsupported visitor configurer type %T", cfg)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (p *VisitorDefinition) activeBlock() (v1.VisitorConfigurer, string, int) {
|
||||
count := 0
|
||||
var block v1.VisitorConfigurer
|
||||
var blockType string
|
||||
|
||||
if p.STCP != nil {
|
||||
count++
|
||||
block = p.STCP
|
||||
blockType = "stcp"
|
||||
}
|
||||
if p.SUDP != nil {
|
||||
count++
|
||||
block = p.SUDP
|
||||
blockType = "sudp"
|
||||
}
|
||||
if p.XTCP != nil {
|
||||
count++
|
||||
block = p.XTCP
|
||||
blockType = "xtcp"
|
||||
}
|
||||
return block, blockType, count
|
||||
}
|
||||
|
||||
func IsVisitorType(typ string) bool {
|
||||
switch typ {
|
||||
case "stcp", "sudp", "xtcp":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
@@ -122,6 +123,33 @@ func (pxy *BaseProxy) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapWorkConn applies rate limiting, encryption, and compression
|
||||
// to a work connection based on the proxy's transport configuration.
|
||||
// The returned recycle function should be called when the stream is no longer in use
|
||||
// to return compression resources to the pool. It is safe to not call recycle,
|
||||
// in which case resources will be garbage collected normally.
|
||||
func (pxy *BaseProxy) wrapWorkConn(conn net.Conn, encKey []byte) (io.ReadWriteCloser, func(), error) {
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
}
|
||||
if pxy.baseCfg.Transport.UseEncryption {
|
||||
var err error
|
||||
rwc, err = libio.WithEncryption(rwc, encKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, fmt.Errorf("create encryption stream error: %w", err)
|
||||
}
|
||||
}
|
||||
var recycleFn func()
|
||||
if pxy.baseCfg.Transport.UseCompression {
|
||||
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
|
||||
}
|
||||
return rwc, recycleFn, nil
|
||||
}
|
||||
|
||||
func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||
pxy.inWorkConnCallback = cb
|
||||
}
|
||||
@@ -139,30 +167,14 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
|
||||
func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) {
|
||||
xl := pxy.xl
|
||||
baseCfg := pxy.baseCfg
|
||||
var (
|
||||
remote io.ReadWriteCloser
|
||||
err error
|
||||
)
|
||||
remote = workConn
|
||||
if pxy.limiter != nil {
|
||||
remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error {
|
||||
return workConn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
xl.Tracef("handle tcp work connection, useEncryption: %t, useCompression: %t",
|
||||
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
|
||||
if baseCfg.Transport.UseEncryption {
|
||||
remote, err = libio.WithEncryption(remote, encKey)
|
||||
if err != nil {
|
||||
workConn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
var compressionResourceRecycleFn func()
|
||||
if baseCfg.Transport.UseCompression {
|
||||
remote, compressionResourceRecycleFn = libio.WithCompressionFromPool(remote)
|
||||
|
||||
remote, recycleFn, err := pxy.wrapWorkConn(workConn, encKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// check if we need to send proxy protocol info
|
||||
@@ -178,7 +190,6 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
}
|
||||
|
||||
if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 {
|
||||
// Use the common proxy protocol builder function
|
||||
header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion)
|
||||
connInfo.ProxyProtocolHeader = header
|
||||
}
|
||||
@@ -187,12 +198,18 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
|
||||
if pxy.proxyPlugin != nil {
|
||||
// if plugin is set, let plugin handle connection first
|
||||
// Don't recycle compression resources here because plugins may
|
||||
// retain the connection after Handle returns.
|
||||
xl.Debugf("handle by plugin: %s", pxy.proxyPlugin.Name())
|
||||
pxy.proxyPlugin.Handle(pxy.ctx, &connInfo)
|
||||
xl.Debugf("handle by plugin finished")
|
||||
return
|
||||
}
|
||||
|
||||
if recycleFn != nil {
|
||||
defer recycleFn()
|
||||
}
|
||||
|
||||
localConn, err := libnet.Dial(
|
||||
net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)),
|
||||
libnet.WithTimeout(10*time.Second),
|
||||
@@ -209,6 +226,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
if connInfo.ProxyProtocolHeader != nil {
|
||||
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
|
||||
workConn.Close()
|
||||
localConn.Close()
|
||||
xl.Errorf("write proxy protocol header to local conn error: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -219,7 +237,4 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||
if len(errs) > 0 {
|
||||
xl.Tracef("join connections errors: %v", errs)
|
||||
}
|
||||
if compressionResourceRecycleFn != nil {
|
||||
compressionResourceRecycleFn()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,9 +118,9 @@ func (pm *Manager) HandleEvent(payload any) error {
|
||||
}
|
||||
|
||||
func (pm *Manager) GetAllProxyStatus() []*WorkingStatus {
|
||||
ps := make([]*WorkingStatus, 0)
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
ps := make([]*WorkingStatus, 0, len(pm.proxies))
|
||||
for _, pxy := range pm.proxies {
|
||||
ps = append(ps, pxy.GetStatus())
|
||||
}
|
||||
|
||||
@@ -29,8 +29,8 @@ import (
|
||||
"github.com/fatedier/frp/client/health"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
)
|
||||
@@ -116,7 +116,7 @@ func NewWrapper(
|
||||
vnetController: vnetController,
|
||||
xl: xl,
|
||||
ctx: xlog.NewContext(ctx, xl),
|
||||
wireName: util.AddUserPrefix(clientCfg.User, baseInfo.Name),
|
||||
wireName: naming.AddUserPrefix(clientCfg.User, baseInfo.Name),
|
||||
}
|
||||
|
||||
if baseInfo.HealthCheck.Type != "" && baseInfo.LocalPort > 0 {
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -25,17 +24,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/udp"
|
||||
"github.com/fatedier/frp/pkg/util/limit"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.SUDPProxyConfig{}), NewSUDPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.SUDPProxyConfig](), NewSUDPProxy)
|
||||
}
|
||||
|
||||
type SUDPProxy struct {
|
||||
@@ -83,27 +80,13 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
xl := pxy.xl
|
||||
xl.Infof("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String())
|
||||
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
var err error
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
if pxy.cfg.Transport.UseEncryption {
|
||||
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if pxy.cfg.Transport.UseCompression {
|
||||
rwc = libio.WithCompression(rwc)
|
||||
}
|
||||
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
|
||||
|
||||
workConn := conn
|
||||
workConn := netpkg.WrapReadWriteCloserToConn(remote, conn)
|
||||
readCh := make(chan *msg.UDPPacket, 1024)
|
||||
sendCh := make(chan msg.Message, 1024)
|
||||
isClose := false
|
||||
|
||||
@@ -17,24 +17,21 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/udp"
|
||||
"github.com/fatedier/frp/pkg/util/limit"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.UDPProxyConfig{}), NewUDPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.UDPProxyConfig](), NewUDPProxy)
|
||||
}
|
||||
|
||||
type UDPProxy struct {
|
||||
@@ -94,28 +91,14 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
// close resources related with old workConn
|
||||
pxy.Close()
|
||||
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
var err error
|
||||
if pxy.limiter != nil {
|
||||
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
|
||||
return conn.Close()
|
||||
})
|
||||
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
xl.Errorf("wrap work connection: %v", err)
|
||||
return
|
||||
}
|
||||
if pxy.cfg.Transport.UseEncryption {
|
||||
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if pxy.cfg.Transport.UseCompression {
|
||||
rwc = libio.WithCompression(rwc)
|
||||
}
|
||||
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
|
||||
|
||||
pxy.mu.Lock()
|
||||
pxy.workConn = conn
|
||||
pxy.workConn = netpkg.WrapReadWriteCloserToConn(remote, conn)
|
||||
pxy.readCh = make(chan *msg.UDPPacket, 1024)
|
||||
pxy.sendCh = make(chan msg.Message, 1024)
|
||||
pxy.closed = false
|
||||
@@ -129,7 +112,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
return
|
||||
}
|
||||
if errRet := errors.PanicToError(func() {
|
||||
xl.Tracef("get udp package from workConn: %s", udpMsg.Content)
|
||||
xl.Tracef("get udp package from workConn, len: %d", len(udpMsg.Content))
|
||||
readCh <- &udpMsg
|
||||
}); errRet != nil {
|
||||
xl.Infof("reader goroutine for udp work connection closed: %v", errRet)
|
||||
@@ -145,7 +128,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
|
||||
for rawMsg := range sendCh {
|
||||
switch m := rawMsg.(type) {
|
||||
case *msg.UDPPacket:
|
||||
xl.Tracef("send udp package to workConn: %s", m.Content)
|
||||
xl.Tracef("send udp package to workConn, len: %d", len(m.Content))
|
||||
case *msg.Ping:
|
||||
xl.Tracef("send ping message to udp workConn")
|
||||
}
|
||||
|
||||
@@ -27,14 +27,14 @@ import (
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
"github.com/fatedier/frp/pkg/nathole"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterProxyFactory(reflect.TypeOf(&v1.XTCPProxyConfig{}), NewXTCPProxy)
|
||||
RegisterProxyFactory(reflect.TypeFor[*v1.XTCPProxyConfig](), NewXTCPProxy)
|
||||
}
|
||||
|
||||
type XTCPProxy struct {
|
||||
@@ -86,7 +86,7 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, startWorkConnMsg *msg.StartWorkC
|
||||
transactionID := nathole.NewTransactionID()
|
||||
natHoleClientMsg := &msg.NatHoleClient{
|
||||
TransactionID: transactionID,
|
||||
ProxyName: util.AddUserPrefix(pxy.clientCfg.User, pxy.cfg.Name),
|
||||
ProxyName: naming.AddUserPrefix(pxy.clientCfg.User, pxy.cfg.Name),
|
||||
Sid: natHoleSidMsg.Sid,
|
||||
MappedAddrs: prepareResult.Addrs,
|
||||
AssistedAddrs: prepareResult.AssistedAddrs,
|
||||
|
||||
@@ -123,8 +123,11 @@ type Service struct {
|
||||
|
||||
vnetController *vnet.Controller
|
||||
|
||||
cfgMu sync.RWMutex
|
||||
common *v1.ClientCommonConfig
|
||||
cfgMu sync.RWMutex
|
||||
// reloadMu serializes reload transactions to keep reloadCommon and applied
|
||||
// config in sync across concurrent API operations.
|
||||
reloadMu sync.Mutex
|
||||
common *v1.ClientCommonConfig
|
||||
// reloadCommon is used for filtering/defaulting during config-source reloads.
|
||||
// It can be updated by /api/reload without mutating startup-only common behavior.
|
||||
reloadCommon *v1.ClientCommonConfig
|
||||
@@ -441,26 +444,28 @@ func (svr *Service) UpdateConfigSource(
|
||||
proxyCfgs []v1.ProxyConfigurer,
|
||||
visitorCfgs []v1.VisitorConfigurer,
|
||||
) error {
|
||||
svr.reloadMu.Lock()
|
||||
defer svr.reloadMu.Unlock()
|
||||
|
||||
cfgSource := svr.configSource
|
||||
if cfgSource == nil {
|
||||
return fmt.Errorf("config source is not available")
|
||||
}
|
||||
|
||||
// Update reloadCommon before ReplaceAll so the subsequent reload uses the
|
||||
// same common config as /api/reload validation.
|
||||
svr.cfgMu.Lock()
|
||||
prevReloadCommon := svr.reloadCommon
|
||||
svr.reloadCommon = common
|
||||
svr.cfgMu.Unlock()
|
||||
|
||||
if err := cfgSource.ReplaceAll(proxyCfgs, visitorCfgs); err != nil {
|
||||
svr.cfgMu.Lock()
|
||||
svr.reloadCommon = prevReloadCommon
|
||||
svr.cfgMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
return svr.reloadConfigFromSources()
|
||||
// Non-atomic update semantics: source has been updated at this point.
|
||||
// Even if reload fails below, keep this common config for subsequent reloads.
|
||||
svr.cfgMu.Lock()
|
||||
svr.reloadCommon = common
|
||||
svr.cfgMu.Unlock()
|
||||
|
||||
if err := svr.reloadConfigFromSourcesLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svr *Service) Close() {
|
||||
@@ -473,6 +478,15 @@ func (svr *Service) GracefulClose(d time.Duration) {
|
||||
}
|
||||
|
||||
func (svr *Service) stop() {
|
||||
// Coordinate shutdown with reload/update paths that read source pointers.
|
||||
svr.reloadMu.Lock()
|
||||
if svr.aggregator != nil {
|
||||
svr.aggregator = nil
|
||||
}
|
||||
svr.configSource = nil
|
||||
svr.storeSource = nil
|
||||
svr.reloadMu.Unlock()
|
||||
|
||||
svr.ctlMu.Lock()
|
||||
defer svr.ctlMu.Unlock()
|
||||
if svr.ctl != nil {
|
||||
@@ -483,11 +497,6 @@ func (svr *Service) stop() {
|
||||
svr.webServer.Close()
|
||||
svr.webServer = nil
|
||||
}
|
||||
if svr.aggregator != nil {
|
||||
svr.aggregator = nil
|
||||
}
|
||||
svr.configSource = nil
|
||||
svr.storeSource = nil
|
||||
}
|
||||
|
||||
func (svr *Service) getProxyStatus(name string) (*proxy.WorkingStatus, bool) {
|
||||
@@ -520,7 +529,14 @@ func (s *statusExporterImpl) GetProxyStatus(name string) (*proxy.WorkingStatus,
|
||||
}
|
||||
|
||||
func (svr *Service) reloadConfigFromSources() error {
|
||||
if svr.aggregator == nil {
|
||||
svr.reloadMu.Lock()
|
||||
defer svr.reloadMu.Unlock()
|
||||
return svr.reloadConfigFromSourcesLocked()
|
||||
}
|
||||
|
||||
func (svr *Service) reloadConfigFromSourcesLocked() error {
|
||||
aggregator := svr.aggregator
|
||||
if aggregator == nil {
|
||||
return errors.New("config aggregator is not initialized")
|
||||
}
|
||||
|
||||
@@ -528,7 +544,7 @@ func (svr *Service) reloadConfigFromSources() error {
|
||||
reloadCommon := svr.reloadCommon
|
||||
svr.cfgMu.RUnlock()
|
||||
|
||||
proxies, visitors, err := svr.aggregator.Load()
|
||||
proxies, visitors, err := aggregator.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reload config from sources failed: %w", err)
|
||||
}
|
||||
|
||||
140
client/service_test.go
Normal file
140
client/service_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/source"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func TestUpdateConfigSourceRollsBackReloadCommonOnReplaceAllFailure(t *testing.T) {
|
||||
prevCommon := &v1.ClientCommonConfig{User: "old-user"}
|
||||
newCommon := &v1.ClientCommonConfig{User: "new-user"}
|
||||
|
||||
svr := &Service{
|
||||
configSource: source.NewConfigSource(),
|
||||
reloadCommon: prevCommon,
|
||||
}
|
||||
|
||||
invalidProxy := &v1.TCPProxyConfig{}
|
||||
err := svr.UpdateConfigSource(newCommon, []v1.ProxyConfigurer{invalidProxy}, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "proxy name cannot be empty") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if svr.reloadCommon != prevCommon {
|
||||
t.Fatalf("reloadCommon should roll back on ReplaceAll failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigSourceKeepsReloadCommonOnReloadFailure(t *testing.T) {
|
||||
prevCommon := &v1.ClientCommonConfig{User: "old-user"}
|
||||
newCommon := &v1.ClientCommonConfig{User: "new-user"}
|
||||
|
||||
svr := &Service{
|
||||
// Keep configSource valid so ReplaceAll succeeds first.
|
||||
configSource: source.NewConfigSource(),
|
||||
reloadCommon: prevCommon,
|
||||
// Keep aggregator nil to force reload failure.
|
||||
aggregator: nil,
|
||||
}
|
||||
|
||||
validProxy := &v1.TCPProxyConfig{
|
||||
ProxyBaseConfig: v1.ProxyBaseConfig{
|
||||
Name: "p1",
|
||||
Type: "tcp",
|
||||
},
|
||||
}
|
||||
err := svr.UpdateConfigSource(newCommon, []v1.ProxyConfigurer{validProxy}, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "config aggregator is not initialized") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if svr.reloadCommon != newCommon {
|
||||
t.Fatalf("reloadCommon should keep new value on reload failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadConfigFromSourcesDoesNotMutateStoreConfigs(t *testing.T) {
|
||||
storeSource, err := source.NewStoreSource(source.StoreSourceConfig{
|
||||
Path: filepath.Join(t.TempDir(), "store.json"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("new store source: %v", err)
|
||||
}
|
||||
|
||||
proxyCfg := &v1.TCPProxyConfig{
|
||||
ProxyBaseConfig: v1.ProxyBaseConfig{
|
||||
Name: "store-proxy",
|
||||
Type: "tcp",
|
||||
},
|
||||
}
|
||||
visitorCfg := &v1.STCPVisitorConfig{
|
||||
VisitorBaseConfig: v1.VisitorBaseConfig{
|
||||
Name: "store-visitor",
|
||||
Type: "stcp",
|
||||
},
|
||||
}
|
||||
if err := storeSource.AddProxy(proxyCfg); err != nil {
|
||||
t.Fatalf("add proxy to store: %v", err)
|
||||
}
|
||||
if err := storeSource.AddVisitor(visitorCfg); err != nil {
|
||||
t.Fatalf("add visitor to store: %v", err)
|
||||
}
|
||||
|
||||
agg := source.NewAggregator(source.NewConfigSource())
|
||||
agg.SetStoreSource(storeSource)
|
||||
svr := &Service{
|
||||
aggregator: agg,
|
||||
configSource: agg.ConfigSource(),
|
||||
storeSource: storeSource,
|
||||
reloadCommon: &v1.ClientCommonConfig{},
|
||||
}
|
||||
|
||||
if err := svr.reloadConfigFromSources(); err != nil {
|
||||
t.Fatalf("reload config from sources: %v", err)
|
||||
}
|
||||
|
||||
gotProxy := storeSource.GetProxy("store-proxy")
|
||||
if gotProxy == nil {
|
||||
t.Fatalf("proxy not found in store")
|
||||
}
|
||||
if gotProxy.GetBaseConfig().LocalIP != "" {
|
||||
t.Fatalf("store proxy localIP should stay empty, got %q", gotProxy.GetBaseConfig().LocalIP)
|
||||
}
|
||||
|
||||
gotVisitor := storeSource.GetVisitor("store-visitor")
|
||||
if gotVisitor == nil {
|
||||
t.Fatalf("visitor not found in store")
|
||||
}
|
||||
if gotVisitor.GetBaseConfig().BindAddr != "" {
|
||||
t.Fatalf("store visitor bindAddr should stay empty, got %q", gotVisitor.GetBaseConfig().BindAddr)
|
||||
}
|
||||
|
||||
svr.cfgMu.RLock()
|
||||
defer svr.cfgMu.RUnlock()
|
||||
|
||||
if len(svr.proxyCfgs) != 1 {
|
||||
t.Fatalf("expected 1 runtime proxy, got %d", len(svr.proxyCfgs))
|
||||
}
|
||||
if svr.proxyCfgs[0].GetBaseConfig().LocalIP != "127.0.0.1" {
|
||||
t.Fatalf("runtime proxy localIP should be defaulted, got %q", svr.proxyCfgs[0].GetBaseConfig().LocalIP)
|
||||
}
|
||||
|
||||
if len(svr.visitorCfgs) != 1 {
|
||||
t.Fatalf("expected 1 runtime visitor, got %d", len(svr.visitorCfgs))
|
||||
}
|
||||
if svr.visitorCfgs[0].GetBaseConfig().BindAddr != "127.0.0.1" {
|
||||
t.Fatalf("runtime visitor bindAddr should be defaulted, got %q", svr.visitorCfgs[0].GetBaseConfig().BindAddr)
|
||||
}
|
||||
}
|
||||
@@ -15,17 +15,12 @@
|
||||
package visitor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
||||
@@ -41,10 +36,10 @@ func (sv *STCPVisitor) Run() (err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go sv.worker()
|
||||
go sv.acceptLoop(sv.l, "stcp local", sv.handleConn)
|
||||
}
|
||||
|
||||
go sv.internalConnWorker()
|
||||
go sv.acceptLoop(sv.internalLn, "stcp internal", sv.handleConn)
|
||||
|
||||
if sv.plugin != nil {
|
||||
sv.plugin.Start()
|
||||
@@ -56,35 +51,10 @@ func (sv *STCPVisitor) Close() {
|
||||
sv.BaseVisitor.Close()
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) worker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("stcp local listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) internalConnWorker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.internalLn.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("stcp internal listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *STCPVisitor) handleConn(userConn net.Conn) {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
var tunnelErr error
|
||||
defer func() {
|
||||
// If there was an error and connection supports CloseWithError, use it
|
||||
if tunnelErr != nil {
|
||||
if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
|
||||
_ = eConn.CloseWithError(tunnelErr)
|
||||
@@ -95,62 +65,21 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
|
||||
}()
|
||||
|
||||
xl.Debugf("get a new stcp user connection")
|
||||
visitorConn, err := sv.helper.ConnectServer()
|
||||
visitorConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
xl.Warnf("dialRawVisitorConn error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
defer visitorConn.Close()
|
||||
|
||||
now := time.Now().Unix()
|
||||
targetProxyName := util.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
|
||||
newVisitorConnMsg := &msg.NewVisitorConn{
|
||||
RunID: sv.helper.RunID(),
|
||||
ProxyName: targetProxyName,
|
||||
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
|
||||
Timestamp: now,
|
||||
UseEncryption: sv.cfg.Transport.UseEncryption,
|
||||
UseCompression: sv.cfg.Transport.UseCompression,
|
||||
}
|
||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||
remote, recycleFn, err := wrapVisitorConn(visitorConn, sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
xl.Warnf("send newVisitorConnMsg to server error: %v", err)
|
||||
xl.Warnf("wrapVisitorConn error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||
if err != nil {
|
||||
xl.Warnf("get newVisitorConnRespMsg error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
_ = visitorConn.SetReadDeadline(time.Time{})
|
||||
|
||||
if newVisitorConnRespMsg.Error != "" {
|
||||
xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
||||
tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error)
|
||||
return
|
||||
}
|
||||
|
||||
var remote io.ReadWriteCloser
|
||||
remote = visitorConn
|
||||
if sv.cfg.Transport.UseEncryption {
|
||||
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
|
||||
if err != nil {
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if sv.cfg.Transport.UseCompression {
|
||||
var recycleFn func()
|
||||
remote, recycleFn = libio.WithCompressionFromPool(remote)
|
||||
defer recycleFn()
|
||||
}
|
||||
defer recycleFn()
|
||||
|
||||
libio.Join(userConn, remote)
|
||||
}
|
||||
|
||||
@@ -16,20 +16,17 @@ package visitor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/udp"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
||||
@@ -75,6 +72,7 @@ func (sv *SUDPVisitor) dispatcher() {
|
||||
|
||||
var (
|
||||
visitorConn net.Conn
|
||||
recycleFn func()
|
||||
err error
|
||||
|
||||
firstPacket *msg.UDPPacket
|
||||
@@ -92,14 +90,17 @@ func (sv *SUDPVisitor) dispatcher() {
|
||||
return
|
||||
}
|
||||
|
||||
visitorConn, err = sv.getNewVisitorConn()
|
||||
visitorConn, recycleFn, err = sv.getNewVisitorConn()
|
||||
if err != nil {
|
||||
xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// visitorConn always be closed when worker done.
|
||||
sv.worker(visitorConn, firstPacket)
|
||||
func() {
|
||||
defer recycleFn()
|
||||
sv.worker(visitorConn, firstPacket)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sv.checkCloseCh:
|
||||
@@ -146,7 +147,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
||||
case *msg.UDPPacket:
|
||||
if errRet := errors.PanicToError(func() {
|
||||
sv.readCh <- m
|
||||
xl.Tracef("frpc visitor get udp packet from workConn: %s", m.Content)
|
||||
xl.Tracef("frpc visitor get udp packet from workConn, len: %d", len(m.Content))
|
||||
}); errRet != nil {
|
||||
xl.Infof("reader goroutine for udp work connection closed")
|
||||
return
|
||||
@@ -168,7 +169,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
||||
xl.Warnf("sender goroutine for udp work connection closed: %v", errRet)
|
||||
return
|
||||
}
|
||||
xl.Tracef("send udp package to workConn: %s", firstPacket.Content)
|
||||
xl.Tracef("send udp package to workConn, len: %d", len(firstPacket.Content))
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -183,7 +184,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
||||
xl.Warnf("sender goroutine for udp work connection closed: %v", errRet)
|
||||
return
|
||||
}
|
||||
xl.Tracef("send udp package to workConn: %s", udpMsg.Content)
|
||||
xl.Tracef("send udp package to workConn, len: %d", len(udpMsg.Content))
|
||||
case <-closeCh:
|
||||
return
|
||||
}
|
||||
@@ -197,53 +198,17 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
|
||||
xl.Infof("sudp worker is closed")
|
||||
}
|
||||
|
||||
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
visitorConn, err := sv.helper.ConnectServer()
|
||||
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, func(), error) {
|
||||
rawConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("frpc connect frps error: %v", err)
|
||||
return nil, func() {}, err
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
targetProxyName := util.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
|
||||
newVisitorConnMsg := &msg.NewVisitorConn{
|
||||
RunID: sv.helper.RunID(),
|
||||
ProxyName: targetProxyName,
|
||||
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
|
||||
Timestamp: now,
|
||||
UseEncryption: sv.cfg.Transport.UseEncryption,
|
||||
UseCompression: sv.cfg.Transport.UseCompression,
|
||||
}
|
||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||
rwc, recycleFn, err := wrapVisitorConn(rawConn, sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err)
|
||||
rawConn.Close()
|
||||
return nil, func() {}, err
|
||||
}
|
||||
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
|
||||
}
|
||||
_ = visitorConn.SetReadDeadline(time.Time{})
|
||||
|
||||
if newVisitorConnRespMsg.Error != "" {
|
||||
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
||||
}
|
||||
|
||||
var remote io.ReadWriteCloser
|
||||
remote = visitorConn
|
||||
if sv.cfg.Transport.UseEncryption {
|
||||
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
|
||||
if err != nil {
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if sv.cfg.Transport.UseCompression {
|
||||
remote = libio.WithCompression(remote)
|
||||
}
|
||||
return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil
|
||||
return netpkg.WrapReadWriteCloserToConn(rwc, rawConn), recycleFn, nil
|
||||
}
|
||||
|
||||
func (sv *SUDPVisitor) Close() {
|
||||
|
||||
@@ -16,13 +16,21 @@ package visitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
plugin "github.com/fatedier/frp/pkg/plugin/visitor"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
)
|
||||
@@ -119,6 +127,18 @@ func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
|
||||
return v.internalLn.PutConn(conn)
|
||||
}
|
||||
|
||||
func (v *BaseVisitor) acceptLoop(l net.Listener, name string, handleConn func(net.Conn)) {
|
||||
xl := xlog.FromContextSafe(v.ctx)
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("%s listener closed", name)
|
||||
return
|
||||
}
|
||||
go handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *BaseVisitor) Close() {
|
||||
if v.l != nil {
|
||||
v.l.Close()
|
||||
@@ -130,3 +150,57 @@ func (v *BaseVisitor) Close() {
|
||||
v.plugin.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, error) {
|
||||
visitorConn, err := v.helper.ConnectServer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to server error: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
targetProxyName := naming.BuildTargetServerProxyName(v.clientCfg.User, cfg.ServerUser, cfg.ServerName)
|
||||
newVisitorConnMsg := &msg.NewVisitorConn{
|
||||
RunID: v.helper.RunID(),
|
||||
ProxyName: targetProxyName,
|
||||
SignKey: util.GetAuthKey(cfg.SecretKey, now),
|
||||
Timestamp: now,
|
||||
UseEncryption: cfg.Transport.UseEncryption,
|
||||
UseCompression: cfg.Transport.UseCompression,
|
||||
}
|
||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||
if err != nil {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
|
||||
}
|
||||
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||
if err != nil {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
|
||||
}
|
||||
_ = visitorConn.SetReadDeadline(time.Time{})
|
||||
|
||||
if newVisitorConnRespMsg.Error != "" {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
|
||||
}
|
||||
return visitorConn, nil
|
||||
}
|
||||
|
||||
func wrapVisitorConn(conn io.ReadWriteCloser, cfg *v1.VisitorBaseConfig) (io.ReadWriteCloser, func(), error) {
|
||||
rwc := conn
|
||||
if cfg.Transport.UseEncryption {
|
||||
var err error
|
||||
rwc, err = libio.WithEncryption(rwc, []byte(cfg.SecretKey))
|
||||
if err != nil {
|
||||
return nil, func() {}, fmt.Errorf("create encryption stream error: %v", err)
|
||||
}
|
||||
}
|
||||
recycleFn := func() {}
|
||||
if cfg.Transport.UseCompression {
|
||||
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
|
||||
}
|
||||
return rwc, recycleFn, nil
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
"github.com/fatedier/frp/pkg/nathole"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
@@ -64,10 +65,10 @@ func (sv *XTCPVisitor) Run() (err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go sv.worker()
|
||||
go sv.acceptLoop(sv.l, "xtcp local", sv.handleConn)
|
||||
}
|
||||
|
||||
go sv.internalConnWorker()
|
||||
go sv.acceptLoop(sv.internalLn, "xtcp internal", sv.handleConn)
|
||||
go sv.processTunnelStartEvents()
|
||||
if sv.cfg.KeepTunnelOpen {
|
||||
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
|
||||
@@ -92,30 +93,6 @@ func (sv *XTCPVisitor) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) worker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.l.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("xtcp local listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) internalConnWorker() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
for {
|
||||
conn, err := sv.internalLn.Accept()
|
||||
if err != nil {
|
||||
xl.Warnf("xtcp internal listener closed")
|
||||
return
|
||||
}
|
||||
go sv.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *XTCPVisitor) processTunnelStartEvents() {
|
||||
for {
|
||||
select {
|
||||
@@ -205,20 +182,14 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
var muxConnRWCloser io.ReadWriteCloser = tunnelConn
|
||||
if sv.cfg.Transport.UseEncryption {
|
||||
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey))
|
||||
if err != nil {
|
||||
xl.Errorf("create encryption stream error: %v", err)
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
if sv.cfg.Transport.UseCompression {
|
||||
var recycleFn func()
|
||||
muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser)
|
||||
defer recycleFn()
|
||||
muxConnRWCloser, recycleFn, err := wrapVisitorConn(tunnelConn, sv.cfg.GetBaseConfig())
|
||||
if err != nil {
|
||||
xl.Errorf("%v", err)
|
||||
tunnelConn.Close()
|
||||
tunnelErr = err
|
||||
return
|
||||
}
|
||||
defer recycleFn()
|
||||
|
||||
_, _, errs := libio.Join(userConn, muxConnRWCloser)
|
||||
xl.Debugf("join connections closed")
|
||||
@@ -280,7 +251,7 @@ func (sv *XTCPVisitor) getTunnelConn(ctx context.Context) (net.Conn, error) {
|
||||
// 4. Create a tunnel session using an underlying UDP connection.
|
||||
func (sv *XTCPVisitor) makeNatHole() {
|
||||
xl := xlog.FromContextSafe(sv.ctx)
|
||||
targetProxyName := util.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
|
||||
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
|
||||
xl.Tracef("makeNatHole start")
|
||||
if err := nathole.PreCheck(sv.ctx, sv.helper.MsgTransporter(), targetProxyName, 5*time.Second); err != nil {
|
||||
xl.Warnf("nathole precheck error: %v", err)
|
||||
@@ -372,6 +343,7 @@ func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) er
|
||||
}
|
||||
remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String())
|
||||
if err != nil {
|
||||
lConn.Close()
|
||||
return fmt.Errorf("create kcp connection from udp connection error: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ var natholeDiscoveryCmd = &cobra.Command{
|
||||
Use: "discover",
|
||||
Short: "Discover nathole information from stun server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// ignore error here, because we can use command line pameters
|
||||
// ignore error here, because we can use command line parameters
|
||||
cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfigMode)
|
||||
if err != nil {
|
||||
cfg = &v1.ClientCommonConfig{}
|
||||
|
||||
@@ -143,6 +143,9 @@ transport.tls.enable = true
|
||||
|
||||
# Proxy names you want to start.
|
||||
# Default is empty, means all proxies.
|
||||
# This list is a global allowlist after config + store are merged, so entries
|
||||
# created via Store API are also filtered by this list.
|
||||
# If start is non-empty, any proxy/visitor not listed here will not be started.
|
||||
# start = ["ssh", "dns"]
|
||||
|
||||
# Alternative to 'start': You can control each proxy individually using the 'enabled' field.
|
||||
|
||||
@@ -5,7 +5,7 @@ COPY web/frpc/ ./
|
||||
RUN npm install
|
||||
RUN npm run build
|
||||
|
||||
FROM golang:1.24 AS building
|
||||
FROM golang:1.25 AS building
|
||||
|
||||
COPY . /building
|
||||
COPY --from=web-builder /web/frpc/dist /building/web/frpc/dist
|
||||
|
||||
@@ -5,7 +5,7 @@ COPY web/frps/ ./
|
||||
RUN npm install
|
||||
RUN npm run build
|
||||
|
||||
FROM golang:1.24 AS building
|
||||
FROM golang:1.25 AS building
|
||||
|
||||
COPY . /building
|
||||
COPY --from=web-builder /web/frps/dist /building/web/frps/dist
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
||||
module github.com/fatedier/frp
|
||||
|
||||
go 1.24.0
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -205,7 +206,8 @@ type OidcAuthConsumer struct {
|
||||
additionalAuthScopes []v1.AuthScope
|
||||
|
||||
verifier TokenVerifier
|
||||
subjectsFromLogin []string
|
||||
mu sync.RWMutex
|
||||
subjectsFromLogin map[string]struct{}
|
||||
}
|
||||
|
||||
func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
|
||||
@@ -226,7 +228,7 @@ func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVeri
|
||||
return &OidcAuthConsumer{
|
||||
additionalAuthScopes: additionalAuthScopes,
|
||||
verifier: verifier,
|
||||
subjectsFromLogin: []string{},
|
||||
subjectsFromLogin: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,9 +237,9 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid OIDC token in login: %v", err)
|
||||
}
|
||||
if !slices.Contains(auth.subjectsFromLogin, token.Subject) {
|
||||
auth.subjectsFromLogin = append(auth.subjectsFromLogin, token.Subject)
|
||||
}
|
||||
auth.mu.Lock()
|
||||
auth.subjectsFromLogin[token.Subject] = struct{}{}
|
||||
auth.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -246,11 +248,13 @@ func (auth *OidcAuthConsumer) verifyPostLoginToken(privilegeKey string) (err err
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid OIDC token in ping: %v", err)
|
||||
}
|
||||
if !slices.Contains(auth.subjectsFromLogin, token.Subject) {
|
||||
auth.mu.RLock()
|
||||
_, ok := auth.subjectsFromLogin[token.Subject]
|
||||
auth.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("received different OIDC subject in login and ping. "+
|
||||
"original subjects: %s, "+
|
||||
"new subject: %s",
|
||||
auth.subjectsFromLogin, token.Subject)
|
||||
token.Subject)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -171,15 +171,14 @@ func Convert_ServerCommonConf_To_v1(conf *ServerCommonConf) *v1.ServerConfig {
|
||||
func transformHeadersFromPluginParams(params map[string]string) v1.HeaderOperations {
|
||||
out := v1.HeaderOperations{}
|
||||
for k, v := range params {
|
||||
if !strings.HasPrefix(k, "plugin_header_") {
|
||||
k, ok := strings.CutPrefix(k, "plugin_header_")
|
||||
if !ok || k == "" {
|
||||
continue
|
||||
}
|
||||
if k = strings.TrimPrefix(k, "plugin_header_"); k != "" {
|
||||
if out.Set == nil {
|
||||
out.Set = make(map[string]string)
|
||||
}
|
||||
out.Set[k] = v
|
||||
if out.Set == nil {
|
||||
out.Set = make(map[string]string)
|
||||
}
|
||||
out.Set[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -39,14 +39,14 @@ const (
|
||||
// Proxy
|
||||
var (
|
||||
proxyConfTypeMap = map[ProxyType]reflect.Type{
|
||||
ProxyTypeTCP: reflect.TypeOf(TCPProxyConf{}),
|
||||
ProxyTypeUDP: reflect.TypeOf(UDPProxyConf{}),
|
||||
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConf{}),
|
||||
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConf{}),
|
||||
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConf{}),
|
||||
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConf{}),
|
||||
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConf{}),
|
||||
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConf{}),
|
||||
ProxyTypeTCP: reflect.TypeFor[TCPProxyConf](),
|
||||
ProxyTypeUDP: reflect.TypeFor[UDPProxyConf](),
|
||||
ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConf](),
|
||||
ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConf](),
|
||||
ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConf](),
|
||||
ProxyTypeSTCP: reflect.TypeFor[STCPProxyConf](),
|
||||
ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConf](),
|
||||
ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConf](),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ func GetMapWithoutPrefix(set map[string]string, prefix string) map[string]string
|
||||
m := make(map[string]string)
|
||||
|
||||
for key, value := range set {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
m[strings.TrimPrefix(key, prefix)] = value
|
||||
if trimmed, ok := strings.CutPrefix(key, prefix); ok {
|
||||
m[trimmed] = value
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,9 +32,9 @@ const (
|
||||
// Visitor
|
||||
var (
|
||||
visitorConfTypeMap = map[VisitorType]reflect.Type{
|
||||
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConf{}),
|
||||
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConf{}),
|
||||
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConf{}),
|
||||
VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConf](),
|
||||
VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConf](),
|
||||
VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConf](),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ package config
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -33,6 +34,7 @@ import (
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/config/v1/validation"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
@@ -108,7 +110,21 @@ func LoadConfigureFromFile(path string, c any, strict bool) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return LoadConfigure(content, c, strict)
|
||||
return LoadConfigure(content, c, strict, detectFormatFromPath(path))
|
||||
}
|
||||
|
||||
// detectFormatFromPath returns a format hint based on the file extension.
|
||||
func detectFormatFromPath(path string) string {
|
||||
switch strings.ToLower(filepath.Ext(path)) {
|
||||
case ".toml":
|
||||
return "toml"
|
||||
case ".yaml", ".yml":
|
||||
return "yaml"
|
||||
case ".json":
|
||||
return "json"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// parseYAMLWithDotFieldsHandling parses YAML with dot-prefixed fields handling
|
||||
@@ -129,48 +145,136 @@ func parseYAMLWithDotFieldsHandling(content []byte, target any) error {
|
||||
}
|
||||
|
||||
// Convert to JSON and decode with strict validation
|
||||
jsonBytes, err := json.Marshal(temp)
|
||||
jsonBytes, err := jsonx.Marshal(temp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decoder := json.NewDecoder(bytes.NewReader(jsonBytes))
|
||||
decoder.DisallowUnknownFields()
|
||||
return decoder.Decode(target)
|
||||
return decodeJSONContent(jsonBytes, target, true)
|
||||
}
|
||||
|
||||
func decodeJSONContent(content []byte, target any, strict bool) error {
|
||||
if clientCfg, ok := target.(*v1.ClientConfig); ok {
|
||||
decoded, err := v1.DecodeClientConfigJSON(content, v1.DecodeOptions{
|
||||
DisallowUnknownFields: strict,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*clientCfg = decoded
|
||||
return nil
|
||||
}
|
||||
|
||||
return jsonx.UnmarshalWithOptions(content, target, jsonx.DecodeOptions{
|
||||
RejectUnknownMembers: strict,
|
||||
})
|
||||
}
|
||||
|
||||
// LoadConfigure loads configuration from bytes and unmarshal into c.
|
||||
// Now it supports json, yaml and toml format.
|
||||
func LoadConfigure(b []byte, c any, strict bool) error {
|
||||
v1.DisallowUnknownFieldsMu.Lock()
|
||||
defer v1.DisallowUnknownFieldsMu.Unlock()
|
||||
v1.DisallowUnknownFields = strict
|
||||
// An optional format hint (e.g. "toml", "yaml", "json") can be provided
|
||||
// to enable better error messages with line number information.
|
||||
func LoadConfigure(b []byte, c any, strict bool, formats ...string) error {
|
||||
format := ""
|
||||
if len(formats) > 0 {
|
||||
format = formats[0]
|
||||
}
|
||||
|
||||
originalBytes := b
|
||||
parsedFromTOML := false
|
||||
|
||||
var tomlObj any
|
||||
// Try to unmarshal as TOML first; swallow errors from that (assume it's not valid TOML).
|
||||
if err := toml.Unmarshal(b, &tomlObj); err == nil {
|
||||
b, err = json.Marshal(&tomlObj)
|
||||
tomlErr := toml.Unmarshal(b, &tomlObj)
|
||||
if tomlErr == nil {
|
||||
parsedFromTOML = true
|
||||
var err error
|
||||
b, err = jsonx.Marshal(&tomlObj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if format == "toml" {
|
||||
// File is known to be TOML but has syntax errors.
|
||||
return formatTOMLError(tomlErr)
|
||||
}
|
||||
|
||||
// If the buffer smells like JSON (first non-whitespace character is '{'), unmarshal as JSON directly.
|
||||
if yaml.IsJSONBuffer(b) {
|
||||
decoder := json.NewDecoder(bytes.NewBuffer(b))
|
||||
if strict {
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decodeJSONContent(b, c, strict); err != nil {
|
||||
return enhanceDecodeError(err, originalBytes, !parsedFromTOML)
|
||||
}
|
||||
return decoder.Decode(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle YAML content
|
||||
if strict {
|
||||
// In strict mode, always use our custom handler to support YAML merge
|
||||
return parseYAMLWithDotFieldsHandling(b, c)
|
||||
if err := parseYAMLWithDotFieldsHandling(b, c); err != nil {
|
||||
return enhanceDecodeError(err, originalBytes, !parsedFromTOML)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Non-strict mode, parse normally
|
||||
return yaml.Unmarshal(b, c)
|
||||
}
|
||||
|
||||
// formatTOMLError extracts line/column information from TOML decode errors.
|
||||
func formatTOMLError(err error) error {
|
||||
var decErr *toml.DecodeError
|
||||
if errors.As(err, &decErr) {
|
||||
row, col := decErr.Position()
|
||||
return fmt.Errorf("toml: line %d, column %d: %s", row, col, decErr.Error())
|
||||
}
|
||||
var strictErr *toml.StrictMissingError
|
||||
if errors.As(err, &strictErr) {
|
||||
return strictErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// enhanceDecodeError tries to add field path and line number information to JSON/YAML decode errors.
|
||||
func enhanceDecodeError(err error, originalContent []byte, includeLine bool) error {
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
if errors.As(err, &typeErr) && typeErr.Field != "" {
|
||||
if includeLine {
|
||||
line := findFieldLineInContent(originalContent, typeErr.Field)
|
||||
if line > 0 {
|
||||
return fmt.Errorf("line %d: field \"%s\": cannot unmarshal %s into %s", line, typeErr.Field, typeErr.Value, typeErr.Type)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("field \"%s\": cannot unmarshal %s into %s", typeErr.Field, typeErr.Value, typeErr.Type)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// findFieldLineInContent searches the original config content for a field name
|
||||
// and returns the 1-indexed line number where it appears, or 0 if not found.
|
||||
func findFieldLineInContent(content []byte, fieldPath string) int {
|
||||
if fieldPath == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Use the last component of the field path (e.g. "proxies" from "proxies" or
|
||||
// "protocol" from "transport.protocol").
|
||||
parts := strings.Split(fieldPath, ".")
|
||||
searchKey := parts[len(parts)-1]
|
||||
|
||||
lines := bytes.Split(content, []byte("\n"))
|
||||
for i, line := range lines {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
// Match TOML key assignments like: key = ...
|
||||
if bytes.HasPrefix(trimmed, []byte(searchKey)) {
|
||||
rest := bytes.TrimSpace(trimmed[len(searchKey):])
|
||||
if len(rest) > 0 && rest[0] == '=' {
|
||||
return i + 1
|
||||
}
|
||||
}
|
||||
// Match TOML table array headers like: [[proxies]]
|
||||
if bytes.Contains(trimmed, []byte("[["+searchKey+"]]")) {
|
||||
return i + 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1.ProxyConfigurer, error) {
|
||||
m.ProxyType = util.EmptyOr(m.ProxyType, string(v1.ProxyTypeTCP))
|
||||
|
||||
@@ -341,7 +445,8 @@ func FilterClientConfigurers(
|
||||
proxyCfgs := proxies
|
||||
visitorCfgs := visitors
|
||||
|
||||
// Filter by start
|
||||
// Filter by start across merged configurers from all sources.
|
||||
// For example, store entries are also filtered by this set.
|
||||
if len(common.Start) > 0 {
|
||||
startSet := sets.New(common.Start...)
|
||||
proxyCfgs = lo.Filter(proxyCfgs, func(c v1.ProxyConfigurer, _ int) bool {
|
||||
|
||||
@@ -189,6 +189,31 @@ unixPath = "/tmp/uds.sock"
|
||||
require.Error(err)
|
||||
}
|
||||
|
||||
func TestLoadClientConfigStrictMode_UnknownPluginField(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
content := `
|
||||
serverPort = 7000
|
||||
|
||||
[[proxies]]
|
||||
name = "test"
|
||||
type = "tcp"
|
||||
localPort = 6000
|
||||
[proxies.plugin]
|
||||
type = "http2https"
|
||||
localAddr = "127.0.0.1:8080"
|
||||
unknownInPlugin = "value"
|
||||
`
|
||||
|
||||
clientCfg := v1.ClientConfig{}
|
||||
|
||||
err := LoadConfigure([]byte(content), &clientCfg, false)
|
||||
require.NoError(err)
|
||||
|
||||
err = LoadConfigure([]byte(content), &clientCfg, true)
|
||||
require.ErrorContains(err, "unknownInPlugin")
|
||||
}
|
||||
|
||||
// TestYAMLMergeInStrictMode tests that YAML merge functionality works
|
||||
// even in strict mode by properly handling dot-prefixed fields
|
||||
func TestYAMLMergeInStrictMode(t *testing.T) {
|
||||
@@ -470,3 +495,111 @@ serverPort: 7000
|
||||
require.Equal("127.0.0.1", clientCfg.ServerAddr)
|
||||
require.Equal(7000, clientCfg.ServerPort)
|
||||
}
|
||||
|
||||
func TestTOMLSyntaxErrorWithPosition(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
// TOML with syntax error (unclosed table array header)
|
||||
content := `serverAddr = "127.0.0.1"
|
||||
serverPort = 7000
|
||||
|
||||
[[proxies]
|
||||
name = "test"
|
||||
`
|
||||
|
||||
clientCfg := v1.ClientConfig{}
|
||||
err := LoadConfigure([]byte(content), &clientCfg, false, "toml")
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), "toml")
|
||||
require.Contains(err.Error(), "line")
|
||||
require.Contains(err.Error(), "column")
|
||||
}
|
||||
|
||||
func TestTOMLTypeMismatchErrorWithFieldInfo(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
// TOML with wrong type: proxies should be a table array, not a string
|
||||
content := `serverAddr = "127.0.0.1"
|
||||
serverPort = 7000
|
||||
proxies = "this should be a table array"
|
||||
`
|
||||
|
||||
clientCfg := v1.ClientConfig{}
|
||||
err := LoadConfigure([]byte(content), &clientCfg, false, "toml")
|
||||
require.Error(err)
|
||||
// The error should contain field info
|
||||
errMsg := err.Error()
|
||||
require.Contains(errMsg, "proxies")
|
||||
require.NotContains(errMsg, "line")
|
||||
}
|
||||
|
||||
func TestFindFieldLineInContent(t *testing.T) {
|
||||
content := []byte(`serverAddr = "127.0.0.1"
|
||||
serverPort = 7000
|
||||
|
||||
[[proxies]]
|
||||
name = "test"
|
||||
type = "tcp"
|
||||
remotePort = 6000
|
||||
`)
|
||||
|
||||
tests := []struct {
|
||||
fieldPath string
|
||||
wantLine int
|
||||
}{
|
||||
{"serverAddr", 1},
|
||||
{"serverPort", 2},
|
||||
{"name", 5},
|
||||
{"type", 6},
|
||||
{"remotePort", 7},
|
||||
{"nonexistent", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.fieldPath, func(t *testing.T) {
|
||||
got := findFieldLineInContent(content, tt.fieldPath)
|
||||
require.Equal(t, tt.wantLine, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
format string
|
||||
}{
|
||||
{"config.toml", "toml"},
|
||||
{"config.TOML", "toml"},
|
||||
{"config.yaml", "yaml"},
|
||||
{"config.yml", "yaml"},
|
||||
{"config.json", "json"},
|
||||
{"config.ini", ""},
|
||||
{"config", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
require.Equal(t, tt.format, detectFormatFromPath(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidTOMLStillWorks(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
// Valid TOML with format hint should work fine
|
||||
content := `serverAddr = "127.0.0.1"
|
||||
serverPort = 7000
|
||||
|
||||
[[proxies]]
|
||||
name = "test"
|
||||
type = "tcp"
|
||||
remotePort = 6000
|
||||
`
|
||||
clientCfg := v1.ClientConfig{}
|
||||
err := LoadConfigure([]byte(content), &clientCfg, false, "toml")
|
||||
require.NoError(err)
|
||||
require.Equal("127.0.0.1", clientCfg.ServerAddr)
|
||||
require.Equal(7000, clientCfg.ServerPort)
|
||||
require.Len(clientCfg.Proxies, 1)
|
||||
}
|
||||
|
||||
@@ -15,18 +15,16 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"maps"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
type sourceEntry struct {
|
||||
source Source
|
||||
}
|
||||
|
||||
type Aggregator struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
@@ -58,17 +56,13 @@ func (a *Aggregator) StoreSource() *StoreSource {
|
||||
return a.storeSource
|
||||
}
|
||||
|
||||
func (a *Aggregator) getSourcesLocked() []sourceEntry {
|
||||
sources := make([]sourceEntry, 0, 2)
|
||||
func (a *Aggregator) getSourcesLocked() []Source {
|
||||
sources := make([]Source, 0, 2)
|
||||
if a.configSource != nil {
|
||||
sources = append(sources, sourceEntry{
|
||||
source: a.configSource,
|
||||
})
|
||||
sources = append(sources, a.configSource)
|
||||
}
|
||||
if a.storeSource != nil {
|
||||
sources = append(sources, sourceEntry{
|
||||
source: a.storeSource,
|
||||
})
|
||||
sources = append(sources, a.storeSource)
|
||||
}
|
||||
return sources
|
||||
}
|
||||
@@ -85,8 +79,8 @@ func (a *Aggregator) Load() ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error
|
||||
proxyMap := make(map[string]v1.ProxyConfigurer)
|
||||
visitorMap := make(map[string]v1.VisitorConfigurer)
|
||||
|
||||
for _, entry := range entries {
|
||||
proxies, visitors, err := entry.source.Load()
|
||||
for _, src := range entries {
|
||||
proxies, visitors, err := src.Load()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("load source: %w", err)
|
||||
}
|
||||
@@ -105,21 +99,11 @@ func (a *Aggregator) mapsToSortedSlices(
|
||||
proxyMap map[string]v1.ProxyConfigurer,
|
||||
visitorMap map[string]v1.VisitorConfigurer,
|
||||
) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer) {
|
||||
proxies := make([]v1.ProxyConfigurer, 0, len(proxyMap))
|
||||
for _, p := range proxyMap {
|
||||
proxies = append(proxies, p)
|
||||
}
|
||||
sort.Slice(proxies, func(i, j int) bool {
|
||||
return proxies[i].GetBaseConfig().Name < proxies[j].GetBaseConfig().Name
|
||||
proxies := slices.SortedFunc(maps.Values(proxyMap), func(x, y v1.ProxyConfigurer) int {
|
||||
return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
|
||||
})
|
||||
|
||||
visitors := make([]v1.VisitorConfigurer, 0, len(visitorMap))
|
||||
for _, v := range visitorMap {
|
||||
visitors = append(visitors, v)
|
||||
}
|
||||
sort.Slice(visitors, func(i, j int) bool {
|
||||
return visitors[i].GetBaseConfig().Name < visitors[j].GetBaseConfig().Name
|
||||
visitors := slices.SortedFunc(maps.Values(visitorMap), func(x, y v1.VisitorConfigurer) int {
|
||||
return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
|
||||
})
|
||||
|
||||
return proxies, visitors
|
||||
}
|
||||
|
||||
@@ -196,7 +196,28 @@ func TestAggregator_VisitorMerge(t *testing.T) {
|
||||
require.Len(visitors, 2)
|
||||
}
|
||||
|
||||
func TestAggregator_Load_ReturnsSharedReferences(t *testing.T) {
|
||||
func TestAggregator_Load_ReturnsSortedByName(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
agg := newTestAggregator(t, nil)
|
||||
err := agg.ConfigSource().ReplaceAll(
|
||||
[]v1.ProxyConfigurer{mockProxy("charlie"), mockProxy("alice"), mockProxy("bob")},
|
||||
[]v1.VisitorConfigurer{mockVisitor("zulu"), mockVisitor("alpha")},
|
||||
)
|
||||
require.NoError(err)
|
||||
|
||||
proxies, visitors, err := agg.Load()
|
||||
require.NoError(err)
|
||||
require.Len(proxies, 3)
|
||||
require.Equal("alice", proxies[0].GetBaseConfig().Name)
|
||||
require.Equal("bob", proxies[1].GetBaseConfig().Name)
|
||||
require.Equal("charlie", proxies[2].GetBaseConfig().Name)
|
||||
require.Len(visitors, 2)
|
||||
require.Equal("alpha", visitors[0].GetBaseConfig().Name)
|
||||
require.Equal("zulu", visitors[1].GetBaseConfig().Name)
|
||||
}
|
||||
|
||||
func TestAggregator_Load_ReturnsDefensiveCopies(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
agg := newTestAggregator(t, nil)
|
||||
@@ -213,5 +234,5 @@ func TestAggregator_Load_ReturnsSharedReferences(t *testing.T) {
|
||||
proxies2, _, err := agg.Load()
|
||||
require.NoError(err)
|
||||
require.Len(proxies2, 1)
|
||||
require.Equal("alice.ssh", proxies2[0].GetBaseConfig().Name)
|
||||
require.Equal("ssh", proxies2[0].GetBaseConfig().Name)
|
||||
}
|
||||
|
||||
@@ -61,5 +61,5 @@ func (s *baseSource) Load() ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error
|
||||
visitors = append(visitors, v)
|
||||
}
|
||||
|
||||
return proxies, visitors, nil
|
||||
return cloneConfigurers(proxies, visitors)
|
||||
}
|
||||
|
||||
48
pkg/config/source/base_source_test.go
Normal file
48
pkg/config/source/base_source_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func TestBaseSourceLoadReturnsClonedConfigurers(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
src := NewConfigSource()
|
||||
|
||||
proxyCfg := &v1.TCPProxyConfig{
|
||||
ProxyBaseConfig: v1.ProxyBaseConfig{
|
||||
Name: "proxy1",
|
||||
Type: "tcp",
|
||||
},
|
||||
}
|
||||
visitorCfg := &v1.STCPVisitorConfig{
|
||||
VisitorBaseConfig: v1.VisitorBaseConfig{
|
||||
Name: "visitor1",
|
||||
Type: "stcp",
|
||||
},
|
||||
}
|
||||
|
||||
err := src.ReplaceAll([]v1.ProxyConfigurer{proxyCfg}, []v1.VisitorConfigurer{visitorCfg})
|
||||
require.NoError(err)
|
||||
|
||||
firstProxies, firstVisitors, err := src.Load()
|
||||
require.NoError(err)
|
||||
require.Len(firstProxies, 1)
|
||||
require.Len(firstVisitors, 1)
|
||||
|
||||
// Mutate loaded objects as runtime completion would do.
|
||||
firstProxies[0].Complete()
|
||||
firstVisitors[0].Complete()
|
||||
|
||||
secondProxies, secondVisitors, err := src.Load()
|
||||
require.NoError(err)
|
||||
require.Len(secondProxies, 1)
|
||||
require.Len(secondVisitors, 1)
|
||||
|
||||
require.Empty(secondProxies[0].GetBaseConfig().LocalIP)
|
||||
require.Empty(secondVisitors[0].GetBaseConfig().BindAddr)
|
||||
}
|
||||
43
pkg/config/source/clone.go
Normal file
43
pkg/config/source/clone.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func cloneConfigurers(
|
||||
proxies []v1.ProxyConfigurer,
|
||||
visitors []v1.VisitorConfigurer,
|
||||
) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) {
|
||||
clonedProxies := make([]v1.ProxyConfigurer, 0, len(proxies))
|
||||
clonedVisitors := make([]v1.VisitorConfigurer, 0, len(visitors))
|
||||
|
||||
for _, cfg := range proxies {
|
||||
if cfg == nil {
|
||||
return nil, nil, fmt.Errorf("proxy cannot be nil")
|
||||
}
|
||||
clonedProxies = append(clonedProxies, cfg.Clone())
|
||||
}
|
||||
for _, cfg := range visitors {
|
||||
if cfg == nil {
|
||||
return nil, nil, fmt.Errorf("visitor cannot be nil")
|
||||
}
|
||||
clonedVisitors = append(clonedVisitors, cfg.Clone())
|
||||
}
|
||||
return clonedProxies, clonedVisitors, nil
|
||||
}
|
||||
@@ -15,12 +15,13 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
)
|
||||
|
||||
type StoreSourceConfig struct {
|
||||
@@ -37,6 +38,11 @@ type StoreSource struct {
|
||||
config StoreSourceConfig
|
||||
}
|
||||
|
||||
var (
|
||||
ErrAlreadyExists = errors.New("already exists")
|
||||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) {
|
||||
if cfg.Path == "" {
|
||||
return nil, fmt.Errorf("path is required")
|
||||
@@ -68,34 +74,44 @@ func (s *StoreSource) loadFromFileUnlocked() error {
|
||||
return err
|
||||
}
|
||||
|
||||
var stored storeData
|
||||
if err := json.Unmarshal(data, &stored); err != nil {
|
||||
type rawStoreData struct {
|
||||
Proxies []jsonx.RawMessage `json:"proxies,omitempty"`
|
||||
Visitors []jsonx.RawMessage `json:"visitors,omitempty"`
|
||||
}
|
||||
stored := rawStoreData{}
|
||||
if err := jsonx.Unmarshal(data, &stored); err != nil {
|
||||
return fmt.Errorf("failed to parse JSON: %w", err)
|
||||
}
|
||||
|
||||
s.proxies = make(map[string]v1.ProxyConfigurer)
|
||||
s.visitors = make(map[string]v1.VisitorConfigurer)
|
||||
|
||||
for _, tp := range stored.Proxies {
|
||||
if tp.ProxyConfigurer != nil {
|
||||
proxyCfg := tp.ProxyConfigurer
|
||||
name := proxyCfg.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
s.proxies[name] = proxyCfg
|
||||
for i, proxyData := range stored.Proxies {
|
||||
proxyCfg, err := v1.DecodeProxyConfigurerJSON(proxyData, v1.DecodeOptions{
|
||||
DisallowUnknownFields: false,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode proxy at index %d: %w", i, err)
|
||||
}
|
||||
name := proxyCfg.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
s.proxies[name] = proxyCfg
|
||||
}
|
||||
|
||||
for _, tv := range stored.Visitors {
|
||||
if tv.VisitorConfigurer != nil {
|
||||
visitorCfg := tv.VisitorConfigurer
|
||||
name := visitorCfg.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
s.visitors[name] = visitorCfg
|
||||
for i, visitorData := range stored.Visitors {
|
||||
visitorCfg, err := v1.DecodeVisitorConfigurerJSON(visitorData, v1.DecodeOptions{
|
||||
DisallowUnknownFields: false,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode visitor at index %d: %w", i, err)
|
||||
}
|
||||
name := visitorCfg.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
s.visitors[name] = visitorCfg
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -114,7 +130,7 @@ func (s *StoreSource) saveToFileUnlocked() error {
|
||||
stored.Visitors = append(stored.Visitors, v1.TypedVisitorConfig{VisitorConfigurer: v})
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(stored, "", " ")
|
||||
data, err := jsonx.MarshalIndent(stored, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal JSON: %w", err)
|
||||
}
|
||||
@@ -170,7 +186,7 @@ func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.proxies[name]; exists {
|
||||
return fmt.Errorf("proxy %q already exists", name)
|
||||
return fmt.Errorf("%w: proxy %q", ErrAlreadyExists, name)
|
||||
}
|
||||
|
||||
s.proxies[name] = proxy
|
||||
@@ -197,7 +213,7 @@ func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error {
|
||||
|
||||
oldProxy, exists := s.proxies[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("proxy %q does not exist", name)
|
||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
s.proxies[name] = proxy
|
||||
@@ -219,7 +235,7 @@ func (s *StoreSource) RemoveProxy(name string) error {
|
||||
|
||||
oldProxy, exists := s.proxies[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("proxy %q does not exist", name)
|
||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
delete(s.proxies, name)
|
||||
@@ -256,7 +272,7 @@ func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error {
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.visitors[name]; exists {
|
||||
return fmt.Errorf("visitor %q already exists", name)
|
||||
return fmt.Errorf("%w: visitor %q", ErrAlreadyExists, name)
|
||||
}
|
||||
|
||||
s.visitors[name] = visitor
|
||||
@@ -283,7 +299,7 @@ func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error {
|
||||
|
||||
oldVisitor, exists := s.visitors[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("visitor %q does not exist", name)
|
||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
s.visitors[name] = visitor
|
||||
@@ -305,7 +321,7 @@ func (s *StoreSource) RemoveVisitor(name string) error {
|
||||
|
||||
oldVisitor, exists := s.visitors[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("visitor %q does not exist", name)
|
||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
delete(s.visitors, name)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -23,6 +22,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
)
|
||||
|
||||
func TestStoreSource_AddProxyAndVisitor_DoesNotApplyRuntimeDefaults(t *testing.T) {
|
||||
@@ -80,7 +80,7 @@ func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
|
||||
Proxies: []v1.TypedProxyConfig{{ProxyConfigurer: proxyCfg}},
|
||||
Visitors: []v1.TypedVisitorConfig{{VisitorConfigurer: visitorCfg}},
|
||||
}
|
||||
data, err := json.Marshal(stored)
|
||||
data, err := jsonx.Marshal(stored)
|
||||
require.NoError(err)
|
||||
err = os.WriteFile(path, data, 0o600)
|
||||
require.NoError(err)
|
||||
@@ -97,3 +97,25 @@ func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
|
||||
require.Empty(gotVisitor.GetBaseConfig().BindAddr)
|
||||
require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol)
|
||||
}
|
||||
|
||||
func TestStoreSource_LoadFromFile_UnknownFieldsAreIgnored(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
path := filepath.Join(t.TempDir(), "store.json")
|
||||
raw := []byte(`{
|
||||
"proxies": [
|
||||
{"name":"proxy1","type":"tcp","localPort":10080,"unexpected":"value"}
|
||||
],
|
||||
"visitors": [
|
||||
{"name":"visitor1","type":"xtcp","serverName":"server1","secretKey":"secret","bindPort":10081,"unexpected":"value"}
|
||||
]
|
||||
}`)
|
||||
err := os.WriteFile(path, raw, 0o600)
|
||||
require.NoError(err)
|
||||
|
||||
storeSource, err := NewStoreSource(StoreSourceConfig{Path: path})
|
||||
require.NoError(err)
|
||||
|
||||
require.NotNil(storeSource.GetProxy("proxy1"))
|
||||
require.NotNil(storeSource.GetVisitor("visitor1"))
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func parseNumberRangePair(firstRangeStr, secondRangeStr string) ([]NumberPair, e
|
||||
return nil, fmt.Errorf("first and second range numbers are not in pairs")
|
||||
}
|
||||
pairs := make([]NumberPair, 0, len(firstRangeNumbers))
|
||||
for i := 0; i < len(firstRangeNumbers); i++ {
|
||||
for i := range firstRangeNumbers {
|
||||
pairs = append(pairs, NumberPair{
|
||||
First: firstRangeNumbers[i],
|
||||
Second: secondRangeNumbers[i],
|
||||
|
||||
@@ -70,24 +70,18 @@ func (q *BandwidthQuantity) UnmarshalString(s string) error {
|
||||
f float64
|
||||
err error
|
||||
)
|
||||
switch {
|
||||
case strings.HasSuffix(s, "MB"):
|
||||
if fstr, ok := strings.CutSuffix(s, "MB"); ok {
|
||||
base = MB
|
||||
fstr := strings.TrimSuffix(s, "MB")
|
||||
f, err = strconv.ParseFloat(fstr, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case strings.HasSuffix(s, "KB"):
|
||||
} else if fstr, ok := strings.CutSuffix(s, "KB"); ok {
|
||||
base = KB
|
||||
fstr := strings.TrimSuffix(s, "KB")
|
||||
f, err = strconv.ParseFloat(fstr, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
} else {
|
||||
return errors.New("unit not support")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.s = s
|
||||
q.i = int64(f * float64(base))
|
||||
@@ -143,8 +137,8 @@ func (p PortsRangeSlice) String() string {
|
||||
func NewPortsRangeSliceFromString(str string) ([]PortsRange, error) {
|
||||
str = strings.TrimSpace(str)
|
||||
out := []PortsRange{}
|
||||
numRanges := strings.Split(str, ",")
|
||||
for _, numRangeStr := range numRanges {
|
||||
numRanges := strings.SplitSeq(str, ",")
|
||||
for numRangeStr := range numRanges {
|
||||
// 1000-2000 or 2001
|
||||
numArray := strings.Split(numRangeStr, "-")
|
||||
// length: only 1 or 2 is correct
|
||||
|
||||
@@ -39,6 +39,31 @@ func TestBandwidthQuantity(t *testing.T) {
|
||||
require.Equal(`{"b":"1KB","int":5}`, string(buf))
|
||||
}
|
||||
|
||||
func TestBandwidthQuantity_MB(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
var w Wrap
|
||||
err := json.Unmarshal([]byte(`{"b":"2MB","int":1}`), &w)
|
||||
require.NoError(err)
|
||||
require.EqualValues(2*MB, w.B.Bytes())
|
||||
|
||||
buf, err := json.Marshal(&w)
|
||||
require.NoError(err)
|
||||
require.Equal(`{"b":"2MB","int":1}`, string(buf))
|
||||
}
|
||||
|
||||
func TestBandwidthQuantity_InvalidUnit(t *testing.T) {
|
||||
var w Wrap
|
||||
err := json.Unmarshal([]byte(`{"b":"1GB","int":1}`), &w)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBandwidthQuantity_InvalidNumber(t *testing.T) {
|
||||
var w Wrap
|
||||
err := json.Unmarshal([]byte(`{"b":"abcKB","int":1}`), &w)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPortsRangeSlice2String(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
|
||||
109
pkg/config/v1/clone_test.go
Normal file
109
pkg/config/v1/clone_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProxyCloneDeepCopy(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
enabled := true
|
||||
pluginHTTP2 := true
|
||||
cfg := &HTTPProxyConfig{
|
||||
ProxyBaseConfig: ProxyBaseConfig{
|
||||
Name: "p1",
|
||||
Type: "http",
|
||||
Enabled: &enabled,
|
||||
Annotations: map[string]string{"a": "1"},
|
||||
Metadatas: map[string]string{"m": "1"},
|
||||
HealthCheck: HealthCheckConfig{
|
||||
Type: "http",
|
||||
HTTPHeaders: []HTTPHeader{
|
||||
{Name: "X-Test", Value: "v1"},
|
||||
},
|
||||
},
|
||||
ProxyBackend: ProxyBackend{
|
||||
Plugin: TypedClientPluginOptions{
|
||||
Type: PluginHTTPS2HTTP,
|
||||
ClientPluginOptions: &HTTPS2HTTPPluginOptions{
|
||||
Type: PluginHTTPS2HTTP,
|
||||
EnableHTTP2: &pluginHTTP2,
|
||||
RequestHeaders: HeaderOperations{Set: map[string]string{"k": "v"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
DomainConfig: DomainConfig{
|
||||
CustomDomains: []string{"a.example.com"},
|
||||
SubDomain: "a",
|
||||
},
|
||||
Locations: []string{"/api"},
|
||||
RequestHeaders: HeaderOperations{Set: map[string]string{"h1": "v1"}},
|
||||
ResponseHeaders: HeaderOperations{Set: map[string]string{"h2": "v2"}},
|
||||
}
|
||||
|
||||
cloned := cfg.Clone().(*HTTPProxyConfig)
|
||||
|
||||
*cloned.Enabled = false
|
||||
cloned.Annotations["a"] = "changed"
|
||||
cloned.Metadatas["m"] = "changed"
|
||||
cloned.HealthCheck.HTTPHeaders[0].Value = "changed"
|
||||
cloned.CustomDomains[0] = "b.example.com"
|
||||
cloned.Locations[0] = "/new"
|
||||
cloned.RequestHeaders.Set["h1"] = "changed"
|
||||
cloned.ResponseHeaders.Set["h2"] = "changed"
|
||||
clientPlugin := cloned.Plugin.ClientPluginOptions.(*HTTPS2HTTPPluginOptions)
|
||||
*clientPlugin.EnableHTTP2 = false
|
||||
clientPlugin.RequestHeaders.Set["k"] = "changed"
|
||||
|
||||
require.True(*cfg.Enabled)
|
||||
require.Equal("1", cfg.Annotations["a"])
|
||||
require.Equal("1", cfg.Metadatas["m"])
|
||||
require.Equal("v1", cfg.HealthCheck.HTTPHeaders[0].Value)
|
||||
require.Equal("a.example.com", cfg.CustomDomains[0])
|
||||
require.Equal("/api", cfg.Locations[0])
|
||||
require.Equal("v1", cfg.RequestHeaders.Set["h1"])
|
||||
require.Equal("v2", cfg.ResponseHeaders.Set["h2"])
|
||||
|
||||
origPlugin := cfg.Plugin.ClientPluginOptions.(*HTTPS2HTTPPluginOptions)
|
||||
require.True(*origPlugin.EnableHTTP2)
|
||||
require.Equal("v", origPlugin.RequestHeaders.Set["k"])
|
||||
}
|
||||
|
||||
func TestVisitorCloneDeepCopy(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
enabled := true
|
||||
cfg := &XTCPVisitorConfig{
|
||||
VisitorBaseConfig: VisitorBaseConfig{
|
||||
Name: "v1",
|
||||
Type: "xtcp",
|
||||
Enabled: &enabled,
|
||||
ServerName: "server",
|
||||
BindPort: 7000,
|
||||
Plugin: TypedVisitorPluginOptions{
|
||||
Type: VisitorPluginVirtualNet,
|
||||
VisitorPluginOptions: &VirtualNetVisitorPluginOptions{
|
||||
Type: VisitorPluginVirtualNet,
|
||||
DestinationIP: "10.0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
NatTraversal: &NatTraversalConfig{
|
||||
DisableAssistedAddrs: true,
|
||||
},
|
||||
}
|
||||
|
||||
cloned := cfg.Clone().(*XTCPVisitorConfig)
|
||||
*cloned.Enabled = false
|
||||
cloned.NatTraversal.DisableAssistedAddrs = false
|
||||
visitorPlugin := cloned.Plugin.VisitorPluginOptions.(*VirtualNetVisitorPluginOptions)
|
||||
visitorPlugin.DestinationIP = "10.0.0.2"
|
||||
|
||||
require.True(*cfg.Enabled)
|
||||
require.True(cfg.NatTraversal.DisableAssistedAddrs)
|
||||
origPlugin := cfg.Plugin.VisitorPluginOptions.(*VirtualNetVisitorPluginOptions)
|
||||
require.Equal("10.0.0.1", origPlugin.DestinationIP)
|
||||
}
|
||||
@@ -15,23 +15,11 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"maps"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
// TODO(fatedier): Due to the current implementation issue of the go json library, the UnmarshalJSON method
|
||||
// of a custom struct cannot access the DisallowUnknownFields parameter of the parent decoder.
|
||||
// Here, a global variable is temporarily used to control whether unknown fields are allowed.
|
||||
// Once the v2 version is implemented by the community, we can switch to a standardized approach.
|
||||
//
|
||||
// https://github.com/golang/go/issues/41144
|
||||
// https://github.com/golang/go/discussions/63397
|
||||
var (
|
||||
DisallowUnknownFields = false
|
||||
DisallowUnknownFieldsMu sync.Mutex
|
||||
)
|
||||
|
||||
type AuthScope string
|
||||
|
||||
const (
|
||||
@@ -104,6 +92,14 @@ type NatTraversalConfig struct {
|
||||
DisableAssistedAddrs bool `json:"disableAssistedAddrs,omitempty"`
|
||||
}
|
||||
|
||||
func (c *NatTraversalConfig) Clone() *NatTraversalConfig {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
out := *c
|
||||
return &out
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
// This is destination where frp should write the logs.
|
||||
// If "console" is used, logs will be printed to stdout, otherwise,
|
||||
@@ -138,6 +134,12 @@ type HeaderOperations struct {
|
||||
Set map[string]string `json:"set,omitempty"`
|
||||
}
|
||||
|
||||
func (o HeaderOperations) Clone() HeaderOperations {
|
||||
return HeaderOperations{
|
||||
Set: maps.Clone(o.Set),
|
||||
}
|
||||
}
|
||||
|
||||
type HTTPHeader struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
|
||||
195
pkg/config/v1/decode.go
Normal file
195
pkg/config/v1/decode.go
Normal file
@@ -0,0 +1,195 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
)
|
||||
|
||||
type DecodeOptions struct {
|
||||
DisallowUnknownFields bool
|
||||
}
|
||||
|
||||
func decodeJSONWithOptions(b []byte, out any, options DecodeOptions) error {
|
||||
return jsonx.UnmarshalWithOptions(b, out, jsonx.DecodeOptions{
|
||||
RejectUnknownMembers: options.DisallowUnknownFields,
|
||||
})
|
||||
}
|
||||
|
||||
func isJSONNull(b []byte) bool {
|
||||
return len(b) == 0 || string(b) == "null"
|
||||
}
|
||||
|
||||
type typedEnvelope struct {
|
||||
Type string `json:"type"`
|
||||
Plugin jsonx.RawMessage `json:"plugin,omitempty"`
|
||||
}
|
||||
|
||||
func DecodeProxyConfigurerJSON(b []byte, options DecodeOptions) (ProxyConfigurer, error) {
|
||||
if isJSONNull(b) {
|
||||
return nil, errors.New("type is required")
|
||||
}
|
||||
|
||||
var env typedEnvelope
|
||||
if err := jsonx.Unmarshal(b, &env); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configurer := NewProxyConfigurerByType(ProxyType(env.Type))
|
||||
if configurer == nil {
|
||||
return nil, fmt.Errorf("unknown proxy type: %s", env.Type)
|
||||
}
|
||||
if err := decodeJSONWithOptions(b, configurer, options); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ProxyConfig error: %v", err)
|
||||
}
|
||||
|
||||
if len(env.Plugin) > 0 && !isJSONNull(env.Plugin) {
|
||||
plugin, err := DecodeClientPluginOptionsJSON(env.Plugin, options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal proxy plugin error: %v", err)
|
||||
}
|
||||
configurer.GetBaseConfig().Plugin = plugin
|
||||
}
|
||||
return configurer, nil
|
||||
}
|
||||
|
||||
func DecodeVisitorConfigurerJSON(b []byte, options DecodeOptions) (VisitorConfigurer, error) {
|
||||
if isJSONNull(b) {
|
||||
return nil, errors.New("type is required")
|
||||
}
|
||||
|
||||
var env typedEnvelope
|
||||
if err := jsonx.Unmarshal(b, &env); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configurer := NewVisitorConfigurerByType(VisitorType(env.Type))
|
||||
if configurer == nil {
|
||||
return nil, fmt.Errorf("unknown visitor type: %s", env.Type)
|
||||
}
|
||||
if err := decodeJSONWithOptions(b, configurer, options); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal VisitorConfig error: %v", err)
|
||||
}
|
||||
|
||||
if len(env.Plugin) > 0 && !isJSONNull(env.Plugin) {
|
||||
plugin, err := DecodeVisitorPluginOptionsJSON(env.Plugin, options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal visitor plugin error: %v", err)
|
||||
}
|
||||
configurer.GetBaseConfig().Plugin = plugin
|
||||
}
|
||||
return configurer, nil
|
||||
}
|
||||
|
||||
func DecodeClientPluginOptionsJSON(b []byte, options DecodeOptions) (TypedClientPluginOptions, error) {
|
||||
if isJSONNull(b) {
|
||||
return TypedClientPluginOptions{}, nil
|
||||
}
|
||||
|
||||
var env typedEnvelope
|
||||
if err := jsonx.Unmarshal(b, &env); err != nil {
|
||||
return TypedClientPluginOptions{}, err
|
||||
}
|
||||
if env.Type == "" {
|
||||
return TypedClientPluginOptions{}, errors.New("plugin type is empty")
|
||||
}
|
||||
|
||||
v, ok := clientPluginOptionsTypeMap[env.Type]
|
||||
if !ok {
|
||||
return TypedClientPluginOptions{}, fmt.Errorf("unknown plugin type: %s", env.Type)
|
||||
}
|
||||
optionsStruct := reflect.New(v).Interface().(ClientPluginOptions)
|
||||
if err := decodeJSONWithOptions(b, optionsStruct, options); err != nil {
|
||||
return TypedClientPluginOptions{}, fmt.Errorf("unmarshal ClientPluginOptions error: %v", err)
|
||||
}
|
||||
return TypedClientPluginOptions{
|
||||
Type: env.Type,
|
||||
ClientPluginOptions: optionsStruct,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DecodeVisitorPluginOptionsJSON(b []byte, options DecodeOptions) (TypedVisitorPluginOptions, error) {
|
||||
if isJSONNull(b) {
|
||||
return TypedVisitorPluginOptions{}, nil
|
||||
}
|
||||
|
||||
var env typedEnvelope
|
||||
if err := jsonx.Unmarshal(b, &env); err != nil {
|
||||
return TypedVisitorPluginOptions{}, err
|
||||
}
|
||||
if env.Type == "" {
|
||||
return TypedVisitorPluginOptions{}, errors.New("visitor plugin type is empty")
|
||||
}
|
||||
|
||||
v, ok := visitorPluginOptionsTypeMap[env.Type]
|
||||
if !ok {
|
||||
return TypedVisitorPluginOptions{}, fmt.Errorf("unknown visitor plugin type: %s", env.Type)
|
||||
}
|
||||
optionsStruct := reflect.New(v).Interface().(VisitorPluginOptions)
|
||||
if err := decodeJSONWithOptions(b, optionsStruct, options); err != nil {
|
||||
return TypedVisitorPluginOptions{}, fmt.Errorf("unmarshal VisitorPluginOptions error: %v", err)
|
||||
}
|
||||
return TypedVisitorPluginOptions{
|
||||
Type: env.Type,
|
||||
VisitorPluginOptions: optionsStruct,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DecodeClientConfigJSON(b []byte, options DecodeOptions) (ClientConfig, error) {
|
||||
type rawClientConfig struct {
|
||||
ClientCommonConfig
|
||||
Proxies []jsonx.RawMessage `json:"proxies,omitempty"`
|
||||
Visitors []jsonx.RawMessage `json:"visitors,omitempty"`
|
||||
}
|
||||
|
||||
raw := rawClientConfig{}
|
||||
if err := decodeJSONWithOptions(b, &raw, options); err != nil {
|
||||
return ClientConfig{}, err
|
||||
}
|
||||
|
||||
cfg := ClientConfig{
|
||||
ClientCommonConfig: raw.ClientCommonConfig,
|
||||
Proxies: make([]TypedProxyConfig, 0, len(raw.Proxies)),
|
||||
Visitors: make([]TypedVisitorConfig, 0, len(raw.Visitors)),
|
||||
}
|
||||
|
||||
for i, proxyData := range raw.Proxies {
|
||||
proxyCfg, err := DecodeProxyConfigurerJSON(proxyData, options)
|
||||
if err != nil {
|
||||
return ClientConfig{}, fmt.Errorf("decode proxy at index %d: %w", i, err)
|
||||
}
|
||||
cfg.Proxies = append(cfg.Proxies, TypedProxyConfig{
|
||||
Type: proxyCfg.GetBaseConfig().Type,
|
||||
ProxyConfigurer: proxyCfg,
|
||||
})
|
||||
}
|
||||
|
||||
for i, visitorData := range raw.Visitors {
|
||||
visitorCfg, err := DecodeVisitorConfigurerJSON(visitorData, options)
|
||||
if err != nil {
|
||||
return ClientConfig{}, fmt.Errorf("decode visitor at index %d: %w", i, err)
|
||||
}
|
||||
cfg.Visitors = append(cfg.Visitors, TypedVisitorConfig{
|
||||
Type: visitorCfg.GetBaseConfig().Type,
|
||||
VisitorConfigurer: visitorCfg,
|
||||
})
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
86
pkg/config/v1/decode_test.go
Normal file
86
pkg/config/v1/decode_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDecodeProxyConfigurerJSON_StrictPluginUnknownFields(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
data := []byte(`{
|
||||
"name":"p1",
|
||||
"type":"tcp",
|
||||
"localPort":10080,
|
||||
"plugin":{
|
||||
"type":"http2https",
|
||||
"localAddr":"127.0.0.1:8080",
|
||||
"unknownInPlugin":"value"
|
||||
}
|
||||
}`)
|
||||
|
||||
_, err := DecodeProxyConfigurerJSON(data, DecodeOptions{DisallowUnknownFields: false})
|
||||
require.NoError(err)
|
||||
|
||||
_, err = DecodeProxyConfigurerJSON(data, DecodeOptions{DisallowUnknownFields: true})
|
||||
require.ErrorContains(err, "unknownInPlugin")
|
||||
}
|
||||
|
||||
func TestDecodeVisitorConfigurerJSON_StrictPluginUnknownFields(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
data := []byte(`{
|
||||
"name":"v1",
|
||||
"type":"stcp",
|
||||
"serverName":"server",
|
||||
"bindPort":10081,
|
||||
"plugin":{
|
||||
"type":"virtual_net",
|
||||
"destinationIP":"10.0.0.1",
|
||||
"unknownInPlugin":"value"
|
||||
}
|
||||
}`)
|
||||
|
||||
_, err := DecodeVisitorConfigurerJSON(data, DecodeOptions{DisallowUnknownFields: false})
|
||||
require.NoError(err)
|
||||
|
||||
_, err = DecodeVisitorConfigurerJSON(data, DecodeOptions{DisallowUnknownFields: true})
|
||||
require.ErrorContains(err, "unknownInPlugin")
|
||||
}
|
||||
|
||||
func TestDecodeClientConfigJSON_StrictUnknownProxyField(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
data := []byte(`{
|
||||
"serverPort":7000,
|
||||
"proxies":[
|
||||
{
|
||||
"name":"p1",
|
||||
"type":"tcp",
|
||||
"localPort":10080,
|
||||
"unknownField":"value"
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
_, err := DecodeClientConfigJSON(data, DecodeOptions{DisallowUnknownFields: false})
|
||||
require.NoError(err)
|
||||
|
||||
_, err = DecodeClientConfigJSON(data, DecodeOptions{DisallowUnknownFields: true})
|
||||
require.ErrorContains(err, "unknownField")
|
||||
}
|
||||
@@ -15,14 +15,13 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"reflect"
|
||||
"slices"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
@@ -100,11 +99,23 @@ type HealthCheckConfig struct {
|
||||
HTTPHeaders []HTTPHeader `json:"httpHeaders,omitempty"`
|
||||
}
|
||||
|
||||
func (c HealthCheckConfig) Clone() HealthCheckConfig {
|
||||
out := c
|
||||
out.HTTPHeaders = slices.Clone(c.HTTPHeaders)
|
||||
return out
|
||||
}
|
||||
|
||||
type DomainConfig struct {
|
||||
CustomDomains []string `json:"customDomains,omitempty"`
|
||||
SubDomain string `json:"subdomain,omitempty"`
|
||||
}
|
||||
|
||||
func (c DomainConfig) Clone() DomainConfig {
|
||||
out := c
|
||||
out.CustomDomains = slices.Clone(c.CustomDomains)
|
||||
return out
|
||||
}
|
||||
|
||||
type ProxyBaseConfig struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
@@ -120,6 +131,22 @@ type ProxyBaseConfig struct {
|
||||
ProxyBackend
|
||||
}
|
||||
|
||||
func (c ProxyBaseConfig) Clone() ProxyBaseConfig {
|
||||
out := c
|
||||
out.Enabled = util.ClonePtr(c.Enabled)
|
||||
out.Annotations = maps.Clone(c.Annotations)
|
||||
out.Metadatas = maps.Clone(c.Metadatas)
|
||||
out.HealthCheck = c.HealthCheck.Clone()
|
||||
out.ProxyBackend = c.ProxyBackend.Clone()
|
||||
return out
|
||||
}
|
||||
|
||||
func (c ProxyBackend) Clone() ProxyBackend {
|
||||
out := c
|
||||
out.Plugin = c.Plugin.Clone()
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *ProxyBaseConfig) GetBaseConfig() *ProxyBaseConfig {
|
||||
return c
|
||||
}
|
||||
@@ -172,40 +199,24 @@ type TypedProxyConfig struct {
|
||||
}
|
||||
|
||||
func (c *TypedProxyConfig) UnmarshalJSON(b []byte) error {
|
||||
if len(b) == 4 && string(b) == "null" {
|
||||
return errors.New("type is required")
|
||||
}
|
||||
|
||||
typeStruct := struct {
|
||||
Type string `json:"type"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, &typeStruct); err != nil {
|
||||
configurer, err := DecodeProxyConfigurerJSON(b, DecodeOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Type = typeStruct.Type
|
||||
configurer := NewProxyConfigurerByType(ProxyType(typeStruct.Type))
|
||||
if configurer == nil {
|
||||
return fmt.Errorf("unknown proxy type: %s", typeStruct.Type)
|
||||
}
|
||||
decoder := json.NewDecoder(bytes.NewBuffer(b))
|
||||
if DisallowUnknownFields {
|
||||
decoder.DisallowUnknownFields()
|
||||
}
|
||||
if err := decoder.Decode(configurer); err != nil {
|
||||
return fmt.Errorf("unmarshal ProxyConfig error: %v", err)
|
||||
}
|
||||
c.Type = configurer.GetBaseConfig().Type
|
||||
c.ProxyConfigurer = configurer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TypedProxyConfig) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(c.ProxyConfigurer)
|
||||
return jsonx.Marshal(c.ProxyConfigurer)
|
||||
}
|
||||
|
||||
type ProxyConfigurer interface {
|
||||
Complete()
|
||||
GetBaseConfig() *ProxyBaseConfig
|
||||
Clone() ProxyConfigurer
|
||||
// MarshalToMsg marshals this config into a msg.NewProxy message. This
|
||||
// function will be called on the frpc side.
|
||||
MarshalToMsg(*msg.NewProxy)
|
||||
@@ -228,14 +239,14 @@ const (
|
||||
)
|
||||
|
||||
var proxyConfigTypeMap = map[ProxyType]reflect.Type{
|
||||
ProxyTypeTCP: reflect.TypeOf(TCPProxyConfig{}),
|
||||
ProxyTypeUDP: reflect.TypeOf(UDPProxyConfig{}),
|
||||
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConfig{}),
|
||||
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConfig{}),
|
||||
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConfig{}),
|
||||
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConfig{}),
|
||||
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConfig{}),
|
||||
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConfig{}),
|
||||
ProxyTypeTCP: reflect.TypeFor[TCPProxyConfig](),
|
||||
ProxyTypeUDP: reflect.TypeFor[UDPProxyConfig](),
|
||||
ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConfig](),
|
||||
ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConfig](),
|
||||
ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConfig](),
|
||||
ProxyTypeSTCP: reflect.TypeFor[STCPProxyConfig](),
|
||||
ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConfig](),
|
||||
ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConfig](),
|
||||
}
|
||||
|
||||
func NewProxyConfigurerByType(proxyType ProxyType) ProxyConfigurer {
|
||||
@@ -268,6 +279,12 @@ func (c *TCPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
|
||||
c.RemotePort = m.RemotePort
|
||||
}
|
||||
|
||||
func (c *TCPProxyConfig) Clone() ProxyConfigurer {
|
||||
out := *c
|
||||
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
var _ ProxyConfigurer = &UDPProxyConfig{}
|
||||
|
||||
type UDPProxyConfig struct {
|
||||
@@ -288,6 +305,12 @@ func (c *UDPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
|
||||
c.RemotePort = m.RemotePort
|
||||
}
|
||||
|
||||
func (c *UDPProxyConfig) Clone() ProxyConfigurer {
|
||||
out := *c
|
||||
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
var _ ProxyConfigurer = &HTTPProxyConfig{}
|
||||
|
||||
type HTTPProxyConfig struct {
|
||||
@@ -331,6 +354,16 @@ func (c *HTTPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
|
||||
c.RouteByHTTPUser = m.RouteByHTTPUser
|
||||
}
|
||||
|
||||
func (c *HTTPProxyConfig) Clone() ProxyConfigurer {
|
||||
out := *c
|
||||
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
|
||||
out.DomainConfig = c.DomainConfig.Clone()
|
||||
out.Locations = slices.Clone(c.Locations)
|
||||
out.RequestHeaders = c.RequestHeaders.Clone()
|
||||
out.ResponseHeaders = c.ResponseHeaders.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
var _ ProxyConfigurer = &HTTPSProxyConfig{}
|
||||
|
||||
type HTTPSProxyConfig struct {
|
||||
@@ -352,6 +385,13 @@ func (c *HTTPSProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
|
||||
c.SubDomain = m.SubDomain
|
||||
}
|
||||
|
||||
func (c *HTTPSProxyConfig) Clone() ProxyConfigurer {
|
||||
out := *c
|
||||
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
|
||||
out.DomainConfig = c.DomainConfig.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
type TCPMultiplexerType string
|
||||
|
||||
const (
|
||||
@@ -392,6 +432,13 @@ func (c *TCPMuxProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
|
||||
c.RouteByHTTPUser = m.RouteByHTTPUser
|
||||
}
|
||||
|
||||
func (c *TCPMuxProxyConfig) Clone() ProxyConfigurer {
|
||||
out := *c
|
||||
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
|
||||
out.DomainConfig = c.DomainConfig.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
var _ ProxyConfigurer = &STCPProxyConfig{}
|
||||
|
||||
type STCPProxyConfig struct {
|
||||
@@ -415,6 +462,13 @@ func (c *STCPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
|
||||
c.AllowUsers = m.AllowUsers
|
||||
}
|
||||
|
||||
func (c *STCPProxyConfig) Clone() ProxyConfigurer {
|
||||
out := *c
|
||||
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
|
||||
out.AllowUsers = slices.Clone(c.AllowUsers)
|
||||
return &out
|
||||
}
|
||||
|
||||
var _ ProxyConfigurer = &XTCPProxyConfig{}
|
||||
|
||||
type XTCPProxyConfig struct {
|
||||
@@ -441,6 +495,14 @@ func (c *XTCPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
|
||||
c.AllowUsers = m.AllowUsers
|
||||
}
|
||||
|
||||
func (c *XTCPProxyConfig) Clone() ProxyConfigurer {
|
||||
out := *c
|
||||
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
|
||||
out.AllowUsers = slices.Clone(c.AllowUsers)
|
||||
out.NatTraversal = c.NatTraversal.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
var _ ProxyConfigurer = &SUDPProxyConfig{}
|
||||
|
||||
type SUDPProxyConfig struct {
|
||||
@@ -463,3 +525,10 @@ func (c *SUDPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
|
||||
c.Secretkey = m.Sk
|
||||
c.AllowUsers = m.AllowUsers
|
||||
}
|
||||
|
||||
func (c *SUDPProxyConfig) Clone() ProxyConfigurer {
|
||||
out := *c
|
||||
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
|
||||
out.AllowUsers = slices.Clone(c.AllowUsers)
|
||||
return &out
|
||||
}
|
||||
|
||||
@@ -15,14 +15,11 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
@@ -40,20 +37,21 @@ const (
|
||||
)
|
||||
|
||||
var clientPluginOptionsTypeMap = map[string]reflect.Type{
|
||||
PluginHTTP2HTTPS: reflect.TypeOf(HTTP2HTTPSPluginOptions{}),
|
||||
PluginHTTPProxy: reflect.TypeOf(HTTPProxyPluginOptions{}),
|
||||
PluginHTTPS2HTTP: reflect.TypeOf(HTTPS2HTTPPluginOptions{}),
|
||||
PluginHTTPS2HTTPS: reflect.TypeOf(HTTPS2HTTPSPluginOptions{}),
|
||||
PluginHTTP2HTTP: reflect.TypeOf(HTTP2HTTPPluginOptions{}),
|
||||
PluginSocks5: reflect.TypeOf(Socks5PluginOptions{}),
|
||||
PluginStaticFile: reflect.TypeOf(StaticFilePluginOptions{}),
|
||||
PluginUnixDomainSocket: reflect.TypeOf(UnixDomainSocketPluginOptions{}),
|
||||
PluginTLS2Raw: reflect.TypeOf(TLS2RawPluginOptions{}),
|
||||
PluginVirtualNet: reflect.TypeOf(VirtualNetPluginOptions{}),
|
||||
PluginHTTP2HTTPS: reflect.TypeFor[HTTP2HTTPSPluginOptions](),
|
||||
PluginHTTPProxy: reflect.TypeFor[HTTPProxyPluginOptions](),
|
||||
PluginHTTPS2HTTP: reflect.TypeFor[HTTPS2HTTPPluginOptions](),
|
||||
PluginHTTPS2HTTPS: reflect.TypeFor[HTTPS2HTTPSPluginOptions](),
|
||||
PluginHTTP2HTTP: reflect.TypeFor[HTTP2HTTPPluginOptions](),
|
||||
PluginSocks5: reflect.TypeFor[Socks5PluginOptions](),
|
||||
PluginStaticFile: reflect.TypeFor[StaticFilePluginOptions](),
|
||||
PluginUnixDomainSocket: reflect.TypeFor[UnixDomainSocketPluginOptions](),
|
||||
PluginTLS2Raw: reflect.TypeFor[TLS2RawPluginOptions](),
|
||||
PluginVirtualNet: reflect.TypeFor[VirtualNetPluginOptions](),
|
||||
}
|
||||
|
||||
type ClientPluginOptions interface {
|
||||
Complete()
|
||||
Clone() ClientPluginOptions
|
||||
}
|
||||
|
||||
type TypedClientPluginOptions struct {
|
||||
@@ -61,43 +59,25 @@ type TypedClientPluginOptions struct {
|
||||
ClientPluginOptions
|
||||
}
|
||||
|
||||
func (c *TypedClientPluginOptions) UnmarshalJSON(b []byte) error {
|
||||
if len(b) == 4 && string(b) == "null" {
|
||||
return nil
|
||||
func (c TypedClientPluginOptions) Clone() TypedClientPluginOptions {
|
||||
out := c
|
||||
if c.ClientPluginOptions != nil {
|
||||
out.ClientPluginOptions = c.ClientPluginOptions.Clone()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
typeStruct := struct {
|
||||
Type string `json:"type"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, &typeStruct); err != nil {
|
||||
func (c *TypedClientPluginOptions) UnmarshalJSON(b []byte) error {
|
||||
decoded, err := DecodeClientPluginOptionsJSON(b, DecodeOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Type = typeStruct.Type
|
||||
if c.Type == "" {
|
||||
return errors.New("plugin type is empty")
|
||||
}
|
||||
|
||||
v, ok := clientPluginOptionsTypeMap[typeStruct.Type]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown plugin type: %s", typeStruct.Type)
|
||||
}
|
||||
options := reflect.New(v).Interface().(ClientPluginOptions)
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewBuffer(b))
|
||||
if DisallowUnknownFields {
|
||||
decoder.DisallowUnknownFields()
|
||||
}
|
||||
|
||||
if err := decoder.Decode(options); err != nil {
|
||||
return fmt.Errorf("unmarshal ClientPluginOptions error: %v", err)
|
||||
}
|
||||
c.ClientPluginOptions = options
|
||||
*c = decoded
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TypedClientPluginOptions) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(c.ClientPluginOptions)
|
||||
return jsonx.Marshal(c.ClientPluginOptions)
|
||||
}
|
||||
|
||||
type HTTP2HTTPSPluginOptions struct {
|
||||
@@ -109,6 +89,15 @@ type HTTP2HTTPSPluginOptions struct {
|
||||
|
||||
func (o *HTTP2HTTPSPluginOptions) Complete() {}
|
||||
|
||||
func (o *HTTP2HTTPSPluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
out.RequestHeaders = o.RequestHeaders.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
type HTTPProxyPluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
HTTPUser string `json:"httpUser,omitempty"`
|
||||
@@ -117,6 +106,14 @@ type HTTPProxyPluginOptions struct {
|
||||
|
||||
func (o *HTTPProxyPluginOptions) Complete() {}
|
||||
|
||||
func (o *HTTPProxyPluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
return &out
|
||||
}
|
||||
|
||||
type HTTPS2HTTPPluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
LocalAddr string `json:"localAddr,omitempty"`
|
||||
@@ -131,6 +128,16 @@ func (o *HTTPS2HTTPPluginOptions) Complete() {
|
||||
o.EnableHTTP2 = util.EmptyOr(o.EnableHTTP2, lo.ToPtr(true))
|
||||
}
|
||||
|
||||
func (o *HTTPS2HTTPPluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
out.RequestHeaders = o.RequestHeaders.Clone()
|
||||
out.EnableHTTP2 = util.ClonePtr(o.EnableHTTP2)
|
||||
return &out
|
||||
}
|
||||
|
||||
type HTTPS2HTTPSPluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
LocalAddr string `json:"localAddr,omitempty"`
|
||||
@@ -145,6 +152,16 @@ func (o *HTTPS2HTTPSPluginOptions) Complete() {
|
||||
o.EnableHTTP2 = util.EmptyOr(o.EnableHTTP2, lo.ToPtr(true))
|
||||
}
|
||||
|
||||
func (o *HTTPS2HTTPSPluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
out.RequestHeaders = o.RequestHeaders.Clone()
|
||||
out.EnableHTTP2 = util.ClonePtr(o.EnableHTTP2)
|
||||
return &out
|
||||
}
|
||||
|
||||
type HTTP2HTTPPluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
LocalAddr string `json:"localAddr,omitempty"`
|
||||
@@ -154,6 +171,15 @@ type HTTP2HTTPPluginOptions struct {
|
||||
|
||||
func (o *HTTP2HTTPPluginOptions) Complete() {}
|
||||
|
||||
func (o *HTTP2HTTPPluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
out.RequestHeaders = o.RequestHeaders.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
type Socks5PluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
@@ -162,6 +188,14 @@ type Socks5PluginOptions struct {
|
||||
|
||||
func (o *Socks5PluginOptions) Complete() {}
|
||||
|
||||
func (o *Socks5PluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
return &out
|
||||
}
|
||||
|
||||
type StaticFilePluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
LocalPath string `json:"localPath,omitempty"`
|
||||
@@ -172,6 +206,14 @@ type StaticFilePluginOptions struct {
|
||||
|
||||
func (o *StaticFilePluginOptions) Complete() {}
|
||||
|
||||
func (o *StaticFilePluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
return &out
|
||||
}
|
||||
|
||||
type UnixDomainSocketPluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
UnixPath string `json:"unixPath,omitempty"`
|
||||
@@ -179,6 +221,14 @@ type UnixDomainSocketPluginOptions struct {
|
||||
|
||||
func (o *UnixDomainSocketPluginOptions) Complete() {}
|
||||
|
||||
func (o *UnixDomainSocketPluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
return &out
|
||||
}
|
||||
|
||||
type TLS2RawPluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
LocalAddr string `json:"localAddr,omitempty"`
|
||||
@@ -188,8 +238,24 @@ type TLS2RawPluginOptions struct {
|
||||
|
||||
func (o *TLS2RawPluginOptions) Complete() {}
|
||||
|
||||
func (o *TLS2RawPluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
return &out
|
||||
}
|
||||
|
||||
type VirtualNetPluginOptions struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
func (o *VirtualNetPluginOptions) Complete() {}
|
||||
|
||||
func (o *VirtualNetPluginOptions) Clone() ClientPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
return &out
|
||||
}
|
||||
|
||||
@@ -15,12 +15,9 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
@@ -50,6 +47,13 @@ type VisitorBaseConfig struct {
|
||||
Plugin TypedVisitorPluginOptions `json:"plugin,omitempty"`
|
||||
}
|
||||
|
||||
func (c VisitorBaseConfig) Clone() VisitorBaseConfig {
|
||||
out := c
|
||||
out.Enabled = util.ClonePtr(c.Enabled)
|
||||
out.Plugin = c.Plugin.Clone()
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *VisitorBaseConfig) GetBaseConfig() *VisitorBaseConfig {
|
||||
return c
|
||||
}
|
||||
@@ -63,6 +67,7 @@ func (c *VisitorBaseConfig) Complete() {
|
||||
type VisitorConfigurer interface {
|
||||
Complete()
|
||||
GetBaseConfig() *VisitorBaseConfig
|
||||
Clone() VisitorConfigurer
|
||||
}
|
||||
|
||||
type VisitorType string
|
||||
@@ -74,9 +79,9 @@ const (
|
||||
)
|
||||
|
||||
var visitorConfigTypeMap = map[VisitorType]reflect.Type{
|
||||
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConfig{}),
|
||||
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConfig{}),
|
||||
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConfig{}),
|
||||
VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConfig](),
|
||||
VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConfig](),
|
||||
VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConfig](),
|
||||
}
|
||||
|
||||
type TypedVisitorConfig struct {
|
||||
@@ -85,35 +90,18 @@ type TypedVisitorConfig struct {
|
||||
}
|
||||
|
||||
func (c *TypedVisitorConfig) UnmarshalJSON(b []byte) error {
|
||||
if len(b) == 4 && string(b) == "null" {
|
||||
return errors.New("type is required")
|
||||
}
|
||||
|
||||
typeStruct := struct {
|
||||
Type string `json:"type"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, &typeStruct); err != nil {
|
||||
configurer, err := DecodeVisitorConfigurerJSON(b, DecodeOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Type = typeStruct.Type
|
||||
configurer := NewVisitorConfigurerByType(VisitorType(typeStruct.Type))
|
||||
if configurer == nil {
|
||||
return fmt.Errorf("unknown visitor type: %s", typeStruct.Type)
|
||||
}
|
||||
decoder := json.NewDecoder(bytes.NewBuffer(b))
|
||||
if DisallowUnknownFields {
|
||||
decoder.DisallowUnknownFields()
|
||||
}
|
||||
if err := decoder.Decode(configurer); err != nil {
|
||||
return fmt.Errorf("unmarshal VisitorConfig error: %v", err)
|
||||
}
|
||||
c.Type = configurer.GetBaseConfig().Type
|
||||
c.VisitorConfigurer = configurer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TypedVisitorConfig) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(c.VisitorConfigurer)
|
||||
return jsonx.Marshal(c.VisitorConfigurer)
|
||||
}
|
||||
|
||||
func NewVisitorConfigurerByType(t VisitorType) VisitorConfigurer {
|
||||
@@ -132,12 +120,24 @@ type STCPVisitorConfig struct {
|
||||
VisitorBaseConfig
|
||||
}
|
||||
|
||||
func (c *STCPVisitorConfig) Clone() VisitorConfigurer {
|
||||
out := *c
|
||||
out.VisitorBaseConfig = c.VisitorBaseConfig.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
var _ VisitorConfigurer = &SUDPVisitorConfig{}
|
||||
|
||||
type SUDPVisitorConfig struct {
|
||||
VisitorBaseConfig
|
||||
}
|
||||
|
||||
func (c *SUDPVisitorConfig) Clone() VisitorConfigurer {
|
||||
out := *c
|
||||
out.VisitorBaseConfig = c.VisitorBaseConfig.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
var _ VisitorConfigurer = &XTCPVisitorConfig{}
|
||||
|
||||
type XTCPVisitorConfig struct {
|
||||
@@ -162,3 +162,10 @@ func (c *XTCPVisitorConfig) Complete() {
|
||||
c.MinRetryInterval = util.EmptyOr(c.MinRetryInterval, 90)
|
||||
c.FallbackTimeoutMs = util.EmptyOr(c.FallbackTimeoutMs, 1000)
|
||||
}
|
||||
|
||||
func (c *XTCPVisitorConfig) Clone() VisitorConfigurer {
|
||||
out := *c
|
||||
out.VisitorBaseConfig = c.VisitorBaseConfig.Clone()
|
||||
out.NatTraversal = c.NatTraversal.Clone()
|
||||
return &out
|
||||
}
|
||||
|
||||
@@ -15,11 +15,9 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/jsonx"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,11 +25,12 @@ const (
|
||||
)
|
||||
|
||||
var visitorPluginOptionsTypeMap = map[string]reflect.Type{
|
||||
VisitorPluginVirtualNet: reflect.TypeOf(VirtualNetVisitorPluginOptions{}),
|
||||
VisitorPluginVirtualNet: reflect.TypeFor[VirtualNetVisitorPluginOptions](),
|
||||
}
|
||||
|
||||
type VisitorPluginOptions interface {
|
||||
Complete()
|
||||
Clone() VisitorPluginOptions
|
||||
}
|
||||
|
||||
type TypedVisitorPluginOptions struct {
|
||||
@@ -39,43 +38,25 @@ type TypedVisitorPluginOptions struct {
|
||||
VisitorPluginOptions
|
||||
}
|
||||
|
||||
func (c *TypedVisitorPluginOptions) UnmarshalJSON(b []byte) error {
|
||||
if len(b) == 4 && string(b) == "null" {
|
||||
return nil
|
||||
func (c TypedVisitorPluginOptions) Clone() TypedVisitorPluginOptions {
|
||||
out := c
|
||||
if c.VisitorPluginOptions != nil {
|
||||
out.VisitorPluginOptions = c.VisitorPluginOptions.Clone()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
typeStruct := struct {
|
||||
Type string `json:"type"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, &typeStruct); err != nil {
|
||||
func (c *TypedVisitorPluginOptions) UnmarshalJSON(b []byte) error {
|
||||
decoded, err := DecodeVisitorPluginOptionsJSON(b, DecodeOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Type = typeStruct.Type
|
||||
if c.Type == "" {
|
||||
return errors.New("visitor plugin type is empty")
|
||||
}
|
||||
|
||||
v, ok := visitorPluginOptionsTypeMap[typeStruct.Type]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown visitor plugin type: %s", typeStruct.Type)
|
||||
}
|
||||
options := reflect.New(v).Interface().(VisitorPluginOptions)
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewBuffer(b))
|
||||
if DisallowUnknownFields {
|
||||
decoder.DisallowUnknownFields()
|
||||
}
|
||||
|
||||
if err := decoder.Decode(options); err != nil {
|
||||
return fmt.Errorf("unmarshal VisitorPluginOptions error: %v", err)
|
||||
}
|
||||
c.VisitorPluginOptions = options
|
||||
*c = decoded
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TypedVisitorPluginOptions) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(c.VisitorPluginOptions)
|
||||
return jsonx.Marshal(c.VisitorPluginOptions)
|
||||
}
|
||||
|
||||
type VirtualNetVisitorPluginOptions struct {
|
||||
@@ -84,3 +65,11 @@ type VirtualNetVisitorPluginOptions struct {
|
||||
}
|
||||
|
||||
func (o *VirtualNetVisitorPluginOptions) Complete() {}
|
||||
|
||||
func (o *VirtualNetVisitorPluginOptions) Clone() VisitorPluginOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
out := *o
|
||||
return &out
|
||||
}
|
||||
|
||||
@@ -143,7 +143,6 @@ func (m *serverMetrics) OpenConnection(name string, _ string) {
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.CurConns.Inc(1)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +154,6 @@ func (m *serverMetrics) CloseConnection(name string, _ string) {
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.CurConns.Dec(1)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,7 +166,6 @@ func (m *serverMetrics) AddTrafficIn(name string, _ string, trafficBytes int64)
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.TrafficIn.Inc(trafficBytes)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,7 +178,6 @@ func (m *serverMetrics) AddTrafficOut(name string, _ string, trafficBytes int64)
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.TrafficOut.Inc(trafficBytes)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,6 +199,25 @@ func (m *serverMetrics) GetServer() *ServerStats {
|
||||
return s
|
||||
}
|
||||
|
||||
func toProxyStats(name string, proxyStats *ProxyStatistics) *ProxyStats {
|
||||
ps := &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
return ps
|
||||
}
|
||||
|
||||
func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
|
||||
res := make([]*ProxyStats, 0)
|
||||
m.mu.Lock()
|
||||
@@ -212,23 +227,7 @@ func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
|
||||
if proxyStats.ProxyType != proxyType {
|
||||
continue
|
||||
}
|
||||
|
||||
ps := &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
res = append(res, ps)
|
||||
res = append(res, toProxyStats(name, proxyStats))
|
||||
}
|
||||
return res
|
||||
}
|
||||
@@ -237,31 +236,9 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for name, proxyStats := range m.info.ProxyStatistics {
|
||||
if proxyStats.ProxyType != proxyType {
|
||||
continue
|
||||
}
|
||||
|
||||
if name != proxyName {
|
||||
continue
|
||||
}
|
||||
|
||||
res = &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
break
|
||||
proxyStats, ok := m.info.ProxyStatistics[proxyName]
|
||||
if ok && proxyStats.ProxyType == proxyType {
|
||||
res = toProxyStats(proxyName, proxyStats)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -272,21 +249,7 @@ func (m *serverMetrics) GetProxyByName(proxyName string) (res *ProxyStats) {
|
||||
|
||||
proxyStats, ok := m.info.ProxyStatistics[proxyName]
|
||||
if ok {
|
||||
res = &ProxyStats{
|
||||
Name: proxyName,
|
||||
Type: proxyStats.ProxyType,
|
||||
User: proxyStats.User,
|
||||
ClientID: proxyStats.ClientID,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
res = toProxyStats(proxyName, proxyStats)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ var msgTypeMap = map[byte]any{
|
||||
TypeNatHoleReport: NatHoleReport{},
|
||||
}
|
||||
|
||||
var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name()
|
||||
var TypeNameNatHoleResp = reflect.TypeFor[NatHoleResp]().Name()
|
||||
|
||||
type ClientSpec struct {
|
||||
// Due to the support of VirtualClient, frps needs to know the client type in order to
|
||||
@@ -184,7 +184,7 @@ type Pong struct {
|
||||
}
|
||||
|
||||
type UDPPacket struct {
|
||||
Content string `json:"c,omitempty"`
|
||||
Content []byte `json:"c,omitempty"`
|
||||
LocalAddr *net.UDPAddr `json:"l,omitempty"`
|
||||
RemoteAddr *net.UDPAddr `json:"r,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package util
|
||||
package naming
|
||||
|
||||
import "strings"
|
||||
|
||||
@@ -16,9 +16,8 @@ func StripUserPrefix(user, name string) string {
|
||||
if user == "" {
|
||||
return name
|
||||
}
|
||||
prefix := user + "."
|
||||
if strings.HasPrefix(name, prefix) {
|
||||
return strings.TrimPrefix(name, prefix)
|
||||
if trimmed, ok := strings.CutPrefix(name, user+"."); ok {
|
||||
return trimmed
|
||||
}
|
||||
return name
|
||||
}
|
||||
27
pkg/naming/names_test.go
Normal file
27
pkg/naming/names_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package naming
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAddUserPrefix(t *testing.T) {
|
||||
require := require.New(t)
|
||||
require.Equal("test", AddUserPrefix("", "test"))
|
||||
require.Equal("alice.test", AddUserPrefix("alice", "test"))
|
||||
}
|
||||
|
||||
func TestStripUserPrefix(t *testing.T) {
|
||||
require := require.New(t)
|
||||
require.Equal("test", StripUserPrefix("", "test"))
|
||||
require.Equal("test", StripUserPrefix("alice", "alice.test"))
|
||||
require.Equal("alice.test", StripUserPrefix("alice", "alice.alice.test"))
|
||||
require.Equal("bob.test", StripUserPrefix("alice", "bob.test"))
|
||||
}
|
||||
|
||||
func TestBuildTargetServerProxyName(t *testing.T) {
|
||||
require := require.New(t)
|
||||
require.Equal("alice.test", BuildTargetServerProxyName("alice", "", "test"))
|
||||
require.Equal("bob.test", BuildTargetServerProxyName("alice", "bob", "test"))
|
||||
}
|
||||
@@ -151,7 +151,7 @@ func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
|
||||
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
|
||||
behaviors := getBehaviorByMode(mode)
|
||||
scores := make([]*BehaviorScore, 0, len(behaviors))
|
||||
for i := 0; i < len(behaviors); i++ {
|
||||
for i := range behaviors {
|
||||
score := receiverScore
|
||||
if behaviors[i].A.Role == DetectRoleSender {
|
||||
score = senderScore
|
||||
|
||||
@@ -70,12 +70,8 @@ func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, err
|
||||
continue
|
||||
}
|
||||
|
||||
if portNum > portMax {
|
||||
portMax = portNum
|
||||
}
|
||||
if portNum < portMin {
|
||||
portMin = portNum
|
||||
}
|
||||
portMax = max(portMax, portNum)
|
||||
portMin = min(portMin, portNum)
|
||||
if baseIP != ip {
|
||||
ipChanged = true
|
||||
}
|
||||
|
||||
@@ -152,7 +152,9 @@ func (c *Controller) GenSid() string {
|
||||
|
||||
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) {
|
||||
if m.PreCheck {
|
||||
c.mu.RLock()
|
||||
cfg, ok := c.clientCfgs[m.ProxyName]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
|
||||
return
|
||||
@@ -375,7 +377,7 @@ func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
|
||||
if !isLast {
|
||||
return nil
|
||||
}
|
||||
var ports []msg.PortsRange
|
||||
ports := make([]msg.PortsRange, 0, 1)
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil
|
||||
|
||||
@@ -298,11 +298,13 @@ func waitDetectMessage(
|
||||
n, raddr, err := conn.ReadFromUDP(buf)
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
pool.PutBuf(buf)
|
||||
return nil, err
|
||||
}
|
||||
xl.Debugf("get udp message local %s, from %s", conn.LocalAddr(), raddr)
|
||||
var m msg.NatHoleSid
|
||||
if err := DecodeMessageInto(buf[:n], key, &m); err != nil {
|
||||
pool.PutBuf(buf)
|
||||
xl.Warnf("decode sid message error: %v", err)
|
||||
continue
|
||||
}
|
||||
@@ -408,7 +410,7 @@ func sendSidMessageToRandomPorts(
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
used := sets.New[int]()
|
||||
getUnusedPort := func() int {
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
port := rand.IntN(65535-1024) + 1024
|
||||
if !used.Has(port) {
|
||||
used.Insert(port)
|
||||
@@ -418,7 +420,7 @@ func sendSidMessageToRandomPorts(
|
||||
return 0
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
for range count {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
@@ -68,7 +69,7 @@ func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
ReadHeaderTimeout: 0,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
@@ -77,7 +78,7 @@ func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugi
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
ReadHeaderTimeout: 0,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -62,11 +62,13 @@ func (p *TLS2RawPlugin) Handle(ctx context.Context, connInfo *ConnectionInfo) {
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
xl.Warnf("tls handshake error: %v", err)
|
||||
tlsConn.Close()
|
||||
return
|
||||
}
|
||||
rawConn, err := net.Dial("tcp", p.opts.LocalAddr)
|
||||
if err != nil {
|
||||
xl.Warnf("dial to local addr error: %v", err)
|
||||
tlsConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -54,10 +54,13 @@ func (uds *UnixDomainSocketPlugin) Handle(ctx context.Context, connInfo *Connect
|
||||
localConn, err := net.DialUnix("unix", nil, uds.UnixAddr)
|
||||
if err != nil {
|
||||
xl.Warnf("dial to uds %s error: %v", uds.UnixAddr, err)
|
||||
connInfo.Conn.Close()
|
||||
return
|
||||
}
|
||||
if connInfo.ProxyProtocolHeader != nil {
|
||||
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
|
||||
localConn.Close()
|
||||
connInfo.Conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
@@ -64,12 +65,7 @@ func (p *httpPlugin) Name() string {
|
||||
}
|
||||
|
||||
func (p *httpPlugin) IsSupport(op string) bool {
|
||||
for _, v := range p.options.Ops {
|
||||
if v == op {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return slices.Contains(p.options.Ops, op)
|
||||
}
|
||||
|
||||
func (p *httpPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
|
||||
@@ -153,10 +153,7 @@ func (p *VirtualNetPlugin) run() {
|
||||
|
||||
// Exponential backoff: 60s, 120s, 240s, 300s (capped)
|
||||
baseDelay := 60 * time.Second
|
||||
reconnectDelay = baseDelay * time.Duration(1<<uint(p.consecutiveErrors-1))
|
||||
if reconnectDelay > 300*time.Second {
|
||||
reconnectDelay = 300 * time.Second
|
||||
}
|
||||
reconnectDelay = min(baseDelay*time.Duration(1<<uint(p.consecutiveErrors-1)), 300*time.Second)
|
||||
} else {
|
||||
// Reset consecutive errors on successful connection
|
||||
if p.consecutiveErrors > 0 {
|
||||
|
||||
@@ -16,6 +16,7 @@ package featuregate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -92,10 +93,7 @@ type featureGate struct {
|
||||
|
||||
// NewFeatureGate creates a new feature gate with the default features
|
||||
func NewFeatureGate() MutableFeatureGate {
|
||||
known := map[Feature]FeatureSpec{}
|
||||
for k, v := range defaultFeatures {
|
||||
known[k] = v
|
||||
}
|
||||
known := maps.Clone(defaultFeatures)
|
||||
|
||||
f := &featureGate{}
|
||||
f.known.Store(known)
|
||||
@@ -109,14 +107,8 @@ func (f *featureGate) SetFromMap(m map[string]bool) error {
|
||||
defer f.lock.Unlock()
|
||||
|
||||
// Copy existing state
|
||||
known := map[Feature]FeatureSpec{}
|
||||
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
|
||||
known[k] = v
|
||||
}
|
||||
enabled := map[Feature]bool{}
|
||||
for k, v := range f.enabled.Load().(map[Feature]bool) {
|
||||
enabled[k] = v
|
||||
}
|
||||
known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
|
||||
enabled := maps.Clone(f.enabled.Load().(map[Feature]bool))
|
||||
|
||||
// Apply the new settings
|
||||
for k, v := range m {
|
||||
@@ -147,10 +139,7 @@ func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
|
||||
}
|
||||
|
||||
// Copy existing state
|
||||
known := map[Feature]FeatureSpec{}
|
||||
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
|
||||
known[k] = v
|
||||
}
|
||||
known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
|
||||
|
||||
// Add new features
|
||||
for name, spec := range features {
|
||||
@@ -171,8 +160,9 @@ func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
|
||||
|
||||
// String returns a string containing all enabled feature gates, formatted as "key1=value1,key2=value2,..."
|
||||
func (f *featureGate) String() string {
|
||||
pairs := []string{}
|
||||
for k, v := range f.enabled.Load().(map[Feature]bool) {
|
||||
enabled := f.enabled.Load().(map[Feature]bool)
|
||||
pairs := make([]string, 0, len(enabled))
|
||||
for k, v := range enabled {
|
||||
pairs = append(pairs, fmt.Sprintf("%s=%t", k, v))
|
||||
}
|
||||
sort.Strings(pairs)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -28,16 +27,17 @@ import (
|
||||
)
|
||||
|
||||
func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket {
|
||||
content := make([]byte, len(buf))
|
||||
copy(content, buf)
|
||||
return &msg.UDPPacket{
|
||||
Content: base64.StdEncoding.EncodeToString(buf),
|
||||
Content: content,
|
||||
LocalAddr: laddr,
|
||||
RemoteAddr: raddr,
|
||||
}
|
||||
}
|
||||
|
||||
func GetContent(m *msg.UDPPacket) (buf []byte, err error) {
|
||||
buf, err = base64.StdEncoding.DecodeString(m.Content)
|
||||
return
|
||||
return m.Content, nil
|
||||
}
|
||||
|
||||
func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh chan<- *msg.UDPPacket, bufSize int) {
|
||||
@@ -60,7 +60,7 @@ func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// buf[:n] will be encoded to string, so the bytes can be reused
|
||||
// NewUDPPacket copies buf[:n], so the read buffer can be reused
|
||||
udpMsg := NewUDPPacket(buf[:n], nil, remoteAddr)
|
||||
|
||||
select {
|
||||
@@ -85,6 +85,7 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
|
||||
}()
|
||||
|
||||
buf := pool.GetBuf(bufSize)
|
||||
defer pool.PutBuf(buf)
|
||||
for {
|
||||
_ = udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, _, err := udpConn.ReadFromUDP(buf)
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fatedier/frp/client/api"
|
||||
"github.com/fatedier/frp/client/http/model"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ func (c *Client) SetAuth(user, pwd string) {
|
||||
c.authPwd = pwd
|
||||
}
|
||||
|
||||
func (c *Client) GetProxyStatus(ctx context.Context, name string) (*api.ProxyStatusResp, error) {
|
||||
func (c *Client) GetProxyStatus(ctx context.Context, name string) (*model.ProxyStatusResp, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://"+c.address+"/api/status", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -41,7 +41,7 @@ func (c *Client) GetProxyStatus(ctx context.Context, name string) (*api.ProxySta
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allStatus := make(api.StatusResp)
|
||||
allStatus := make(model.StatusResp)
|
||||
if err = json.Unmarshal([]byte(content), &allStatus); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal http response error: %s", strings.TrimSpace(content))
|
||||
}
|
||||
@@ -55,7 +55,7 @@ func (c *Client) GetProxyStatus(ctx context.Context, name string) (*api.ProxySta
|
||||
return nil, fmt.Errorf("no proxy status found")
|
||||
}
|
||||
|
||||
func (c *Client) GetAllProxyStatus(ctx context.Context) (api.StatusResp, error) {
|
||||
func (c *Client) GetAllProxyStatus(ctx context.Context) (model.StatusResp, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://"+c.address+"/api/status", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -64,7 +64,7 @@ func (c *Client) GetAllProxyStatus(ctx context.Context) (api.StatusResp, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allStatus := make(api.StatusResp)
|
||||
allStatus := make(model.StatusResp)
|
||||
if err = json.Unmarshal([]byte(content), &allStatus); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal http response error: %s", strings.TrimSpace(content))
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
@@ -85,7 +86,9 @@ func newCertPool(caPath string) (*x509.CertPool, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool.AppendCertsFromPEM(caCrt)
|
||||
if !pool.AppendCertsFromPEM(caCrt) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate from file %q: no valid PEM certificates found", caPath)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
@@ -89,11 +89,11 @@ func ParseBasicAuth(auth string) (username, password string, ok bool) {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
before, after, found := strings.Cut(cs, ":")
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
return cs[:s], cs[s+1:], true
|
||||
return before, after, true
|
||||
}
|
||||
|
||||
func BasicAuth(username, passwd string) string {
|
||||
|
||||
45
pkg/util/jsonx/json_v1.go
Normal file
45
pkg/util/jsonx/json_v1.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package jsonx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type DecodeOptions struct {
|
||||
RejectUnknownMembers bool
|
||||
}
|
||||
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
|
||||
return json.MarshalIndent(v, prefix, indent)
|
||||
}
|
||||
|
||||
func Unmarshal(data []byte, out any) error {
|
||||
return json.Unmarshal(data, out)
|
||||
}
|
||||
|
||||
func UnmarshalWithOptions(data []byte, out any, options DecodeOptions) error {
|
||||
if !options.RejectUnknownMembers {
|
||||
return json.Unmarshal(data, out)
|
||||
}
|
||||
decoder := json.NewDecoder(bytes.NewReader(data))
|
||||
decoder.DisallowUnknownFields()
|
||||
return decoder.Decode(out)
|
||||
}
|
||||
36
pkg/util/jsonx/raw_message.go
Normal file
36
pkg/util/jsonx/raw_message.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package jsonx
|
||||
|
||||
import "fmt"
|
||||
|
||||
// RawMessage stores a raw encoded JSON value.
|
||||
// It is equivalent to encoding/json.RawMessage behavior.
|
||||
type RawMessage []byte
|
||||
|
||||
func (m RawMessage) MarshalJSON() ([]byte, error) {
|
||||
if m == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *RawMessage) UnmarshalJSON(data []byte) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("jsonx.RawMessage: UnmarshalJSON on nil pointer")
|
||||
}
|
||||
*m = append((*m)[:0], data...)
|
||||
return nil
|
||||
}
|
||||
@@ -86,11 +86,7 @@ func (c *FakeUDPConn) Read(b []byte) (n int, err error) {
|
||||
c.lastActive = time.Now()
|
||||
c.mu.Unlock()
|
||||
|
||||
if len(b) < len(content) {
|
||||
n = len(b)
|
||||
} else {
|
||||
n = len(content)
|
||||
}
|
||||
n = min(len(b), len(content))
|
||||
copy(b, content)
|
||||
return n, nil
|
||||
}
|
||||
@@ -168,11 +164,15 @@ func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
|
||||
return l, err
|
||||
}
|
||||
readConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return l, err
|
||||
}
|
||||
|
||||
l = &UDPListener{
|
||||
addr: udpAddr,
|
||||
acceptCh: make(chan net.Conn),
|
||||
writeCh: make(chan *UDPPacket, 1000),
|
||||
readConn: readConn,
|
||||
fakeConns: make(map[string]*FakeUDPConn),
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ type WebsocketListener struct {
|
||||
// ln: tcp listener for websocket connections
|
||||
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
|
||||
wl = &WebsocketListener{
|
||||
ln: ln,
|
||||
acceptCh: make(chan net.Conn),
|
||||
}
|
||||
|
||||
|
||||
@@ -68,8 +68,8 @@ func ParseRangeNumbers(rangeStr string) (numbers []int64, err error) {
|
||||
rangeStr = strings.TrimSpace(rangeStr)
|
||||
numbers = make([]int64, 0)
|
||||
// e.g. 1000-2000,2001,2002,3000-4000
|
||||
numRanges := strings.Split(rangeStr, ",")
|
||||
for _, numRangeStr := range numRanges {
|
||||
numRanges := strings.SplitSeq(rangeStr, ",")
|
||||
for numRangeStr := range numRanges {
|
||||
// 1000-2000 or 2001
|
||||
numArray := strings.Split(numRangeStr, "-")
|
||||
// length: only 1 or 2 is correct
|
||||
@@ -134,3 +134,12 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati
|
||||
func ConstantTimeEqString(a, b string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
||||
}
|
||||
|
||||
// ClonePtr returns a pointer to a copied value. If v is nil, it returns nil.
|
||||
func ClonePtr[T any](v *T) *T {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
out := *v
|
||||
return &out
|
||||
}
|
||||
|
||||
@@ -42,22 +42,15 @@ func TestParseRangeNumbers(t *testing.T) {
|
||||
require.Error(err)
|
||||
}
|
||||
|
||||
func TestAddUserPrefix(t *testing.T) {
|
||||
func TestClonePtr(t *testing.T) {
|
||||
require := require.New(t)
|
||||
require.Equal("test", AddUserPrefix("", "test"))
|
||||
require.Equal("alice.test", AddUserPrefix("alice", "test"))
|
||||
}
|
||||
|
||||
func TestStripUserPrefix(t *testing.T) {
|
||||
require := require.New(t)
|
||||
require.Equal("test", StripUserPrefix("", "test"))
|
||||
require.Equal("test", StripUserPrefix("alice", "alice.test"))
|
||||
require.Equal("alice.test", StripUserPrefix("alice", "alice.alice.test"))
|
||||
require.Equal("bob.test", StripUserPrefix("alice", "bob.test"))
|
||||
}
|
||||
var nilInt *int
|
||||
require.Nil(ClonePtr(nilInt))
|
||||
|
||||
func TestBuildTargetServerProxyName(t *testing.T) {
|
||||
require := require.New(t)
|
||||
require.Equal("alice.test", BuildTargetServerProxyName("alice", "", "test"))
|
||||
require.Equal("bob.test", BuildTargetServerProxyName("alice", "bob", "test"))
|
||||
v := 42
|
||||
cloned := ClonePtr(&v)
|
||||
require.NotNil(cloned)
|
||||
require.Equal(v, *cloned)
|
||||
require.NotSame(&v, cloned)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
package version
|
||||
|
||||
var version = "0.67.0"
|
||||
var version = "0.68.0"
|
||||
|
||||
func Full() string {
|
||||
return version
|
||||
|
||||
@@ -266,31 +266,13 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
|
||||
go libio.Join(remote, client)
|
||||
}
|
||||
|
||||
func parseBasicAuth(auth string) (username, password string, ok bool) {
|
||||
const prefix = "Basic "
|
||||
// Case insensitive prefix match. See Issue 22736.
|
||||
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
|
||||
return
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
return
|
||||
}
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request {
|
||||
user := ""
|
||||
// If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header.
|
||||
if req.URL.Host != "" {
|
||||
proxyAuth := req.Header.Get("Proxy-Authorization")
|
||||
if proxyAuth != "" {
|
||||
user, _, _ = parseBasicAuth(proxyAuth)
|
||||
user, _, _ = httppkg.ParseBasicAuth(proxyAuth)
|
||||
}
|
||||
}
|
||||
if user == "" {
|
||||
|
||||
@@ -63,11 +63,12 @@ func (l *Logger) AddPrefix(prefix LogPrefix) *Logger {
|
||||
if prefix.Priority <= 0 {
|
||||
prefix.Priority = 10
|
||||
}
|
||||
for _, p := range l.prefixes {
|
||||
for i, p := range l.prefixes {
|
||||
if p.Name == prefix.Name {
|
||||
found = true
|
||||
p.Value = prefix.Value
|
||||
p.Priority = prefix.Priority
|
||||
l.prefixes[i].Value = prefix.Value
|
||||
l.prefixes[i].Priority = prefix.Priority
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
|
||||
64
server/api_router.go
Normal file
64
server/api_router.go
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
adminapi "github.com/fatedier/frp/server/http"
|
||||
)
|
||||
|
||||
func (svr *Service) registerRouteHandlers(helper *httppkg.RouterRegisterHelper) {
|
||||
helper.Router.HandleFunc("/healthz", healthz)
|
||||
subRouter := helper.Router.NewRoute().Subrouter()
|
||||
|
||||
subRouter.Use(helper.AuthMiddleware)
|
||||
subRouter.Use(httppkg.NewRequestLogger)
|
||||
|
||||
// metrics
|
||||
if svr.cfg.EnablePrometheus {
|
||||
subRouter.Handle("/metrics", promhttp.Handler())
|
||||
}
|
||||
|
||||
apiController := adminapi.NewController(svr.cfg, svr.clientRegistry, svr.pxyManager)
|
||||
|
||||
// apis
|
||||
subRouter.HandleFunc("/api/serverinfo", httppkg.MakeHTTPHandlerFunc(apiController.APIServerInfo)).Methods("GET")
|
||||
subRouter.HandleFunc("/api/proxy/{type}", httppkg.MakeHTTPHandlerFunc(apiController.APIProxyByType)).Methods("GET")
|
||||
subRouter.HandleFunc("/api/proxy/{type}/{name}", httppkg.MakeHTTPHandlerFunc(apiController.APIProxyByTypeAndName)).Methods("GET")
|
||||
subRouter.HandleFunc("/api/proxies/{name}", httppkg.MakeHTTPHandlerFunc(apiController.APIProxyByName)).Methods("GET")
|
||||
subRouter.HandleFunc("/api/traffic/{name}", httppkg.MakeHTTPHandlerFunc(apiController.APIProxyTraffic)).Methods("GET")
|
||||
subRouter.HandleFunc("/api/clients", httppkg.MakeHTTPHandlerFunc(apiController.APIClientList)).Methods("GET")
|
||||
subRouter.HandleFunc("/api/clients/{key}", httppkg.MakeHTTPHandlerFunc(apiController.APIClientDetail)).Methods("GET")
|
||||
subRouter.HandleFunc("/api/proxies", httppkg.MakeHTTPHandlerFunc(apiController.DeleteProxies)).Methods("DELETE")
|
||||
|
||||
// view
|
||||
subRouter.Handle("/favicon.ico", http.FileServer(helper.AssetsFS)).Methods("GET")
|
||||
subRouter.PathPrefix("/static/").Handler(
|
||||
netpkg.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(helper.AssetsFS))),
|
||||
).Methods("GET")
|
||||
|
||||
subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
|
||||
})
|
||||
}
|
||||
|
||||
func healthz(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
@@ -95,20 +95,33 @@ func (cm *ControlManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
// SessionContext encapsulates the input parameters for creating a new Control.
|
||||
type SessionContext struct {
|
||||
// all resource managers and controllers
|
||||
rc *controller.ResourceController
|
||||
|
||||
RC *controller.ResourceController
|
||||
// proxy manager
|
||||
pxyManager *proxy.Manager
|
||||
|
||||
PxyManager *proxy.Manager
|
||||
// plugin manager
|
||||
pluginManager *plugin.Manager
|
||||
|
||||
PluginManager *plugin.Manager
|
||||
// verifies authentication based on selected method
|
||||
authVerifier auth.Verifier
|
||||
AuthVerifier auth.Verifier
|
||||
// key used for connection encryption
|
||||
encryptionKey []byte
|
||||
EncryptionKey []byte
|
||||
// control connection
|
||||
Conn net.Conn
|
||||
// indicates whether the connection is encrypted
|
||||
ConnEncrypted bool
|
||||
// login message
|
||||
LoginMsg *msg.Login
|
||||
// server configuration
|
||||
ServerCfg *v1.ServerConfig
|
||||
// client registry
|
||||
ClientRegistry *registry.ClientRegistry
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
// session context
|
||||
sessionCtx *SessionContext
|
||||
|
||||
// other components can use this to communicate with client
|
||||
msgTransporter transport.MessageTransporter
|
||||
@@ -117,12 +130,6 @@ type Control struct {
|
||||
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
|
||||
msgDispatcher *msg.Dispatcher
|
||||
|
||||
// login message
|
||||
loginMsg *msg.Login
|
||||
|
||||
// control connection
|
||||
conn net.Conn
|
||||
|
||||
// work connections
|
||||
workConnCh chan net.Conn
|
||||
|
||||
@@ -145,61 +152,34 @@ type Control struct {
|
||||
|
||||
mu sync.RWMutex
|
||||
|
||||
// Server configuration information
|
||||
serverCfg *v1.ServerConfig
|
||||
|
||||
clientRegistry *registry.ClientRegistry
|
||||
|
||||
xl *xlog.Logger
|
||||
ctx context.Context
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext.
|
||||
func NewControl(
|
||||
ctx context.Context,
|
||||
rc *controller.ResourceController,
|
||||
pxyManager *proxy.Manager,
|
||||
pluginManager *plugin.Manager,
|
||||
authVerifier auth.Verifier,
|
||||
encryptionKey []byte,
|
||||
ctlConn net.Conn,
|
||||
ctlConnEncrypted bool,
|
||||
loginMsg *msg.Login,
|
||||
serverCfg *v1.ServerConfig,
|
||||
) (*Control, error) {
|
||||
poolCount := loginMsg.PoolCount
|
||||
if poolCount > int(serverCfg.Transport.MaxPoolCount) {
|
||||
poolCount = int(serverCfg.Transport.MaxPoolCount)
|
||||
}
|
||||
func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
|
||||
poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
|
||||
ctl := &Control{
|
||||
rc: rc,
|
||||
pxyManager: pxyManager,
|
||||
pluginManager: pluginManager,
|
||||
authVerifier: authVerifier,
|
||||
encryptionKey: encryptionKey,
|
||||
conn: ctlConn,
|
||||
loginMsg: loginMsg,
|
||||
workConnCh: make(chan net.Conn, poolCount+10),
|
||||
proxies: make(map[string]proxy.Proxy),
|
||||
poolCount: poolCount,
|
||||
portsUsedNum: 0,
|
||||
runID: loginMsg.RunID,
|
||||
serverCfg: serverCfg,
|
||||
xl: xlog.FromContextSafe(ctx),
|
||||
ctx: ctx,
|
||||
doneCh: make(chan struct{}),
|
||||
sessionCtx: sessionCtx,
|
||||
workConnCh: make(chan net.Conn, poolCount+10),
|
||||
proxies: make(map[string]proxy.Proxy),
|
||||
poolCount: poolCount,
|
||||
portsUsedNum: 0,
|
||||
runID: sessionCtx.LoginMsg.RunID,
|
||||
xl: xlog.FromContextSafe(ctx),
|
||||
ctx: ctx,
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
ctl.lastPing.Store(time.Now())
|
||||
|
||||
if ctlConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey)
|
||||
if sessionCtx.ConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
|
||||
} else {
|
||||
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn)
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
}
|
||||
ctl.registerMsgHandlers()
|
||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||
@@ -213,7 +193,7 @@ func (ctl *Control) Start() {
|
||||
RunID: ctl.runID,
|
||||
Error: "",
|
||||
}
|
||||
_ = msg.WriteMsg(ctl.conn, loginRespMsg)
|
||||
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
|
||||
|
||||
go func() {
|
||||
for i := 0; i < ctl.poolCount; i++ {
|
||||
@@ -225,7 +205,7 @@ func (ctl *Control) Start() {
|
||||
}
|
||||
|
||||
func (ctl *Control) Close() error {
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -233,7 +213,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
|
||||
xl := ctl.xl
|
||||
xl.Infof("replaced by client [%s]", newCtl.runID)
|
||||
ctl.runID = ""
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
}
|
||||
|
||||
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
|
||||
@@ -291,7 +271,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second):
|
||||
case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second):
|
||||
err = fmt.Errorf("timeout trying to get work connection")
|
||||
xl.Warnf("%v", err)
|
||||
return
|
||||
@@ -304,15 +284,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
|
||||
}
|
||||
|
||||
func (ctl *Control) heartbeatWorker() {
|
||||
if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 {
|
||||
if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
xl := ctl.xl
|
||||
go wait.Until(func() {
|
||||
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
|
||||
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second {
|
||||
xl.Warnf("heartbeat timeout")
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
return
|
||||
}
|
||||
}, time.Second, ctl.doneCh)
|
||||
@@ -323,6 +303,30 @@ func (ctl *Control) WaitClosed() {
|
||||
<-ctl.doneCh
|
||||
}
|
||||
|
||||
func (ctl *Control) loginUserInfo() plugin.UserInfo {
|
||||
return plugin.UserInfo{
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.sessionCtx.LoginMsg.RunID,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctl *Control) closeProxy(pxy proxy.Proxy) {
|
||||
pxy.Close()
|
||||
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
|
||||
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
|
||||
|
||||
notifyContent := &plugin.CloseProxyContent{
|
||||
User: ctl.loginUserInfo(),
|
||||
CloseProxy: msg.CloseProxy{
|
||||
ProxyName: pxy.GetName(),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
|
||||
}()
|
||||
}
|
||||
|
||||
func (ctl *Control) worker() {
|
||||
xl := ctl.xl
|
||||
|
||||
@@ -330,38 +334,23 @@ func (ctl *Control) worker() {
|
||||
go ctl.msgDispatcher.Run()
|
||||
|
||||
<-ctl.msgDispatcher.Done()
|
||||
ctl.conn.Close()
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
|
||||
close(ctl.workConnCh)
|
||||
for workConn := range ctl.workConnCh {
|
||||
workConn.Close()
|
||||
}
|
||||
proxies := ctl.proxies
|
||||
ctl.proxies = make(map[string]proxy.Proxy)
|
||||
ctl.mu.Unlock()
|
||||
|
||||
for _, pxy := range ctl.proxies {
|
||||
pxy.Close()
|
||||
ctl.pxyManager.Del(pxy.GetName())
|
||||
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
|
||||
|
||||
notifyContent := &plugin.CloseProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
},
|
||||
CloseProxy: msg.CloseProxy{
|
||||
ProxyName: pxy.GetName(),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
_ = ctl.pluginManager.CloseProxy(notifyContent)
|
||||
}()
|
||||
for _, pxy := range proxies {
|
||||
ctl.closeProxy(pxy)
|
||||
}
|
||||
|
||||
metrics.Server.CloseClient()
|
||||
ctl.clientRegistry.MarkOfflineByRunID(ctl.runID)
|
||||
ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID)
|
||||
xl.Infof("client exit success")
|
||||
close(ctl.doneCh)
|
||||
}
|
||||
@@ -380,15 +369,11 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
|
||||
inMsg := m.(*msg.NewProxy)
|
||||
|
||||
content := &plugin.NewProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
},
|
||||
User: ctl.loginUserInfo(),
|
||||
NewProxy: *inMsg,
|
||||
}
|
||||
var remoteAddr string
|
||||
retContent, err := ctl.pluginManager.NewProxy(content)
|
||||
retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content)
|
||||
if err == nil {
|
||||
inMsg = &retContent.NewProxy
|
||||
remoteAddr, err = ctl.RegisterProxy(inMsg)
|
||||
@@ -401,15 +386,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
|
||||
if err != nil {
|
||||
xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
|
||||
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
|
||||
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
|
||||
err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient))
|
||||
} else {
|
||||
resp.RemoteAddr = remoteAddr
|
||||
xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
|
||||
clientID := ctl.loginMsg.ClientID
|
||||
clientID := ctl.sessionCtx.LoginMsg.ClientID
|
||||
if clientID == "" {
|
||||
clientID = ctl.loginMsg.RunID
|
||||
clientID = ctl.sessionCtx.LoginMsg.RunID
|
||||
}
|
||||
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID)
|
||||
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID)
|
||||
}
|
||||
_ = ctl.msgDispatcher.Send(resp)
|
||||
}
|
||||
@@ -419,22 +404,18 @@ func (ctl *Control) handlePing(m msg.Message) {
|
||||
inMsg := m.(*msg.Ping)
|
||||
|
||||
content := &plugin.PingContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
},
|
||||
User: ctl.loginUserInfo(),
|
||||
Ping: *inMsg,
|
||||
}
|
||||
retContent, err := ctl.pluginManager.Ping(content)
|
||||
retContent, err := ctl.sessionCtx.PluginManager.Ping(content)
|
||||
if err == nil {
|
||||
inMsg = &retContent.Ping
|
||||
err = ctl.authVerifier.VerifyPing(inMsg)
|
||||
err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warnf("received invalid ping: %v", err)
|
||||
_ = ctl.msgDispatcher.Send(&msg.Pong{
|
||||
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
|
||||
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)),
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -445,17 +426,17 @@ func (ctl *Control) handlePing(m msg.Message) {
|
||||
|
||||
func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleVisitor)
|
||||
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleNatHoleClient(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleClient)
|
||||
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleNatHoleReport(m msg.Message) {
|
||||
inMsg := m.(*msg.NatHoleReport)
|
||||
ctl.rc.NatHoleController.HandleReport(inMsg)
|
||||
ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg)
|
||||
}
|
||||
|
||||
func (ctl *Control) handleCloseProxy(m msg.Message) {
|
||||
@@ -468,15 +449,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) {
|
||||
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
|
||||
var pxyConf v1.ProxyConfigurer
|
||||
// Load configures from NewProxy message and validate.
|
||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg)
|
||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// User info
|
||||
userInfo := plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
User: ctl.sessionCtx.LoginMsg.User,
|
||||
Metas: ctl.sessionCtx.LoginMsg.Metas,
|
||||
RunID: ctl.runID,
|
||||
}
|
||||
|
||||
@@ -484,22 +465,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
// In fact, it creates different proxies based on the proxy type. We just call run() here.
|
||||
pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{
|
||||
UserInfo: userInfo,
|
||||
LoginMsg: ctl.loginMsg,
|
||||
LoginMsg: ctl.sessionCtx.LoginMsg,
|
||||
PoolCount: ctl.poolCount,
|
||||
ResourceController: ctl.rc,
|
||||
ResourceController: ctl.sessionCtx.RC,
|
||||
GetWorkConnFn: ctl.GetWorkConn,
|
||||
Configurer: pxyConf,
|
||||
ServerCfg: ctl.serverCfg,
|
||||
EncryptionKey: ctl.encryptionKey,
|
||||
ServerCfg: ctl.sessionCtx.ServerCfg,
|
||||
EncryptionKey: ctl.sessionCtx.EncryptionKey,
|
||||
})
|
||||
if err != nil {
|
||||
return remoteAddr, err
|
||||
}
|
||||
|
||||
// Check ports used number in each client
|
||||
if ctl.serverCfg.MaxPortsPerClient > 0 {
|
||||
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
|
||||
ctl.mu.Lock()
|
||||
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) {
|
||||
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) {
|
||||
ctl.mu.Unlock()
|
||||
err = fmt.Errorf("exceed the max_ports_per_client")
|
||||
return
|
||||
@@ -516,7 +497,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
}()
|
||||
}
|
||||
|
||||
if ctl.pxyManager.Exist(pxyMsg.ProxyName) {
|
||||
if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) {
|
||||
err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName)
|
||||
return
|
||||
}
|
||||
@@ -531,7 +512,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
|
||||
}
|
||||
}()
|
||||
|
||||
err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy)
|
||||
err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -550,28 +531,12 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
if ctl.serverCfg.MaxPortsPerClient > 0 {
|
||||
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
|
||||
ctl.portsUsedNum -= pxy.GetUsedPortsNum()
|
||||
}
|
||||
pxy.Close()
|
||||
ctl.pxyManager.Del(pxy.GetName())
|
||||
delete(ctl.proxies, closeMsg.ProxyName)
|
||||
ctl.mu.Unlock()
|
||||
|
||||
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
|
||||
|
||||
notifyContent := &plugin.CloseProxyContent{
|
||||
User: plugin.UserInfo{
|
||||
User: ctl.loginMsg.User,
|
||||
Metas: ctl.loginMsg.Metas,
|
||||
RunID: ctl.loginMsg.RunID,
|
||||
},
|
||||
CloseProxy: msg.CloseProxy{
|
||||
ProxyName: pxy.GetName(),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
_ = ctl.pluginManager.CloseProxy(notifyContent)
|
||||
}()
|
||||
ctl.closeProxy(pxy)
|
||||
return
|
||||
}
|
||||
|
||||
77
server/group/base.go
Normal file
77
server/group/base.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
)
|
||||
|
||||
// baseGroup contains the shared plumbing for listener-based groups
|
||||
// (TCP, HTTPS, TCPMux). Each concrete group embeds this and provides
|
||||
// its own Listen method with protocol-specific validation.
|
||||
type baseGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
|
||||
acceptCh chan net.Conn
|
||||
realLn net.Listener
|
||||
lns []*Listener
|
||||
mu sync.Mutex
|
||||
cleanupFn func()
|
||||
}
|
||||
|
||||
// initBase resets the baseGroup for a fresh listen cycle.
|
||||
// Must be called under mu when len(lns) == 0.
|
||||
func (bg *baseGroup) initBase(group, groupKey string, realLn net.Listener, cleanupFn func()) {
|
||||
bg.group = group
|
||||
bg.groupKey = groupKey
|
||||
bg.realLn = realLn
|
||||
bg.acceptCh = make(chan net.Conn)
|
||||
bg.cleanupFn = cleanupFn
|
||||
}
|
||||
|
||||
// worker reads from the real listener and fans out to acceptCh.
|
||||
// The parameters are captured at creation time so that the worker is
|
||||
// bound to a specific listen cycle and cannot observe a later initBase.
|
||||
func (bg *baseGroup) worker(realLn net.Listener, acceptCh chan<- net.Conn) {
|
||||
for {
|
||||
c, err := realLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newListener creates a new Listener wired to this baseGroup.
|
||||
// Must be called under mu.
|
||||
func (bg *baseGroup) newListener(addr net.Addr) *Listener {
|
||||
ln := newListener(bg.acceptCh, addr, bg.closeListener)
|
||||
bg.lns = append(bg.lns, ln)
|
||||
return ln
|
||||
}
|
||||
|
||||
// closeListener removes ln from the list. When the last listener is removed,
|
||||
// it closes acceptCh, closes the real listener, and calls cleanupFn.
|
||||
func (bg *baseGroup) closeListener(ln *Listener) {
|
||||
bg.mu.Lock()
|
||||
defer bg.mu.Unlock()
|
||||
for i, l := range bg.lns {
|
||||
if l == ln {
|
||||
bg.lns = append(bg.lns[:i], bg.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(bg.lns) == 0 {
|
||||
close(bg.acceptCh)
|
||||
bg.realLn.Close()
|
||||
bg.cleanupFn()
|
||||
}
|
||||
}
|
||||
169
server/group/base_test.go
Normal file
169
server/group/base_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeLn is a controllable net.Listener for tests.
|
||||
type fakeLn struct {
|
||||
connCh chan net.Conn
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newFakeLn() *fakeLn {
|
||||
return &fakeLn{
|
||||
connCh: make(chan net.Conn, 8),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLn) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case c := <-f.connCh:
|
||||
return c, nil
|
||||
case <-f.closed:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLn) Close() error {
|
||||
f.once.Do(func() { close(f.closed) })
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeLn) Addr() net.Addr { return fakeAddr("127.0.0.1:9999") }
|
||||
|
||||
func (f *fakeLn) inject(c net.Conn) {
|
||||
select {
|
||||
case f.connCh <- c:
|
||||
case <-f.closed:
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerFanOut(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
select {
|
||||
case got := <-bg.acceptCh:
|
||||
assert.Equal(t, c1, got)
|
||||
got.Close()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for connection on acceptCh")
|
||||
}
|
||||
|
||||
fl.Close()
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerStopsOnListenerClose(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bg.worker(fl, bg.acceptCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fl.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("worker did not stop after listener close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseGroup_WorkerClosesConnOnClosedChannel(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
bg.initBase("g", "key", fl, func() {})
|
||||
|
||||
// Close acceptCh before worker sends.
|
||||
close(bg.acceptCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bg.worker(fl, bg.acceptCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("worker did not stop after panic recovery")
|
||||
}
|
||||
|
||||
// c1 should have been closed by worker's panic recovery path.
|
||||
buf := make([]byte, 1)
|
||||
_, err := c1.Read(buf)
|
||||
assert.Error(t, err, "connection should be closed by worker")
|
||||
}
|
||||
|
||||
func TestBaseGroup_CloseLastListenerTriggersCleanup(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
cleanupCalled := 0
|
||||
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
|
||||
|
||||
bg.mu.Lock()
|
||||
ln1 := bg.newListener(fl.Addr())
|
||||
ln2 := bg.newListener(fl.Addr())
|
||||
bg.mu.Unlock()
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
ln1.Close()
|
||||
assert.Equal(t, 0, cleanupCalled, "cleanup should not run while listeners remain")
|
||||
|
||||
ln2.Close()
|
||||
assert.Equal(t, 1, cleanupCalled, "cleanup should run after last listener closed")
|
||||
}
|
||||
|
||||
func TestBaseGroup_CloseOneOfTwoListeners(t *testing.T) {
|
||||
fl := newFakeLn()
|
||||
var bg baseGroup
|
||||
cleanupCalled := 0
|
||||
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
|
||||
|
||||
bg.mu.Lock()
|
||||
ln1 := bg.newListener(fl.Addr())
|
||||
ln2 := bg.newListener(fl.Addr())
|
||||
bg.mu.Unlock()
|
||||
|
||||
go bg.worker(fl, bg.acceptCh)
|
||||
|
||||
ln1.Close()
|
||||
assert.Equal(t, 0, cleanupCalled)
|
||||
|
||||
// ln2 should still receive connections.
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
fl.inject(c1)
|
||||
|
||||
got, err := ln2.Accept()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, c1, got)
|
||||
got.Close()
|
||||
|
||||
ln2.Close()
|
||||
assert.Equal(t, 1, cleanupCalled)
|
||||
}
|
||||
@@ -24,4 +24,6 @@ var (
|
||||
ErrListenerClosed = errors.New("group listener closed")
|
||||
ErrGroupDifferentPort = errors.New("group should have same remote port")
|
||||
ErrProxyRepeated = errors.New("group proxy repeated")
|
||||
|
||||
errGroupStale = errors.New("stale group reference")
|
||||
)
|
||||
|
||||
@@ -9,53 +9,42 @@ import (
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
// HTTPGroupController manages HTTP groups that use round-robin
|
||||
// callback routing (fundamentally different from listener-based groups).
|
||||
type HTTPGroupController struct {
|
||||
// groups indexed by group name
|
||||
groups map[string]*HTTPGroup
|
||||
|
||||
// register createConn for each group to vhostRouter.
|
||||
// createConn will get a connection from one proxy of the group
|
||||
groupRegistry[*HTTPGroup]
|
||||
vhostRouter *vhost.Routers
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController {
|
||||
return &HTTPGroupController{
|
||||
groups: make(map[string]*HTTPGroup),
|
||||
vhostRouter: vhostRouter,
|
||||
groupRegistry: newGroupRegistry[*HTTPGroup](),
|
||||
vhostRouter: vhostRouter,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctl *HTTPGroupController) Register(
|
||||
proxyName, group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (err error) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
if !ok {
|
||||
g = NewHTTPGroup(ctl)
|
||||
ctl.groups[indexKey] = g
|
||||
) error {
|
||||
for {
|
||||
g := ctl.getOrCreate(group, func() *HTTPGroup {
|
||||
return NewHTTPGroup(ctl)
|
||||
})
|
||||
err := g.Register(proxyName, group, groupKey, routeConfig)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
ctl.mu.Unlock()
|
||||
|
||||
return g.Register(proxyName, group, groupKey, routeConfig)
|
||||
}
|
||||
|
||||
func (ctl *HTTPGroupController) UnRegister(proxyName, group string, _ vhost.RouteConfig) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
g, ok := ctl.get(group)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
isEmpty := g.UnRegister(proxyName)
|
||||
if isEmpty {
|
||||
delete(ctl.groups, indexKey)
|
||||
}
|
||||
g.UnRegister(proxyName)
|
||||
}
|
||||
|
||||
type HTTPGroup struct {
|
||||
@@ -87,6 +76,9 @@ func (g *HTTPGroup) Register(
|
||||
) (err error) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if !g.ctl.isCurrent(group, func(cur *HTTPGroup) bool { return cur == g }) {
|
||||
return errGroupStale
|
||||
}
|
||||
if len(g.createFuncs) == 0 {
|
||||
// the first proxy in this group
|
||||
tmp := routeConfig // copy object
|
||||
@@ -123,7 +115,7 @@ func (g *HTTPGroup) Register(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
|
||||
func (g *HTTPGroup) UnRegister(proxyName string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
delete(g.createFuncs, proxyName)
|
||||
@@ -135,10 +127,11 @@ func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
|
||||
}
|
||||
|
||||
if len(g.createFuncs) == 0 {
|
||||
isEmpty = true
|
||||
g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser)
|
||||
g.ctl.removeIf(g.group, func(cur *HTTPGroup) bool {
|
||||
return cur == g
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||
@@ -151,7 +144,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||
location := g.location
|
||||
routeByHTTPUser := g.routeByHTTPUser
|
||||
if len(g.pxyNames) > 0 {
|
||||
name := g.pxyNames[int(newIndex)%len(g.pxyNames)]
|
||||
name := g.pxyNames[newIndex%uint64(len(g.pxyNames))]
|
||||
f = g.createFuncs[name]
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
@@ -174,7 +167,7 @@ func (g *HTTPGroup) chooseEndpoint() (string, error) {
|
||||
location := g.location
|
||||
routeByHTTPUser := g.routeByHTTPUser
|
||||
if len(g.pxyNames) > 0 {
|
||||
name = g.pxyNames[int(newIndex)%len(g.pxyNames)]
|
||||
name = g.pxyNames[newIndex%uint64(len(g.pxyNames))]
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
|
||||
|
||||
@@ -17,25 +17,19 @@ package group
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
gerr "github.com/fatedier/golib/errors"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
type HTTPSGroupController struct {
|
||||
groups map[string]*HTTPSGroup
|
||||
|
||||
groupRegistry[*HTTPSGroup]
|
||||
httpsMuxer *vhost.HTTPSMuxer
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHTTPSGroupController(httpsMuxer *vhost.HTTPSMuxer) *HTTPSGroupController {
|
||||
return &HTTPSGroupController{
|
||||
groups: make(map[string]*HTTPSGroup),
|
||||
httpsMuxer: httpsMuxer,
|
||||
groupRegistry: newGroupRegistry[*HTTPSGroup](),
|
||||
httpsMuxer: httpsMuxer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,41 +38,28 @@ func (ctl *HTTPSGroupController) Listen(
|
||||
group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (l net.Listener, err error) {
|
||||
indexKey := group
|
||||
ctl.mu.Lock()
|
||||
g, ok := ctl.groups[indexKey]
|
||||
if !ok {
|
||||
g = NewHTTPSGroup(ctl)
|
||||
ctl.groups[indexKey] = g
|
||||
for {
|
||||
g := ctl.getOrCreate(group, func() *HTTPSGroup {
|
||||
return NewHTTPSGroup(ctl)
|
||||
})
|
||||
l, err = g.Listen(ctx, group, groupKey, routeConfig)
|
||||
if err == errGroupStale {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
ctl.mu.Unlock()
|
||||
|
||||
return g.Listen(ctx, group, groupKey, routeConfig)
|
||||
}
|
||||
|
||||
func (ctl *HTTPSGroupController) RemoveGroup(group string) {
|
||||
ctl.mu.Lock()
|
||||
defer ctl.mu.Unlock()
|
||||
delete(ctl.groups, group)
|
||||
}
|
||||
|
||||
type HTTPSGroup struct {
|
||||
group string
|
||||
groupKey string
|
||||
domain string
|
||||
baseGroup
|
||||
|
||||
acceptCh chan net.Conn
|
||||
httpsLn *vhost.Listener
|
||||
lns []*HTTPSGroupListener
|
||||
ctl *HTTPSGroupController
|
||||
mu sync.Mutex
|
||||
domain string
|
||||
ctl *HTTPSGroupController
|
||||
}
|
||||
|
||||
func NewHTTPSGroup(ctl *HTTPSGroupController) *HTTPSGroup {
|
||||
return &HTTPSGroup{
|
||||
lns: make([]*HTTPSGroupListener, 0),
|
||||
ctl: ctl,
|
||||
acceptCh: make(chan net.Conn),
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,23 +67,27 @@ func (g *HTTPSGroup) Listen(
|
||||
ctx context.Context,
|
||||
group, groupKey string,
|
||||
routeConfig vhost.RouteConfig,
|
||||
) (ln *HTTPSGroupListener, err error) {
|
||||
) (ln *Listener, err error) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if !g.ctl.isCurrent(group, func(cur *HTTPSGroup) bool { return cur == g }) {
|
||||
return nil, errGroupStale
|
||||
}
|
||||
if len(g.lns) == 0 {
|
||||
// the first listener, listen on the real address
|
||||
httpsLn, errRet := g.ctl.httpsMuxer.Listen(ctx, &routeConfig)
|
||||
if errRet != nil {
|
||||
return nil, errRet
|
||||
}
|
||||
ln = newHTTPSGroupListener(group, g, httpsLn.Addr())
|
||||
|
||||
g.group = group
|
||||
g.groupKey = groupKey
|
||||
g.domain = routeConfig.Domain
|
||||
g.httpsLn = httpsLn
|
||||
g.lns = append(g.lns, ln)
|
||||
go g.worker()
|
||||
g.initBase(group, groupKey, httpsLn, func() {
|
||||
g.ctl.removeIf(g.group, func(cur *HTTPSGroup) bool {
|
||||
return cur == g
|
||||
})
|
||||
})
|
||||
ln = g.newListener(httpsLn.Addr())
|
||||
go g.worker(httpsLn, g.acceptCh)
|
||||
} else {
|
||||
// route config in the same group must be equal
|
||||
if g.group != group || g.domain != routeConfig.Domain {
|
||||
@@ -111,87 +96,7 @@ func (g *HTTPSGroup) Listen(
|
||||
if g.groupKey != groupKey {
|
||||
return nil, ErrGroupAuthFailed
|
||||
}
|
||||
ln = newHTTPSGroupListener(group, g, g.lns[0].Addr())
|
||||
g.lns = append(g.lns, ln)
|
||||
ln = g.newListener(g.lns[0].Addr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) worker() {
|
||||
for {
|
||||
c, err := g.httpsLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gerr.PanicToError(func() {
|
||||
g.acceptCh <- c
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) Accept() <-chan net.Conn {
|
||||
return g.acceptCh
|
||||
}
|
||||
|
||||
func (g *HTTPSGroup) CloseListener(ln *HTTPSGroupListener) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
for i, tmpLn := range g.lns {
|
||||
if tmpLn == ln {
|
||||
g.lns = append(g.lns[:i], g.lns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(g.lns) == 0 {
|
||||
close(g.acceptCh)
|
||||
if g.httpsLn != nil {
|
||||
g.httpsLn.Close()
|
||||
}
|
||||
g.ctl.RemoveGroup(g.group)
|
||||
}
|
||||
}
|
||||
|
||||
type HTTPSGroupListener struct {
|
||||
groupName string
|
||||
group *HTTPSGroup
|
||||
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newHTTPSGroupListener(name string, group *HTTPSGroup, addr net.Addr) *HTTPSGroupListener {
|
||||
return &HTTPSGroupListener{
|
||||
groupName: name,
|
||||
group: group,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Accept() (c net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok = <-ln.group.Accept():
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
func (ln *HTTPSGroupListener) Close() (err error) {
|
||||
close(ln.closeCh)
|
||||
|
||||
// remove self from HTTPSGroup
|
||||
ln.group.CloseListener(ln)
|
||||
return
|
||||
}
|
||||
|
||||
49
server/group/listener.go
Normal file
49
server/group/listener.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Listener is a per-proxy virtual listener that receives connections
|
||||
// from a shared group. It implements net.Listener.
|
||||
type Listener struct {
|
||||
acceptCh <-chan net.Conn
|
||||
addr net.Addr
|
||||
closeCh chan struct{}
|
||||
onClose func(*Listener)
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newListener(acceptCh <-chan net.Conn, addr net.Addr, onClose func(*Listener)) *Listener {
|
||||
return &Listener{
|
||||
acceptCh: acceptCh,
|
||||
addr: addr,
|
||||
closeCh: make(chan struct{}),
|
||||
onClose: onClose,
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *Listener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-ln.closeCh:
|
||||
return nil, ErrListenerClosed
|
||||
case c, ok := <-ln.acceptCh:
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *Listener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
func (ln *Listener) Close() error {
|
||||
ln.once.Do(func() {
|
||||
close(ln.closeCh)
|
||||
ln.onClose(ln)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
68
server/group/listener_test.go
Normal file
68
server/group/listener_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestListener_Accept(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn, 1)
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
|
||||
acceptCh <- c1
|
||||
got, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, c1, got)
|
||||
}
|
||||
|
||||
func TestListener_AcceptAfterChannelClose(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn)
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
close(acceptCh)
|
||||
_, err := ln.Accept()
|
||||
assert.ErrorIs(t, err, ErrListenerClosed)
|
||||
}
|
||||
|
||||
func TestListener_AcceptAfterListenerClose(t *testing.T) {
|
||||
acceptCh := make(chan net.Conn) // open, not closed
|
||||
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
|
||||
|
||||
ln.Close()
|
||||
_, err := ln.Accept()
|
||||
assert.ErrorIs(t, err, ErrListenerClosed)
|
||||
}
|
||||
|
||||
func TestListener_DoubleClose(t *testing.T) {
|
||||
closeCalls := 0
|
||||
ln := newListener(
|
||||
make(chan net.Conn),
|
||||
fakeAddr("127.0.0.1:1234"),
|
||||
func(*Listener) { closeCalls++ },
|
||||
)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
ln.Close()
|
||||
ln.Close()
|
||||
})
|
||||
assert.Equal(t, 1, closeCalls, "onClose should be called exactly once")
|
||||
}
|
||||
|
||||
func TestListener_Addr(t *testing.T) {
|
||||
addr := fakeAddr("10.0.0.1:5555")
|
||||
ln := newListener(make(chan net.Conn), addr, func(*Listener) {})
|
||||
assert.Equal(t, addr, ln.Addr())
|
||||
}
|
||||
|
||||
// fakeAddr implements net.Addr for testing.
|
||||
type fakeAddr string
|
||||
|
||||
func (a fakeAddr) Network() string { return "tcp" }
|
||||
func (a fakeAddr) String() string { return string(a) }
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user