Compare commits

..

32 Commits
tmp ... new

Author SHA1 Message Date
fatedier
01413c3853 test/e2e: use shared deadline for proxy readiness and fix doc comment
- Use a single deadline in waitForClientProxyReady so total wait across
  all proxies does not exceed the given timeout
- Fix WaitForOutput doc comment to accurately describe single pattern
  with count semantics
2026-03-09 18:45:12 +08:00
fatedier
adcd2e64b6 test/e2e: replace RunProcesses client sleep with log-based proxy readiness detection
Replace the fixed 1500ms sleep in RunProcesses with event-driven proxy
registration detection by monitoring frpc log output for "start proxy
success" messages.

Key changes:
- Add thread-safe SafeBuffer to replace bytes.Buffer in Process, enabling
  concurrent read/write of process output during execution
- Add Process.WaitForOutput() to poll process output for pattern matches
  with timeout and early exit on process termination
- Add waitForClientProxyReady() that uses config.LoadClientConfig() to
  extract proxy names, then waits for each proxy's success log
- For visitor-only clients (no deterministic readiness signal), fall back
  to the original sleep with elapsed time deducted
2026-03-09 13:46:13 +08:00
fatedier
48e8901466 test/e2e: optimize RunFrps/RunFrpc with process exit detection (#5225)
* test/e2e: optimize RunFrps/RunFrpc with process exit detection

Refactor Process to track subprocess lifecycle via a done channel,
replacing direct cmd.Wait() in Stop() to avoid double-Wait races.
RunFrps/RunFrpc now use select on the done channel instead of fixed
sleeps, allowing short-lived processes (verify, startup failures) to
return immediately while preserving existing timeout behavior for
long-running daemons.

* test/e2e: guard Process against double-Start and Stop-before-Start

Add started flag to prevent double-Start panics and allow Stop to
return immediately when the process was never started. Use sync.Once
for closing the done channel as defense-in-depth against double close.
2026-03-09 10:28:47 +08:00
fatedier
bcd2424c24 test/e2e: optimize e2e test time by replacing sleeps with TCP readiness checks (#5223)
Replace the fixed 500ms sleep after each frps startup in RunProcesses
with a TCP dial-based readiness check that polls the server bind port.
This reduces the e2e suite wall time from ~97s to ~43s.

Also simplify the RunProcesses API to accept a single server template
string instead of a slice, matching how every call site uses it.
2026-03-08 23:41:33 +08:00
fatedier
c7ac12ea0f server/group: refactor with shared abstractions and fix concurrency issues (#5222)
* server/group: refactor group package with shared abstractions and fix concurrency issues

Extract common patterns into reusable components:
- groupRegistry[G]: generic concurrent map for group lifecycle management
- baseGroup: shared plumbing for listener-based groups (TCP, HTTPS, TCPMux)
- Listener: unified virtual listener replacing 3 identical implementations

Fix concurrency issues:
- Stale-pointer race: isCurrent check + errGroupStale + controller retry loops
- Worker generation safety: pass realLn and acceptCh as params instead of reading mutable fields
- Connection leak: close conn on worker panic recovery path
- ABBA deadlock in HTTP UnRegister: consistent lock ordering (group.mu -> registry.mu)
- Round-robin overflow in HTTPGroup: use unsigned modulo

Add unit tests (17 tests) for registry, listener, and baseGroup.
Add TCPMux group load balancing e2e test.

* server/group: replace tautological assertion with require.NotPanics

* server/group: remove blank line between doc comment and type declaration
2026-03-08 18:57:21 +08:00
fatedier
eeb0dacfc1 pkg/metrics/mem: remove redundant map write-backs and optimize proxy lookup (#5221)
Remove 4 redundant pointer map write-backs in OpenConnection,
CloseConnection, AddTrafficIn, and AddTrafficOut since the map stores
pointers and mutations are already visible without reassignment.

Optimize GetProxiesByTypeAndName from O(n) full map scan to O(1) direct
map lookup by proxy name.
2026-03-08 10:40:39 +08:00
Oleksandr Redko
535eb3db35 refactor: use maps.Clone and slices.Concat (#5220) 2026-03-08 10:38:16 +08:00
fatedier
605f3bdece client/visitor: deduplicate visitor connection handshake and wrapping (#5219)
Extract two shared helpers to eliminate duplicated code across STCP,
SUDP, and XTCP visitors:

- dialRawVisitorConn: handles ConnectServer + NewVisitorConn handshake
  (auth, sign key, 10s read deadline, error check)
- wrapVisitorConn: handles encryption + pooled compression wrapping,
  returning a recycleFn for pool resource cleanup

SUDP is upgraded from WithCompression to WithCompressionFromPool,
aligning with the pooled compression used by STCP and XTCP.
2026-03-08 01:03:40 +08:00
fatedier
764a626b6e server/control: deduplicate close-proxy logic and UserInfo construction (#5218)
Extract closeProxy() helper to eliminate duplicated 4-step cleanup
sequence (Close, PxyManager.Del, metrics, plugin notify) between
worker() and CloseProxy().

Extract loginUserInfo() helper to eliminate 4 repeated plugin.UserInfo
constructions using LoginMsg fields.

Optimize worker() to snapshot and clear the proxies map under lock,
then perform cleanup outside the lock to reduce lock hold time.
2026-03-08 00:02:14 +08:00
Oleksandr Redko
c2454e7114 refactor: fix modernize lint issues (#5215) 2026-03-07 23:10:19 +08:00
fatedier
017d71717f server: introduce SessionContext to encapsulate NewControl parameters (#5217)
Replace 10 positional parameters in NewControl() with a single
SessionContext struct, matching the client-side pattern. This also
eliminates the post-construction mutation of clientRegistry and
removes two TODO comments.
2026-03-07 20:17:00 +08:00
Oleksandr Redko
bd200b1a3b fix: typos in comments, tests, functions (#5216) 2026-03-07 18:43:04 +08:00
fatedier
c70ceff370 fix: three high-severity bugs across nathole, proxy, and udp modules (#5214)
- pkg/nathole: add RLock when reading clientCfgs map in PreCheck path
  to prevent concurrent map read/write crash
- server/proxy: fix error variable shadowing in GetWorkConnFromPool
  that could return a closed connection with nil error
- pkg/util/net: check ListenUDP error before spawning goroutines
  and assign readConn to struct field so Close() works correctly
2026-03-07 13:36:02 +08:00
fatedier
bb3d0e7140 deduplicate common logic across proxy, visitor, and metrics modules (#5213)
- Replace duplicate parseBasicAuth with existing httppkg.ParseBasicAuth
- Extract buildDomains helper in BaseProxy for HTTP/HTTPS/TCPMux proxies
- Extract toProxyStats helper to deduplicate ProxyStats construction
- Extract startVisitorListener helper in BaseProxy for STCP/SUDP proxies
- Extract acceptLoop helper in BaseVisitor for STCP/XTCP visitors
2026-03-07 12:00:27 +08:00
fatedier
cf396563f8 client/proxy: unify work conn wrapping across all proxy types (#5212)
* client/proxy: extract wrapWorkConn to deduplicate UDP/SUDP connection wrapping

Move the repeated rate-limiting, encryption, and compression wrapping
logic from UDPProxy and SUDPProxy into a shared BaseProxy.wrapWorkConn
method, reducing ~18 lines of duplication in each proxy type.

* client/proxy: unify work conn wrapping with pooled compression for all proxy types

Refactor wrapWorkConn to accept encKey parameter and return
(io.ReadWriteCloser, recycleFn, error), enabling HandleTCPWorkConnection
to reuse the same limiter/encryption/compression pipeline.

Switch all proxy types from WithCompression to WithCompressionFromPool.
TCP non-plugin path calls recycleFn via defer after Join; plugin and
UDP/SUDP paths skip recycle (objects are GC'd safely, per golib contract).
2026-03-07 01:33:37 +08:00
fatedier
0b4f83cd04 pkg/config: use modern Go stdlib for sorting and string operations (#5210)
- slices.SortedFunc + maps.Values + cmp.Compare instead of manual
  map-to-slice collection + sort.Slice (source/aggregator.go)
- strings.CutSuffix instead of HasSuffix+TrimSuffix, and deduplicate
  error handling in BandwidthQuantity.UnmarshalString (types/types.go)
2026-03-06 23:13:29 +08:00
fatedier
e9f7a1a9f2 pkg: use modern Go stdlib functions to simplify code (#5209)
- strings.CutPrefix instead of HasPrefix+TrimPrefix (naming, legacy)
- slices.Contains instead of manual loop (plugin/server)
- min/max builtins instead of manual comparisons (nathole)
2026-03-06 22:14:46 +08:00
fatedier
d644593342 server/proxy: simplify HTTPS and TCPMux proxy domain registration (#5208)
Consolidate the separate custom-domain loop and subdomain block into a
single unified loop, matching the pattern already applied to HTTPProxy
in PR #5207. No behavioral change.
2026-03-06 21:31:29 +08:00
fatedier
427c4ca3ae server/proxy: simplify HTTP proxy domain registration by removing duplicate loop (#5207)
The Run() method had two nearly identical loop blocks for registering
custom domains and subdomain, with the same group/non-group registration
logic copy-pasted (~30 lines of duplication).

Consolidate by collecting all domains into a single slice first, then
iterating once with the shared registration logic. Also fixes a minor
inconsistency where the custom domain block used routeConfig.Domain in
CanonicalAddr but the subdomain block used tmpRouteConfig.Domain.
2026-03-06 21:17:30 +08:00
fatedier
f2d1f3739a pkg/util/xlog: fix AddPrefix not updating existing prefix due to range value copy (#5206)
In AddPrefix, the loop `for _, p := range l.prefixes` creates a copy
of each element. Assignments to p.Value and p.Priority only modify
the local copy, not the original slice element, causing updates to
existing prefixes to be silently lost.

This affects client/service.go where AddPrefix is called with
Name:"runID" on reconnection — the old runID value would persist
in log output instead of being updated to the new one.

Fix by using index-based access `l.prefixes[i]` to modify the
original slice element, and add break since prefix names are unique.
2026-03-06 20:44:40 +08:00
fatedier
c23894f156 fix: validate CA cert parsing and add missing ReadHeaderTimeout (#5205)
- pkg/transport/tls.go: check AppendCertsFromPEM return value and
  return clear error when CA file contains no valid PEM certificates
- pkg/plugin/client/http2http.go: set ReadHeaderTimeout to 60s to
  match other plugins and prevent slow header attacks
- pkg/plugin/client/http2https.go: same ReadHeaderTimeout fix
2026-03-06 17:59:41 +08:00
fatedier
cb459b02b6 fix: WebsocketListener nil panic and OIDC auth data race (#5204)
- pkg/util/net/websocket.go: store ln parameter in struct to prevent
  nil pointer panic when Addr() is called
- pkg/auth/oidc.go: replace unsynchronized []string with map + RWMutex
  for subjectsFromLogin to fix data race across concurrent connections
2026-03-06 16:51:52 +08:00
fatedier
8f633fe363 fix: return buffers to pool on error paths to reduce GC pressure (#5203)
- pkg/nathole/nathole.go: add pool.PutBuf(buf) on ReadFromUDP error
  and DecodeMessageInto error paths in waitDetectMessage
- pkg/proto/udp/udp.go: add defer pool.PutBuf(buf) in writerFn to
  ensure buffer is returned when the goroutine exits
2026-03-06 15:55:22 +08:00
fatedier
c62a1da161 fix: close connections on error paths to prevent resource leaks (#5202)
Fix connection leaks in multiple error paths across client and server:
- server/proxy/http: close tmpConn when WithEncryption fails
- client/proxy: close localConn when ProxyProtocol WriteTo fails
- client/visitor/sudp: close visitorConn on all error paths in getNewVisitorConn
- client/visitor/xtcp: close tunnelConn when WithEncryption fails
- client/visitor/xtcp: close lConn when NewKCPConnFromUDP fails
- pkg/plugin/client/unix_domain_socket: close localConn and connInfo.Conn when WriteTo fails, close connInfo.Conn when DialUnix fails
- pkg/plugin/client/tls2raw: close tlsConn when Handshake or Dial fails
2026-03-06 15:18:38 +08:00
fatedier
f22f7d539c server/group: fix port leak and incorrect Listen port in TCPGroup (#5200)
Fix two bugs in TCPGroup.Listen():
- Release acquired port when net.Listen fails to prevent port leak
- Use realPort instead of port for net.Listen to ensure consistency
  between port manager records and actual listening port
2026-03-06 02:25:47 +08:00
fatedier
462c987f6d pkg/msg: change UDPPacket.Content from string to []byte to avoid redundant base64 encode/decode (#5198) 2026-03-06 01:38:24 +08:00
fatedier
541878af4d docs: update release notes for config parsing error improvements (#5196) 2026-03-05 00:00:46 +08:00
fatedier
b7435967b0 pkg/config: fix line numbers shown incorrectly for TOML type mismatch errors (#5195) 2026-03-04 20:53:22 +08:00
fatedier
774478d071 pkg/config: improve error messages with line number details for config parsing (#5194)
When frpc verify encounters config errors, the error messages now include
line/column information to help users locate problems:

- TOML syntax errors: report line and column from the TOML parser
  (e.g., "toml: line 4, column 11: expected character ]")
- Type mismatch errors: report the field name and approximate line number
  (e.g., "line 3: field \"proxies\": cannot unmarshal string into ...")
- File format detection: use file extension to determine format, preventing
  silent fallthrough from TOML to YAML parser which produced confusing errors

Previously, a TOML file with syntax errors would silently fall through to the
YAML parser, which would misinterpret the content and produce unhelpful errors
like "json: cannot unmarshal string into Go value of type v1.ClientConfig".

https://claude.ai/code/session_017HWLfcXS3U2hLoy4dsg8Nv

Co-authored-by: Claude <noreply@anthropic.com>
2026-03-04 19:27:30 +08:00
fatedier
fbeb6ca43a refactor: restructure API packages into client/http and server/http with typed proxy/visitor models (#5193) 2026-03-04 17:38:43 +08:00
fatedier
381245a439 build: add noweb tag to allow building without frontend assets (#5189) 2026-03-02 01:32:19 +08:00
fatedier
01997deb98 add persistent proxy/visitor store with CRUD API and web UI (#5188) 2026-03-02 01:09:59 +08:00
183 changed files with 12370 additions and 6110 deletions

View File

@@ -2,7 +2,7 @@ version: 2
jobs:
go-version-latest:
docker:
- image: cimg/go:1.24-node
- image: cimg/go:1.25-node
resource_class: large
steps:
- checkout

View File

@@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.24'
go-version: '1.25'
cache: false
- uses: actions/setup-node@v4
with:
@@ -29,7 +29,7 @@ jobs:
run: make build
working-directory: web/frpc
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
uses: golangci/golangci-lint-action@v9
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: v2.3
version: v2.10

View File

@@ -15,7 +15,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.24'
go-version: '1.25'
- uses: actions/setup-node@v4
with:
node-version: '22'

1
.gitignore vendored
View File

@@ -30,4 +30,5 @@ client.key
# AI
CLAUDE.md
AGENTS.md
.sisyphus/

View File

@@ -18,6 +18,7 @@ linters:
- lll
- makezero
- misspell
- modernize
- prealloc
- predeclared
- revive
@@ -33,13 +34,7 @@ linters:
disabled-checks:
- exitAfterDefer
gosec:
excludes:
- G401
- G402
- G404
- G501
- G115
- G204
excludes: ["G115", "G117", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"]
severity: low
confidence: low
govet:
@@ -53,6 +48,9 @@ linters:
ignore-rules:
- cancelled
- marshalled
modernize:
disable:
- omitzero
unparam:
check-exported: false
exclusions:
@@ -77,6 +75,9 @@ linters:
- linters:
- revive
text: "avoid meaningless package names"
- linters:
- revive
text: "Go standard library package names"
- linters:
- unparam
text: is always false

View File

@@ -1,6 +1,7 @@
export PATH := $(PATH):`go env GOPATH`/bin
export GO111MODULE=on
LDFLAGS := -s -w
NOWEB_TAG = $(shell [ ! -d web/frps/dist ] || [ ! -d web/frpc/dist ] && echo ',noweb')
.PHONY: web frps-web frpc-web frps frpc
@@ -28,23 +29,23 @@ fmt-more:
gci:
gci write -s standard -s default -s "prefix(github.com/fatedier/frp/)" ./
vet: web
go vet ./...
vet:
go vet -tags "$(NOWEB_TAG)" ./...
frps:
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frps -o bin/frps ./cmd/frps
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags "frps$(NOWEB_TAG)" -o bin/frps ./cmd/frps
frpc:
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frpc -o bin/frpc ./cmd/frpc
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags "frpc$(NOWEB_TAG)" -o bin/frpc ./cmd/frpc
test: gotest
gotest: web
go test -v --cover ./assets/...
go test -v --cover ./cmd/...
go test -v --cover ./client/...
go test -v --cover ./server/...
go test -v --cover ./pkg/...
gotest:
go test -tags "$(NOWEB_TAG)" -v --cover ./assets/...
go test -tags "$(NOWEB_TAG)" -v --cover ./cmd/...
go test -tags "$(NOWEB_TAG)" -v --cover ./client/...
go test -tags "$(NOWEB_TAG)" -v --cover ./server/...
go test -tags "$(NOWEB_TAG)" -v --cover ./pkg/...
e2e:
./hack/run-e2e.sh

View File

@@ -13,6 +13,16 @@ frp is an open source project with its ongoing development made possible entirel
<h3 align="center">Gold Sponsors</h3>
<!--gold sponsors start-->
<div align="center">
## Recall.ai - API for meeting recordings
If you're looking for a meeting recording API, consider checking out [Recall.ai](https://www.recall.ai/?utm_source=github&utm_medium=sponsorship&utm_campaign=fatedier-frp),
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
</div>
<p align="center">
<a href="https://requestly.com/?utm_source=github&utm_medium=partnered&utm_campaign=frp" target="_blank">
<img width="480px" src="https://github.com/user-attachments/assets/24670320-997d-4d62-9bca-955c59fe883d">
@@ -40,15 +50,6 @@ frp is an open source project with its ongoing development made possible entirel
<sub>An open source, self-hosted alternative to public clouds, built for data ownership and privacy</sub>
</a>
</p>
<div align="center">
## Recall.ai - API for meeting recordings
If you're looking for a meeting recording API, consider checking out [Recall.ai](https://www.recall.ai/?utm_source=github&utm_medium=sponsorship&utm_campaign=fatedier-frp),
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
</div>
<!--gold sponsors end-->
## What is frp?
@@ -800,6 +801,14 @@ Then run command `frpc reload -c ./frpc.toml` and wait for about 10 seconds to l
**Note that global client parameters won't be modified except 'start'.**
`start` is a global allowlist evaluated after all sources are merged (config file/include/store).
If `start` is non-empty, any proxy or visitor not listed there will not be started, including
entries created via Store API.
`start` is kept mainly for compatibility and is generally not recommended for new configurations.
Prefer per-proxy/per-visitor `enabled`, and keep `start` empty unless you explicitly want this
global allowlist behavior.
You can run command `frpc verify -c ./frpc.toml` before reloading to check if there are config errors.
### Get proxy status from client

View File

@@ -15,6 +15,16 @@ frp 是一个完全开源的项目,我们的开发工作完全依靠赞助者
<h3 align="center">Gold Sponsors</h3>
<!--gold sponsors start-->
<div align="center">
## Recall.ai - API for meeting recordings
If you're looking for a meeting recording API, consider checking out [Recall.ai](https://www.recall.ai/?utm_source=github&utm_medium=sponsorship&utm_campaign=fatedier-frp),
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
</div>
<p align="center">
<a href="https://requestly.com/?utm_source=github&utm_medium=partnered&utm_campaign=frp" target="_blank">
<img width="480px" src="https://github.com/user-attachments/assets/24670320-997d-4d62-9bca-955c59fe883d">
@@ -42,15 +52,6 @@ frp 是一个完全开源的项目,我们的开发工作完全依靠赞助者
<sub>An open source, self-hosted alternative to public clouds, built for data ownership and privacy</sub>
</a>
</p>
<div align="center">
## Recall.ai - API for meeting recordings
If you're looking for a meeting recording API, consider checking out [Recall.ai](https://www.recall.ai/?utm_source=github&utm_medium=sponsorship&utm_campaign=fatedier-frp),
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
</div>
<!--gold sponsors end-->
## 为什么使用 frp

View File

@@ -1,8 +1,9 @@
## Features
* frpc now supports a `clientID` option to uniquely identify client instances. The server dashboard displays all connected clients with their online/offline status, connection history, and metadata, making it easier to monitor and manage multiple frpc deployments.
* Redesigned the frp web dashboard with a modern UI, dark mode support, and improved navigation.
* Added a built-in `store` capability for frpc, including persisted store source (`[store] path = "..."`), Store CRUD admin APIs (`/api/store/proxies*`, `/api/store/visitors*`) with runtime reload, and Store management pages in the frpc web dashboard.
## Fixes
## Improvements
* Fixed UDP proxy protocol sending header on every packet instead of only the first packet of each session.
* Kept proxy/visitor names as raw config names during completion; moved user-prefix handling to explicit wire-level naming logic.
* Added `noweb` build tag to allow compiling without frontend assets. `make build` now auto-detects missing `web/*/dist` directories and skips embedding, so a fresh clone can build without running `make web` first. The dashboard gracefully returns 404 when assets are not embedded.
* Improved config parsing errors: for `.toml` files, syntax errors now return immediately with parser position details (line/column when available) instead of falling through to YAML/JSON parsing, and TOML type mismatches report field-level errors without misleading line numbers.

View File

@@ -29,14 +29,23 @@ var (
prefixPath string
)
type emptyFS struct{}
func (emptyFS) Open(name string) (http.File, error) {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}
// if path is empty, load assets in memory
// or set FileSystem using disk files
func Load(path string) {
prefixPath = path
if prefixPath != "" {
switch {
case prefixPath != "":
FileSystem = http.Dir(prefixPath)
} else {
case content != nil:
FileSystem = http.FS(content)
default:
FileSystem = emptyFS{}
}
}

View File

@@ -1,498 +0,0 @@
// Copyright 2025 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"cmp"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"slices"
"strconv"
"time"
"github.com/fatedier/frp/client/proxy"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/config/source"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/config/v1/validation"
"github.com/fatedier/frp/pkg/policy/security"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/log"
)
// Controller handles HTTP API requests for frpc.
type Controller struct {
getProxyStatus func() []*proxy.WorkingStatus
serverAddr string
configFilePath string
unsafeFeatures *security.UnsafeFeatures
updateConfig func(common *v1.ClientCommonConfig, proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error
reloadFromSources func() error
gracefulClose func(d time.Duration)
storeSource *source.StoreSource
}
// ControllerParams contains parameters for creating an APIController.
type ControllerParams struct {
GetProxyStatus func() []*proxy.WorkingStatus
ServerAddr string
ConfigFilePath string
UnsafeFeatures *security.UnsafeFeatures
UpdateConfig func(common *v1.ClientCommonConfig, proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error
ReloadFromSources func() error
GracefulClose func(d time.Duration)
StoreSource *source.StoreSource
}
func NewController(params ControllerParams) *Controller {
return &Controller{
getProxyStatus: params.GetProxyStatus,
serverAddr: params.ServerAddr,
configFilePath: params.ConfigFilePath,
unsafeFeatures: params.UnsafeFeatures,
updateConfig: params.UpdateConfig,
reloadFromSources: params.ReloadFromSources,
gracefulClose: params.GracefulClose,
storeSource: params.StoreSource,
}
}
func (c *Controller) reloadFromSourcesOrError() error {
if err := c.reloadFromSources(); err != nil {
return httppkg.NewError(http.StatusInternalServerError, fmt.Sprintf("failed to apply config: %v", err))
}
return nil
}
// Reload handles GET /api/reload
func (c *Controller) Reload(ctx *httppkg.Context) (any, error) {
strictConfigMode := false
strictStr := ctx.Query("strictConfig")
if strictStr != "" {
strictConfigMode, _ = strconv.ParseBool(strictStr)
}
result, err := config.LoadClientConfigResult(c.configFilePath, strictConfigMode)
if err != nil {
log.Warnf("reload frpc proxy config error: %s", err.Error())
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
proxyCfgs := result.Proxies
visitorCfgs := result.Visitors
proxyCfgsForValidation, visitorCfgsForValidation := config.FilterClientConfigurers(
result.Common,
proxyCfgs,
visitorCfgs,
)
proxyCfgsForValidation = config.CompleteProxyConfigurers(proxyCfgsForValidation)
visitorCfgsForValidation = config.CompleteVisitorConfigurers(visitorCfgsForValidation)
if _, err := validation.ValidateAllClientConfig(result.Common, proxyCfgsForValidation, visitorCfgsForValidation, c.unsafeFeatures); err != nil {
log.Warnf("reload frpc proxy config error: %s", err.Error())
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
if err := c.updateConfig(result.Common, proxyCfgs, visitorCfgs); err != nil {
log.Warnf("reload frpc proxy config error: %s", err.Error())
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
log.Infof("success reload conf")
return nil, nil
}
// Stop handles POST /api/stop
func (c *Controller) Stop(ctx *httppkg.Context) (any, error) {
go c.gracefulClose(100 * time.Millisecond)
return nil, nil
}
// Status handles GET /api/status
func (c *Controller) Status(ctx *httppkg.Context) (any, error) {
res := make(StatusResp)
ps := c.getProxyStatus()
if ps == nil {
return res, nil
}
for _, status := range ps {
res[status.Type] = append(res[status.Type], c.buildProxyStatusResp(status))
}
for _, arrs := range res {
if len(arrs) <= 1 {
continue
}
slices.SortFunc(arrs, func(a, b ProxyStatusResp) int {
return cmp.Compare(a.Name, b.Name)
})
}
return res, nil
}
// GetConfig handles GET /api/config
func (c *Controller) GetConfig(ctx *httppkg.Context) (any, error) {
if c.configFilePath == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "frpc has no config file path")
}
content, err := os.ReadFile(c.configFilePath)
if err != nil {
log.Warnf("load frpc config file error: %s", err.Error())
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
return string(content), nil
}
// PutConfig handles PUT /api/config
func (c *Controller) PutConfig(ctx *httppkg.Context) (any, error) {
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read request body error: %v", err))
}
if len(body) == 0 {
return nil, httppkg.NewError(http.StatusBadRequest, "body can't be empty")
}
if err := os.WriteFile(c.configFilePath, body, 0o600); err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, fmt.Sprintf("write content to frpc config file error: %v", err))
}
return nil, nil
}
func (c *Controller) buildProxyStatusResp(status *proxy.WorkingStatus) ProxyStatusResp {
psr := ProxyStatusResp{
Name: status.Name,
Type: status.Type,
Status: status.Phase,
Err: status.Err,
}
baseCfg := status.Cfg.GetBaseConfig()
if baseCfg.LocalPort != 0 {
psr.LocalAddr = net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort))
}
psr.Plugin = baseCfg.Plugin.Type
if status.Err == "" {
psr.RemoteAddr = status.RemoteAddr
if slices.Contains([]string{"tcp", "udp"}, status.Type) {
psr.RemoteAddr = c.serverAddr + psr.RemoteAddr
}
}
// Check if proxy is from store
if c.storeSource != nil {
if c.storeSource.GetProxy(status.Name) != nil {
psr.Source = "store"
}
}
return psr
}
func (c *Controller) ListStoreProxies(ctx *httppkg.Context) (any, error) {
proxies, err := c.storeSource.GetAllProxies()
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, fmt.Sprintf("failed to list proxies: %v", err))
}
resp := ProxyListResp{Proxies: make([]ProxyConfig, 0, len(proxies))}
for _, p := range proxies {
cfg, err := proxyConfigurerToMap(p)
if err != nil {
continue
}
resp.Proxies = append(resp.Proxies, ProxyConfig{
Name: p.GetBaseConfig().Name,
Type: p.GetBaseConfig().Type,
Config: cfg,
})
}
return resp, nil
}
func (c *Controller) GetStoreProxy(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
}
p := c.storeSource.GetProxy(name)
if p == nil {
return nil, httppkg.NewError(http.StatusNotFound, fmt.Sprintf("proxy %q not found", name))
}
cfg, err := proxyConfigurerToMap(p)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
return ProxyConfig{
Name: p.GetBaseConfig().Name,
Type: p.GetBaseConfig().Type,
Config: cfg,
}, nil
}
func (c *Controller) CreateStoreProxy(ctx *httppkg.Context) (any, error) {
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
}
var typed v1.TypedProxyConfig
if err := json.Unmarshal(body, &typed); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
}
if typed.ProxyConfigurer == nil {
return nil, httppkg.NewError(http.StatusBadRequest, "invalid proxy config: type is required")
}
typed.Complete()
if err := validation.ValidateProxyConfigurerForClient(typed.ProxyConfigurer); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("validation error: %v", err))
}
if err := c.storeSource.AddProxy(typed.ProxyConfigurer); err != nil {
return nil, httppkg.NewError(http.StatusConflict, err.Error())
}
if err := c.reloadFromSourcesOrError(); err != nil {
return nil, err
}
log.Infof("store: created proxy %q", typed.ProxyConfigurer.GetBaseConfig().Name)
return nil, nil
}
func (c *Controller) UpdateStoreProxy(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
}
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
}
var typed v1.TypedProxyConfig
if err := json.Unmarshal(body, &typed); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
}
if typed.ProxyConfigurer == nil {
return nil, httppkg.NewError(http.StatusBadRequest, "invalid proxy config: type is required")
}
bodyName := typed.ProxyConfigurer.GetBaseConfig().Name
if bodyName != name {
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name in URL must match name in body")
}
typed.Complete()
if err := validation.ValidateProxyConfigurerForClient(typed.ProxyConfigurer); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("validation error: %v", err))
}
if err := c.storeSource.UpdateProxy(typed.ProxyConfigurer); err != nil {
return nil, httppkg.NewError(http.StatusNotFound, err.Error())
}
if err := c.reloadFromSourcesOrError(); err != nil {
return nil, err
}
log.Infof("store: updated proxy %q", name)
return nil, nil
}
func (c *Controller) DeleteStoreProxy(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
}
if err := c.storeSource.RemoveProxy(name); err != nil {
return nil, httppkg.NewError(http.StatusNotFound, err.Error())
}
if err := c.reloadFromSourcesOrError(); err != nil {
return nil, err
}
log.Infof("store: deleted proxy %q", name)
return nil, nil
}
func (c *Controller) ListStoreVisitors(ctx *httppkg.Context) (any, error) {
visitors, err := c.storeSource.GetAllVisitors()
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, fmt.Sprintf("failed to list visitors: %v", err))
}
resp := VisitorListResp{Visitors: make([]VisitorConfig, 0, len(visitors))}
for _, v := range visitors {
cfg, err := visitorConfigurerToMap(v)
if err != nil {
continue
}
resp.Visitors = append(resp.Visitors, VisitorConfig{
Name: v.GetBaseConfig().Name,
Type: v.GetBaseConfig().Type,
Config: cfg,
})
}
return resp, nil
}
func (c *Controller) GetStoreVisitor(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
}
v := c.storeSource.GetVisitor(name)
if v == nil {
return nil, httppkg.NewError(http.StatusNotFound, fmt.Sprintf("visitor %q not found", name))
}
cfg, err := visitorConfigurerToMap(v)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
return VisitorConfig{
Name: v.GetBaseConfig().Name,
Type: v.GetBaseConfig().Type,
Config: cfg,
}, nil
}
func (c *Controller) CreateStoreVisitor(ctx *httppkg.Context) (any, error) {
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
}
var typed v1.TypedVisitorConfig
if err := json.Unmarshal(body, &typed); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
}
if typed.VisitorConfigurer == nil {
return nil, httppkg.NewError(http.StatusBadRequest, "invalid visitor config: type is required")
}
typed.Complete()
if err := validation.ValidateVisitorConfigurer(typed.VisitorConfigurer); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("validation error: %v", err))
}
if err := c.storeSource.AddVisitor(typed.VisitorConfigurer); err != nil {
return nil, httppkg.NewError(http.StatusConflict, err.Error())
}
if err := c.reloadFromSourcesOrError(); err != nil {
return nil, err
}
log.Infof("store: created visitor %q", typed.VisitorConfigurer.GetBaseConfig().Name)
return nil, nil
}
func (c *Controller) UpdateStoreVisitor(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
}
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
}
var typed v1.TypedVisitorConfig
if err := json.Unmarshal(body, &typed); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
}
if typed.VisitorConfigurer == nil {
return nil, httppkg.NewError(http.StatusBadRequest, "invalid visitor config: type is required")
}
bodyName := typed.VisitorConfigurer.GetBaseConfig().Name
if bodyName != name {
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name in URL must match name in body")
}
typed.Complete()
if err := validation.ValidateVisitorConfigurer(typed.VisitorConfigurer); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("validation error: %v", err))
}
if err := c.storeSource.UpdateVisitor(typed.VisitorConfigurer); err != nil {
return nil, httppkg.NewError(http.StatusNotFound, err.Error())
}
if err := c.reloadFromSourcesOrError(); err != nil {
return nil, err
}
log.Infof("store: updated visitor %q", name)
return nil, nil
}
func (c *Controller) DeleteStoreVisitor(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
}
if err := c.storeSource.RemoveVisitor(name); err != nil {
return nil, httppkg.NewError(http.StatusNotFound, err.Error())
}
if err := c.reloadFromSourcesOrError(); err != nil {
return nil, err
}
log.Infof("store: deleted visitor %q", name)
return nil, nil
}
func proxyConfigurerToMap(p v1.ProxyConfigurer) (map[string]any, error) {
data, err := json.Marshal(p)
if err != nil {
return nil, err
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return nil, err
}
return m, nil
}
func visitorConfigurerToMap(v v1.VisitorConfigurer) (map[string]any, error) {
data, err := json.Marshal(v)
if err != nil {
return nil, err
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return nil, err
}
return m, nil
}

View File

@@ -17,7 +17,7 @@ package client
import (
"net/http"
"github.com/fatedier/frp/client/api"
adminapi "github.com/fatedier/frp/client/http"
"github.com/fatedier/frp/client/proxy"
httppkg "github.com/fatedier/frp/pkg/util/http"
netpkg "github.com/fatedier/frp/pkg/util/net"
@@ -65,16 +65,11 @@ func healthz(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}
func newAPIController(svr *Service) *api.Controller {
return api.NewController(api.ControllerParams{
GetProxyStatus: svr.getAllProxyStatus,
ServerAddr: svr.common.ServerAddr,
ConfigFilePath: svr.configFilePath,
UnsafeFeatures: svr.unsafeFeatures,
UpdateConfig: svr.UpdateConfigSource,
ReloadFromSources: svr.reloadConfigFromSources,
GracefulClose: svr.GracefulClose,
StoreSource: svr.storeSource,
func newAPIController(svr *Service) *adminapi.Controller {
manager := newServiceConfigManager(svr)
return adminapi.NewController(adminapi.ControllerParams{
ServerAddr: svr.common.ServerAddr,
Manager: manager,
})
}

422
client/config_manager.go Normal file
View File

@@ -0,0 +1,422 @@
package client
import (
"errors"
"fmt"
"os"
"time"
"github.com/fatedier/frp/client/configmgmt"
"github.com/fatedier/frp/client/proxy"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/config/source"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/config/v1/validation"
"github.com/fatedier/frp/pkg/util/log"
)
type serviceConfigManager struct {
svr *Service
}
func newServiceConfigManager(svr *Service) configmgmt.ConfigManager {
return &serviceConfigManager{svr: svr}
}
func (m *serviceConfigManager) ReloadFromFile(strict bool) error {
if m.svr.configFilePath == "" {
return fmt.Errorf("%w: frpc has no config file path", configmgmt.ErrInvalidArgument)
}
result, err := config.LoadClientConfigResult(m.svr.configFilePath, strict)
if err != nil {
return fmt.Errorf("%w: %v", configmgmt.ErrInvalidArgument, err)
}
proxyCfgsForValidation, visitorCfgsForValidation := config.FilterClientConfigurers(
result.Common,
result.Proxies,
result.Visitors,
)
proxyCfgsForValidation = config.CompleteProxyConfigurers(proxyCfgsForValidation)
visitorCfgsForValidation = config.CompleteVisitorConfigurers(visitorCfgsForValidation)
if _, err := validation.ValidateAllClientConfig(result.Common, proxyCfgsForValidation, visitorCfgsForValidation, m.svr.unsafeFeatures); err != nil {
return fmt.Errorf("%w: %v", configmgmt.ErrInvalidArgument, err)
}
if err := m.svr.UpdateConfigSource(result.Common, result.Proxies, result.Visitors); err != nil {
return fmt.Errorf("%w: %v", configmgmt.ErrApplyConfig, err)
}
log.Infof("success reload conf")
return nil
}
func (m *serviceConfigManager) ReadConfigFile() (string, error) {
if m.svr.configFilePath == "" {
return "", fmt.Errorf("%w: frpc has no config file path", configmgmt.ErrInvalidArgument)
}
content, err := os.ReadFile(m.svr.configFilePath)
if err != nil {
return "", fmt.Errorf("%w: %v", configmgmt.ErrInvalidArgument, err)
}
return string(content), nil
}
func (m *serviceConfigManager) WriteConfigFile(content []byte) error {
if len(content) == 0 {
return fmt.Errorf("%w: body can't be empty", configmgmt.ErrInvalidArgument)
}
if err := os.WriteFile(m.svr.configFilePath, content, 0o600); err != nil {
return err
}
return nil
}
func (m *serviceConfigManager) GetProxyStatus() []*proxy.WorkingStatus {
return m.svr.getAllProxyStatus()
}
func (m *serviceConfigManager) IsStoreProxyEnabled(name string) bool {
if name == "" {
return false
}
m.svr.reloadMu.Lock()
storeSource := m.svr.storeSource
m.svr.reloadMu.Unlock()
if storeSource == nil {
return false
}
cfg := storeSource.GetProxy(name)
if cfg == nil {
return false
}
enabled := cfg.GetBaseConfig().Enabled
return enabled == nil || *enabled
}
func (m *serviceConfigManager) StoreEnabled() bool {
m.svr.reloadMu.Lock()
storeSource := m.svr.storeSource
m.svr.reloadMu.Unlock()
return storeSource != nil
}
func (m *serviceConfigManager) ListStoreProxies() ([]v1.ProxyConfigurer, error) {
storeSource, err := m.storeSourceOrError()
if err != nil {
return nil, err
}
return storeSource.GetAllProxies()
}
func (m *serviceConfigManager) GetStoreProxy(name string) (v1.ProxyConfigurer, error) {
if name == "" {
return nil, fmt.Errorf("%w: proxy name is required", configmgmt.ErrInvalidArgument)
}
storeSource, err := m.storeSourceOrError()
if err != nil {
return nil, err
}
cfg := storeSource.GetProxy(name)
if cfg == nil {
return nil, fmt.Errorf("%w: proxy %q", configmgmt.ErrNotFound, name)
}
return cfg, nil
}
func (m *serviceConfigManager) CreateStoreProxy(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
if err := m.validateStoreProxyConfigurer(cfg); err != nil {
return nil, fmt.Errorf("%w: validation error: %v", configmgmt.ErrInvalidArgument, err)
}
name := cfg.GetBaseConfig().Name
persisted, err := m.withStoreProxyMutationAndReload(name, func(storeSource *source.StoreSource) error {
if err := storeSource.AddProxy(cfg); err != nil {
if errors.Is(err, source.ErrAlreadyExists) {
return fmt.Errorf("%w: %v", configmgmt.ErrConflict, err)
}
return err
}
return nil
})
if err != nil {
return nil, err
}
log.Infof("store: created proxy %q", name)
return persisted, nil
}
func (m *serviceConfigManager) UpdateStoreProxy(name string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
if name == "" {
return nil, fmt.Errorf("%w: proxy name is required", configmgmt.ErrInvalidArgument)
}
if cfg == nil {
return nil, fmt.Errorf("%w: invalid proxy config: type is required", configmgmt.ErrInvalidArgument)
}
bodyName := cfg.GetBaseConfig().Name
if bodyName != name {
return nil, fmt.Errorf("%w: proxy name in URL must match name in body", configmgmt.ErrInvalidArgument)
}
if err := m.validateStoreProxyConfigurer(cfg); err != nil {
return nil, fmt.Errorf("%w: validation error: %v", configmgmt.ErrInvalidArgument, err)
}
persisted, err := m.withStoreProxyMutationAndReload(name, func(storeSource *source.StoreSource) error {
if err := storeSource.UpdateProxy(cfg); err != nil {
if errors.Is(err, source.ErrNotFound) {
return fmt.Errorf("%w: %v", configmgmt.ErrNotFound, err)
}
return err
}
return nil
})
if err != nil {
return nil, err
}
log.Infof("store: updated proxy %q", name)
return persisted, nil
}
func (m *serviceConfigManager) DeleteStoreProxy(name string) error {
if name == "" {
return fmt.Errorf("%w: proxy name is required", configmgmt.ErrInvalidArgument)
}
if err := m.withStoreMutationAndReload(func(storeSource *source.StoreSource) error {
if err := storeSource.RemoveProxy(name); err != nil {
if errors.Is(err, source.ErrNotFound) {
return fmt.Errorf("%w: %v", configmgmt.ErrNotFound, err)
}
return err
}
return nil
}); err != nil {
return err
}
log.Infof("store: deleted proxy %q", name)
return nil
}
func (m *serviceConfigManager) ListStoreVisitors() ([]v1.VisitorConfigurer, error) {
storeSource, err := m.storeSourceOrError()
if err != nil {
return nil, err
}
return storeSource.GetAllVisitors()
}
func (m *serviceConfigManager) GetStoreVisitor(name string) (v1.VisitorConfigurer, error) {
if name == "" {
return nil, fmt.Errorf("%w: visitor name is required", configmgmt.ErrInvalidArgument)
}
storeSource, err := m.storeSourceOrError()
if err != nil {
return nil, err
}
cfg := storeSource.GetVisitor(name)
if cfg == nil {
return nil, fmt.Errorf("%w: visitor %q", configmgmt.ErrNotFound, name)
}
return cfg, nil
}
func (m *serviceConfigManager) CreateStoreVisitor(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
if err := m.validateStoreVisitorConfigurer(cfg); err != nil {
return nil, fmt.Errorf("%w: validation error: %v", configmgmt.ErrInvalidArgument, err)
}
name := cfg.GetBaseConfig().Name
persisted, err := m.withStoreVisitorMutationAndReload(name, func(storeSource *source.StoreSource) error {
if err := storeSource.AddVisitor(cfg); err != nil {
if errors.Is(err, source.ErrAlreadyExists) {
return fmt.Errorf("%w: %v", configmgmt.ErrConflict, err)
}
return err
}
return nil
})
if err != nil {
return nil, err
}
log.Infof("store: created visitor %q", name)
return persisted, nil
}
func (m *serviceConfigManager) UpdateStoreVisitor(name string, cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
if name == "" {
return nil, fmt.Errorf("%w: visitor name is required", configmgmt.ErrInvalidArgument)
}
if cfg == nil {
return nil, fmt.Errorf("%w: invalid visitor config: type is required", configmgmt.ErrInvalidArgument)
}
bodyName := cfg.GetBaseConfig().Name
if bodyName != name {
return nil, fmt.Errorf("%w: visitor name in URL must match name in body", configmgmt.ErrInvalidArgument)
}
if err := m.validateStoreVisitorConfigurer(cfg); err != nil {
return nil, fmt.Errorf("%w: validation error: %v", configmgmt.ErrInvalidArgument, err)
}
persisted, err := m.withStoreVisitorMutationAndReload(name, func(storeSource *source.StoreSource) error {
if err := storeSource.UpdateVisitor(cfg); err != nil {
if errors.Is(err, source.ErrNotFound) {
return fmt.Errorf("%w: %v", configmgmt.ErrNotFound, err)
}
return err
}
return nil
})
if err != nil {
return nil, err
}
log.Infof("store: updated visitor %q", name)
return persisted, nil
}
func (m *serviceConfigManager) DeleteStoreVisitor(name string) error {
if name == "" {
return fmt.Errorf("%w: visitor name is required", configmgmt.ErrInvalidArgument)
}
if err := m.withStoreMutationAndReload(func(storeSource *source.StoreSource) error {
if err := storeSource.RemoveVisitor(name); err != nil {
if errors.Is(err, source.ErrNotFound) {
return fmt.Errorf("%w: %v", configmgmt.ErrNotFound, err)
}
return err
}
return nil
}); err != nil {
return err
}
log.Infof("store: deleted visitor %q", name)
return nil
}
func (m *serviceConfigManager) GracefulClose(d time.Duration) {
m.svr.GracefulClose(d)
}
func (m *serviceConfigManager) storeSourceOrError() (*source.StoreSource, error) {
m.svr.reloadMu.Lock()
storeSource := m.svr.storeSource
m.svr.reloadMu.Unlock()
if storeSource == nil {
return nil, fmt.Errorf("%w: store API is disabled", configmgmt.ErrStoreDisabled)
}
return storeSource, nil
}
func (m *serviceConfigManager) withStoreMutationAndReload(
fn func(storeSource *source.StoreSource) error,
) error {
m.svr.reloadMu.Lock()
defer m.svr.reloadMu.Unlock()
storeSource := m.svr.storeSource
if storeSource == nil {
return fmt.Errorf("%w: store API is disabled", configmgmt.ErrStoreDisabled)
}
if err := fn(storeSource); err != nil {
return err
}
if err := m.svr.reloadConfigFromSourcesLocked(); err != nil {
return fmt.Errorf("%w: failed to apply config: %v", configmgmt.ErrApplyConfig, err)
}
return nil
}
func (m *serviceConfigManager) withStoreProxyMutationAndReload(
name string,
fn func(storeSource *source.StoreSource) error,
) (v1.ProxyConfigurer, error) {
m.svr.reloadMu.Lock()
defer m.svr.reloadMu.Unlock()
storeSource := m.svr.storeSource
if storeSource == nil {
return nil, fmt.Errorf("%w: store API is disabled", configmgmt.ErrStoreDisabled)
}
if err := fn(storeSource); err != nil {
return nil, err
}
if err := m.svr.reloadConfigFromSourcesLocked(); err != nil {
return nil, fmt.Errorf("%w: failed to apply config: %v", configmgmt.ErrApplyConfig, err)
}
persisted := storeSource.GetProxy(name)
if persisted == nil {
return nil, fmt.Errorf("%w: proxy %q not found in store after mutation", configmgmt.ErrApplyConfig, name)
}
return persisted.Clone(), nil
}
func (m *serviceConfigManager) withStoreVisitorMutationAndReload(
name string,
fn func(storeSource *source.StoreSource) error,
) (v1.VisitorConfigurer, error) {
m.svr.reloadMu.Lock()
defer m.svr.reloadMu.Unlock()
storeSource := m.svr.storeSource
if storeSource == nil {
return nil, fmt.Errorf("%w: store API is disabled", configmgmt.ErrStoreDisabled)
}
if err := fn(storeSource); err != nil {
return nil, err
}
if err := m.svr.reloadConfigFromSourcesLocked(); err != nil {
return nil, fmt.Errorf("%w: failed to apply config: %v", configmgmt.ErrApplyConfig, err)
}
persisted := storeSource.GetVisitor(name)
if persisted == nil {
return nil, fmt.Errorf("%w: visitor %q not found in store after mutation", configmgmt.ErrApplyConfig, name)
}
return persisted.Clone(), nil
}
func (m *serviceConfigManager) validateStoreProxyConfigurer(cfg v1.ProxyConfigurer) error {
if cfg == nil {
return fmt.Errorf("invalid proxy config")
}
runtimeCfg := cfg.Clone()
if runtimeCfg == nil {
return fmt.Errorf("invalid proxy config")
}
runtimeCfg.Complete()
return validation.ValidateProxyConfigurerForClient(runtimeCfg)
}
func (m *serviceConfigManager) validateStoreVisitorConfigurer(cfg v1.VisitorConfigurer) error {
if cfg == nil {
return fmt.Errorf("invalid visitor config")
}
runtimeCfg := cfg.Clone()
if runtimeCfg == nil {
return fmt.Errorf("invalid visitor config")
}
runtimeCfg.Complete()
return validation.ValidateVisitorConfigurer(runtimeCfg)
}

View File

@@ -0,0 +1,137 @@
package client
import (
"errors"
"path/filepath"
"testing"
"github.com/fatedier/frp/client/configmgmt"
"github.com/fatedier/frp/pkg/config/source"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
func newTestRawTCPProxyConfig(name string) *v1.TCPProxyConfig {
return &v1.TCPProxyConfig{
ProxyBaseConfig: v1.ProxyBaseConfig{
Name: name,
Type: "tcp",
ProxyBackend: v1.ProxyBackend{
LocalPort: 10080,
},
},
}
}
func TestServiceConfigManagerCreateStoreProxyConflict(t *testing.T) {
storeSource, err := source.NewStoreSource(source.StoreSourceConfig{
Path: filepath.Join(t.TempDir(), "store.json"),
})
if err != nil {
t.Fatalf("new store source: %v", err)
}
if err := storeSource.AddProxy(newTestRawTCPProxyConfig("p1")); err != nil {
t.Fatalf("seed proxy: %v", err)
}
agg := source.NewAggregator(source.NewConfigSource())
agg.SetStoreSource(storeSource)
mgr := &serviceConfigManager{
svr: &Service{
aggregator: agg,
configSource: agg.ConfigSource(),
storeSource: storeSource,
reloadCommon: &v1.ClientCommonConfig{},
},
}
_, err = mgr.CreateStoreProxy(newTestRawTCPProxyConfig("p1"))
if err == nil {
t.Fatal("expected conflict error")
}
if !errors.Is(err, configmgmt.ErrConflict) {
t.Fatalf("unexpected error: %v", err)
}
}
func TestServiceConfigManagerCreateStoreProxyKeepsStoreOnReloadFailure(t *testing.T) {
storeSource, err := source.NewStoreSource(source.StoreSourceConfig{
Path: filepath.Join(t.TempDir(), "store.json"),
})
if err != nil {
t.Fatalf("new store source: %v", err)
}
mgr := &serviceConfigManager{
svr: &Service{
storeSource: storeSource,
reloadCommon: &v1.ClientCommonConfig{},
},
}
_, err = mgr.CreateStoreProxy(newTestRawTCPProxyConfig("p1"))
if err == nil {
t.Fatal("expected apply config error")
}
if !errors.Is(err, configmgmt.ErrApplyConfig) {
t.Fatalf("unexpected error: %v", err)
}
if storeSource.GetProxy("p1") == nil {
t.Fatal("proxy should remain in store after reload failure")
}
}
func TestServiceConfigManagerCreateStoreProxyStoreDisabled(t *testing.T) {
mgr := &serviceConfigManager{
svr: &Service{
reloadCommon: &v1.ClientCommonConfig{},
},
}
_, err := mgr.CreateStoreProxy(newTestRawTCPProxyConfig("p1"))
if err == nil {
t.Fatal("expected store disabled error")
}
if !errors.Is(err, configmgmt.ErrStoreDisabled) {
t.Fatalf("unexpected error: %v", err)
}
}
func TestServiceConfigManagerCreateStoreProxyDoesNotPersistRuntimeDefaults(t *testing.T) {
storeSource, err := source.NewStoreSource(source.StoreSourceConfig{
Path: filepath.Join(t.TempDir(), "store.json"),
})
if err != nil {
t.Fatalf("new store source: %v", err)
}
agg := source.NewAggregator(source.NewConfigSource())
agg.SetStoreSource(storeSource)
mgr := &serviceConfigManager{
svr: &Service{
aggregator: agg,
configSource: agg.ConfigSource(),
storeSource: storeSource,
reloadCommon: &v1.ClientCommonConfig{},
},
}
persisted, err := mgr.CreateStoreProxy(newTestRawTCPProxyConfig("raw-proxy"))
if err != nil {
t.Fatalf("create store proxy: %v", err)
}
if persisted == nil {
t.Fatal("expected persisted proxy to be returned")
}
got := storeSource.GetProxy("raw-proxy")
if got == nil {
t.Fatal("proxy not found in store")
}
if got.GetBaseConfig().LocalIP != "" {
t.Fatalf("localIP was persisted with runtime default: %q", got.GetBaseConfig().LocalIP)
}
if got.GetBaseConfig().Transport.BandwidthLimitMode != "" {
t.Fatalf("bandwidthLimitMode was persisted with runtime default: %q", got.GetBaseConfig().Transport.BandwidthLimitMode)
}
}

View File

@@ -0,0 +1,42 @@
package configmgmt
import (
"errors"
"time"
"github.com/fatedier/frp/client/proxy"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
var (
ErrInvalidArgument = errors.New("invalid argument")
ErrNotFound = errors.New("not found")
ErrConflict = errors.New("conflict")
ErrStoreDisabled = errors.New("store disabled")
ErrApplyConfig = errors.New("apply config failed")
)
type ConfigManager interface {
ReloadFromFile(strict bool) error
ReadConfigFile() (string, error)
WriteConfigFile(content []byte) error
GetProxyStatus() []*proxy.WorkingStatus
IsStoreProxyEnabled(name string) bool
StoreEnabled() bool
ListStoreProxies() ([]v1.ProxyConfigurer, error)
GetStoreProxy(name string) (v1.ProxyConfigurer, error)
CreateStoreProxy(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error)
UpdateStoreProxy(name string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error)
DeleteStoreProxy(name string) error
ListStoreVisitors() ([]v1.VisitorConfigurer, error)
GetStoreVisitor(name string) (v1.VisitorConfigurer, error)
CreateStoreVisitor(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error)
UpdateStoreVisitor(name string, cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error)
DeleteStoreVisitor(name string) error
GracefulClose(d time.Duration)
}

View File

@@ -25,9 +25,9 @@ import (
"github.com/fatedier/frp/pkg/auth"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/transport"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet"
@@ -157,7 +157,7 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) {
return
}
startMsg.ProxyName = util.StripUserPrefix(ctl.sessionCtx.Common.User, startMsg.ProxyName)
startMsg.ProxyName = naming.StripUserPrefix(ctl.sessionCtx.Common.User, startMsg.ProxyName)
// dispatch this work connection to related proxy
ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn, &startMsg)
@@ -168,7 +168,7 @@ func (ctl *Control) handleNewProxyResp(m msg.Message) {
inMsg := m.(*msg.NewProxyResp)
// Server will return NewProxyResp message to each NewProxy message.
// Start a new proxy handler if no error got
proxyName := util.StripUserPrefix(ctl.sessionCtx.Common.User, inMsg.ProxyName)
proxyName := naming.StripUserPrefix(ctl.sessionCtx.Common.User, inMsg.ProxyName)
err := ctl.pm.StartProxy(proxyName, inMsg.RemoteAddr, inMsg.Error)
if err != nil {
xl.Warnf("[%s] start error: %v", proxyName, err)

395
client/http/controller.go Normal file
View File

@@ -0,0 +1,395 @@
// Copyright 2025 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package http
import (
"cmp"
"errors"
"fmt"
"net"
"net/http"
"slices"
"strconv"
"time"
"github.com/fatedier/frp/client/configmgmt"
"github.com/fatedier/frp/client/http/model"
"github.com/fatedier/frp/client/proxy"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/jsonx"
)
// Controller handles HTTP API requests for frpc.
type Controller struct {
serverAddr string
manager configmgmt.ConfigManager
}
// ControllerParams contains parameters for creating an APIController.
type ControllerParams struct {
ServerAddr string
Manager configmgmt.ConfigManager
}
func NewController(params ControllerParams) *Controller {
return &Controller{
serverAddr: params.ServerAddr,
manager: params.Manager,
}
}
func (c *Controller) toHTTPError(err error) error {
if err == nil {
return nil
}
code := http.StatusInternalServerError
switch {
case errors.Is(err, configmgmt.ErrInvalidArgument):
code = http.StatusBadRequest
case errors.Is(err, configmgmt.ErrNotFound), errors.Is(err, configmgmt.ErrStoreDisabled):
code = http.StatusNotFound
case errors.Is(err, configmgmt.ErrConflict):
code = http.StatusConflict
}
return httppkg.NewError(code, err.Error())
}
// Reload handles GET /api/reload
func (c *Controller) Reload(ctx *httppkg.Context) (any, error) {
strictConfigMode := false
strictStr := ctx.Query("strictConfig")
if strictStr != "" {
strictConfigMode, _ = strconv.ParseBool(strictStr)
}
if err := c.manager.ReloadFromFile(strictConfigMode); err != nil {
return nil, c.toHTTPError(err)
}
return nil, nil
}
// Stop handles POST /api/stop
func (c *Controller) Stop(ctx *httppkg.Context) (any, error) {
go c.manager.GracefulClose(100 * time.Millisecond)
return nil, nil
}
// Status handles GET /api/status
func (c *Controller) Status(ctx *httppkg.Context) (any, error) {
res := make(model.StatusResp)
ps := c.manager.GetProxyStatus()
if ps == nil {
return res, nil
}
for _, status := range ps {
res[status.Type] = append(res[status.Type], c.buildProxyStatusResp(status))
}
for _, arrs := range res {
if len(arrs) <= 1 {
continue
}
slices.SortFunc(arrs, func(a, b model.ProxyStatusResp) int {
return cmp.Compare(a.Name, b.Name)
})
}
return res, nil
}
// GetConfig handles GET /api/config
func (c *Controller) GetConfig(ctx *httppkg.Context) (any, error) {
content, err := c.manager.ReadConfigFile()
if err != nil {
return nil, c.toHTTPError(err)
}
return content, nil
}
// PutConfig handles PUT /api/config
func (c *Controller) PutConfig(ctx *httppkg.Context) (any, error) {
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read request body error: %v", err))
}
if len(body) == 0 {
return nil, httppkg.NewError(http.StatusBadRequest, "body can't be empty")
}
if err := c.manager.WriteConfigFile(body); err != nil {
return nil, c.toHTTPError(err)
}
return nil, nil
}
func (c *Controller) buildProxyStatusResp(status *proxy.WorkingStatus) model.ProxyStatusResp {
psr := model.ProxyStatusResp{
Name: status.Name,
Type: status.Type,
Status: status.Phase,
Err: status.Err,
}
baseCfg := status.Cfg.GetBaseConfig()
if baseCfg.LocalPort != 0 {
psr.LocalAddr = net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort))
}
psr.Plugin = baseCfg.Plugin.Type
if status.Err == "" {
psr.RemoteAddr = status.RemoteAddr
if slices.Contains([]string{"tcp", "udp"}, status.Type) {
psr.RemoteAddr = c.serverAddr + psr.RemoteAddr
}
}
if c.manager.IsStoreProxyEnabled(status.Name) {
psr.Source = model.SourceStore
}
return psr
}
func (c *Controller) ListStoreProxies(ctx *httppkg.Context) (any, error) {
proxies, err := c.manager.ListStoreProxies()
if err != nil {
return nil, c.toHTTPError(err)
}
resp := model.ProxyListResp{Proxies: make([]model.ProxyDefinition, 0, len(proxies))}
for _, p := range proxies {
payload, err := model.ProxyDefinitionFromConfigurer(p)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
resp.Proxies = append(resp.Proxies, payload)
}
slices.SortFunc(resp.Proxies, func(a, b model.ProxyDefinition) int {
return cmp.Compare(a.Name, b.Name)
})
return resp, nil
}
func (c *Controller) GetStoreProxy(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
}
p, err := c.manager.GetStoreProxy(name)
if err != nil {
return nil, c.toHTTPError(err)
}
payload, err := model.ProxyDefinitionFromConfigurer(p)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
return payload, nil
}
func (c *Controller) CreateStoreProxy(ctx *httppkg.Context) (any, error) {
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
}
var payload model.ProxyDefinition
if err := jsonx.Unmarshal(body, &payload); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
}
if err := payload.Validate("", false); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
cfg, err := payload.ToConfigurer()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
created, err := c.manager.CreateStoreProxy(cfg)
if err != nil {
return nil, c.toHTTPError(err)
}
resp, err := model.ProxyDefinitionFromConfigurer(created)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
return resp, nil
}
func (c *Controller) UpdateStoreProxy(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
}
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
}
var payload model.ProxyDefinition
if err := jsonx.Unmarshal(body, &payload); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
}
if err := payload.Validate(name, true); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
cfg, err := payload.ToConfigurer()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
updated, err := c.manager.UpdateStoreProxy(name, cfg)
if err != nil {
return nil, c.toHTTPError(err)
}
resp, err := model.ProxyDefinitionFromConfigurer(updated)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
return resp, nil
}
func (c *Controller) DeleteStoreProxy(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "proxy name is required")
}
if err := c.manager.DeleteStoreProxy(name); err != nil {
return nil, c.toHTTPError(err)
}
return nil, nil
}
func (c *Controller) ListStoreVisitors(ctx *httppkg.Context) (any, error) {
visitors, err := c.manager.ListStoreVisitors()
if err != nil {
return nil, c.toHTTPError(err)
}
resp := model.VisitorListResp{Visitors: make([]model.VisitorDefinition, 0, len(visitors))}
for _, v := range visitors {
payload, err := model.VisitorDefinitionFromConfigurer(v)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
resp.Visitors = append(resp.Visitors, payload)
}
slices.SortFunc(resp.Visitors, func(a, b model.VisitorDefinition) int {
return cmp.Compare(a.Name, b.Name)
})
return resp, nil
}
func (c *Controller) GetStoreVisitor(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
}
v, err := c.manager.GetStoreVisitor(name)
if err != nil {
return nil, c.toHTTPError(err)
}
payload, err := model.VisitorDefinitionFromConfigurer(v)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
return payload, nil
}
func (c *Controller) CreateStoreVisitor(ctx *httppkg.Context) (any, error) {
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
}
var payload model.VisitorDefinition
if err := jsonx.Unmarshal(body, &payload); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
}
if err := payload.Validate("", false); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
cfg, err := payload.ToConfigurer()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
created, err := c.manager.CreateStoreVisitor(cfg)
if err != nil {
return nil, c.toHTTPError(err)
}
resp, err := model.VisitorDefinitionFromConfigurer(created)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
return resp, nil
}
func (c *Controller) UpdateStoreVisitor(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
}
body, err := ctx.Body()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("read body error: %v", err))
}
var payload model.VisitorDefinition
if err := jsonx.Unmarshal(body, &payload); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, fmt.Sprintf("parse JSON error: %v", err))
}
if err := payload.Validate(name, true); err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
cfg, err := payload.ToConfigurer()
if err != nil {
return nil, httppkg.NewError(http.StatusBadRequest, err.Error())
}
updated, err := c.manager.UpdateStoreVisitor(name, cfg)
if err != nil {
return nil, c.toHTTPError(err)
}
resp, err := model.VisitorDefinitionFromConfigurer(updated)
if err != nil {
return nil, httppkg.NewError(http.StatusInternalServerError, err.Error())
}
return resp, nil
}
func (c *Controller) DeleteStoreVisitor(ctx *httppkg.Context) (any, error) {
name := ctx.Param("name")
if name == "" {
return nil, httppkg.NewError(http.StatusBadRequest, "visitor name is required")
}
if err := c.manager.DeleteStoreVisitor(name); err != nil {
return nil, c.toHTTPError(err)
}
return nil, nil
}

View File

@@ -0,0 +1,531 @@
package http
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/fatedier/frp/client/configmgmt"
"github.com/fatedier/frp/client/http/model"
"github.com/fatedier/frp/client/proxy"
v1 "github.com/fatedier/frp/pkg/config/v1"
httppkg "github.com/fatedier/frp/pkg/util/http"
)
type fakeConfigManager struct {
reloadFromFileFn func(strict bool) error
readConfigFileFn func() (string, error)
writeConfigFileFn func(content []byte) error
getProxyStatusFn func() []*proxy.WorkingStatus
isStoreProxyEnabledFn func(name string) bool
storeEnabledFn func() bool
listStoreProxiesFn func() ([]v1.ProxyConfigurer, error)
getStoreProxyFn func(name string) (v1.ProxyConfigurer, error)
createStoreProxyFn func(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error)
updateStoreProxyFn func(name string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error)
deleteStoreProxyFn func(name string) error
listStoreVisitorsFn func() ([]v1.VisitorConfigurer, error)
getStoreVisitorFn func(name string) (v1.VisitorConfigurer, error)
createStoreVisitFn func(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error)
updateStoreVisitFn func(name string, cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error)
deleteStoreVisitFn func(name string) error
gracefulCloseFn func(d time.Duration)
}
func (m *fakeConfigManager) ReloadFromFile(strict bool) error {
if m.reloadFromFileFn != nil {
return m.reloadFromFileFn(strict)
}
return nil
}
func (m *fakeConfigManager) ReadConfigFile() (string, error) {
if m.readConfigFileFn != nil {
return m.readConfigFileFn()
}
return "", nil
}
func (m *fakeConfigManager) WriteConfigFile(content []byte) error {
if m.writeConfigFileFn != nil {
return m.writeConfigFileFn(content)
}
return nil
}
func (m *fakeConfigManager) GetProxyStatus() []*proxy.WorkingStatus {
if m.getProxyStatusFn != nil {
return m.getProxyStatusFn()
}
return nil
}
func (m *fakeConfigManager) IsStoreProxyEnabled(name string) bool {
if m.isStoreProxyEnabledFn != nil {
return m.isStoreProxyEnabledFn(name)
}
return false
}
func (m *fakeConfigManager) StoreEnabled() bool {
if m.storeEnabledFn != nil {
return m.storeEnabledFn()
}
return false
}
func (m *fakeConfigManager) ListStoreProxies() ([]v1.ProxyConfigurer, error) {
if m.listStoreProxiesFn != nil {
return m.listStoreProxiesFn()
}
return nil, nil
}
func (m *fakeConfigManager) GetStoreProxy(name string) (v1.ProxyConfigurer, error) {
if m.getStoreProxyFn != nil {
return m.getStoreProxyFn(name)
}
return nil, nil
}
func (m *fakeConfigManager) CreateStoreProxy(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
if m.createStoreProxyFn != nil {
return m.createStoreProxyFn(cfg)
}
return cfg, nil
}
func (m *fakeConfigManager) UpdateStoreProxy(name string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
if m.updateStoreProxyFn != nil {
return m.updateStoreProxyFn(name, cfg)
}
return cfg, nil
}
func (m *fakeConfigManager) DeleteStoreProxy(name string) error {
if m.deleteStoreProxyFn != nil {
return m.deleteStoreProxyFn(name)
}
return nil
}
func (m *fakeConfigManager) ListStoreVisitors() ([]v1.VisitorConfigurer, error) {
if m.listStoreVisitorsFn != nil {
return m.listStoreVisitorsFn()
}
return nil, nil
}
func (m *fakeConfigManager) GetStoreVisitor(name string) (v1.VisitorConfigurer, error) {
if m.getStoreVisitorFn != nil {
return m.getStoreVisitorFn(name)
}
return nil, nil
}
func (m *fakeConfigManager) CreateStoreVisitor(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
if m.createStoreVisitFn != nil {
return m.createStoreVisitFn(cfg)
}
return cfg, nil
}
func (m *fakeConfigManager) UpdateStoreVisitor(name string, cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
if m.updateStoreVisitFn != nil {
return m.updateStoreVisitFn(name, cfg)
}
return cfg, nil
}
func (m *fakeConfigManager) DeleteStoreVisitor(name string) error {
if m.deleteStoreVisitFn != nil {
return m.deleteStoreVisitFn(name)
}
return nil
}
func (m *fakeConfigManager) GracefulClose(d time.Duration) {
if m.gracefulCloseFn != nil {
m.gracefulCloseFn(d)
}
}
func newRawTCPProxyConfig(name string) *v1.TCPProxyConfig {
return &v1.TCPProxyConfig{
ProxyBaseConfig: v1.ProxyBaseConfig{
Name: name,
Type: "tcp",
ProxyBackend: v1.ProxyBackend{
LocalPort: 10080,
},
},
}
}
func TestBuildProxyStatusRespStoreSourceEnabled(t *testing.T) {
status := &proxy.WorkingStatus{
Name: "shared-proxy",
Type: "tcp",
Phase: proxy.ProxyPhaseRunning,
RemoteAddr: ":8080",
Cfg: newRawTCPProxyConfig("shared-proxy"),
}
controller := &Controller{
serverAddr: "127.0.0.1",
manager: &fakeConfigManager{
isStoreProxyEnabledFn: func(name string) bool {
return name == "shared-proxy"
},
},
}
resp := controller.buildProxyStatusResp(status)
if resp.Source != "store" {
t.Fatalf("unexpected source: %q", resp.Source)
}
if resp.RemoteAddr != "127.0.0.1:8080" {
t.Fatalf("unexpected remote addr: %q", resp.RemoteAddr)
}
}
func TestReloadErrorMapping(t *testing.T) {
tests := []struct {
name string
err error
expectedCode int
}{
{name: "invalid arg", err: fmtError(configmgmt.ErrInvalidArgument, "bad cfg"), expectedCode: http.StatusBadRequest},
{name: "apply fail", err: fmtError(configmgmt.ErrApplyConfig, "reload failed"), expectedCode: http.StatusInternalServerError},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
controller := &Controller{
manager: &fakeConfigManager{reloadFromFileFn: func(bool) error { return tc.err }},
}
ctx := httppkg.NewContext(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/api/reload", nil))
_, err := controller.Reload(ctx)
if err == nil {
t.Fatal("expected error")
}
assertHTTPCode(t, err, tc.expectedCode)
})
}
}
func TestStoreProxyErrorMapping(t *testing.T) {
tests := []struct {
name string
err error
expectedCode int
}{
{name: "not found", err: fmtError(configmgmt.ErrNotFound, "not found"), expectedCode: http.StatusNotFound},
{name: "conflict", err: fmtError(configmgmt.ErrConflict, "exists"), expectedCode: http.StatusConflict},
{name: "internal", err: errors.New("persist failed"), expectedCode: http.StatusInternalServerError},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
body := []byte(`{"name":"shared-proxy","type":"tcp","tcp":{"localPort":10080}}`)
req := httptest.NewRequest(http.MethodPut, "/api/store/proxies/shared-proxy", bytes.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"name": "shared-proxy"})
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
controller := &Controller{
manager: &fakeConfigManager{
updateStoreProxyFn: func(_ string, _ v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
return nil, tc.err
},
},
}
_, err := controller.UpdateStoreProxy(ctx)
if err == nil {
t.Fatal("expected error")
}
assertHTTPCode(t, err, tc.expectedCode)
})
}
}
func TestStoreVisitorErrorMapping(t *testing.T) {
body := []byte(`{"name":"shared-visitor","type":"xtcp","xtcp":{"serverName":"server","bindPort":10081,"secretKey":"secret"}}`)
req := httptest.NewRequest(http.MethodDelete, "/api/store/visitors/shared-visitor", bytes.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"name": "shared-visitor"})
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
controller := &Controller{
manager: &fakeConfigManager{
deleteStoreVisitFn: func(string) error {
return fmtError(configmgmt.ErrStoreDisabled, "disabled")
},
},
}
_, err := controller.DeleteStoreVisitor(ctx)
if err == nil {
t.Fatal("expected error")
}
assertHTTPCode(t, err, http.StatusNotFound)
}
func TestCreateStoreProxyIgnoresUnknownFields(t *testing.T) {
var gotName string
controller := &Controller{
manager: &fakeConfigManager{
createStoreProxyFn: func(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
gotName = cfg.GetBaseConfig().Name
return cfg, nil
},
},
}
body := []byte(`{"name":"raw-proxy","type":"tcp","unexpected":"value","tcp":{"localPort":10080,"unknownInBlock":"value"}}`)
req := httptest.NewRequest(http.MethodPost, "/api/store/proxies", bytes.NewReader(body))
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
resp, err := controller.CreateStoreProxy(ctx)
if err != nil {
t.Fatalf("create store proxy: %v", err)
}
if gotName != "raw-proxy" {
t.Fatalf("unexpected proxy name: %q", gotName)
}
payload, ok := resp.(model.ProxyDefinition)
if !ok {
t.Fatalf("unexpected response type: %T", resp)
}
if payload.Type != "tcp" || payload.TCP == nil {
t.Fatalf("unexpected payload: %#v", payload)
}
}
func TestCreateStoreVisitorIgnoresUnknownFields(t *testing.T) {
var gotName string
controller := &Controller{
manager: &fakeConfigManager{
createStoreVisitFn: func(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
gotName = cfg.GetBaseConfig().Name
return cfg, nil
},
},
}
body := []byte(`{
"name":"raw-visitor","type":"xtcp","unexpected":"value",
"xtcp":{"serverName":"server","bindPort":10081,"secretKey":"secret","unknownInBlock":"value"}
}`)
req := httptest.NewRequest(http.MethodPost, "/api/store/visitors", bytes.NewReader(body))
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
resp, err := controller.CreateStoreVisitor(ctx)
if err != nil {
t.Fatalf("create store visitor: %v", err)
}
if gotName != "raw-visitor" {
t.Fatalf("unexpected visitor name: %q", gotName)
}
payload, ok := resp.(model.VisitorDefinition)
if !ok {
t.Fatalf("unexpected response type: %T", resp)
}
if payload.Type != "xtcp" || payload.XTCP == nil {
t.Fatalf("unexpected payload: %#v", payload)
}
}
func TestCreateStoreProxyPluginUnknownFieldsAreIgnored(t *testing.T) {
var gotPluginType string
controller := &Controller{
manager: &fakeConfigManager{
createStoreProxyFn: func(cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
gotPluginType = cfg.GetBaseConfig().Plugin.Type
return cfg, nil
},
},
}
body := []byte(`{"name":"plugin-proxy","type":"tcp","tcp":{"plugin":{"type":"http2https","localAddr":"127.0.0.1:8080","unknownInPlugin":"value"}}}`)
req := httptest.NewRequest(http.MethodPost, "/api/store/proxies", bytes.NewReader(body))
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
resp, err := controller.CreateStoreProxy(ctx)
if err != nil {
t.Fatalf("create store proxy: %v", err)
}
if gotPluginType != "http2https" {
t.Fatalf("unexpected plugin type: %q", gotPluginType)
}
payload, ok := resp.(model.ProxyDefinition)
if !ok {
t.Fatalf("unexpected response type: %T", resp)
}
if payload.TCP == nil {
t.Fatalf("unexpected response payload: %#v", payload)
}
pluginType := payload.TCP.Plugin.Type
if pluginType != "http2https" {
t.Fatalf("unexpected plugin type in response payload: %q", pluginType)
}
}
func TestCreateStoreVisitorPluginUnknownFieldsAreIgnored(t *testing.T) {
var gotPluginType string
controller := &Controller{
manager: &fakeConfigManager{
createStoreVisitFn: func(cfg v1.VisitorConfigurer) (v1.VisitorConfigurer, error) {
gotPluginType = cfg.GetBaseConfig().Plugin.Type
return cfg, nil
},
},
}
body := []byte(`{
"name":"plugin-visitor","type":"stcp",
"stcp":{"serverName":"server","bindPort":10081,"plugin":{"type":"virtual_net","destinationIP":"10.0.0.1","unknownInPlugin":"value"}}
}`)
req := httptest.NewRequest(http.MethodPost, "/api/store/visitors", bytes.NewReader(body))
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
resp, err := controller.CreateStoreVisitor(ctx)
if err != nil {
t.Fatalf("create store visitor: %v", err)
}
if gotPluginType != "virtual_net" {
t.Fatalf("unexpected plugin type: %q", gotPluginType)
}
payload, ok := resp.(model.VisitorDefinition)
if !ok {
t.Fatalf("unexpected response type: %T", resp)
}
if payload.STCP == nil {
t.Fatalf("unexpected response payload: %#v", payload)
}
pluginType := payload.STCP.Plugin.Type
if pluginType != "virtual_net" {
t.Fatalf("unexpected plugin type in response payload: %q", pluginType)
}
}
func TestUpdateStoreProxyRejectsMismatchedTypeBlock(t *testing.T) {
controller := &Controller{manager: &fakeConfigManager{}}
body := []byte(`{"name":"p1","type":"tcp","udp":{"localPort":10080}}`)
req := httptest.NewRequest(http.MethodPut, "/api/store/proxies/p1", bytes.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"name": "p1"})
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
_, err := controller.UpdateStoreProxy(ctx)
if err == nil {
t.Fatal("expected error")
}
assertHTTPCode(t, err, http.StatusBadRequest)
}
func TestUpdateStoreProxyRejectsNameMismatch(t *testing.T) {
controller := &Controller{manager: &fakeConfigManager{}}
body := []byte(`{"name":"p2","type":"tcp","tcp":{"localPort":10080}}`)
req := httptest.NewRequest(http.MethodPut, "/api/store/proxies/p1", bytes.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"name": "p1"})
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
_, err := controller.UpdateStoreProxy(ctx)
if err == nil {
t.Fatal("expected error")
}
assertHTTPCode(t, err, http.StatusBadRequest)
}
func TestListStoreProxiesReturnsSortedPayload(t *testing.T) {
controller := &Controller{
manager: &fakeConfigManager{
listStoreProxiesFn: func() ([]v1.ProxyConfigurer, error) {
b := newRawTCPProxyConfig("b")
a := newRawTCPProxyConfig("a")
return []v1.ProxyConfigurer{b, a}, nil
},
},
}
ctx := httppkg.NewContext(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/api/store/proxies", nil))
resp, err := controller.ListStoreProxies(ctx)
if err != nil {
t.Fatalf("list store proxies: %v", err)
}
out, ok := resp.(model.ProxyListResp)
if !ok {
t.Fatalf("unexpected response type: %T", resp)
}
if len(out.Proxies) != 2 {
t.Fatalf("unexpected proxy count: %d", len(out.Proxies))
}
if out.Proxies[0].Name != "a" || out.Proxies[1].Name != "b" {
t.Fatalf("proxies are not sorted by name: %#v", out.Proxies)
}
}
func fmtError(sentinel error, msg string) error {
return fmt.Errorf("%w: %s", sentinel, msg)
}
func assertHTTPCode(t *testing.T, err error, expected int) {
t.Helper()
var httpErr *httppkg.Error
if !errors.As(err, &httpErr) {
t.Fatalf("unexpected error type: %T", err)
}
if httpErr.Code != expected {
t.Fatalf("unexpected status code: got %d, want %d", httpErr.Code, expected)
}
}
func TestUpdateStoreProxyReturnsTypedPayload(t *testing.T) {
controller := &Controller{
manager: &fakeConfigManager{
updateStoreProxyFn: func(_ string, cfg v1.ProxyConfigurer) (v1.ProxyConfigurer, error) {
return cfg, nil
},
},
}
body := map[string]any{
"name": "shared-proxy",
"type": "tcp",
"tcp": map[string]any{
"localPort": 10080,
"remotePort": 7000,
},
}
data, err := json.Marshal(body)
if err != nil {
t.Fatalf("marshal request: %v", err)
}
req := httptest.NewRequest(http.MethodPut, "/api/store/proxies/shared-proxy", bytes.NewReader(data))
req = mux.SetURLVars(req, map[string]string{"name": "shared-proxy"})
ctx := httppkg.NewContext(httptest.NewRecorder(), req)
resp, err := controller.UpdateStoreProxy(ctx)
if err != nil {
t.Fatalf("update store proxy: %v", err)
}
payload, ok := resp.(model.ProxyDefinition)
if !ok {
t.Fatalf("unexpected response type: %T", resp)
}
if payload.TCP == nil || payload.TCP.RemotePort != 7000 {
t.Fatalf("unexpected response payload: %#v", payload)
}
}

View File

@@ -0,0 +1,148 @@
package model
import (
"fmt"
"strings"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
type ProxyDefinition struct {
Name string `json:"name"`
Type string `json:"type"`
TCP *v1.TCPProxyConfig `json:"tcp,omitempty"`
UDP *v1.UDPProxyConfig `json:"udp,omitempty"`
HTTP *v1.HTTPProxyConfig `json:"http,omitempty"`
HTTPS *v1.HTTPSProxyConfig `json:"https,omitempty"`
TCPMux *v1.TCPMuxProxyConfig `json:"tcpmux,omitempty"`
STCP *v1.STCPProxyConfig `json:"stcp,omitempty"`
SUDP *v1.SUDPProxyConfig `json:"sudp,omitempty"`
XTCP *v1.XTCPProxyConfig `json:"xtcp,omitempty"`
}
func (p *ProxyDefinition) Validate(pathName string, isUpdate bool) error {
if strings.TrimSpace(p.Name) == "" {
return fmt.Errorf("proxy name is required")
}
if !IsProxyType(p.Type) {
return fmt.Errorf("invalid proxy type: %s", p.Type)
}
if isUpdate && pathName != "" && pathName != p.Name {
return fmt.Errorf("proxy name in URL must match name in body")
}
_, blockType, blockCount := p.activeBlock()
if blockCount != 1 {
return fmt.Errorf("exactly one proxy type block is required")
}
if blockType != p.Type {
return fmt.Errorf("proxy type block %q does not match type %q", blockType, p.Type)
}
return nil
}
func (p *ProxyDefinition) ToConfigurer() (v1.ProxyConfigurer, error) {
block, _, _ := p.activeBlock()
if block == nil {
return nil, fmt.Errorf("exactly one proxy type block is required")
}
cfg := block
cfg.GetBaseConfig().Name = p.Name
cfg.GetBaseConfig().Type = p.Type
return cfg, nil
}
func ProxyDefinitionFromConfigurer(cfg v1.ProxyConfigurer) (ProxyDefinition, error) {
if cfg == nil {
return ProxyDefinition{}, fmt.Errorf("proxy config is nil")
}
base := cfg.GetBaseConfig()
payload := ProxyDefinition{
Name: base.Name,
Type: base.Type,
}
switch c := cfg.(type) {
case *v1.TCPProxyConfig:
payload.TCP = c
case *v1.UDPProxyConfig:
payload.UDP = c
case *v1.HTTPProxyConfig:
payload.HTTP = c
case *v1.HTTPSProxyConfig:
payload.HTTPS = c
case *v1.TCPMuxProxyConfig:
payload.TCPMux = c
case *v1.STCPProxyConfig:
payload.STCP = c
case *v1.SUDPProxyConfig:
payload.SUDP = c
case *v1.XTCPProxyConfig:
payload.XTCP = c
default:
return ProxyDefinition{}, fmt.Errorf("unsupported proxy configurer type %T", cfg)
}
return payload, nil
}
func (p *ProxyDefinition) activeBlock() (v1.ProxyConfigurer, string, int) {
count := 0
var block v1.ProxyConfigurer
var blockType string
if p.TCP != nil {
count++
block = p.TCP
blockType = "tcp"
}
if p.UDP != nil {
count++
block = p.UDP
blockType = "udp"
}
if p.HTTP != nil {
count++
block = p.HTTP
blockType = "http"
}
if p.HTTPS != nil {
count++
block = p.HTTPS
blockType = "https"
}
if p.TCPMux != nil {
count++
block = p.TCPMux
blockType = "tcpmux"
}
if p.STCP != nil {
count++
block = p.STCP
blockType = "stcp"
}
if p.SUDP != nil {
count++
block = p.SUDP
blockType = "sudp"
}
if p.XTCP != nil {
count++
block = p.XTCP
blockType = "xtcp"
}
return block, blockType, count
}
func IsProxyType(typ string) bool {
switch typ {
case "tcp", "udp", "http", "https", "tcpmux", "stcp", "sudp", "xtcp":
return true
default:
return false
}
}

View File

@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package api
package model
const SourceStore = "store"
// StatusResp is the response for GET /api/status
type StatusResp map[string][]ProxyStatusResp
@@ -29,31 +31,12 @@ type ProxyStatusResp struct {
Source string `json:"source,omitempty"` // "store" or "config"
}
// ProxyConfig wraps proxy configuration for API requests/responses.
type ProxyConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Config map[string]any `json:"config"`
}
// VisitorConfig wraps visitor configuration for API requests/responses.
type VisitorConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Config map[string]any `json:"config"`
}
// ProxyListResp is the response for GET /api/store/proxies
type ProxyListResp struct {
Proxies []ProxyConfig `json:"proxies"`
Proxies []ProxyDefinition `json:"proxies"`
}
// VisitorListResp is the response for GET /api/store/visitors
type VisitorListResp struct {
Visitors []VisitorConfig `json:"visitors"`
}
// ErrorResp represents an error response
type ErrorResp struct {
Error string `json:"error"`
Visitors []VisitorDefinition `json:"visitors"`
}

View File

@@ -0,0 +1,107 @@
package model
import (
"fmt"
"strings"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
type VisitorDefinition struct {
Name string `json:"name"`
Type string `json:"type"`
STCP *v1.STCPVisitorConfig `json:"stcp,omitempty"`
SUDP *v1.SUDPVisitorConfig `json:"sudp,omitempty"`
XTCP *v1.XTCPVisitorConfig `json:"xtcp,omitempty"`
}
func (p *VisitorDefinition) Validate(pathName string, isUpdate bool) error {
if strings.TrimSpace(p.Name) == "" {
return fmt.Errorf("visitor name is required")
}
if !IsVisitorType(p.Type) {
return fmt.Errorf("invalid visitor type: %s", p.Type)
}
if isUpdate && pathName != "" && pathName != p.Name {
return fmt.Errorf("visitor name in URL must match name in body")
}
_, blockType, blockCount := p.activeBlock()
if blockCount != 1 {
return fmt.Errorf("exactly one visitor type block is required")
}
if blockType != p.Type {
return fmt.Errorf("visitor type block %q does not match type %q", blockType, p.Type)
}
return nil
}
func (p *VisitorDefinition) ToConfigurer() (v1.VisitorConfigurer, error) {
block, _, _ := p.activeBlock()
if block == nil {
return nil, fmt.Errorf("exactly one visitor type block is required")
}
cfg := block
cfg.GetBaseConfig().Name = p.Name
cfg.GetBaseConfig().Type = p.Type
return cfg, nil
}
func VisitorDefinitionFromConfigurer(cfg v1.VisitorConfigurer) (VisitorDefinition, error) {
if cfg == nil {
return VisitorDefinition{}, fmt.Errorf("visitor config is nil")
}
base := cfg.GetBaseConfig()
payload := VisitorDefinition{
Name: base.Name,
Type: base.Type,
}
switch c := cfg.(type) {
case *v1.STCPVisitorConfig:
payload.STCP = c
case *v1.SUDPVisitorConfig:
payload.SUDP = c
case *v1.XTCPVisitorConfig:
payload.XTCP = c
default:
return VisitorDefinition{}, fmt.Errorf("unsupported visitor configurer type %T", cfg)
}
return payload, nil
}
func (p *VisitorDefinition) activeBlock() (v1.VisitorConfigurer, string, int) {
count := 0
var block v1.VisitorConfigurer
var blockType string
if p.STCP != nil {
count++
block = p.STCP
blockType = "stcp"
}
if p.SUDP != nil {
count++
block = p.SUDP
blockType = "sudp"
}
if p.XTCP != nil {
count++
block = p.XTCP
blockType = "xtcp"
}
return block, blockType, count
}
func IsVisitorType(typ string) bool {
switch typ {
case "stcp", "sudp", "xtcp":
return true
default:
return false
}
}

View File

@@ -16,6 +16,7 @@ package proxy
import (
"context"
"fmt"
"io"
"net"
"reflect"
@@ -122,6 +123,33 @@ func (pxy *BaseProxy) Close() {
}
}
// wrapWorkConn applies rate limiting, encryption, and compression
// to a work connection based on the proxy's transport configuration.
// The returned recycle function should be called when the stream is no longer in use
// to return compression resources to the pool. It is safe to not call recycle,
// in which case resources will be garbage collected normally.
func (pxy *BaseProxy) wrapWorkConn(conn net.Conn, encKey []byte) (io.ReadWriteCloser, func(), error) {
var rwc io.ReadWriteCloser = conn
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.baseCfg.Transport.UseEncryption {
var err error
rwc, err = libio.WithEncryption(rwc, encKey)
if err != nil {
conn.Close()
return nil, nil, fmt.Errorf("create encryption stream error: %w", err)
}
}
var recycleFn func()
if pxy.baseCfg.Transport.UseCompression {
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
}
return rwc, recycleFn, nil
}
func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
pxy.inWorkConnCallback = cb
}
@@ -139,30 +167,14 @@ func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) {
xl := pxy.xl
baseCfg := pxy.baseCfg
var (
remote io.ReadWriteCloser
err error
)
remote = workConn
if pxy.limiter != nil {
remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error {
return workConn.Close()
})
}
xl.Tracef("handle tcp work connection, useEncryption: %t, useCompression: %t",
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
if baseCfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, encKey)
if err != nil {
workConn.Close()
xl.Errorf("create encryption stream error: %v", err)
return
}
}
var compressionResourceRecycleFn func()
if baseCfg.Transport.UseCompression {
remote, compressionResourceRecycleFn = libio.WithCompressionFromPool(remote)
remote, recycleFn, err := pxy.wrapWorkConn(workConn, encKey)
if err != nil {
xl.Errorf("wrap work connection: %v", err)
return
}
// check if we need to send proxy protocol info
@@ -178,7 +190,6 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
}
if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 {
// Use the common proxy protocol builder function
header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion)
connInfo.ProxyProtocolHeader = header
}
@@ -187,12 +198,18 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if pxy.proxyPlugin != nil {
// if plugin is set, let plugin handle connection first
// Don't recycle compression resources here because plugins may
// retain the connection after Handle returns.
xl.Debugf("handle by plugin: %s", pxy.proxyPlugin.Name())
pxy.proxyPlugin.Handle(pxy.ctx, &connInfo)
xl.Debugf("handle by plugin finished")
return
}
if recycleFn != nil {
defer recycleFn()
}
localConn, err := libnet.Dial(
net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)),
libnet.WithTimeout(10*time.Second),
@@ -209,6 +226,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if connInfo.ProxyProtocolHeader != nil {
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
workConn.Close()
localConn.Close()
xl.Errorf("write proxy protocol header to local conn error: %v", err)
return
}
@@ -219,7 +237,4 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
if len(errs) > 0 {
xl.Tracef("join connections errors: %v", errs)
}
if compressionResourceRecycleFn != nil {
compressionResourceRecycleFn()
}
}

View File

@@ -118,9 +118,9 @@ func (pm *Manager) HandleEvent(payload any) error {
}
func (pm *Manager) GetAllProxyStatus() []*WorkingStatus {
ps := make([]*WorkingStatus, 0)
pm.mu.RLock()
defer pm.mu.RUnlock()
ps := make([]*WorkingStatus, 0, len(pm.proxies))
for _, pxy := range pm.proxies {
ps = append(ps, pxy.GetStatus())
}

View File

@@ -29,8 +29,8 @@ import (
"github.com/fatedier/frp/client/health"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet"
)
@@ -116,7 +116,7 @@ func NewWrapper(
vnetController: vnetController,
xl: xl,
ctx: xlog.NewContext(ctx, xl),
wireName: util.AddUserPrefix(clientCfg.User, baseInfo.Name),
wireName: naming.AddUserPrefix(clientCfg.User, baseInfo.Name),
}
if baseInfo.HealthCheck.Type != "" && baseInfo.LocalPort > 0 {

View File

@@ -17,7 +17,6 @@
package proxy
import (
"io"
"net"
"reflect"
"strconv"
@@ -25,17 +24,15 @@ import (
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.SUDPProxyConfig{}), NewSUDPProxy)
RegisterProxyFactory(reflect.TypeFor[*v1.SUDPProxyConfig](), NewSUDPProxy)
}
type SUDPProxy struct {
@@ -83,27 +80,13 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
xl := pxy.xl
xl.Infof("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String())
var rwc io.ReadWriteCloser = conn
var err error
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
if err != nil {
xl.Errorf("wrap work connection: %v", err)
return
}
if pxy.cfg.Transport.UseEncryption {
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
if err != nil {
conn.Close()
xl.Errorf("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.Transport.UseCompression {
rwc = libio.WithCompression(rwc)
}
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
workConn := conn
workConn := netpkg.WrapReadWriteCloserToConn(remote, conn)
readCh := make(chan *msg.UDPPacket, 1024)
sendCh := make(chan msg.Message, 1024)
isClose := false

View File

@@ -17,24 +17,21 @@
package proxy
import (
"io"
"net"
"reflect"
"strconv"
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.UDPProxyConfig{}), NewUDPProxy)
RegisterProxyFactory(reflect.TypeFor[*v1.UDPProxyConfig](), NewUDPProxy)
}
type UDPProxy struct {
@@ -94,28 +91,14 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
// close resources related with old workConn
pxy.Close()
var rwc io.ReadWriteCloser = conn
var err error
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
remote, _, err := pxy.wrapWorkConn(conn, pxy.encryptionKey)
if err != nil {
xl.Errorf("wrap work connection: %v", err)
return
}
if pxy.cfg.Transport.UseEncryption {
rwc, err = libio.WithEncryption(rwc, pxy.encryptionKey)
if err != nil {
conn.Close()
xl.Errorf("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.Transport.UseCompression {
rwc = libio.WithCompression(rwc)
}
conn = netpkg.WrapReadWriteCloserToConn(rwc, conn)
pxy.mu.Lock()
pxy.workConn = conn
pxy.workConn = netpkg.WrapReadWriteCloserToConn(remote, conn)
pxy.readCh = make(chan *msg.UDPPacket, 1024)
pxy.sendCh = make(chan msg.Message, 1024)
pxy.closed = false
@@ -129,7 +112,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
return
}
if errRet := errors.PanicToError(func() {
xl.Tracef("get udp package from workConn: %s", udpMsg.Content)
xl.Tracef("get udp package from workConn, len: %d", len(udpMsg.Content))
readCh <- &udpMsg
}); errRet != nil {
xl.Infof("reader goroutine for udp work connection closed: %v", errRet)
@@ -145,7 +128,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
for rawMsg := range sendCh {
switch m := rawMsg.(type) {
case *msg.UDPPacket:
xl.Tracef("send udp package to workConn: %s", m.Content)
xl.Tracef("send udp package to workConn, len: %d", len(m.Content))
case *msg.Ping:
xl.Tracef("send ping message to udp workConn")
}

View File

@@ -27,14 +27,14 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/nathole"
"github.com/fatedier/frp/pkg/transport"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
)
func init() {
RegisterProxyFactory(reflect.TypeOf(&v1.XTCPProxyConfig{}), NewXTCPProxy)
RegisterProxyFactory(reflect.TypeFor[*v1.XTCPProxyConfig](), NewXTCPProxy)
}
type XTCPProxy struct {
@@ -86,7 +86,7 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, startWorkConnMsg *msg.StartWorkC
transactionID := nathole.NewTransactionID()
natHoleClientMsg := &msg.NatHoleClient{
TransactionID: transactionID,
ProxyName: util.AddUserPrefix(pxy.clientCfg.User, pxy.cfg.Name),
ProxyName: naming.AddUserPrefix(pxy.clientCfg.User, pxy.cfg.Name),
Sid: natHoleSidMsg.Sid,
MappedAddrs: prepareResult.Addrs,
AssistedAddrs: prepareResult.AssistedAddrs,

View File

@@ -123,8 +123,11 @@ type Service struct {
vnetController *vnet.Controller
cfgMu sync.RWMutex
common *v1.ClientCommonConfig
cfgMu sync.RWMutex
// reloadMu serializes reload transactions to keep reloadCommon and applied
// config in sync across concurrent API operations.
reloadMu sync.Mutex
common *v1.ClientCommonConfig
// reloadCommon is used for filtering/defaulting during config-source reloads.
// It can be updated by /api/reload without mutating startup-only common behavior.
reloadCommon *v1.ClientCommonConfig
@@ -441,26 +444,28 @@ func (svr *Service) UpdateConfigSource(
proxyCfgs []v1.ProxyConfigurer,
visitorCfgs []v1.VisitorConfigurer,
) error {
svr.reloadMu.Lock()
defer svr.reloadMu.Unlock()
cfgSource := svr.configSource
if cfgSource == nil {
return fmt.Errorf("config source is not available")
}
// Update reloadCommon before ReplaceAll so the subsequent reload uses the
// same common config as /api/reload validation.
svr.cfgMu.Lock()
prevReloadCommon := svr.reloadCommon
svr.reloadCommon = common
svr.cfgMu.Unlock()
if err := cfgSource.ReplaceAll(proxyCfgs, visitorCfgs); err != nil {
svr.cfgMu.Lock()
svr.reloadCommon = prevReloadCommon
svr.cfgMu.Unlock()
return err
}
return svr.reloadConfigFromSources()
// Non-atomic update semantics: source has been updated at this point.
// Even if reload fails below, keep this common config for subsequent reloads.
svr.cfgMu.Lock()
svr.reloadCommon = common
svr.cfgMu.Unlock()
if err := svr.reloadConfigFromSourcesLocked(); err != nil {
return err
}
return nil
}
func (svr *Service) Close() {
@@ -473,6 +478,15 @@ func (svr *Service) GracefulClose(d time.Duration) {
}
func (svr *Service) stop() {
// Coordinate shutdown with reload/update paths that read source pointers.
svr.reloadMu.Lock()
if svr.aggregator != nil {
svr.aggregator = nil
}
svr.configSource = nil
svr.storeSource = nil
svr.reloadMu.Unlock()
svr.ctlMu.Lock()
defer svr.ctlMu.Unlock()
if svr.ctl != nil {
@@ -483,11 +497,6 @@ func (svr *Service) stop() {
svr.webServer.Close()
svr.webServer = nil
}
if svr.aggregator != nil {
svr.aggregator = nil
}
svr.configSource = nil
svr.storeSource = nil
}
func (svr *Service) getProxyStatus(name string) (*proxy.WorkingStatus, bool) {
@@ -520,7 +529,14 @@ func (s *statusExporterImpl) GetProxyStatus(name string) (*proxy.WorkingStatus,
}
func (svr *Service) reloadConfigFromSources() error {
if svr.aggregator == nil {
svr.reloadMu.Lock()
defer svr.reloadMu.Unlock()
return svr.reloadConfigFromSourcesLocked()
}
func (svr *Service) reloadConfigFromSourcesLocked() error {
aggregator := svr.aggregator
if aggregator == nil {
return errors.New("config aggregator is not initialized")
}
@@ -528,7 +544,7 @@ func (svr *Service) reloadConfigFromSources() error {
reloadCommon := svr.reloadCommon
svr.cfgMu.RUnlock()
proxies, visitors, err := svr.aggregator.Load()
proxies, visitors, err := aggregator.Load()
if err != nil {
return fmt.Errorf("reload config from sources failed: %w", err)
}

140
client/service_test.go Normal file
View File

@@ -0,0 +1,140 @@
package client
import (
"path/filepath"
"strings"
"testing"
"github.com/fatedier/frp/pkg/config/source"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
func TestUpdateConfigSourceRollsBackReloadCommonOnReplaceAllFailure(t *testing.T) {
prevCommon := &v1.ClientCommonConfig{User: "old-user"}
newCommon := &v1.ClientCommonConfig{User: "new-user"}
svr := &Service{
configSource: source.NewConfigSource(),
reloadCommon: prevCommon,
}
invalidProxy := &v1.TCPProxyConfig{}
err := svr.UpdateConfigSource(newCommon, []v1.ProxyConfigurer{invalidProxy}, nil)
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "proxy name cannot be empty") {
t.Fatalf("unexpected error: %v", err)
}
if svr.reloadCommon != prevCommon {
t.Fatalf("reloadCommon should roll back on ReplaceAll failure")
}
}
func TestUpdateConfigSourceKeepsReloadCommonOnReloadFailure(t *testing.T) {
prevCommon := &v1.ClientCommonConfig{User: "old-user"}
newCommon := &v1.ClientCommonConfig{User: "new-user"}
svr := &Service{
// Keep configSource valid so ReplaceAll succeeds first.
configSource: source.NewConfigSource(),
reloadCommon: prevCommon,
// Keep aggregator nil to force reload failure.
aggregator: nil,
}
validProxy := &v1.TCPProxyConfig{
ProxyBaseConfig: v1.ProxyBaseConfig{
Name: "p1",
Type: "tcp",
},
}
err := svr.UpdateConfigSource(newCommon, []v1.ProxyConfigurer{validProxy}, nil)
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "config aggregator is not initialized") {
t.Fatalf("unexpected error: %v", err)
}
if svr.reloadCommon != newCommon {
t.Fatalf("reloadCommon should keep new value on reload failure")
}
}
func TestReloadConfigFromSourcesDoesNotMutateStoreConfigs(t *testing.T) {
storeSource, err := source.NewStoreSource(source.StoreSourceConfig{
Path: filepath.Join(t.TempDir(), "store.json"),
})
if err != nil {
t.Fatalf("new store source: %v", err)
}
proxyCfg := &v1.TCPProxyConfig{
ProxyBaseConfig: v1.ProxyBaseConfig{
Name: "store-proxy",
Type: "tcp",
},
}
visitorCfg := &v1.STCPVisitorConfig{
VisitorBaseConfig: v1.VisitorBaseConfig{
Name: "store-visitor",
Type: "stcp",
},
}
if err := storeSource.AddProxy(proxyCfg); err != nil {
t.Fatalf("add proxy to store: %v", err)
}
if err := storeSource.AddVisitor(visitorCfg); err != nil {
t.Fatalf("add visitor to store: %v", err)
}
agg := source.NewAggregator(source.NewConfigSource())
agg.SetStoreSource(storeSource)
svr := &Service{
aggregator: agg,
configSource: agg.ConfigSource(),
storeSource: storeSource,
reloadCommon: &v1.ClientCommonConfig{},
}
if err := svr.reloadConfigFromSources(); err != nil {
t.Fatalf("reload config from sources: %v", err)
}
gotProxy := storeSource.GetProxy("store-proxy")
if gotProxy == nil {
t.Fatalf("proxy not found in store")
}
if gotProxy.GetBaseConfig().LocalIP != "" {
t.Fatalf("store proxy localIP should stay empty, got %q", gotProxy.GetBaseConfig().LocalIP)
}
gotVisitor := storeSource.GetVisitor("store-visitor")
if gotVisitor == nil {
t.Fatalf("visitor not found in store")
}
if gotVisitor.GetBaseConfig().BindAddr != "" {
t.Fatalf("store visitor bindAddr should stay empty, got %q", gotVisitor.GetBaseConfig().BindAddr)
}
svr.cfgMu.RLock()
defer svr.cfgMu.RUnlock()
if len(svr.proxyCfgs) != 1 {
t.Fatalf("expected 1 runtime proxy, got %d", len(svr.proxyCfgs))
}
if svr.proxyCfgs[0].GetBaseConfig().LocalIP != "127.0.0.1" {
t.Fatalf("runtime proxy localIP should be defaulted, got %q", svr.proxyCfgs[0].GetBaseConfig().LocalIP)
}
if len(svr.visitorCfgs) != 1 {
t.Fatalf("expected 1 runtime visitor, got %d", len(svr.visitorCfgs))
}
if svr.visitorCfgs[0].GetBaseConfig().BindAddr != "127.0.0.1" {
t.Fatalf("runtime visitor bindAddr should be defaulted, got %q", svr.visitorCfgs[0].GetBaseConfig().BindAddr)
}
}

View File

@@ -15,17 +15,12 @@
package visitor
import (
"fmt"
"io"
"net"
"strconv"
"time"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
)
@@ -41,10 +36,10 @@ func (sv *STCPVisitor) Run() (err error) {
if err != nil {
return
}
go sv.worker()
go sv.acceptLoop(sv.l, "stcp local", sv.handleConn)
}
go sv.internalConnWorker()
go sv.acceptLoop(sv.internalLn, "stcp internal", sv.handleConn)
if sv.plugin != nil {
sv.plugin.Start()
@@ -56,35 +51,10 @@ func (sv *STCPVisitor) Close() {
sv.BaseVisitor.Close()
}
func (sv *STCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warnf("stcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warnf("stcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx)
var tunnelErr error
defer func() {
// If there was an error and connection supports CloseWithError, use it
if tunnelErr != nil {
if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
_ = eConn.CloseWithError(tunnelErr)
@@ -95,62 +65,21 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) {
}()
xl.Debugf("get a new stcp user connection")
visitorConn, err := sv.helper.ConnectServer()
visitorConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
if err != nil {
xl.Warnf("dialRawVisitorConn error: %v", err)
tunnelErr = err
return
}
defer visitorConn.Close()
now := time.Now().Unix()
targetProxyName := util.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
Timestamp: now,
UseEncryption: sv.cfg.Transport.UseEncryption,
UseCompression: sv.cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
remote, recycleFn, err := wrapVisitorConn(visitorConn, sv.cfg.GetBaseConfig())
if err != nil {
xl.Warnf("send newVisitorConnMsg to server error: %v", err)
xl.Warnf("wrapVisitorConn error: %v", err)
tunnelErr = err
return
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
xl.Warnf("get newVisitorConnRespMsg error: %v", err)
tunnelErr = err
return
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error)
return
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
tunnelErr = err
return
}
}
if sv.cfg.Transport.UseCompression {
var recycleFn func()
remote, recycleFn = libio.WithCompressionFromPool(remote)
defer recycleFn()
}
defer recycleFn()
libio.Join(userConn, remote)
}

View File

@@ -16,20 +16,17 @@ package visitor
import (
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
)
@@ -75,6 +72,7 @@ func (sv *SUDPVisitor) dispatcher() {
var (
visitorConn net.Conn
recycleFn func()
err error
firstPacket *msg.UDPPacket
@@ -92,14 +90,17 @@ func (sv *SUDPVisitor) dispatcher() {
return
}
visitorConn, err = sv.getNewVisitorConn()
visitorConn, recycleFn, err = sv.getNewVisitorConn()
if err != nil {
xl.Warnf("newVisitorConn to frps error: %v, try to reconnect", err)
continue
}
// visitorConn always be closed when worker done.
sv.worker(visitorConn, firstPacket)
func() {
defer recycleFn()
sv.worker(visitorConn, firstPacket)
}()
select {
case <-sv.checkCloseCh:
@@ -146,7 +147,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
case *msg.UDPPacket:
if errRet := errors.PanicToError(func() {
sv.readCh <- m
xl.Tracef("frpc visitor get udp packet from workConn: %s", m.Content)
xl.Tracef("frpc visitor get udp packet from workConn, len: %d", len(m.Content))
}); errRet != nil {
xl.Infof("reader goroutine for udp work connection closed")
return
@@ -168,7 +169,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl.Warnf("sender goroutine for udp work connection closed: %v", errRet)
return
}
xl.Tracef("send udp package to workConn: %s", firstPacket.Content)
xl.Tracef("send udp package to workConn, len: %d", len(firstPacket.Content))
}
for {
@@ -183,7 +184,7 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl.Warnf("sender goroutine for udp work connection closed: %v", errRet)
return
}
xl.Tracef("send udp package to workConn: %s", udpMsg.Content)
xl.Tracef("send udp package to workConn, len: %d", len(udpMsg.Content))
case <-closeCh:
return
}
@@ -197,53 +198,17 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl.Infof("sudp worker is closed")
}
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) {
xl := xlog.FromContextSafe(sv.ctx)
visitorConn, err := sv.helper.ConnectServer()
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, func(), error) {
rawConn, err := sv.dialRawVisitorConn(sv.cfg.GetBaseConfig())
if err != nil {
return nil, fmt.Errorf("frpc connect frps error: %v", err)
return nil, func() {}, err
}
now := time.Now().Unix()
targetProxyName := util.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
Timestamp: now,
UseEncryption: sv.cfg.Transport.UseEncryption,
UseCompression: sv.cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
rwc, recycleFn, err := wrapVisitorConn(rawConn, sv.cfg.GetBaseConfig())
if err != nil {
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err)
rawConn.Close()
return nil, func() {}, err
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
return nil, err
}
}
if sv.cfg.Transport.UseCompression {
remote = libio.WithCompression(remote)
}
return netpkg.WrapReadWriteCloserToConn(remote, visitorConn), nil
return netpkg.WrapReadWriteCloserToConn(rwc, rawConn), recycleFn, nil
}
func (sv *SUDPVisitor) Close() {

View File

@@ -16,13 +16,21 @@ package visitor
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
libio "github.com/fatedier/golib/io"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
plugin "github.com/fatedier/frp/pkg/plugin/visitor"
"github.com/fatedier/frp/pkg/transport"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/vnet"
)
@@ -119,6 +127,18 @@ func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
return v.internalLn.PutConn(conn)
}
func (v *BaseVisitor) acceptLoop(l net.Listener, name string, handleConn func(net.Conn)) {
xl := xlog.FromContextSafe(v.ctx)
for {
conn, err := l.Accept()
if err != nil {
xl.Warnf("%s listener closed", name)
return
}
go handleConn(conn)
}
}
func (v *BaseVisitor) Close() {
if v.l != nil {
v.l.Close()
@@ -130,3 +150,57 @@ func (v *BaseVisitor) Close() {
v.plugin.Close()
}
}
func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, error) {
visitorConn, err := v.helper.ConnectServer()
if err != nil {
return nil, fmt.Errorf("connect to server error: %v", err)
}
now := time.Now().Unix()
targetProxyName := naming.BuildTargetServerProxyName(v.clientCfg.User, cfg.ServerUser, cfg.ServerName)
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: v.helper.RunID(),
ProxyName: targetProxyName,
SignKey: util.GetAuthKey(cfg.SecretKey, now),
Timestamp: now,
UseEncryption: cfg.Transport.UseEncryption,
UseCompression: cfg.Transport.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
visitorConn.Close()
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
visitorConn.Close()
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
return visitorConn, nil
}
func wrapVisitorConn(conn io.ReadWriteCloser, cfg *v1.VisitorBaseConfig) (io.ReadWriteCloser, func(), error) {
rwc := conn
if cfg.Transport.UseEncryption {
var err error
rwc, err = libio.WithEncryption(rwc, []byte(cfg.SecretKey))
if err != nil {
return nil, func() {}, fmt.Errorf("create encryption stream error: %v", err)
}
}
recycleFn := func() {}
if cfg.Transport.UseCompression {
rwc, recycleFn = libio.WithCompressionFromPool(rwc)
}
return rwc, recycleFn, nil
}

View File

@@ -31,6 +31,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/naming"
"github.com/fatedier/frp/pkg/nathole"
"github.com/fatedier/frp/pkg/transport"
netpkg "github.com/fatedier/frp/pkg/util/net"
@@ -64,10 +65,10 @@ func (sv *XTCPVisitor) Run() (err error) {
if err != nil {
return
}
go sv.worker()
go sv.acceptLoop(sv.l, "xtcp local", sv.handleConn)
}
go sv.internalConnWorker()
go sv.acceptLoop(sv.internalLn, "xtcp internal", sv.handleConn)
go sv.processTunnelStartEvents()
if sv.cfg.KeepTunnelOpen {
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
@@ -92,30 +93,6 @@ func (sv *XTCPVisitor) Close() {
}
}
func (sv *XTCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warnf("xtcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warnf("xtcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) processTunnelStartEvents() {
for {
select {
@@ -205,20 +182,14 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
return
}
var muxConnRWCloser io.ReadWriteCloser = tunnelConn
if sv.cfg.Transport.UseEncryption {
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey))
if err != nil {
xl.Errorf("create encryption stream error: %v", err)
tunnelErr = err
return
}
}
if sv.cfg.Transport.UseCompression {
var recycleFn func()
muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser)
defer recycleFn()
muxConnRWCloser, recycleFn, err := wrapVisitorConn(tunnelConn, sv.cfg.GetBaseConfig())
if err != nil {
xl.Errorf("%v", err)
tunnelConn.Close()
tunnelErr = err
return
}
defer recycleFn()
_, _, errs := libio.Join(userConn, muxConnRWCloser)
xl.Debugf("join connections closed")
@@ -280,7 +251,7 @@ func (sv *XTCPVisitor) getTunnelConn(ctx context.Context) (net.Conn, error) {
// 4. Create a tunnel session using an underlying UDP connection.
func (sv *XTCPVisitor) makeNatHole() {
xl := xlog.FromContextSafe(sv.ctx)
targetProxyName := util.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
targetProxyName := naming.BuildTargetServerProxyName(sv.clientCfg.User, sv.cfg.ServerUser, sv.cfg.ServerName)
xl.Tracef("makeNatHole start")
if err := nathole.PreCheck(sv.ctx, sv.helper.MsgTransporter(), targetProxyName, 5*time.Second); err != nil {
xl.Warnf("nathole precheck error: %v", err)
@@ -372,6 +343,7 @@ func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) er
}
remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String())
if err != nil {
lConn.Close()
return fmt.Errorf("create kcp connection from udp connection error: %v", err)
}

View File

@@ -47,7 +47,7 @@ var natholeDiscoveryCmd = &cobra.Command{
Use: "discover",
Short: "Discover nathole information from stun server",
RunE: func(cmd *cobra.Command, args []string) error {
// ignore error here, because we can use command line pameters
// ignore error here, because we can use command line parameters
cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfigMode)
if err != nil {
cfg = &v1.ClientCommonConfig{}

View File

@@ -143,6 +143,9 @@ transport.tls.enable = true
# Proxy names you want to start.
# Default is empty, means all proxies.
# This list is a global allowlist after config + store are merged, so entries
# created via Store API are also filtered by this list.
# If start is non-empty, any proxy/visitor not listed here will not be started.
# start = ["ssh", "dns"]
# Alternative to 'start': You can control each proxy individually using the 'enabled' field.

View File

@@ -5,7 +5,7 @@ COPY web/frpc/ ./
RUN npm install
RUN npm run build
FROM golang:1.24 AS building
FROM golang:1.25 AS building
COPY . /building
COPY --from=web-builder /web/frpc/dist /building/web/frpc/dist

View File

@@ -5,7 +5,7 @@ COPY web/frps/ ./
RUN npm install
RUN npm run build
FROM golang:1.24 AS building
FROM golang:1.25 AS building
COPY . /building
COPY --from=web-builder /web/frps/dist /building/web/frps/dist

2
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/fatedier/frp
go 1.24.0
go 1.25.0
require (
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5

View File

@@ -23,6 +23,7 @@ import (
"net/url"
"os"
"slices"
"sync"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
@@ -205,7 +206,8 @@ type OidcAuthConsumer struct {
additionalAuthScopes []v1.AuthScope
verifier TokenVerifier
subjectsFromLogin []string
mu sync.RWMutex
subjectsFromLogin map[string]struct{}
}
func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier {
@@ -226,7 +228,7 @@ func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVeri
return &OidcAuthConsumer{
additionalAuthScopes: additionalAuthScopes,
verifier: verifier,
subjectsFromLogin: []string{},
subjectsFromLogin: make(map[string]struct{}),
}
}
@@ -235,9 +237,9 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
if err != nil {
return fmt.Errorf("invalid OIDC token in login: %v", err)
}
if !slices.Contains(auth.subjectsFromLogin, token.Subject) {
auth.subjectsFromLogin = append(auth.subjectsFromLogin, token.Subject)
}
auth.mu.Lock()
auth.subjectsFromLogin[token.Subject] = struct{}{}
auth.mu.Unlock()
return nil
}
@@ -246,11 +248,13 @@ func (auth *OidcAuthConsumer) verifyPostLoginToken(privilegeKey string) (err err
if err != nil {
return fmt.Errorf("invalid OIDC token in ping: %v", err)
}
if !slices.Contains(auth.subjectsFromLogin, token.Subject) {
auth.mu.RLock()
_, ok := auth.subjectsFromLogin[token.Subject]
auth.mu.RUnlock()
if !ok {
return fmt.Errorf("received different OIDC subject in login and ping. "+
"original subjects: %s, "+
"new subject: %s",
auth.subjectsFromLogin, token.Subject)
token.Subject)
}
return nil
}

View File

@@ -171,15 +171,14 @@ func Convert_ServerCommonConf_To_v1(conf *ServerCommonConf) *v1.ServerConfig {
func transformHeadersFromPluginParams(params map[string]string) v1.HeaderOperations {
out := v1.HeaderOperations{}
for k, v := range params {
if !strings.HasPrefix(k, "plugin_header_") {
k, ok := strings.CutPrefix(k, "plugin_header_")
if !ok || k == "" {
continue
}
if k = strings.TrimPrefix(k, "plugin_header_"); k != "" {
if out.Set == nil {
out.Set = make(map[string]string)
}
out.Set[k] = v
if out.Set == nil {
out.Set = make(map[string]string)
}
out.Set[k] = v
}
return out
}

View File

@@ -39,14 +39,14 @@ const (
// Proxy
var (
proxyConfTypeMap = map[ProxyType]reflect.Type{
ProxyTypeTCP: reflect.TypeOf(TCPProxyConf{}),
ProxyTypeUDP: reflect.TypeOf(UDPProxyConf{}),
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConf{}),
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConf{}),
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConf{}),
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConf{}),
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConf{}),
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConf{}),
ProxyTypeTCP: reflect.TypeFor[TCPProxyConf](),
ProxyTypeUDP: reflect.TypeFor[UDPProxyConf](),
ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConf](),
ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConf](),
ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConf](),
ProxyTypeSTCP: reflect.TypeFor[STCPProxyConf](),
ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConf](),
ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConf](),
}
)

View File

@@ -22,8 +22,8 @@ func GetMapWithoutPrefix(set map[string]string, prefix string) map[string]string
m := make(map[string]string)
for key, value := range set {
if strings.HasPrefix(key, prefix) {
m[strings.TrimPrefix(key, prefix)] = value
if trimmed, ok := strings.CutPrefix(key, prefix); ok {
m[trimmed] = value
}
}

View File

@@ -32,9 +32,9 @@ const (
// Visitor
var (
visitorConfTypeMap = map[VisitorType]reflect.Type{
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConf{}),
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConf{}),
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConf{}),
VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConf](),
VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConf](),
VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConf](),
}
)

View File

@@ -17,6 +17,7 @@ package config
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
@@ -33,6 +34,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/config/v1/validation"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/util/jsonx"
"github.com/fatedier/frp/pkg/util/util"
)
@@ -108,7 +110,21 @@ func LoadConfigureFromFile(path string, c any, strict bool) error {
if err != nil {
return err
}
return LoadConfigure(content, c, strict)
return LoadConfigure(content, c, strict, detectFormatFromPath(path))
}
// detectFormatFromPath returns a format hint based on the file extension.
func detectFormatFromPath(path string) string {
switch strings.ToLower(filepath.Ext(path)) {
case ".toml":
return "toml"
case ".yaml", ".yml":
return "yaml"
case ".json":
return "json"
default:
return ""
}
}
// parseYAMLWithDotFieldsHandling parses YAML with dot-prefixed fields handling
@@ -129,48 +145,136 @@ func parseYAMLWithDotFieldsHandling(content []byte, target any) error {
}
// Convert to JSON and decode with strict validation
jsonBytes, err := json.Marshal(temp)
jsonBytes, err := jsonx.Marshal(temp)
if err != nil {
return err
}
decoder := json.NewDecoder(bytes.NewReader(jsonBytes))
decoder.DisallowUnknownFields()
return decoder.Decode(target)
return decodeJSONContent(jsonBytes, target, true)
}
func decodeJSONContent(content []byte, target any, strict bool) error {
if clientCfg, ok := target.(*v1.ClientConfig); ok {
decoded, err := v1.DecodeClientConfigJSON(content, v1.DecodeOptions{
DisallowUnknownFields: strict,
})
if err != nil {
return err
}
*clientCfg = decoded
return nil
}
return jsonx.UnmarshalWithOptions(content, target, jsonx.DecodeOptions{
RejectUnknownMembers: strict,
})
}
// LoadConfigure loads configuration from bytes and unmarshal into c.
// Now it supports json, yaml and toml format.
func LoadConfigure(b []byte, c any, strict bool) error {
v1.DisallowUnknownFieldsMu.Lock()
defer v1.DisallowUnknownFieldsMu.Unlock()
v1.DisallowUnknownFields = strict
// An optional format hint (e.g. "toml", "yaml", "json") can be provided
// to enable better error messages with line number information.
func LoadConfigure(b []byte, c any, strict bool, formats ...string) error {
format := ""
if len(formats) > 0 {
format = formats[0]
}
originalBytes := b
parsedFromTOML := false
var tomlObj any
// Try to unmarshal as TOML first; swallow errors from that (assume it's not valid TOML).
if err := toml.Unmarshal(b, &tomlObj); err == nil {
b, err = json.Marshal(&tomlObj)
tomlErr := toml.Unmarshal(b, &tomlObj)
if tomlErr == nil {
parsedFromTOML = true
var err error
b, err = jsonx.Marshal(&tomlObj)
if err != nil {
return err
}
} else if format == "toml" {
// File is known to be TOML but has syntax errors.
return formatTOMLError(tomlErr)
}
// If the buffer smells like JSON (first non-whitespace character is '{'), unmarshal as JSON directly.
if yaml.IsJSONBuffer(b) {
decoder := json.NewDecoder(bytes.NewBuffer(b))
if strict {
decoder.DisallowUnknownFields()
if err := decodeJSONContent(b, c, strict); err != nil {
return enhanceDecodeError(err, originalBytes, !parsedFromTOML)
}
return decoder.Decode(c)
return nil
}
// Handle YAML content
if strict {
// In strict mode, always use our custom handler to support YAML merge
return parseYAMLWithDotFieldsHandling(b, c)
if err := parseYAMLWithDotFieldsHandling(b, c); err != nil {
return enhanceDecodeError(err, originalBytes, !parsedFromTOML)
}
return nil
}
// Non-strict mode, parse normally
return yaml.Unmarshal(b, c)
}
// formatTOMLError extracts line/column information from TOML decode errors.
func formatTOMLError(err error) error {
var decErr *toml.DecodeError
if errors.As(err, &decErr) {
row, col := decErr.Position()
return fmt.Errorf("toml: line %d, column %d: %s", row, col, decErr.Error())
}
var strictErr *toml.StrictMissingError
if errors.As(err, &strictErr) {
return strictErr
}
return err
}
// enhanceDecodeError tries to add field path and line number information to JSON/YAML decode errors.
func enhanceDecodeError(err error, originalContent []byte, includeLine bool) error {
var typeErr *json.UnmarshalTypeError
if errors.As(err, &typeErr) && typeErr.Field != "" {
if includeLine {
line := findFieldLineInContent(originalContent, typeErr.Field)
if line > 0 {
return fmt.Errorf("line %d: field \"%s\": cannot unmarshal %s into %s", line, typeErr.Field, typeErr.Value, typeErr.Type)
}
}
return fmt.Errorf("field \"%s\": cannot unmarshal %s into %s", typeErr.Field, typeErr.Value, typeErr.Type)
}
return err
}
// findFieldLineInContent searches the original config content for a field name
// and returns the 1-indexed line number where it appears, or 0 if not found.
func findFieldLineInContent(content []byte, fieldPath string) int {
if fieldPath == "" {
return 0
}
// Use the last component of the field path (e.g. "proxies" from "proxies" or
// "protocol" from "transport.protocol").
parts := strings.Split(fieldPath, ".")
searchKey := parts[len(parts)-1]
lines := bytes.Split(content, []byte("\n"))
for i, line := range lines {
trimmed := bytes.TrimSpace(line)
// Match TOML key assignments like: key = ...
if bytes.HasPrefix(trimmed, []byte(searchKey)) {
rest := bytes.TrimSpace(trimmed[len(searchKey):])
if len(rest) > 0 && rest[0] == '=' {
return i + 1
}
}
// Match TOML table array headers like: [[proxies]]
if bytes.Contains(trimmed, []byte("[["+searchKey+"]]")) {
return i + 1
}
}
return 0
}
func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1.ProxyConfigurer, error) {
m.ProxyType = util.EmptyOr(m.ProxyType, string(v1.ProxyTypeTCP))
@@ -341,7 +445,8 @@ func FilterClientConfigurers(
proxyCfgs := proxies
visitorCfgs := visitors
// Filter by start
// Filter by start across merged configurers from all sources.
// For example, store entries are also filtered by this set.
if len(common.Start) > 0 {
startSet := sets.New(common.Start...)
proxyCfgs = lo.Filter(proxyCfgs, func(c v1.ProxyConfigurer, _ int) bool {

View File

@@ -189,6 +189,31 @@ unixPath = "/tmp/uds.sock"
require.Error(err)
}
func TestLoadClientConfigStrictMode_UnknownPluginField(t *testing.T) {
require := require.New(t)
content := `
serverPort = 7000
[[proxies]]
name = "test"
type = "tcp"
localPort = 6000
[proxies.plugin]
type = "http2https"
localAddr = "127.0.0.1:8080"
unknownInPlugin = "value"
`
clientCfg := v1.ClientConfig{}
err := LoadConfigure([]byte(content), &clientCfg, false)
require.NoError(err)
err = LoadConfigure([]byte(content), &clientCfg, true)
require.ErrorContains(err, "unknownInPlugin")
}
// TestYAMLMergeInStrictMode tests that YAML merge functionality works
// even in strict mode by properly handling dot-prefixed fields
func TestYAMLMergeInStrictMode(t *testing.T) {
@@ -470,3 +495,111 @@ serverPort: 7000
require.Equal("127.0.0.1", clientCfg.ServerAddr)
require.Equal(7000, clientCfg.ServerPort)
}
func TestTOMLSyntaxErrorWithPosition(t *testing.T) {
require := require.New(t)
// TOML with syntax error (unclosed table array header)
content := `serverAddr = "127.0.0.1"
serverPort = 7000
[[proxies]
name = "test"
`
clientCfg := v1.ClientConfig{}
err := LoadConfigure([]byte(content), &clientCfg, false, "toml")
require.Error(err)
require.Contains(err.Error(), "toml")
require.Contains(err.Error(), "line")
require.Contains(err.Error(), "column")
}
func TestTOMLTypeMismatchErrorWithFieldInfo(t *testing.T) {
require := require.New(t)
// TOML with wrong type: proxies should be a table array, not a string
content := `serverAddr = "127.0.0.1"
serverPort = 7000
proxies = "this should be a table array"
`
clientCfg := v1.ClientConfig{}
err := LoadConfigure([]byte(content), &clientCfg, false, "toml")
require.Error(err)
// The error should contain field info
errMsg := err.Error()
require.Contains(errMsg, "proxies")
require.NotContains(errMsg, "line")
}
func TestFindFieldLineInContent(t *testing.T) {
content := []byte(`serverAddr = "127.0.0.1"
serverPort = 7000
[[proxies]]
name = "test"
type = "tcp"
remotePort = 6000
`)
tests := []struct {
fieldPath string
wantLine int
}{
{"serverAddr", 1},
{"serverPort", 2},
{"name", 5},
{"type", 6},
{"remotePort", 7},
{"nonexistent", 0},
}
for _, tt := range tests {
t.Run(tt.fieldPath, func(t *testing.T) {
got := findFieldLineInContent(content, tt.fieldPath)
require.Equal(t, tt.wantLine, got)
})
}
}
func TestFormatDetection(t *testing.T) {
tests := []struct {
path string
format string
}{
{"config.toml", "toml"},
{"config.TOML", "toml"},
{"config.yaml", "yaml"},
{"config.yml", "yaml"},
{"config.json", "json"},
{"config.ini", ""},
{"config", ""},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
require.Equal(t, tt.format, detectFormatFromPath(tt.path))
})
}
}
func TestValidTOMLStillWorks(t *testing.T) {
require := require.New(t)
// Valid TOML with format hint should work fine
content := `serverAddr = "127.0.0.1"
serverPort = 7000
[[proxies]]
name = "test"
type = "tcp"
remotePort = 6000
`
clientCfg := v1.ClientConfig{}
err := LoadConfigure([]byte(content), &clientCfg, false, "toml")
require.NoError(err)
require.Equal("127.0.0.1", clientCfg.ServerAddr)
require.Equal(7000, clientCfg.ServerPort)
require.Len(clientCfg.Proxies, 1)
}

View File

@@ -15,18 +15,16 @@
package source
import (
"cmp"
"errors"
"fmt"
"sort"
"maps"
"slices"
"sync"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
type sourceEntry struct {
source Source
}
type Aggregator struct {
mu sync.RWMutex
@@ -58,17 +56,13 @@ func (a *Aggregator) StoreSource() *StoreSource {
return a.storeSource
}
func (a *Aggregator) getSourcesLocked() []sourceEntry {
sources := make([]sourceEntry, 0, 2)
func (a *Aggregator) getSourcesLocked() []Source {
sources := make([]Source, 0, 2)
if a.configSource != nil {
sources = append(sources, sourceEntry{
source: a.configSource,
})
sources = append(sources, a.configSource)
}
if a.storeSource != nil {
sources = append(sources, sourceEntry{
source: a.storeSource,
})
sources = append(sources, a.storeSource)
}
return sources
}
@@ -85,8 +79,8 @@ func (a *Aggregator) Load() ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error
proxyMap := make(map[string]v1.ProxyConfigurer)
visitorMap := make(map[string]v1.VisitorConfigurer)
for _, entry := range entries {
proxies, visitors, err := entry.source.Load()
for _, src := range entries {
proxies, visitors, err := src.Load()
if err != nil {
return nil, nil, fmt.Errorf("load source: %w", err)
}
@@ -105,21 +99,11 @@ func (a *Aggregator) mapsToSortedSlices(
proxyMap map[string]v1.ProxyConfigurer,
visitorMap map[string]v1.VisitorConfigurer,
) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer) {
proxies := make([]v1.ProxyConfigurer, 0, len(proxyMap))
for _, p := range proxyMap {
proxies = append(proxies, p)
}
sort.Slice(proxies, func(i, j int) bool {
return proxies[i].GetBaseConfig().Name < proxies[j].GetBaseConfig().Name
proxies := slices.SortedFunc(maps.Values(proxyMap), func(x, y v1.ProxyConfigurer) int {
return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
})
visitors := make([]v1.VisitorConfigurer, 0, len(visitorMap))
for _, v := range visitorMap {
visitors = append(visitors, v)
}
sort.Slice(visitors, func(i, j int) bool {
return visitors[i].GetBaseConfig().Name < visitors[j].GetBaseConfig().Name
visitors := slices.SortedFunc(maps.Values(visitorMap), func(x, y v1.VisitorConfigurer) int {
return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name)
})
return proxies, visitors
}

View File

@@ -196,7 +196,28 @@ func TestAggregator_VisitorMerge(t *testing.T) {
require.Len(visitors, 2)
}
func TestAggregator_Load_ReturnsSharedReferences(t *testing.T) {
func TestAggregator_Load_ReturnsSortedByName(t *testing.T) {
require := require.New(t)
agg := newTestAggregator(t, nil)
err := agg.ConfigSource().ReplaceAll(
[]v1.ProxyConfigurer{mockProxy("charlie"), mockProxy("alice"), mockProxy("bob")},
[]v1.VisitorConfigurer{mockVisitor("zulu"), mockVisitor("alpha")},
)
require.NoError(err)
proxies, visitors, err := agg.Load()
require.NoError(err)
require.Len(proxies, 3)
require.Equal("alice", proxies[0].GetBaseConfig().Name)
require.Equal("bob", proxies[1].GetBaseConfig().Name)
require.Equal("charlie", proxies[2].GetBaseConfig().Name)
require.Len(visitors, 2)
require.Equal("alpha", visitors[0].GetBaseConfig().Name)
require.Equal("zulu", visitors[1].GetBaseConfig().Name)
}
func TestAggregator_Load_ReturnsDefensiveCopies(t *testing.T) {
require := require.New(t)
agg := newTestAggregator(t, nil)
@@ -213,5 +234,5 @@ func TestAggregator_Load_ReturnsSharedReferences(t *testing.T) {
proxies2, _, err := agg.Load()
require.NoError(err)
require.Len(proxies2, 1)
require.Equal("alice.ssh", proxies2[0].GetBaseConfig().Name)
require.Equal("ssh", proxies2[0].GetBaseConfig().Name)
}

View File

@@ -61,5 +61,5 @@ func (s *baseSource) Load() ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error
visitors = append(visitors, v)
}
return proxies, visitors, nil
return cloneConfigurers(proxies, visitors)
}

View File

@@ -0,0 +1,48 @@
package source
import (
"testing"
"github.com/stretchr/testify/require"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
func TestBaseSourceLoadReturnsClonedConfigurers(t *testing.T) {
require := require.New(t)
src := NewConfigSource()
proxyCfg := &v1.TCPProxyConfig{
ProxyBaseConfig: v1.ProxyBaseConfig{
Name: "proxy1",
Type: "tcp",
},
}
visitorCfg := &v1.STCPVisitorConfig{
VisitorBaseConfig: v1.VisitorBaseConfig{
Name: "visitor1",
Type: "stcp",
},
}
err := src.ReplaceAll([]v1.ProxyConfigurer{proxyCfg}, []v1.VisitorConfigurer{visitorCfg})
require.NoError(err)
firstProxies, firstVisitors, err := src.Load()
require.NoError(err)
require.Len(firstProxies, 1)
require.Len(firstVisitors, 1)
// Mutate loaded objects as runtime completion would do.
firstProxies[0].Complete()
firstVisitors[0].Complete()
secondProxies, secondVisitors, err := src.Load()
require.NoError(err)
require.Len(secondProxies, 1)
require.Len(secondVisitors, 1)
require.Empty(secondProxies[0].GetBaseConfig().LocalIP)
require.Empty(secondVisitors[0].GetBaseConfig().BindAddr)
}

View File

@@ -0,0 +1,43 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package source
import (
"fmt"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
func cloneConfigurers(
proxies []v1.ProxyConfigurer,
visitors []v1.VisitorConfigurer,
) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) {
clonedProxies := make([]v1.ProxyConfigurer, 0, len(proxies))
clonedVisitors := make([]v1.VisitorConfigurer, 0, len(visitors))
for _, cfg := range proxies {
if cfg == nil {
return nil, nil, fmt.Errorf("proxy cannot be nil")
}
clonedProxies = append(clonedProxies, cfg.Clone())
}
for _, cfg := range visitors {
if cfg == nil {
return nil, nil, fmt.Errorf("visitor cannot be nil")
}
clonedVisitors = append(clonedVisitors, cfg.Clone())
}
return clonedProxies, clonedVisitors, nil
}

View File

@@ -15,12 +15,13 @@
package source
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/jsonx"
)
type StoreSourceConfig struct {
@@ -37,6 +38,11 @@ type StoreSource struct {
config StoreSourceConfig
}
var (
ErrAlreadyExists = errors.New("already exists")
ErrNotFound = errors.New("not found")
)
func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) {
if cfg.Path == "" {
return nil, fmt.Errorf("path is required")
@@ -68,34 +74,44 @@ func (s *StoreSource) loadFromFileUnlocked() error {
return err
}
var stored storeData
if err := json.Unmarshal(data, &stored); err != nil {
type rawStoreData struct {
Proxies []jsonx.RawMessage `json:"proxies,omitempty"`
Visitors []jsonx.RawMessage `json:"visitors,omitempty"`
}
stored := rawStoreData{}
if err := jsonx.Unmarshal(data, &stored); err != nil {
return fmt.Errorf("failed to parse JSON: %w", err)
}
s.proxies = make(map[string]v1.ProxyConfigurer)
s.visitors = make(map[string]v1.VisitorConfigurer)
for _, tp := range stored.Proxies {
if tp.ProxyConfigurer != nil {
proxyCfg := tp.ProxyConfigurer
name := proxyCfg.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("proxy name cannot be empty")
}
s.proxies[name] = proxyCfg
for i, proxyData := range stored.Proxies {
proxyCfg, err := v1.DecodeProxyConfigurerJSON(proxyData, v1.DecodeOptions{
DisallowUnknownFields: false,
})
if err != nil {
return fmt.Errorf("failed to decode proxy at index %d: %w", i, err)
}
name := proxyCfg.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("proxy name cannot be empty")
}
s.proxies[name] = proxyCfg
}
for _, tv := range stored.Visitors {
if tv.VisitorConfigurer != nil {
visitorCfg := tv.VisitorConfigurer
name := visitorCfg.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("visitor name cannot be empty")
}
s.visitors[name] = visitorCfg
for i, visitorData := range stored.Visitors {
visitorCfg, err := v1.DecodeVisitorConfigurerJSON(visitorData, v1.DecodeOptions{
DisallowUnknownFields: false,
})
if err != nil {
return fmt.Errorf("failed to decode visitor at index %d: %w", i, err)
}
name := visitorCfg.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("visitor name cannot be empty")
}
s.visitors[name] = visitorCfg
}
return nil
@@ -114,7 +130,7 @@ func (s *StoreSource) saveToFileUnlocked() error {
stored.Visitors = append(stored.Visitors, v1.TypedVisitorConfig{VisitorConfigurer: v})
}
data, err := json.MarshalIndent(stored, "", " ")
data, err := jsonx.MarshalIndent(stored, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal JSON: %w", err)
}
@@ -170,7 +186,7 @@ func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
defer s.mu.Unlock()
if _, exists := s.proxies[name]; exists {
return fmt.Errorf("proxy %q already exists", name)
return fmt.Errorf("%w: proxy %q", ErrAlreadyExists, name)
}
s.proxies[name] = proxy
@@ -197,7 +213,7 @@ func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error {
oldProxy, exists := s.proxies[name]
if !exists {
return fmt.Errorf("proxy %q does not exist", name)
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
}
s.proxies[name] = proxy
@@ -219,7 +235,7 @@ func (s *StoreSource) RemoveProxy(name string) error {
oldProxy, exists := s.proxies[name]
if !exists {
return fmt.Errorf("proxy %q does not exist", name)
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
}
delete(s.proxies, name)
@@ -256,7 +272,7 @@ func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error {
defer s.mu.Unlock()
if _, exists := s.visitors[name]; exists {
return fmt.Errorf("visitor %q already exists", name)
return fmt.Errorf("%w: visitor %q", ErrAlreadyExists, name)
}
s.visitors[name] = visitor
@@ -283,7 +299,7 @@ func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error {
oldVisitor, exists := s.visitors[name]
if !exists {
return fmt.Errorf("visitor %q does not exist", name)
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
}
s.visitors[name] = visitor
@@ -305,7 +321,7 @@ func (s *StoreSource) RemoveVisitor(name string) error {
oldVisitor, exists := s.visitors[name]
if !exists {
return fmt.Errorf("visitor %q does not exist", name)
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
}
delete(s.visitors, name)

View File

@@ -15,7 +15,6 @@
package source
import (
"encoding/json"
"os"
"path/filepath"
"testing"
@@ -23,6 +22,7 @@ import (
"github.com/stretchr/testify/require"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/jsonx"
)
func TestStoreSource_AddProxyAndVisitor_DoesNotApplyRuntimeDefaults(t *testing.T) {
@@ -80,7 +80,7 @@ func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
Proxies: []v1.TypedProxyConfig{{ProxyConfigurer: proxyCfg}},
Visitors: []v1.TypedVisitorConfig{{VisitorConfigurer: visitorCfg}},
}
data, err := json.Marshal(stored)
data, err := jsonx.Marshal(stored)
require.NoError(err)
err = os.WriteFile(path, data, 0o600)
require.NoError(err)
@@ -97,3 +97,25 @@ func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
require.Empty(gotVisitor.GetBaseConfig().BindAddr)
require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol)
}
func TestStoreSource_LoadFromFile_UnknownFieldsAreIgnored(t *testing.T) {
require := require.New(t)
path := filepath.Join(t.TempDir(), "store.json")
raw := []byte(`{
"proxies": [
{"name":"proxy1","type":"tcp","localPort":10080,"unexpected":"value"}
],
"visitors": [
{"name":"visitor1","type":"xtcp","serverName":"server1","secretKey":"secret","bindPort":10081,"unexpected":"value"}
]
}`)
err := os.WriteFile(path, raw, 0o600)
require.NoError(err)
storeSource, err := NewStoreSource(StoreSourceConfig{Path: path})
require.NoError(err)
require.NotNil(storeSource.GetProxy("proxy1"))
require.NotNil(storeSource.GetVisitor("visitor1"))
}

View File

@@ -38,7 +38,7 @@ func parseNumberRangePair(firstRangeStr, secondRangeStr string) ([]NumberPair, e
return nil, fmt.Errorf("first and second range numbers are not in pairs")
}
pairs := make([]NumberPair, 0, len(firstRangeNumbers))
for i := 0; i < len(firstRangeNumbers); i++ {
for i := range firstRangeNumbers {
pairs = append(pairs, NumberPair{
First: firstRangeNumbers[i],
Second: secondRangeNumbers[i],

View File

@@ -70,24 +70,18 @@ func (q *BandwidthQuantity) UnmarshalString(s string) error {
f float64
err error
)
switch {
case strings.HasSuffix(s, "MB"):
if fstr, ok := strings.CutSuffix(s, "MB"); ok {
base = MB
fstr := strings.TrimSuffix(s, "MB")
f, err = strconv.ParseFloat(fstr, 64)
if err != nil {
return err
}
case strings.HasSuffix(s, "KB"):
} else if fstr, ok := strings.CutSuffix(s, "KB"); ok {
base = KB
fstr := strings.TrimSuffix(s, "KB")
f, err = strconv.ParseFloat(fstr, 64)
if err != nil {
return err
}
default:
} else {
return errors.New("unit not support")
}
if err != nil {
return err
}
q.s = s
q.i = int64(f * float64(base))
@@ -143,8 +137,8 @@ func (p PortsRangeSlice) String() string {
func NewPortsRangeSliceFromString(str string) ([]PortsRange, error) {
str = strings.TrimSpace(str)
out := []PortsRange{}
numRanges := strings.Split(str, ",")
for _, numRangeStr := range numRanges {
numRanges := strings.SplitSeq(str, ",")
for numRangeStr := range numRanges {
// 1000-2000 or 2001
numArray := strings.Split(numRangeStr, "-")
// length: only 1 or 2 is correct

View File

@@ -39,6 +39,31 @@ func TestBandwidthQuantity(t *testing.T) {
require.Equal(`{"b":"1KB","int":5}`, string(buf))
}
func TestBandwidthQuantity_MB(t *testing.T) {
require := require.New(t)
var w Wrap
err := json.Unmarshal([]byte(`{"b":"2MB","int":1}`), &w)
require.NoError(err)
require.EqualValues(2*MB, w.B.Bytes())
buf, err := json.Marshal(&w)
require.NoError(err)
require.Equal(`{"b":"2MB","int":1}`, string(buf))
}
func TestBandwidthQuantity_InvalidUnit(t *testing.T) {
var w Wrap
err := json.Unmarshal([]byte(`{"b":"1GB","int":1}`), &w)
require.Error(t, err)
}
func TestBandwidthQuantity_InvalidNumber(t *testing.T) {
var w Wrap
err := json.Unmarshal([]byte(`{"b":"abcKB","int":1}`), &w)
require.Error(t, err)
}
func TestPortsRangeSlice2String(t *testing.T) {
require := require.New(t)

109
pkg/config/v1/clone_test.go Normal file
View File

@@ -0,0 +1,109 @@
package v1
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestProxyCloneDeepCopy(t *testing.T) {
require := require.New(t)
enabled := true
pluginHTTP2 := true
cfg := &HTTPProxyConfig{
ProxyBaseConfig: ProxyBaseConfig{
Name: "p1",
Type: "http",
Enabled: &enabled,
Annotations: map[string]string{"a": "1"},
Metadatas: map[string]string{"m": "1"},
HealthCheck: HealthCheckConfig{
Type: "http",
HTTPHeaders: []HTTPHeader{
{Name: "X-Test", Value: "v1"},
},
},
ProxyBackend: ProxyBackend{
Plugin: TypedClientPluginOptions{
Type: PluginHTTPS2HTTP,
ClientPluginOptions: &HTTPS2HTTPPluginOptions{
Type: PluginHTTPS2HTTP,
EnableHTTP2: &pluginHTTP2,
RequestHeaders: HeaderOperations{Set: map[string]string{"k": "v"}},
},
},
},
},
DomainConfig: DomainConfig{
CustomDomains: []string{"a.example.com"},
SubDomain: "a",
},
Locations: []string{"/api"},
RequestHeaders: HeaderOperations{Set: map[string]string{"h1": "v1"}},
ResponseHeaders: HeaderOperations{Set: map[string]string{"h2": "v2"}},
}
cloned := cfg.Clone().(*HTTPProxyConfig)
*cloned.Enabled = false
cloned.Annotations["a"] = "changed"
cloned.Metadatas["m"] = "changed"
cloned.HealthCheck.HTTPHeaders[0].Value = "changed"
cloned.CustomDomains[0] = "b.example.com"
cloned.Locations[0] = "/new"
cloned.RequestHeaders.Set["h1"] = "changed"
cloned.ResponseHeaders.Set["h2"] = "changed"
clientPlugin := cloned.Plugin.ClientPluginOptions.(*HTTPS2HTTPPluginOptions)
*clientPlugin.EnableHTTP2 = false
clientPlugin.RequestHeaders.Set["k"] = "changed"
require.True(*cfg.Enabled)
require.Equal("1", cfg.Annotations["a"])
require.Equal("1", cfg.Metadatas["m"])
require.Equal("v1", cfg.HealthCheck.HTTPHeaders[0].Value)
require.Equal("a.example.com", cfg.CustomDomains[0])
require.Equal("/api", cfg.Locations[0])
require.Equal("v1", cfg.RequestHeaders.Set["h1"])
require.Equal("v2", cfg.ResponseHeaders.Set["h2"])
origPlugin := cfg.Plugin.ClientPluginOptions.(*HTTPS2HTTPPluginOptions)
require.True(*origPlugin.EnableHTTP2)
require.Equal("v", origPlugin.RequestHeaders.Set["k"])
}
func TestVisitorCloneDeepCopy(t *testing.T) {
require := require.New(t)
enabled := true
cfg := &XTCPVisitorConfig{
VisitorBaseConfig: VisitorBaseConfig{
Name: "v1",
Type: "xtcp",
Enabled: &enabled,
ServerName: "server",
BindPort: 7000,
Plugin: TypedVisitorPluginOptions{
Type: VisitorPluginVirtualNet,
VisitorPluginOptions: &VirtualNetVisitorPluginOptions{
Type: VisitorPluginVirtualNet,
DestinationIP: "10.0.0.1",
},
},
},
NatTraversal: &NatTraversalConfig{
DisableAssistedAddrs: true,
},
}
cloned := cfg.Clone().(*XTCPVisitorConfig)
*cloned.Enabled = false
cloned.NatTraversal.DisableAssistedAddrs = false
visitorPlugin := cloned.Plugin.VisitorPluginOptions.(*VirtualNetVisitorPluginOptions)
visitorPlugin.DestinationIP = "10.0.0.2"
require.True(*cfg.Enabled)
require.True(cfg.NatTraversal.DisableAssistedAddrs)
origPlugin := cfg.Plugin.VisitorPluginOptions.(*VirtualNetVisitorPluginOptions)
require.Equal("10.0.0.1", origPlugin.DestinationIP)
}

View File

@@ -15,23 +15,11 @@
package v1
import (
"sync"
"maps"
"github.com/fatedier/frp/pkg/util/util"
)
// TODO(fatedier): Due to the current implementation issue of the go json library, the UnmarshalJSON method
// of a custom struct cannot access the DisallowUnknownFields parameter of the parent decoder.
// Here, a global variable is temporarily used to control whether unknown fields are allowed.
// Once the v2 version is implemented by the community, we can switch to a standardized approach.
//
// https://github.com/golang/go/issues/41144
// https://github.com/golang/go/discussions/63397
var (
DisallowUnknownFields = false
DisallowUnknownFieldsMu sync.Mutex
)
type AuthScope string
const (
@@ -104,6 +92,14 @@ type NatTraversalConfig struct {
DisableAssistedAddrs bool `json:"disableAssistedAddrs,omitempty"`
}
func (c *NatTraversalConfig) Clone() *NatTraversalConfig {
if c == nil {
return nil
}
out := *c
return &out
}
type LogConfig struct {
// This is destination where frp should write the logs.
// If "console" is used, logs will be printed to stdout, otherwise,
@@ -138,6 +134,12 @@ type HeaderOperations struct {
Set map[string]string `json:"set,omitempty"`
}
func (o HeaderOperations) Clone() HeaderOperations {
return HeaderOperations{
Set: maps.Clone(o.Set),
}
}
type HTTPHeader struct {
Name string `json:"name"`
Value string `json:"value"`

195
pkg/config/v1/decode.go Normal file
View File

@@ -0,0 +1,195 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v1
import (
"errors"
"fmt"
"reflect"
"github.com/fatedier/frp/pkg/util/jsonx"
)
type DecodeOptions struct {
DisallowUnknownFields bool
}
func decodeJSONWithOptions(b []byte, out any, options DecodeOptions) error {
return jsonx.UnmarshalWithOptions(b, out, jsonx.DecodeOptions{
RejectUnknownMembers: options.DisallowUnknownFields,
})
}
func isJSONNull(b []byte) bool {
return len(b) == 0 || string(b) == "null"
}
type typedEnvelope struct {
Type string `json:"type"`
Plugin jsonx.RawMessage `json:"plugin,omitempty"`
}
func DecodeProxyConfigurerJSON(b []byte, options DecodeOptions) (ProxyConfigurer, error) {
if isJSONNull(b) {
return nil, errors.New("type is required")
}
var env typedEnvelope
if err := jsonx.Unmarshal(b, &env); err != nil {
return nil, err
}
configurer := NewProxyConfigurerByType(ProxyType(env.Type))
if configurer == nil {
return nil, fmt.Errorf("unknown proxy type: %s", env.Type)
}
if err := decodeJSONWithOptions(b, configurer, options); err != nil {
return nil, fmt.Errorf("unmarshal ProxyConfig error: %v", err)
}
if len(env.Plugin) > 0 && !isJSONNull(env.Plugin) {
plugin, err := DecodeClientPluginOptionsJSON(env.Plugin, options)
if err != nil {
return nil, fmt.Errorf("unmarshal proxy plugin error: %v", err)
}
configurer.GetBaseConfig().Plugin = plugin
}
return configurer, nil
}
func DecodeVisitorConfigurerJSON(b []byte, options DecodeOptions) (VisitorConfigurer, error) {
if isJSONNull(b) {
return nil, errors.New("type is required")
}
var env typedEnvelope
if err := jsonx.Unmarshal(b, &env); err != nil {
return nil, err
}
configurer := NewVisitorConfigurerByType(VisitorType(env.Type))
if configurer == nil {
return nil, fmt.Errorf("unknown visitor type: %s", env.Type)
}
if err := decodeJSONWithOptions(b, configurer, options); err != nil {
return nil, fmt.Errorf("unmarshal VisitorConfig error: %v", err)
}
if len(env.Plugin) > 0 && !isJSONNull(env.Plugin) {
plugin, err := DecodeVisitorPluginOptionsJSON(env.Plugin, options)
if err != nil {
return nil, fmt.Errorf("unmarshal visitor plugin error: %v", err)
}
configurer.GetBaseConfig().Plugin = plugin
}
return configurer, nil
}
func DecodeClientPluginOptionsJSON(b []byte, options DecodeOptions) (TypedClientPluginOptions, error) {
if isJSONNull(b) {
return TypedClientPluginOptions{}, nil
}
var env typedEnvelope
if err := jsonx.Unmarshal(b, &env); err != nil {
return TypedClientPluginOptions{}, err
}
if env.Type == "" {
return TypedClientPluginOptions{}, errors.New("plugin type is empty")
}
v, ok := clientPluginOptionsTypeMap[env.Type]
if !ok {
return TypedClientPluginOptions{}, fmt.Errorf("unknown plugin type: %s", env.Type)
}
optionsStruct := reflect.New(v).Interface().(ClientPluginOptions)
if err := decodeJSONWithOptions(b, optionsStruct, options); err != nil {
return TypedClientPluginOptions{}, fmt.Errorf("unmarshal ClientPluginOptions error: %v", err)
}
return TypedClientPluginOptions{
Type: env.Type,
ClientPluginOptions: optionsStruct,
}, nil
}
func DecodeVisitorPluginOptionsJSON(b []byte, options DecodeOptions) (TypedVisitorPluginOptions, error) {
if isJSONNull(b) {
return TypedVisitorPluginOptions{}, nil
}
var env typedEnvelope
if err := jsonx.Unmarshal(b, &env); err != nil {
return TypedVisitorPluginOptions{}, err
}
if env.Type == "" {
return TypedVisitorPluginOptions{}, errors.New("visitor plugin type is empty")
}
v, ok := visitorPluginOptionsTypeMap[env.Type]
if !ok {
return TypedVisitorPluginOptions{}, fmt.Errorf("unknown visitor plugin type: %s", env.Type)
}
optionsStruct := reflect.New(v).Interface().(VisitorPluginOptions)
if err := decodeJSONWithOptions(b, optionsStruct, options); err != nil {
return TypedVisitorPluginOptions{}, fmt.Errorf("unmarshal VisitorPluginOptions error: %v", err)
}
return TypedVisitorPluginOptions{
Type: env.Type,
VisitorPluginOptions: optionsStruct,
}, nil
}
func DecodeClientConfigJSON(b []byte, options DecodeOptions) (ClientConfig, error) {
type rawClientConfig struct {
ClientCommonConfig
Proxies []jsonx.RawMessage `json:"proxies,omitempty"`
Visitors []jsonx.RawMessage `json:"visitors,omitempty"`
}
raw := rawClientConfig{}
if err := decodeJSONWithOptions(b, &raw, options); err != nil {
return ClientConfig{}, err
}
cfg := ClientConfig{
ClientCommonConfig: raw.ClientCommonConfig,
Proxies: make([]TypedProxyConfig, 0, len(raw.Proxies)),
Visitors: make([]TypedVisitorConfig, 0, len(raw.Visitors)),
}
for i, proxyData := range raw.Proxies {
proxyCfg, err := DecodeProxyConfigurerJSON(proxyData, options)
if err != nil {
return ClientConfig{}, fmt.Errorf("decode proxy at index %d: %w", i, err)
}
cfg.Proxies = append(cfg.Proxies, TypedProxyConfig{
Type: proxyCfg.GetBaseConfig().Type,
ProxyConfigurer: proxyCfg,
})
}
for i, visitorData := range raw.Visitors {
visitorCfg, err := DecodeVisitorConfigurerJSON(visitorData, options)
if err != nil {
return ClientConfig{}, fmt.Errorf("decode visitor at index %d: %w", i, err)
}
cfg.Visitors = append(cfg.Visitors, TypedVisitorConfig{
Type: visitorCfg.GetBaseConfig().Type,
VisitorConfigurer: visitorCfg,
})
}
return cfg, nil
}

View File

@@ -0,0 +1,86 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v1
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDecodeProxyConfigurerJSON_StrictPluginUnknownFields(t *testing.T) {
require := require.New(t)
data := []byte(`{
"name":"p1",
"type":"tcp",
"localPort":10080,
"plugin":{
"type":"http2https",
"localAddr":"127.0.0.1:8080",
"unknownInPlugin":"value"
}
}`)
_, err := DecodeProxyConfigurerJSON(data, DecodeOptions{DisallowUnknownFields: false})
require.NoError(err)
_, err = DecodeProxyConfigurerJSON(data, DecodeOptions{DisallowUnknownFields: true})
require.ErrorContains(err, "unknownInPlugin")
}
func TestDecodeVisitorConfigurerJSON_StrictPluginUnknownFields(t *testing.T) {
require := require.New(t)
data := []byte(`{
"name":"v1",
"type":"stcp",
"serverName":"server",
"bindPort":10081,
"plugin":{
"type":"virtual_net",
"destinationIP":"10.0.0.1",
"unknownInPlugin":"value"
}
}`)
_, err := DecodeVisitorConfigurerJSON(data, DecodeOptions{DisallowUnknownFields: false})
require.NoError(err)
_, err = DecodeVisitorConfigurerJSON(data, DecodeOptions{DisallowUnknownFields: true})
require.ErrorContains(err, "unknownInPlugin")
}
func TestDecodeClientConfigJSON_StrictUnknownProxyField(t *testing.T) {
require := require.New(t)
data := []byte(`{
"serverPort":7000,
"proxies":[
{
"name":"p1",
"type":"tcp",
"localPort":10080,
"unknownField":"value"
}
]
}`)
_, err := DecodeClientConfigJSON(data, DecodeOptions{DisallowUnknownFields: false})
require.NoError(err)
_, err = DecodeClientConfigJSON(data, DecodeOptions{DisallowUnknownFields: true})
require.ErrorContains(err, "unknownField")
}

View File

@@ -15,14 +15,13 @@
package v1
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"maps"
"reflect"
"slices"
"github.com/fatedier/frp/pkg/config/types"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/util/jsonx"
"github.com/fatedier/frp/pkg/util/util"
)
@@ -100,11 +99,23 @@ type HealthCheckConfig struct {
HTTPHeaders []HTTPHeader `json:"httpHeaders,omitempty"`
}
func (c HealthCheckConfig) Clone() HealthCheckConfig {
out := c
out.HTTPHeaders = slices.Clone(c.HTTPHeaders)
return out
}
type DomainConfig struct {
CustomDomains []string `json:"customDomains,omitempty"`
SubDomain string `json:"subdomain,omitempty"`
}
func (c DomainConfig) Clone() DomainConfig {
out := c
out.CustomDomains = slices.Clone(c.CustomDomains)
return out
}
type ProxyBaseConfig struct {
Name string `json:"name"`
Type string `json:"type"`
@@ -120,6 +131,22 @@ type ProxyBaseConfig struct {
ProxyBackend
}
func (c ProxyBaseConfig) Clone() ProxyBaseConfig {
out := c
out.Enabled = util.ClonePtr(c.Enabled)
out.Annotations = maps.Clone(c.Annotations)
out.Metadatas = maps.Clone(c.Metadatas)
out.HealthCheck = c.HealthCheck.Clone()
out.ProxyBackend = c.ProxyBackend.Clone()
return out
}
func (c ProxyBackend) Clone() ProxyBackend {
out := c
out.Plugin = c.Plugin.Clone()
return out
}
func (c *ProxyBaseConfig) GetBaseConfig() *ProxyBaseConfig {
return c
}
@@ -172,40 +199,24 @@ type TypedProxyConfig struct {
}
func (c *TypedProxyConfig) UnmarshalJSON(b []byte) error {
if len(b) == 4 && string(b) == "null" {
return errors.New("type is required")
}
typeStruct := struct {
Type string `json:"type"`
}{}
if err := json.Unmarshal(b, &typeStruct); err != nil {
configurer, err := DecodeProxyConfigurerJSON(b, DecodeOptions{})
if err != nil {
return err
}
c.Type = typeStruct.Type
configurer := NewProxyConfigurerByType(ProxyType(typeStruct.Type))
if configurer == nil {
return fmt.Errorf("unknown proxy type: %s", typeStruct.Type)
}
decoder := json.NewDecoder(bytes.NewBuffer(b))
if DisallowUnknownFields {
decoder.DisallowUnknownFields()
}
if err := decoder.Decode(configurer); err != nil {
return fmt.Errorf("unmarshal ProxyConfig error: %v", err)
}
c.Type = configurer.GetBaseConfig().Type
c.ProxyConfigurer = configurer
return nil
}
func (c *TypedProxyConfig) MarshalJSON() ([]byte, error) {
return json.Marshal(c.ProxyConfigurer)
return jsonx.Marshal(c.ProxyConfigurer)
}
type ProxyConfigurer interface {
Complete()
GetBaseConfig() *ProxyBaseConfig
Clone() ProxyConfigurer
// MarshalToMsg marshals this config into a msg.NewProxy message. This
// function will be called on the frpc side.
MarshalToMsg(*msg.NewProxy)
@@ -228,14 +239,14 @@ const (
)
var proxyConfigTypeMap = map[ProxyType]reflect.Type{
ProxyTypeTCP: reflect.TypeOf(TCPProxyConfig{}),
ProxyTypeUDP: reflect.TypeOf(UDPProxyConfig{}),
ProxyTypeHTTP: reflect.TypeOf(HTTPProxyConfig{}),
ProxyTypeHTTPS: reflect.TypeOf(HTTPSProxyConfig{}),
ProxyTypeTCPMUX: reflect.TypeOf(TCPMuxProxyConfig{}),
ProxyTypeSTCP: reflect.TypeOf(STCPProxyConfig{}),
ProxyTypeXTCP: reflect.TypeOf(XTCPProxyConfig{}),
ProxyTypeSUDP: reflect.TypeOf(SUDPProxyConfig{}),
ProxyTypeTCP: reflect.TypeFor[TCPProxyConfig](),
ProxyTypeUDP: reflect.TypeFor[UDPProxyConfig](),
ProxyTypeHTTP: reflect.TypeFor[HTTPProxyConfig](),
ProxyTypeHTTPS: reflect.TypeFor[HTTPSProxyConfig](),
ProxyTypeTCPMUX: reflect.TypeFor[TCPMuxProxyConfig](),
ProxyTypeSTCP: reflect.TypeFor[STCPProxyConfig](),
ProxyTypeXTCP: reflect.TypeFor[XTCPProxyConfig](),
ProxyTypeSUDP: reflect.TypeFor[SUDPProxyConfig](),
}
func NewProxyConfigurerByType(proxyType ProxyType) ProxyConfigurer {
@@ -268,6 +279,12 @@ func (c *TCPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
c.RemotePort = m.RemotePort
}
func (c *TCPProxyConfig) Clone() ProxyConfigurer {
out := *c
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
return &out
}
var _ ProxyConfigurer = &UDPProxyConfig{}
type UDPProxyConfig struct {
@@ -288,6 +305,12 @@ func (c *UDPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
c.RemotePort = m.RemotePort
}
func (c *UDPProxyConfig) Clone() ProxyConfigurer {
out := *c
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
return &out
}
var _ ProxyConfigurer = &HTTPProxyConfig{}
type HTTPProxyConfig struct {
@@ -331,6 +354,16 @@ func (c *HTTPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
c.RouteByHTTPUser = m.RouteByHTTPUser
}
func (c *HTTPProxyConfig) Clone() ProxyConfigurer {
out := *c
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
out.DomainConfig = c.DomainConfig.Clone()
out.Locations = slices.Clone(c.Locations)
out.RequestHeaders = c.RequestHeaders.Clone()
out.ResponseHeaders = c.ResponseHeaders.Clone()
return &out
}
var _ ProxyConfigurer = &HTTPSProxyConfig{}
type HTTPSProxyConfig struct {
@@ -352,6 +385,13 @@ func (c *HTTPSProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
c.SubDomain = m.SubDomain
}
func (c *HTTPSProxyConfig) Clone() ProxyConfigurer {
out := *c
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
out.DomainConfig = c.DomainConfig.Clone()
return &out
}
type TCPMultiplexerType string
const (
@@ -392,6 +432,13 @@ func (c *TCPMuxProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
c.RouteByHTTPUser = m.RouteByHTTPUser
}
func (c *TCPMuxProxyConfig) Clone() ProxyConfigurer {
out := *c
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
out.DomainConfig = c.DomainConfig.Clone()
return &out
}
var _ ProxyConfigurer = &STCPProxyConfig{}
type STCPProxyConfig struct {
@@ -415,6 +462,13 @@ func (c *STCPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
c.AllowUsers = m.AllowUsers
}
func (c *STCPProxyConfig) Clone() ProxyConfigurer {
out := *c
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
out.AllowUsers = slices.Clone(c.AllowUsers)
return &out
}
var _ ProxyConfigurer = &XTCPProxyConfig{}
type XTCPProxyConfig struct {
@@ -441,6 +495,14 @@ func (c *XTCPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
c.AllowUsers = m.AllowUsers
}
func (c *XTCPProxyConfig) Clone() ProxyConfigurer {
out := *c
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
out.AllowUsers = slices.Clone(c.AllowUsers)
out.NatTraversal = c.NatTraversal.Clone()
return &out
}
var _ ProxyConfigurer = &SUDPProxyConfig{}
type SUDPProxyConfig struct {
@@ -463,3 +525,10 @@ func (c *SUDPProxyConfig) UnmarshalFromMsg(m *msg.NewProxy) {
c.Secretkey = m.Sk
c.AllowUsers = m.AllowUsers
}
func (c *SUDPProxyConfig) Clone() ProxyConfigurer {
out := *c
out.ProxyBaseConfig = c.ProxyBaseConfig.Clone()
out.AllowUsers = slices.Clone(c.AllowUsers)
return &out
}

View File

@@ -15,14 +15,11 @@
package v1
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"reflect"
"github.com/samber/lo"
"github.com/fatedier/frp/pkg/util/jsonx"
"github.com/fatedier/frp/pkg/util/util"
)
@@ -40,20 +37,21 @@ const (
)
var clientPluginOptionsTypeMap = map[string]reflect.Type{
PluginHTTP2HTTPS: reflect.TypeOf(HTTP2HTTPSPluginOptions{}),
PluginHTTPProxy: reflect.TypeOf(HTTPProxyPluginOptions{}),
PluginHTTPS2HTTP: reflect.TypeOf(HTTPS2HTTPPluginOptions{}),
PluginHTTPS2HTTPS: reflect.TypeOf(HTTPS2HTTPSPluginOptions{}),
PluginHTTP2HTTP: reflect.TypeOf(HTTP2HTTPPluginOptions{}),
PluginSocks5: reflect.TypeOf(Socks5PluginOptions{}),
PluginStaticFile: reflect.TypeOf(StaticFilePluginOptions{}),
PluginUnixDomainSocket: reflect.TypeOf(UnixDomainSocketPluginOptions{}),
PluginTLS2Raw: reflect.TypeOf(TLS2RawPluginOptions{}),
PluginVirtualNet: reflect.TypeOf(VirtualNetPluginOptions{}),
PluginHTTP2HTTPS: reflect.TypeFor[HTTP2HTTPSPluginOptions](),
PluginHTTPProxy: reflect.TypeFor[HTTPProxyPluginOptions](),
PluginHTTPS2HTTP: reflect.TypeFor[HTTPS2HTTPPluginOptions](),
PluginHTTPS2HTTPS: reflect.TypeFor[HTTPS2HTTPSPluginOptions](),
PluginHTTP2HTTP: reflect.TypeFor[HTTP2HTTPPluginOptions](),
PluginSocks5: reflect.TypeFor[Socks5PluginOptions](),
PluginStaticFile: reflect.TypeFor[StaticFilePluginOptions](),
PluginUnixDomainSocket: reflect.TypeFor[UnixDomainSocketPluginOptions](),
PluginTLS2Raw: reflect.TypeFor[TLS2RawPluginOptions](),
PluginVirtualNet: reflect.TypeFor[VirtualNetPluginOptions](),
}
type ClientPluginOptions interface {
Complete()
Clone() ClientPluginOptions
}
type TypedClientPluginOptions struct {
@@ -61,43 +59,25 @@ type TypedClientPluginOptions struct {
ClientPluginOptions
}
func (c *TypedClientPluginOptions) UnmarshalJSON(b []byte) error {
if len(b) == 4 && string(b) == "null" {
return nil
func (c TypedClientPluginOptions) Clone() TypedClientPluginOptions {
out := c
if c.ClientPluginOptions != nil {
out.ClientPluginOptions = c.ClientPluginOptions.Clone()
}
return out
}
typeStruct := struct {
Type string `json:"type"`
}{}
if err := json.Unmarshal(b, &typeStruct); err != nil {
func (c *TypedClientPluginOptions) UnmarshalJSON(b []byte) error {
decoded, err := DecodeClientPluginOptionsJSON(b, DecodeOptions{})
if err != nil {
return err
}
c.Type = typeStruct.Type
if c.Type == "" {
return errors.New("plugin type is empty")
}
v, ok := clientPluginOptionsTypeMap[typeStruct.Type]
if !ok {
return fmt.Errorf("unknown plugin type: %s", typeStruct.Type)
}
options := reflect.New(v).Interface().(ClientPluginOptions)
decoder := json.NewDecoder(bytes.NewBuffer(b))
if DisallowUnknownFields {
decoder.DisallowUnknownFields()
}
if err := decoder.Decode(options); err != nil {
return fmt.Errorf("unmarshal ClientPluginOptions error: %v", err)
}
c.ClientPluginOptions = options
*c = decoded
return nil
}
func (c *TypedClientPluginOptions) MarshalJSON() ([]byte, error) {
return json.Marshal(c.ClientPluginOptions)
return jsonx.Marshal(c.ClientPluginOptions)
}
type HTTP2HTTPSPluginOptions struct {
@@ -109,6 +89,15 @@ type HTTP2HTTPSPluginOptions struct {
func (o *HTTP2HTTPSPluginOptions) Complete() {}
func (o *HTTP2HTTPSPluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
out.RequestHeaders = o.RequestHeaders.Clone()
return &out
}
type HTTPProxyPluginOptions struct {
Type string `json:"type,omitempty"`
HTTPUser string `json:"httpUser,omitempty"`
@@ -117,6 +106,14 @@ type HTTPProxyPluginOptions struct {
func (o *HTTPProxyPluginOptions) Complete() {}
func (o *HTTPProxyPluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
return &out
}
type HTTPS2HTTPPluginOptions struct {
Type string `json:"type,omitempty"`
LocalAddr string `json:"localAddr,omitempty"`
@@ -131,6 +128,16 @@ func (o *HTTPS2HTTPPluginOptions) Complete() {
o.EnableHTTP2 = util.EmptyOr(o.EnableHTTP2, lo.ToPtr(true))
}
func (o *HTTPS2HTTPPluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
out.RequestHeaders = o.RequestHeaders.Clone()
out.EnableHTTP2 = util.ClonePtr(o.EnableHTTP2)
return &out
}
type HTTPS2HTTPSPluginOptions struct {
Type string `json:"type,omitempty"`
LocalAddr string `json:"localAddr,omitempty"`
@@ -145,6 +152,16 @@ func (o *HTTPS2HTTPSPluginOptions) Complete() {
o.EnableHTTP2 = util.EmptyOr(o.EnableHTTP2, lo.ToPtr(true))
}
func (o *HTTPS2HTTPSPluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
out.RequestHeaders = o.RequestHeaders.Clone()
out.EnableHTTP2 = util.ClonePtr(o.EnableHTTP2)
return &out
}
type HTTP2HTTPPluginOptions struct {
Type string `json:"type,omitempty"`
LocalAddr string `json:"localAddr,omitempty"`
@@ -154,6 +171,15 @@ type HTTP2HTTPPluginOptions struct {
func (o *HTTP2HTTPPluginOptions) Complete() {}
func (o *HTTP2HTTPPluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
out.RequestHeaders = o.RequestHeaders.Clone()
return &out
}
type Socks5PluginOptions struct {
Type string `json:"type,omitempty"`
Username string `json:"username,omitempty"`
@@ -162,6 +188,14 @@ type Socks5PluginOptions struct {
func (o *Socks5PluginOptions) Complete() {}
func (o *Socks5PluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
return &out
}
type StaticFilePluginOptions struct {
Type string `json:"type,omitempty"`
LocalPath string `json:"localPath,omitempty"`
@@ -172,6 +206,14 @@ type StaticFilePluginOptions struct {
func (o *StaticFilePluginOptions) Complete() {}
func (o *StaticFilePluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
return &out
}
type UnixDomainSocketPluginOptions struct {
Type string `json:"type,omitempty"`
UnixPath string `json:"unixPath,omitempty"`
@@ -179,6 +221,14 @@ type UnixDomainSocketPluginOptions struct {
func (o *UnixDomainSocketPluginOptions) Complete() {}
func (o *UnixDomainSocketPluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
return &out
}
type TLS2RawPluginOptions struct {
Type string `json:"type,omitempty"`
LocalAddr string `json:"localAddr,omitempty"`
@@ -188,8 +238,24 @@ type TLS2RawPluginOptions struct {
func (o *TLS2RawPluginOptions) Complete() {}
func (o *TLS2RawPluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
return &out
}
type VirtualNetPluginOptions struct {
Type string `json:"type,omitempty"`
}
func (o *VirtualNetPluginOptions) Complete() {}
func (o *VirtualNetPluginOptions) Clone() ClientPluginOptions {
if o == nil {
return nil
}
out := *o
return &out
}

View File

@@ -15,12 +15,9 @@
package v1
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"reflect"
"github.com/fatedier/frp/pkg/util/jsonx"
"github.com/fatedier/frp/pkg/util/util"
)
@@ -50,6 +47,13 @@ type VisitorBaseConfig struct {
Plugin TypedVisitorPluginOptions `json:"plugin,omitempty"`
}
func (c VisitorBaseConfig) Clone() VisitorBaseConfig {
out := c
out.Enabled = util.ClonePtr(c.Enabled)
out.Plugin = c.Plugin.Clone()
return out
}
func (c *VisitorBaseConfig) GetBaseConfig() *VisitorBaseConfig {
return c
}
@@ -63,6 +67,7 @@ func (c *VisitorBaseConfig) Complete() {
type VisitorConfigurer interface {
Complete()
GetBaseConfig() *VisitorBaseConfig
Clone() VisitorConfigurer
}
type VisitorType string
@@ -74,9 +79,9 @@ const (
)
var visitorConfigTypeMap = map[VisitorType]reflect.Type{
VisitorTypeSTCP: reflect.TypeOf(STCPVisitorConfig{}),
VisitorTypeXTCP: reflect.TypeOf(XTCPVisitorConfig{}),
VisitorTypeSUDP: reflect.TypeOf(SUDPVisitorConfig{}),
VisitorTypeSTCP: reflect.TypeFor[STCPVisitorConfig](),
VisitorTypeXTCP: reflect.TypeFor[XTCPVisitorConfig](),
VisitorTypeSUDP: reflect.TypeFor[SUDPVisitorConfig](),
}
type TypedVisitorConfig struct {
@@ -85,35 +90,18 @@ type TypedVisitorConfig struct {
}
func (c *TypedVisitorConfig) UnmarshalJSON(b []byte) error {
if len(b) == 4 && string(b) == "null" {
return errors.New("type is required")
}
typeStruct := struct {
Type string `json:"type"`
}{}
if err := json.Unmarshal(b, &typeStruct); err != nil {
configurer, err := DecodeVisitorConfigurerJSON(b, DecodeOptions{})
if err != nil {
return err
}
c.Type = typeStruct.Type
configurer := NewVisitorConfigurerByType(VisitorType(typeStruct.Type))
if configurer == nil {
return fmt.Errorf("unknown visitor type: %s", typeStruct.Type)
}
decoder := json.NewDecoder(bytes.NewBuffer(b))
if DisallowUnknownFields {
decoder.DisallowUnknownFields()
}
if err := decoder.Decode(configurer); err != nil {
return fmt.Errorf("unmarshal VisitorConfig error: %v", err)
}
c.Type = configurer.GetBaseConfig().Type
c.VisitorConfigurer = configurer
return nil
}
func (c *TypedVisitorConfig) MarshalJSON() ([]byte, error) {
return json.Marshal(c.VisitorConfigurer)
return jsonx.Marshal(c.VisitorConfigurer)
}
func NewVisitorConfigurerByType(t VisitorType) VisitorConfigurer {
@@ -132,12 +120,24 @@ type STCPVisitorConfig struct {
VisitorBaseConfig
}
func (c *STCPVisitorConfig) Clone() VisitorConfigurer {
out := *c
out.VisitorBaseConfig = c.VisitorBaseConfig.Clone()
return &out
}
var _ VisitorConfigurer = &SUDPVisitorConfig{}
type SUDPVisitorConfig struct {
VisitorBaseConfig
}
func (c *SUDPVisitorConfig) Clone() VisitorConfigurer {
out := *c
out.VisitorBaseConfig = c.VisitorBaseConfig.Clone()
return &out
}
var _ VisitorConfigurer = &XTCPVisitorConfig{}
type XTCPVisitorConfig struct {
@@ -162,3 +162,10 @@ func (c *XTCPVisitorConfig) Complete() {
c.MinRetryInterval = util.EmptyOr(c.MinRetryInterval, 90)
c.FallbackTimeoutMs = util.EmptyOr(c.FallbackTimeoutMs, 1000)
}
func (c *XTCPVisitorConfig) Clone() VisitorConfigurer {
out := *c
out.VisitorBaseConfig = c.VisitorBaseConfig.Clone()
out.NatTraversal = c.NatTraversal.Clone()
return &out
}

View File

@@ -15,11 +15,9 @@
package v1
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"reflect"
"github.com/fatedier/frp/pkg/util/jsonx"
)
const (
@@ -27,11 +25,12 @@ const (
)
var visitorPluginOptionsTypeMap = map[string]reflect.Type{
VisitorPluginVirtualNet: reflect.TypeOf(VirtualNetVisitorPluginOptions{}),
VisitorPluginVirtualNet: reflect.TypeFor[VirtualNetVisitorPluginOptions](),
}
type VisitorPluginOptions interface {
Complete()
Clone() VisitorPluginOptions
}
type TypedVisitorPluginOptions struct {
@@ -39,43 +38,25 @@ type TypedVisitorPluginOptions struct {
VisitorPluginOptions
}
func (c *TypedVisitorPluginOptions) UnmarshalJSON(b []byte) error {
if len(b) == 4 && string(b) == "null" {
return nil
func (c TypedVisitorPluginOptions) Clone() TypedVisitorPluginOptions {
out := c
if c.VisitorPluginOptions != nil {
out.VisitorPluginOptions = c.VisitorPluginOptions.Clone()
}
return out
}
typeStruct := struct {
Type string `json:"type"`
}{}
if err := json.Unmarshal(b, &typeStruct); err != nil {
func (c *TypedVisitorPluginOptions) UnmarshalJSON(b []byte) error {
decoded, err := DecodeVisitorPluginOptionsJSON(b, DecodeOptions{})
if err != nil {
return err
}
c.Type = typeStruct.Type
if c.Type == "" {
return errors.New("visitor plugin type is empty")
}
v, ok := visitorPluginOptionsTypeMap[typeStruct.Type]
if !ok {
return fmt.Errorf("unknown visitor plugin type: %s", typeStruct.Type)
}
options := reflect.New(v).Interface().(VisitorPluginOptions)
decoder := json.NewDecoder(bytes.NewBuffer(b))
if DisallowUnknownFields {
decoder.DisallowUnknownFields()
}
if err := decoder.Decode(options); err != nil {
return fmt.Errorf("unmarshal VisitorPluginOptions error: %v", err)
}
c.VisitorPluginOptions = options
*c = decoded
return nil
}
func (c *TypedVisitorPluginOptions) MarshalJSON() ([]byte, error) {
return json.Marshal(c.VisitorPluginOptions)
return jsonx.Marshal(c.VisitorPluginOptions)
}
type VirtualNetVisitorPluginOptions struct {
@@ -84,3 +65,11 @@ type VirtualNetVisitorPluginOptions struct {
}
func (o *VirtualNetVisitorPluginOptions) Complete() {}
func (o *VirtualNetVisitorPluginOptions) Clone() VisitorPluginOptions {
if o == nil {
return nil
}
out := *o
return &out
}

View File

@@ -143,7 +143,6 @@ func (m *serverMetrics) OpenConnection(name string, _ string) {
proxyStats, ok := m.info.ProxyStatistics[name]
if ok {
proxyStats.CurConns.Inc(1)
m.info.ProxyStatistics[name] = proxyStats
}
}
@@ -155,7 +154,6 @@ func (m *serverMetrics) CloseConnection(name string, _ string) {
proxyStats, ok := m.info.ProxyStatistics[name]
if ok {
proxyStats.CurConns.Dec(1)
m.info.ProxyStatistics[name] = proxyStats
}
}
@@ -168,7 +166,6 @@ func (m *serverMetrics) AddTrafficIn(name string, _ string, trafficBytes int64)
proxyStats, ok := m.info.ProxyStatistics[name]
if ok {
proxyStats.TrafficIn.Inc(trafficBytes)
m.info.ProxyStatistics[name] = proxyStats
}
}
@@ -181,7 +178,6 @@ func (m *serverMetrics) AddTrafficOut(name string, _ string, trafficBytes int64)
proxyStats, ok := m.info.ProxyStatistics[name]
if ok {
proxyStats.TrafficOut.Inc(trafficBytes)
m.info.ProxyStatistics[name] = proxyStats
}
}
@@ -203,6 +199,25 @@ func (m *serverMetrics) GetServer() *ServerStats {
return s
}
func toProxyStats(name string, proxyStats *ProxyStatistics) *ProxyStats {
ps := &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
return ps
}
func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
res := make([]*ProxyStats, 0)
m.mu.Lock()
@@ -212,23 +227,7 @@ func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
if proxyStats.ProxyType != proxyType {
continue
}
ps := &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
res = append(res, ps)
res = append(res, toProxyStats(name, proxyStats))
}
return res
}
@@ -237,31 +236,9 @@ func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName stri
m.mu.Lock()
defer m.mu.Unlock()
for name, proxyStats := range m.info.ProxyStatistics {
if proxyStats.ProxyType != proxyType {
continue
}
if name != proxyName {
continue
}
res = &ProxyStats{
Name: name,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
break
proxyStats, ok := m.info.ProxyStatistics[proxyName]
if ok && proxyStats.ProxyType == proxyType {
res = toProxyStats(proxyName, proxyStats)
}
return
}
@@ -272,21 +249,7 @@ func (m *serverMetrics) GetProxyByName(proxyName string) (res *ProxyStats) {
proxyStats, ok := m.info.ProxyStatistics[proxyName]
if ok {
res = &ProxyStats{
Name: proxyName,
Type: proxyStats.ProxyType,
User: proxyStats.User,
ClientID: proxyStats.ClientID,
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
CurConns: int64(proxyStats.CurConns.Count()),
}
if !proxyStats.LastStartTime.IsZero() {
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
}
if !proxyStats.LastCloseTime.IsZero() {
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
}
res = toProxyStats(proxyName, proxyStats)
}
return
}

View File

@@ -61,7 +61,7 @@ var msgTypeMap = map[byte]any{
TypeNatHoleReport: NatHoleReport{},
}
var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name()
var TypeNameNatHoleResp = reflect.TypeFor[NatHoleResp]().Name()
type ClientSpec struct {
// Due to the support of VirtualClient, frps needs to know the client type in order to
@@ -184,7 +184,7 @@ type Pong struct {
}
type UDPPacket struct {
Content string `json:"c,omitempty"`
Content []byte `json:"c,omitempty"`
LocalAddr *net.UDPAddr `json:"l,omitempty"`
RemoteAddr *net.UDPAddr `json:"r,omitempty"`
}

View File

@@ -1,4 +1,4 @@
package util
package naming
import "strings"
@@ -16,9 +16,8 @@ func StripUserPrefix(user, name string) string {
if user == "" {
return name
}
prefix := user + "."
if strings.HasPrefix(name, prefix) {
return strings.TrimPrefix(name, prefix)
if trimmed, ok := strings.CutPrefix(name, user+"."); ok {
return trimmed
}
return name
}

27
pkg/naming/names_test.go Normal file
View File

@@ -0,0 +1,27 @@
package naming
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAddUserPrefix(t *testing.T) {
require := require.New(t)
require.Equal("test", AddUserPrefix("", "test"))
require.Equal("alice.test", AddUserPrefix("alice", "test"))
}
func TestStripUserPrefix(t *testing.T) {
require := require.New(t)
require.Equal("test", StripUserPrefix("", "test"))
require.Equal("test", StripUserPrefix("alice", "alice.test"))
require.Equal("alice.test", StripUserPrefix("alice", "alice.alice.test"))
require.Equal("bob.test", StripUserPrefix("alice", "bob.test"))
}
func TestBuildTargetServerProxyName(t *testing.T) {
require := require.New(t)
require.Equal("alice.test", BuildTargetServerProxyName("alice", "", "test"))
require.Equal("bob.test", BuildTargetServerProxyName("alice", "bob", "test"))
}

View File

@@ -151,7 +151,7 @@ func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
behaviors := getBehaviorByMode(mode)
scores := make([]*BehaviorScore, 0, len(behaviors))
for i := 0; i < len(behaviors); i++ {
for i := range behaviors {
score := receiverScore
if behaviors[i].A.Role == DetectRoleSender {
score = senderScore

View File

@@ -70,12 +70,8 @@ func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, err
continue
}
if portNum > portMax {
portMax = portNum
}
if portNum < portMin {
portMin = portNum
}
portMax = max(portMax, portNum)
portMin = min(portMin, portNum)
if baseIP != ip {
ipChanged = true
}

View File

@@ -152,7 +152,9 @@ func (c *Controller) GenSid() string {
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) {
if m.PreCheck {
c.mu.RLock()
cfg, ok := c.clientCfgs[m.ProxyName]
c.mu.RUnlock()
if !ok {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
return
@@ -375,7 +377,7 @@ func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
if !isLast {
return nil
}
var ports []msg.PortsRange
ports := make([]msg.PortsRange, 0, 1)
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil

View File

@@ -298,11 +298,13 @@ func waitDetectMessage(
n, raddr, err := conn.ReadFromUDP(buf)
_ = conn.SetReadDeadline(time.Time{})
if err != nil {
pool.PutBuf(buf)
return nil, err
}
xl.Debugf("get udp message local %s, from %s", conn.LocalAddr(), raddr)
var m msg.NatHoleSid
if err := DecodeMessageInto(buf[:n], key, &m); err != nil {
pool.PutBuf(buf)
xl.Warnf("decode sid message error: %v", err)
continue
}
@@ -408,7 +410,7 @@ func sendSidMessageToRandomPorts(
xl := xlog.FromContextSafe(ctx)
used := sets.New[int]()
getUnusedPort := func() int {
for i := 0; i < 10; i++ {
for range 10 {
port := rand.IntN(65535-1024) + 1024
if !used.Has(port) {
used.Insert(port)
@@ -418,7 +420,7 @@ func sendSidMessageToRandomPorts(
return 0
}
for i := 0; i < count; i++ {
for range count {
select {
case <-ctx.Done():
return

View File

@@ -21,6 +21,7 @@ import (
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
@@ -68,7 +69,7 @@ func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin
p.s = &http.Server{
Handler: rp,
ReadHeaderTimeout: 0,
ReadHeaderTimeout: 60 * time.Second,
}
go func() {

View File

@@ -22,6 +22,7 @@ import (
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
@@ -77,7 +78,7 @@ func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugi
p.s = &http.Server{
Handler: rp,
ReadHeaderTimeout: 0,
ReadHeaderTimeout: 60 * time.Second,
}
go func() {

View File

@@ -62,11 +62,13 @@ func (p *TLS2RawPlugin) Handle(ctx context.Context, connInfo *ConnectionInfo) {
if err := tlsConn.Handshake(); err != nil {
xl.Warnf("tls handshake error: %v", err)
tlsConn.Close()
return
}
rawConn, err := net.Dial("tcp", p.opts.LocalAddr)
if err != nil {
xl.Warnf("dial to local addr error: %v", err)
tlsConn.Close()
return
}

View File

@@ -54,10 +54,13 @@ func (uds *UnixDomainSocketPlugin) Handle(ctx context.Context, connInfo *Connect
localConn, err := net.DialUnix("unix", nil, uds.UnixAddr)
if err != nil {
xl.Warnf("dial to uds %s error: %v", uds.UnixAddr, err)
connInfo.Conn.Close()
return
}
if connInfo.ProxyProtocolHeader != nil {
if _, err := connInfo.ProxyProtocolHeader.WriteTo(localConn); err != nil {
localConn.Close()
connInfo.Conn.Close()
return
}
}

View File

@@ -24,6 +24,7 @@ import (
"net/http"
"net/url"
"reflect"
"slices"
"strings"
v1 "github.com/fatedier/frp/pkg/config/v1"
@@ -64,12 +65,7 @@ func (p *httpPlugin) Name() string {
}
func (p *httpPlugin) IsSupport(op string) bool {
for _, v := range p.options.Ops {
if v == op {
return true
}
}
return false
return slices.Contains(p.options.Ops, op)
}
func (p *httpPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {

View File

@@ -153,10 +153,7 @@ func (p *VirtualNetPlugin) run() {
// Exponential backoff: 60s, 120s, 240s, 300s (capped)
baseDelay := 60 * time.Second
reconnectDelay = baseDelay * time.Duration(1<<uint(p.consecutiveErrors-1))
if reconnectDelay > 300*time.Second {
reconnectDelay = 300 * time.Second
}
reconnectDelay = min(baseDelay*time.Duration(1<<uint(p.consecutiveErrors-1)), 300*time.Second)
} else {
// Reset consecutive errors on successful connection
if p.consecutiveErrors > 0 {

View File

@@ -16,6 +16,7 @@ package featuregate
import (
"fmt"
"maps"
"sort"
"strings"
"sync"
@@ -92,10 +93,7 @@ type featureGate struct {
// NewFeatureGate creates a new feature gate with the default features
func NewFeatureGate() MutableFeatureGate {
known := map[Feature]FeatureSpec{}
for k, v := range defaultFeatures {
known[k] = v
}
known := maps.Clone(defaultFeatures)
f := &featureGate{}
f.known.Store(known)
@@ -109,14 +107,8 @@ func (f *featureGate) SetFromMap(m map[string]bool) error {
defer f.lock.Unlock()
// Copy existing state
known := map[Feature]FeatureSpec{}
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
known[k] = v
}
enabled := map[Feature]bool{}
for k, v := range f.enabled.Load().(map[Feature]bool) {
enabled[k] = v
}
known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
enabled := maps.Clone(f.enabled.Load().(map[Feature]bool))
// Apply the new settings
for k, v := range m {
@@ -147,10 +139,7 @@ func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
}
// Copy existing state
known := map[Feature]FeatureSpec{}
for k, v := range f.known.Load().(map[Feature]FeatureSpec) {
known[k] = v
}
known := maps.Clone(f.known.Load().(map[Feature]FeatureSpec))
// Add new features
for name, spec := range features {
@@ -171,8 +160,9 @@ func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
// String returns a string containing all enabled feature gates, formatted as "key1=value1,key2=value2,..."
func (f *featureGate) String() string {
pairs := []string{}
for k, v := range f.enabled.Load().(map[Feature]bool) {
enabled := f.enabled.Load().(map[Feature]bool)
pairs := make([]string, 0, len(enabled))
for k, v := range enabled {
pairs = append(pairs, fmt.Sprintf("%s=%t", k, v))
}
sort.Strings(pairs)

View File

@@ -15,7 +15,6 @@
package udp
import (
"encoding/base64"
"net"
"sync"
"time"
@@ -28,16 +27,17 @@ import (
)
func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket {
content := make([]byte, len(buf))
copy(content, buf)
return &msg.UDPPacket{
Content: base64.StdEncoding.EncodeToString(buf),
Content: content,
LocalAddr: laddr,
RemoteAddr: raddr,
}
}
func GetContent(m *msg.UDPPacket) (buf []byte, err error) {
buf, err = base64.StdEncoding.DecodeString(m.Content)
return
return m.Content, nil
}
func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh chan<- *msg.UDPPacket, bufSize int) {
@@ -60,7 +60,7 @@ func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh
if err != nil {
return
}
// buf[:n] will be encoded to string, so the bytes can be reused
// NewUDPPacket copies buf[:n], so the read buffer can be reused
udpMsg := NewUDPPacket(buf[:n], nil, remoteAddr)
select {
@@ -85,6 +85,7 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
}()
buf := pool.GetBuf(bufSize)
defer pool.PutBuf(buf)
for {
_ = udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
n, _, err := udpConn.ReadFromUDP(buf)

View File

@@ -11,7 +11,7 @@ import (
"strconv"
"strings"
"github.com/fatedier/frp/client/api"
"github.com/fatedier/frp/client/http/model"
httppkg "github.com/fatedier/frp/pkg/util/http"
)
@@ -32,7 +32,7 @@ func (c *Client) SetAuth(user, pwd string) {
c.authPwd = pwd
}
func (c *Client) GetProxyStatus(ctx context.Context, name string) (*api.ProxyStatusResp, error) {
func (c *Client) GetProxyStatus(ctx context.Context, name string) (*model.ProxyStatusResp, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "http://"+c.address+"/api/status", nil)
if err != nil {
return nil, err
@@ -41,7 +41,7 @@ func (c *Client) GetProxyStatus(ctx context.Context, name string) (*api.ProxySta
if err != nil {
return nil, err
}
allStatus := make(api.StatusResp)
allStatus := make(model.StatusResp)
if err = json.Unmarshal([]byte(content), &allStatus); err != nil {
return nil, fmt.Errorf("unmarshal http response error: %s", strings.TrimSpace(content))
}
@@ -55,7 +55,7 @@ func (c *Client) GetProxyStatus(ctx context.Context, name string) (*api.ProxySta
return nil, fmt.Errorf("no proxy status found")
}
func (c *Client) GetAllProxyStatus(ctx context.Context) (api.StatusResp, error) {
func (c *Client) GetAllProxyStatus(ctx context.Context) (model.StatusResp, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "http://"+c.address+"/api/status", nil)
if err != nil {
return nil, err
@@ -64,7 +64,7 @@ func (c *Client) GetAllProxyStatus(ctx context.Context) (api.StatusResp, error)
if err != nil {
return nil, err
}
allStatus := make(api.StatusResp)
allStatus := make(model.StatusResp)
if err = json.Unmarshal([]byte(content), &allStatus); err != nil {
return nil, fmt.Errorf("unmarshal http response error: %s", strings.TrimSpace(content))
}

View File

@@ -20,6 +20,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"os"
"time"
@@ -85,7 +86,9 @@ func newCertPool(caPath string) (*x509.CertPool, error) {
return nil, err
}
pool.AppendCertsFromPEM(caCrt)
if !pool.AppendCertsFromPEM(caCrt) {
return nil, fmt.Errorf("failed to parse CA certificate from file %q: no valid PEM certificates found", caPath)
}
return pool, nil
}

View File

@@ -89,11 +89,11 @@ func ParseBasicAuth(auth string) (username, password string, ok bool) {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
before, after, found := strings.Cut(cs, ":")
if !found {
return
}
return cs[:s], cs[s+1:], true
return before, after, true
}
func BasicAuth(username, passwd string) string {

45
pkg/util/jsonx/json_v1.go Normal file
View File

@@ -0,0 +1,45 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package jsonx
import (
"bytes"
"encoding/json"
)
type DecodeOptions struct {
RejectUnknownMembers bool
}
func Marshal(v any) ([]byte, error) {
return json.Marshal(v)
}
func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
return json.MarshalIndent(v, prefix, indent)
}
func Unmarshal(data []byte, out any) error {
return json.Unmarshal(data, out)
}
func UnmarshalWithOptions(data []byte, out any, options DecodeOptions) error {
if !options.RejectUnknownMembers {
return json.Unmarshal(data, out)
}
decoder := json.NewDecoder(bytes.NewReader(data))
decoder.DisallowUnknownFields()
return decoder.Decode(out)
}

View File

@@ -0,0 +1,36 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package jsonx
import "fmt"
// RawMessage stores a raw encoded JSON value.
// It is equivalent to encoding/json.RawMessage behavior.
type RawMessage []byte
func (m RawMessage) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
func (m *RawMessage) UnmarshalJSON(data []byte) error {
if m == nil {
return fmt.Errorf("jsonx.RawMessage: UnmarshalJSON on nil pointer")
}
*m = append((*m)[:0], data...)
return nil
}

View File

@@ -86,11 +86,7 @@ func (c *FakeUDPConn) Read(b []byte) (n int, err error) {
c.lastActive = time.Now()
c.mu.Unlock()
if len(b) < len(content) {
n = len(b)
} else {
n = len(content)
}
n = min(len(b), len(content))
copy(b, content)
return n, nil
}
@@ -168,11 +164,15 @@ func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
return l, err
}
readConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return l, err
}
l = &UDPListener{
addr: udpAddr,
acceptCh: make(chan net.Conn),
writeCh: make(chan *UDPPacket, 1000),
readConn: readConn,
fakeConns: make(map[string]*FakeUDPConn),
}

View File

@@ -26,6 +26,7 @@ type WebsocketListener struct {
// ln: tcp listener for websocket connections
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
wl = &WebsocketListener{
ln: ln,
acceptCh: make(chan net.Conn),
}

View File

@@ -68,8 +68,8 @@ func ParseRangeNumbers(rangeStr string) (numbers []int64, err error) {
rangeStr = strings.TrimSpace(rangeStr)
numbers = make([]int64, 0)
// e.g. 1000-2000,2001,2002,3000-4000
numRanges := strings.Split(rangeStr, ",")
for _, numRangeStr := range numRanges {
numRanges := strings.SplitSeq(rangeStr, ",")
for numRangeStr := range numRanges {
// 1000-2000 or 2001
numArray := strings.Split(numRangeStr, "-")
// length: only 1 or 2 is correct
@@ -134,3 +134,12 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati
func ConstantTimeEqString(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
// ClonePtr returns a pointer to a copied value. If v is nil, it returns nil.
func ClonePtr[T any](v *T) *T {
if v == nil {
return nil
}
out := *v
return &out
}

View File

@@ -42,22 +42,15 @@ func TestParseRangeNumbers(t *testing.T) {
require.Error(err)
}
func TestAddUserPrefix(t *testing.T) {
func TestClonePtr(t *testing.T) {
require := require.New(t)
require.Equal("test", AddUserPrefix("", "test"))
require.Equal("alice.test", AddUserPrefix("alice", "test"))
}
func TestStripUserPrefix(t *testing.T) {
require := require.New(t)
require.Equal("test", StripUserPrefix("", "test"))
require.Equal("test", StripUserPrefix("alice", "alice.test"))
require.Equal("alice.test", StripUserPrefix("alice", "alice.alice.test"))
require.Equal("bob.test", StripUserPrefix("alice", "bob.test"))
}
var nilInt *int
require.Nil(ClonePtr(nilInt))
func TestBuildTargetServerProxyName(t *testing.T) {
require := require.New(t)
require.Equal("alice.test", BuildTargetServerProxyName("alice", "", "test"))
require.Equal("bob.test", BuildTargetServerProxyName("alice", "bob", "test"))
v := 42
cloned := ClonePtr(&v)
require.NotNil(cloned)
require.Equal(v, *cloned)
require.NotSame(&v, cloned)
}

View File

@@ -14,7 +14,7 @@
package version
var version = "0.67.0"
var version = "0.68.0"
func Full() string {
return version

View File

@@ -266,31 +266,13 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
go libio.Join(remote, client)
}
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func (rp *HTTPReverseProxy) injectRequestInfoToCtx(req *http.Request) *http.Request {
user := ""
// If url host isn't empty, it's a proxy request. Get http user from Proxy-Authorization header.
if req.URL.Host != "" {
proxyAuth := req.Header.Get("Proxy-Authorization")
if proxyAuth != "" {
user, _, _ = parseBasicAuth(proxyAuth)
user, _, _ = httppkg.ParseBasicAuth(proxyAuth)
}
}
if user == "" {

View File

@@ -63,11 +63,12 @@ func (l *Logger) AddPrefix(prefix LogPrefix) *Logger {
if prefix.Priority <= 0 {
prefix.Priority = 10
}
for _, p := range l.prefixes {
for i, p := range l.prefixes {
if p.Name == prefix.Name {
found = true
p.Value = prefix.Value
p.Priority = prefix.Priority
l.prefixes[i].Value = prefix.Value
l.prefixes[i].Priority = prefix.Priority
break
}
}
if !found {

64
server/api_router.go Normal file
View File

@@ -0,0 +1,64 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"net/http"
"github.com/prometheus/client_golang/prometheus/promhttp"
httppkg "github.com/fatedier/frp/pkg/util/http"
netpkg "github.com/fatedier/frp/pkg/util/net"
adminapi "github.com/fatedier/frp/server/http"
)
func (svr *Service) registerRouteHandlers(helper *httppkg.RouterRegisterHelper) {
helper.Router.HandleFunc("/healthz", healthz)
subRouter := helper.Router.NewRoute().Subrouter()
subRouter.Use(helper.AuthMiddleware)
subRouter.Use(httppkg.NewRequestLogger)
// metrics
if svr.cfg.EnablePrometheus {
subRouter.Handle("/metrics", promhttp.Handler())
}
apiController := adminapi.NewController(svr.cfg, svr.clientRegistry, svr.pxyManager)
// apis
subRouter.HandleFunc("/api/serverinfo", httppkg.MakeHTTPHandlerFunc(apiController.APIServerInfo)).Methods("GET")
subRouter.HandleFunc("/api/proxy/{type}", httppkg.MakeHTTPHandlerFunc(apiController.APIProxyByType)).Methods("GET")
subRouter.HandleFunc("/api/proxy/{type}/{name}", httppkg.MakeHTTPHandlerFunc(apiController.APIProxyByTypeAndName)).Methods("GET")
subRouter.HandleFunc("/api/proxies/{name}", httppkg.MakeHTTPHandlerFunc(apiController.APIProxyByName)).Methods("GET")
subRouter.HandleFunc("/api/traffic/{name}", httppkg.MakeHTTPHandlerFunc(apiController.APIProxyTraffic)).Methods("GET")
subRouter.HandleFunc("/api/clients", httppkg.MakeHTTPHandlerFunc(apiController.APIClientList)).Methods("GET")
subRouter.HandleFunc("/api/clients/{key}", httppkg.MakeHTTPHandlerFunc(apiController.APIClientDetail)).Methods("GET")
subRouter.HandleFunc("/api/proxies", httppkg.MakeHTTPHandlerFunc(apiController.DeleteProxies)).Methods("DELETE")
// view
subRouter.Handle("/favicon.ico", http.FileServer(helper.AssetsFS)).Methods("GET")
subRouter.PathPrefix("/static/").Handler(
netpkg.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(helper.AssetsFS))),
).Methods("GET")
subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
})
}
func healthz(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(200)
}

View File

@@ -95,20 +95,33 @@ func (cm *ControlManager) Close() error {
return nil
}
type Control struct {
// SessionContext encapsulates the input parameters for creating a new Control.
type SessionContext struct {
// all resource managers and controllers
rc *controller.ResourceController
RC *controller.ResourceController
// proxy manager
pxyManager *proxy.Manager
PxyManager *proxy.Manager
// plugin manager
pluginManager *plugin.Manager
PluginManager *plugin.Manager
// verifies authentication based on selected method
authVerifier auth.Verifier
AuthVerifier auth.Verifier
// key used for connection encryption
encryptionKey []byte
EncryptionKey []byte
// control connection
Conn net.Conn
// indicates whether the connection is encrypted
ConnEncrypted bool
// login message
LoginMsg *msg.Login
// server configuration
ServerCfg *v1.ServerConfig
// client registry
ClientRegistry *registry.ClientRegistry
}
type Control struct {
// session context
sessionCtx *SessionContext
// other components can use this to communicate with client
msgTransporter transport.MessageTransporter
@@ -117,12 +130,6 @@ type Control struct {
// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
msgDispatcher *msg.Dispatcher
// login message
loginMsg *msg.Login
// control connection
conn net.Conn
// work connections
workConnCh chan net.Conn
@@ -145,61 +152,34 @@ type Control struct {
mu sync.RWMutex
// Server configuration information
serverCfg *v1.ServerConfig
clientRegistry *registry.ClientRegistry
xl *xlog.Logger
ctx context.Context
doneCh chan struct{}
}
// TODO(fatedier): Referencing the implementation of frpc, encapsulate the input parameters as SessionContext.
func NewControl(
ctx context.Context,
rc *controller.ResourceController,
pxyManager *proxy.Manager,
pluginManager *plugin.Manager,
authVerifier auth.Verifier,
encryptionKey []byte,
ctlConn net.Conn,
ctlConnEncrypted bool,
loginMsg *msg.Login,
serverCfg *v1.ServerConfig,
) (*Control, error) {
poolCount := loginMsg.PoolCount
if poolCount > int(serverCfg.Transport.MaxPoolCount) {
poolCount = int(serverCfg.Transport.MaxPoolCount)
}
func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, error) {
poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
ctl := &Control{
rc: rc,
pxyManager: pxyManager,
pluginManager: pluginManager,
authVerifier: authVerifier,
encryptionKey: encryptionKey,
conn: ctlConn,
loginMsg: loginMsg,
workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
runID: loginMsg.RunID,
serverCfg: serverCfg,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
sessionCtx: sessionCtx,
workConnCh: make(chan net.Conn, poolCount+10),
proxies: make(map[string]proxy.Proxy),
poolCount: poolCount,
portsUsedNum: 0,
runID: sessionCtx.LoginMsg.RunID,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
doneCh: make(chan struct{}),
}
ctl.lastPing.Store(time.Now())
if ctlConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(ctl.conn, ctl.encryptionKey)
if sessionCtx.ConnEncrypted {
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
if err != nil {
return nil, err
}
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
} else {
ctl.msgDispatcher = msg.NewDispatcher(ctl.conn)
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
}
ctl.registerMsgHandlers()
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
@@ -213,7 +193,7 @@ func (ctl *Control) Start() {
RunID: ctl.runID,
Error: "",
}
_ = msg.WriteMsg(ctl.conn, loginRespMsg)
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
go func() {
for i := 0; i < ctl.poolCount; i++ {
@@ -225,7 +205,7 @@ func (ctl *Control) Start() {
}
func (ctl *Control) Close() error {
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
return nil
}
@@ -233,7 +213,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
xl := ctl.xl
xl.Infof("replaced by client [%s]", newCtl.runID)
ctl.runID = ""
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
}
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
@@ -291,7 +271,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
return
}
case <-time.After(time.Duration(ctl.serverCfg.UserConnTimeout) * time.Second):
case <-time.After(time.Duration(ctl.sessionCtx.ServerCfg.UserConnTimeout) * time.Second):
err = fmt.Errorf("timeout trying to get work connection")
xl.Warnf("%v", err)
return
@@ -304,15 +284,15 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
}
func (ctl *Control) heartbeatWorker() {
if ctl.serverCfg.Transport.HeartbeatTimeout <= 0 {
if ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout <= 0 {
return
}
xl := ctl.xl
go wait.Until(func() {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.sessionCtx.ServerCfg.Transport.HeartbeatTimeout)*time.Second {
xl.Warnf("heartbeat timeout")
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
return
}
}, time.Second, ctl.doneCh)
@@ -323,6 +303,30 @@ func (ctl *Control) WaitClosed() {
<-ctl.doneCh
}
func (ctl *Control) loginUserInfo() plugin.UserInfo {
return plugin.UserInfo{
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.sessionCtx.LoginMsg.RunID,
}
}
func (ctl *Control) closeProxy(pxy proxy.Proxy) {
pxy.Close()
ctl.sessionCtx.PxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: ctl.loginUserInfo(),
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.sessionCtx.PluginManager.CloseProxy(notifyContent)
}()
}
func (ctl *Control) worker() {
xl := ctl.xl
@@ -330,38 +334,23 @@ func (ctl *Control) worker() {
go ctl.msgDispatcher.Run()
<-ctl.msgDispatcher.Done()
ctl.conn.Close()
ctl.sessionCtx.Conn.Close()
ctl.mu.Lock()
defer ctl.mu.Unlock()
close(ctl.workConnCh)
for workConn := range ctl.workConnCh {
workConn.Close()
}
proxies := ctl.proxies
ctl.proxies = make(map[string]proxy.Proxy)
ctl.mu.Unlock()
for _, pxy := range ctl.proxies {
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
}()
for _, pxy := range proxies {
ctl.closeProxy(pxy)
}
metrics.Server.CloseClient()
ctl.clientRegistry.MarkOfflineByRunID(ctl.runID)
ctl.sessionCtx.ClientRegistry.MarkOfflineByRunID(ctl.runID)
xl.Infof("client exit success")
close(ctl.doneCh)
}
@@ -380,15 +369,11 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
inMsg := m.(*msg.NewProxy)
content := &plugin.NewProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
User: ctl.loginUserInfo(),
NewProxy: *inMsg,
}
var remoteAddr string
retContent, err := ctl.pluginManager.NewProxy(content)
retContent, err := ctl.sessionCtx.PluginManager.NewProxy(content)
if err == nil {
inMsg = &retContent.NewProxy
remoteAddr, err = ctl.RegisterProxy(inMsg)
@@ -401,15 +386,15 @@ func (ctl *Control) handleNewProxy(m msg.Message) {
if err != nil {
xl.Warnf("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient))
} else {
resp.RemoteAddr = remoteAddr
xl.Infof("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
clientID := ctl.loginMsg.ClientID
clientID := ctl.sessionCtx.LoginMsg.ClientID
if clientID == "" {
clientID = ctl.loginMsg.RunID
clientID = ctl.sessionCtx.LoginMsg.RunID
}
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.loginMsg.User, clientID)
metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType, ctl.sessionCtx.LoginMsg.User, clientID)
}
_ = ctl.msgDispatcher.Send(resp)
}
@@ -419,22 +404,18 @@ func (ctl *Control) handlePing(m msg.Message) {
inMsg := m.(*msg.Ping)
content := &plugin.PingContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
User: ctl.loginUserInfo(),
Ping: *inMsg,
}
retContent, err := ctl.pluginManager.Ping(content)
retContent, err := ctl.sessionCtx.PluginManager.Ping(content)
if err == nil {
inMsg = &retContent.Ping
err = ctl.authVerifier.VerifyPing(inMsg)
err = ctl.sessionCtx.AuthVerifier.VerifyPing(inMsg)
}
if err != nil {
xl.Warnf("received invalid ping: %v", err)
_ = ctl.msgDispatcher.Send(&msg.Pong{
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.sessionCtx.ServerCfg.DetailedErrorsToClient)),
})
return
}
@@ -445,17 +426,17 @@ func (ctl *Control) handlePing(m msg.Message) {
func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
inMsg := m.(*msg.NatHoleVisitor)
ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
ctl.sessionCtx.RC.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.sessionCtx.LoginMsg.User)
}
func (ctl *Control) handleNatHoleClient(m msg.Message) {
inMsg := m.(*msg.NatHoleClient)
ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
ctl.sessionCtx.RC.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
}
func (ctl *Control) handleNatHoleReport(m msg.Message) {
inMsg := m.(*msg.NatHoleReport)
ctl.rc.NatHoleController.HandleReport(inMsg)
ctl.sessionCtx.RC.NatHoleController.HandleReport(inMsg)
}
func (ctl *Control) handleCloseProxy(m msg.Message) {
@@ -468,15 +449,15 @@ func (ctl *Control) handleCloseProxy(m msg.Message) {
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
var pxyConf v1.ProxyConfigurer
// Load configures from NewProxy message and validate.
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.serverCfg)
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, ctl.sessionCtx.ServerCfg)
if err != nil {
return
}
// User info
userInfo := plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
User: ctl.sessionCtx.LoginMsg.User,
Metas: ctl.sessionCtx.LoginMsg.Metas,
RunID: ctl.runID,
}
@@ -484,22 +465,22 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
// In fact, it creates different proxies based on the proxy type. We just call run() here.
pxy, err := proxy.NewProxy(ctl.ctx, &proxy.Options{
UserInfo: userInfo,
LoginMsg: ctl.loginMsg,
LoginMsg: ctl.sessionCtx.LoginMsg,
PoolCount: ctl.poolCount,
ResourceController: ctl.rc,
ResourceController: ctl.sessionCtx.RC,
GetWorkConnFn: ctl.GetWorkConn,
Configurer: pxyConf,
ServerCfg: ctl.serverCfg,
EncryptionKey: ctl.encryptionKey,
ServerCfg: ctl.sessionCtx.ServerCfg,
EncryptionKey: ctl.sessionCtx.EncryptionKey,
})
if err != nil {
return remoteAddr, err
}
// Check ports used number in each client
if ctl.serverCfg.MaxPortsPerClient > 0 {
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
ctl.mu.Lock()
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.serverCfg.MaxPortsPerClient) {
if ctl.portsUsedNum+pxy.GetUsedPortsNum() > int(ctl.sessionCtx.ServerCfg.MaxPortsPerClient) {
ctl.mu.Unlock()
err = fmt.Errorf("exceed the max_ports_per_client")
return
@@ -516,7 +497,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}()
}
if ctl.pxyManager.Exist(pxyMsg.ProxyName) {
if ctl.sessionCtx.PxyManager.Exist(pxyMsg.ProxyName) {
err = fmt.Errorf("proxy [%s] already exists", pxyMsg.ProxyName)
return
}
@@ -531,7 +512,7 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err
}
}()
err = ctl.pxyManager.Add(pxyMsg.ProxyName, pxy)
err = ctl.sessionCtx.PxyManager.Add(pxyMsg.ProxyName, pxy)
if err != nil {
return
}
@@ -550,28 +531,12 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
return
}
if ctl.serverCfg.MaxPortsPerClient > 0 {
if ctl.sessionCtx.ServerCfg.MaxPortsPerClient > 0 {
ctl.portsUsedNum -= pxy.GetUsedPortsNum()
}
pxy.Close()
ctl.pxyManager.Del(pxy.GetName())
delete(ctl.proxies, closeMsg.ProxyName)
ctl.mu.Unlock()
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConfigurer().GetBaseConfig().Type)
notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{
User: ctl.loginMsg.User,
Metas: ctl.loginMsg.Metas,
RunID: ctl.loginMsg.RunID,
},
CloseProxy: msg.CloseProxy{
ProxyName: pxy.GetName(),
},
}
go func() {
_ = ctl.pluginManager.CloseProxy(notifyContent)
}()
ctl.closeProxy(pxy)
return
}

77
server/group/base.go Normal file
View File

@@ -0,0 +1,77 @@
package group
import (
"net"
"sync"
gerr "github.com/fatedier/golib/errors"
)
// baseGroup contains the shared plumbing for listener-based groups
// (TCP, HTTPS, TCPMux). Each concrete group embeds this and provides
// its own Listen method with protocol-specific validation.
type baseGroup struct {
group string
groupKey string
acceptCh chan net.Conn
realLn net.Listener
lns []*Listener
mu sync.Mutex
cleanupFn func()
}
// initBase resets the baseGroup for a fresh listen cycle.
// Must be called under mu when len(lns) == 0.
func (bg *baseGroup) initBase(group, groupKey string, realLn net.Listener, cleanupFn func()) {
bg.group = group
bg.groupKey = groupKey
bg.realLn = realLn
bg.acceptCh = make(chan net.Conn)
bg.cleanupFn = cleanupFn
}
// worker reads from the real listener and fans out to acceptCh.
// The parameters are captured at creation time so that the worker is
// bound to a specific listen cycle and cannot observe a later initBase.
func (bg *baseGroup) worker(realLn net.Listener, acceptCh chan<- net.Conn) {
for {
c, err := realLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
acceptCh <- c
})
if err != nil {
c.Close()
return
}
}
}
// newListener creates a new Listener wired to this baseGroup.
// Must be called under mu.
func (bg *baseGroup) newListener(addr net.Addr) *Listener {
ln := newListener(bg.acceptCh, addr, bg.closeListener)
bg.lns = append(bg.lns, ln)
return ln
}
// closeListener removes ln from the list. When the last listener is removed,
// it closes acceptCh, closes the real listener, and calls cleanupFn.
func (bg *baseGroup) closeListener(ln *Listener) {
bg.mu.Lock()
defer bg.mu.Unlock()
for i, l := range bg.lns {
if l == ln {
bg.lns = append(bg.lns[:i], bg.lns[i+1:]...)
break
}
}
if len(bg.lns) == 0 {
close(bg.acceptCh)
bg.realLn.Close()
bg.cleanupFn()
}
}

169
server/group/base_test.go Normal file
View File

@@ -0,0 +1,169 @@
package group
import (
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeLn is a controllable net.Listener for tests.
type fakeLn struct {
connCh chan net.Conn
closed chan struct{}
once sync.Once
}
func newFakeLn() *fakeLn {
return &fakeLn{
connCh: make(chan net.Conn, 8),
closed: make(chan struct{}),
}
}
func (f *fakeLn) Accept() (net.Conn, error) {
select {
case c := <-f.connCh:
return c, nil
case <-f.closed:
return nil, net.ErrClosed
}
}
func (f *fakeLn) Close() error {
f.once.Do(func() { close(f.closed) })
return nil
}
func (f *fakeLn) Addr() net.Addr { return fakeAddr("127.0.0.1:9999") }
func (f *fakeLn) inject(c net.Conn) {
select {
case f.connCh <- c:
case <-f.closed:
}
}
func TestBaseGroup_WorkerFanOut(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
go bg.worker(fl, bg.acceptCh)
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
select {
case got := <-bg.acceptCh:
assert.Equal(t, c1, got)
got.Close()
case <-time.After(time.Second):
t.Fatal("timed out waiting for connection on acceptCh")
}
fl.Close()
}
func TestBaseGroup_WorkerStopsOnListenerClose(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
done := make(chan struct{})
go func() {
bg.worker(fl, bg.acceptCh)
close(done)
}()
fl.Close()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("worker did not stop after listener close")
}
}
func TestBaseGroup_WorkerClosesConnOnClosedChannel(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
bg.initBase("g", "key", fl, func() {})
// Close acceptCh before worker sends.
close(bg.acceptCh)
done := make(chan struct{})
go func() {
bg.worker(fl, bg.acceptCh)
close(done)
}()
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("worker did not stop after panic recovery")
}
// c1 should have been closed by worker's panic recovery path.
buf := make([]byte, 1)
_, err := c1.Read(buf)
assert.Error(t, err, "connection should be closed by worker")
}
func TestBaseGroup_CloseLastListenerTriggersCleanup(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
cleanupCalled := 0
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
bg.mu.Lock()
ln1 := bg.newListener(fl.Addr())
ln2 := bg.newListener(fl.Addr())
bg.mu.Unlock()
go bg.worker(fl, bg.acceptCh)
ln1.Close()
assert.Equal(t, 0, cleanupCalled, "cleanup should not run while listeners remain")
ln2.Close()
assert.Equal(t, 1, cleanupCalled, "cleanup should run after last listener closed")
}
func TestBaseGroup_CloseOneOfTwoListeners(t *testing.T) {
fl := newFakeLn()
var bg baseGroup
cleanupCalled := 0
bg.initBase("g", "key", fl, func() { cleanupCalled++ })
bg.mu.Lock()
ln1 := bg.newListener(fl.Addr())
ln2 := bg.newListener(fl.Addr())
bg.mu.Unlock()
go bg.worker(fl, bg.acceptCh)
ln1.Close()
assert.Equal(t, 0, cleanupCalled)
// ln2 should still receive connections.
c1, c2 := net.Pipe()
defer c2.Close()
fl.inject(c1)
got, err := ln2.Accept()
require.NoError(t, err)
assert.Equal(t, c1, got)
got.Close()
ln2.Close()
assert.Equal(t, 1, cleanupCalled)
}

View File

@@ -24,4 +24,6 @@ var (
ErrListenerClosed = errors.New("group listener closed")
ErrGroupDifferentPort = errors.New("group should have same remote port")
ErrProxyRepeated = errors.New("group proxy repeated")
errGroupStale = errors.New("stale group reference")
)

View File

@@ -9,53 +9,42 @@ import (
"github.com/fatedier/frp/pkg/util/vhost"
)
// HTTPGroupController manages HTTP groups that use round-robin
// callback routing (fundamentally different from listener-based groups).
type HTTPGroupController struct {
// groups indexed by group name
groups map[string]*HTTPGroup
// register createConn for each group to vhostRouter.
// createConn will get a connection from one proxy of the group
groupRegistry[*HTTPGroup]
vhostRouter *vhost.Routers
mu sync.Mutex
}
func NewHTTPGroupController(vhostRouter *vhost.Routers) *HTTPGroupController {
return &HTTPGroupController{
groups: make(map[string]*HTTPGroup),
vhostRouter: vhostRouter,
groupRegistry: newGroupRegistry[*HTTPGroup](),
vhostRouter: vhostRouter,
}
}
func (ctl *HTTPGroupController) Register(
proxyName, group, groupKey string,
routeConfig vhost.RouteConfig,
) (err error) {
indexKey := group
ctl.mu.Lock()
g, ok := ctl.groups[indexKey]
if !ok {
g = NewHTTPGroup(ctl)
ctl.groups[indexKey] = g
) error {
for {
g := ctl.getOrCreate(group, func() *HTTPGroup {
return NewHTTPGroup(ctl)
})
err := g.Register(proxyName, group, groupKey, routeConfig)
if err == errGroupStale {
continue
}
return err
}
ctl.mu.Unlock()
return g.Register(proxyName, group, groupKey, routeConfig)
}
func (ctl *HTTPGroupController) UnRegister(proxyName, group string, _ vhost.RouteConfig) {
indexKey := group
ctl.mu.Lock()
defer ctl.mu.Unlock()
g, ok := ctl.groups[indexKey]
g, ok := ctl.get(group)
if !ok {
return
}
isEmpty := g.UnRegister(proxyName)
if isEmpty {
delete(ctl.groups, indexKey)
}
g.UnRegister(proxyName)
}
type HTTPGroup struct {
@@ -87,6 +76,9 @@ func (g *HTTPGroup) Register(
) (err error) {
g.mu.Lock()
defer g.mu.Unlock()
if !g.ctl.isCurrent(group, func(cur *HTTPGroup) bool { return cur == g }) {
return errGroupStale
}
if len(g.createFuncs) == 0 {
// the first proxy in this group
tmp := routeConfig // copy object
@@ -123,7 +115,7 @@ func (g *HTTPGroup) Register(
return nil
}
func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
func (g *HTTPGroup) UnRegister(proxyName string) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.createFuncs, proxyName)
@@ -135,10 +127,11 @@ func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
}
if len(g.createFuncs) == 0 {
isEmpty = true
g.ctl.vhostRouter.Del(g.domain, g.location, g.routeByHTTPUser)
g.ctl.removeIf(g.group, func(cur *HTTPGroup) bool {
return cur == g
})
}
return
}
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
@@ -151,7 +144,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
location := g.location
routeByHTTPUser := g.routeByHTTPUser
if len(g.pxyNames) > 0 {
name := g.pxyNames[int(newIndex)%len(g.pxyNames)]
name := g.pxyNames[newIndex%uint64(len(g.pxyNames))]
f = g.createFuncs[name]
}
g.mu.RUnlock()
@@ -174,7 +167,7 @@ func (g *HTTPGroup) chooseEndpoint() (string, error) {
location := g.location
routeByHTTPUser := g.routeByHTTPUser
if len(g.pxyNames) > 0 {
name = g.pxyNames[int(newIndex)%len(g.pxyNames)]
name = g.pxyNames[newIndex%uint64(len(g.pxyNames))]
}
g.mu.RUnlock()

View File

@@ -17,25 +17,19 @@ package group
import (
"context"
"net"
"sync"
gerr "github.com/fatedier/golib/errors"
"github.com/fatedier/frp/pkg/util/vhost"
)
type HTTPSGroupController struct {
groups map[string]*HTTPSGroup
groupRegistry[*HTTPSGroup]
httpsMuxer *vhost.HTTPSMuxer
mu sync.Mutex
}
func NewHTTPSGroupController(httpsMuxer *vhost.HTTPSMuxer) *HTTPSGroupController {
return &HTTPSGroupController{
groups: make(map[string]*HTTPSGroup),
httpsMuxer: httpsMuxer,
groupRegistry: newGroupRegistry[*HTTPSGroup](),
httpsMuxer: httpsMuxer,
}
}
@@ -44,41 +38,28 @@ func (ctl *HTTPSGroupController) Listen(
group, groupKey string,
routeConfig vhost.RouteConfig,
) (l net.Listener, err error) {
indexKey := group
ctl.mu.Lock()
g, ok := ctl.groups[indexKey]
if !ok {
g = NewHTTPSGroup(ctl)
ctl.groups[indexKey] = g
for {
g := ctl.getOrCreate(group, func() *HTTPSGroup {
return NewHTTPSGroup(ctl)
})
l, err = g.Listen(ctx, group, groupKey, routeConfig)
if err == errGroupStale {
continue
}
return
}
ctl.mu.Unlock()
return g.Listen(ctx, group, groupKey, routeConfig)
}
func (ctl *HTTPSGroupController) RemoveGroup(group string) {
ctl.mu.Lock()
defer ctl.mu.Unlock()
delete(ctl.groups, group)
}
type HTTPSGroup struct {
group string
groupKey string
domain string
baseGroup
acceptCh chan net.Conn
httpsLn *vhost.Listener
lns []*HTTPSGroupListener
ctl *HTTPSGroupController
mu sync.Mutex
domain string
ctl *HTTPSGroupController
}
func NewHTTPSGroup(ctl *HTTPSGroupController) *HTTPSGroup {
return &HTTPSGroup{
lns: make([]*HTTPSGroupListener, 0),
ctl: ctl,
acceptCh: make(chan net.Conn),
ctl: ctl,
}
}
@@ -86,23 +67,27 @@ func (g *HTTPSGroup) Listen(
ctx context.Context,
group, groupKey string,
routeConfig vhost.RouteConfig,
) (ln *HTTPSGroupListener, err error) {
) (ln *Listener, err error) {
g.mu.Lock()
defer g.mu.Unlock()
if !g.ctl.isCurrent(group, func(cur *HTTPSGroup) bool { return cur == g }) {
return nil, errGroupStale
}
if len(g.lns) == 0 {
// the first listener, listen on the real address
httpsLn, errRet := g.ctl.httpsMuxer.Listen(ctx, &routeConfig)
if errRet != nil {
return nil, errRet
}
ln = newHTTPSGroupListener(group, g, httpsLn.Addr())
g.group = group
g.groupKey = groupKey
g.domain = routeConfig.Domain
g.httpsLn = httpsLn
g.lns = append(g.lns, ln)
go g.worker()
g.initBase(group, groupKey, httpsLn, func() {
g.ctl.removeIf(g.group, func(cur *HTTPSGroup) bool {
return cur == g
})
})
ln = g.newListener(httpsLn.Addr())
go g.worker(httpsLn, g.acceptCh)
} else {
// route config in the same group must be equal
if g.group != group || g.domain != routeConfig.Domain {
@@ -111,87 +96,7 @@ func (g *HTTPSGroup) Listen(
if g.groupKey != groupKey {
return nil, ErrGroupAuthFailed
}
ln = newHTTPSGroupListener(group, g, g.lns[0].Addr())
g.lns = append(g.lns, ln)
ln = g.newListener(g.lns[0].Addr())
}
return
}
func (g *HTTPSGroup) worker() {
for {
c, err := g.httpsLn.Accept()
if err != nil {
return
}
err = gerr.PanicToError(func() {
g.acceptCh <- c
})
if err != nil {
return
}
}
}
func (g *HTTPSGroup) Accept() <-chan net.Conn {
return g.acceptCh
}
func (g *HTTPSGroup) CloseListener(ln *HTTPSGroupListener) {
g.mu.Lock()
defer g.mu.Unlock()
for i, tmpLn := range g.lns {
if tmpLn == ln {
g.lns = append(g.lns[:i], g.lns[i+1:]...)
break
}
}
if len(g.lns) == 0 {
close(g.acceptCh)
if g.httpsLn != nil {
g.httpsLn.Close()
}
g.ctl.RemoveGroup(g.group)
}
}
type HTTPSGroupListener struct {
groupName string
group *HTTPSGroup
addr net.Addr
closeCh chan struct{}
}
func newHTTPSGroupListener(name string, group *HTTPSGroup, addr net.Addr) *HTTPSGroupListener {
return &HTTPSGroupListener{
groupName: name,
group: group,
addr: addr,
closeCh: make(chan struct{}),
}
}
func (ln *HTTPSGroupListener) Accept() (c net.Conn, err error) {
var ok bool
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok = <-ln.group.Accept():
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *HTTPSGroupListener) Addr() net.Addr {
return ln.addr
}
func (ln *HTTPSGroupListener) Close() (err error) {
close(ln.closeCh)
// remove self from HTTPSGroup
ln.group.CloseListener(ln)
return
}

49
server/group/listener.go Normal file
View File

@@ -0,0 +1,49 @@
package group
import (
"net"
"sync"
)
// Listener is a per-proxy virtual listener that receives connections
// from a shared group. It implements net.Listener.
type Listener struct {
acceptCh <-chan net.Conn
addr net.Addr
closeCh chan struct{}
onClose func(*Listener)
once sync.Once
}
func newListener(acceptCh <-chan net.Conn, addr net.Addr, onClose func(*Listener)) *Listener {
return &Listener{
acceptCh: acceptCh,
addr: addr,
closeCh: make(chan struct{}),
onClose: onClose,
}
}
func (ln *Listener) Accept() (net.Conn, error) {
select {
case <-ln.closeCh:
return nil, ErrListenerClosed
case c, ok := <-ln.acceptCh:
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
}
func (ln *Listener) Addr() net.Addr {
return ln.addr
}
func (ln *Listener) Close() error {
ln.once.Do(func() {
close(ln.closeCh)
ln.onClose(ln)
})
return nil
}

View File

@@ -0,0 +1,68 @@
package group
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestListener_Accept(t *testing.T) {
acceptCh := make(chan net.Conn, 1)
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
c1, c2 := net.Pipe()
defer c1.Close()
defer c2.Close()
acceptCh <- c1
got, err := ln.Accept()
require.NoError(t, err)
assert.Equal(t, c1, got)
}
func TestListener_AcceptAfterChannelClose(t *testing.T) {
acceptCh := make(chan net.Conn)
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
close(acceptCh)
_, err := ln.Accept()
assert.ErrorIs(t, err, ErrListenerClosed)
}
func TestListener_AcceptAfterListenerClose(t *testing.T) {
acceptCh := make(chan net.Conn) // open, not closed
ln := newListener(acceptCh, fakeAddr("127.0.0.1:1234"), func(*Listener) {})
ln.Close()
_, err := ln.Accept()
assert.ErrorIs(t, err, ErrListenerClosed)
}
func TestListener_DoubleClose(t *testing.T) {
closeCalls := 0
ln := newListener(
make(chan net.Conn),
fakeAddr("127.0.0.1:1234"),
func(*Listener) { closeCalls++ },
)
assert.NotPanics(t, func() {
ln.Close()
ln.Close()
})
assert.Equal(t, 1, closeCalls, "onClose should be called exactly once")
}
func TestListener_Addr(t *testing.T) {
addr := fakeAddr("10.0.0.1:5555")
ln := newListener(make(chan net.Conn), addr, func(*Listener) {})
assert.Equal(t, addr, ln.Addr())
}
// fakeAddr implements net.Addr for testing.
type fakeAddr string
func (a fakeAddr) Network() string { return "tcp" }
func (a fakeAddr) String() string { return string(a) }

Some files were not shown because too many files have changed in this diff Show More