diff --git a/pkg/config/source/aggregator.go b/pkg/config/source/aggregator.go index f3be67bd..58496932 100644 --- a/pkg/config/source/aggregator.go +++ b/pkg/config/source/aggregator.go @@ -15,9 +15,11 @@ package source import ( + "cmp" "errors" "fmt" - "sort" + "maps" + "slices" "sync" v1 "github.com/fatedier/frp/pkg/config/v1" @@ -97,21 +99,11 @@ func (a *Aggregator) mapsToSortedSlices( proxyMap map[string]v1.ProxyConfigurer, visitorMap map[string]v1.VisitorConfigurer, ) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer) { - proxies := make([]v1.ProxyConfigurer, 0, len(proxyMap)) - for _, p := range proxyMap { - proxies = append(proxies, p) - } - sort.Slice(proxies, func(i, j int) bool { - return proxies[i].GetBaseConfig().Name < proxies[j].GetBaseConfig().Name + proxies := slices.SortedFunc(maps.Values(proxyMap), func(x, y v1.ProxyConfigurer) int { + return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name) }) - - visitors := make([]v1.VisitorConfigurer, 0, len(visitorMap)) - for _, v := range visitorMap { - visitors = append(visitors, v) - } - sort.Slice(visitors, func(i, j int) bool { - return visitors[i].GetBaseConfig().Name < visitors[j].GetBaseConfig().Name + visitors := slices.SortedFunc(maps.Values(visitorMap), func(x, y v1.VisitorConfigurer) int { + return cmp.Compare(x.GetBaseConfig().Name, y.GetBaseConfig().Name) }) - return proxies, visitors } diff --git a/pkg/config/source/aggregator_test.go b/pkg/config/source/aggregator_test.go index 5fc9636a..380c05cf 100644 --- a/pkg/config/source/aggregator_test.go +++ b/pkg/config/source/aggregator_test.go @@ -196,6 +196,27 @@ func TestAggregator_VisitorMerge(t *testing.T) { require.Len(visitors, 2) } +func TestAggregator_Load_ReturnsSortedByName(t *testing.T) { + require := require.New(t) + + agg := newTestAggregator(t, nil) + err := agg.ConfigSource().ReplaceAll( + []v1.ProxyConfigurer{mockProxy("charlie"), mockProxy("alice"), mockProxy("bob")}, + []v1.VisitorConfigurer{mockVisitor("zulu"), mockVisitor("alpha")}, + ) + require.NoError(err) + + proxies, visitors, err := agg.Load() + require.NoError(err) + require.Len(proxies, 3) + require.Equal("alice", proxies[0].GetBaseConfig().Name) + require.Equal("bob", proxies[1].GetBaseConfig().Name) + require.Equal("charlie", proxies[2].GetBaseConfig().Name) + require.Len(visitors, 2) + require.Equal("alpha", visitors[0].GetBaseConfig().Name) + require.Equal("zulu", visitors[1].GetBaseConfig().Name) +} + func TestAggregator_Load_ReturnsDefensiveCopies(t *testing.T) { require := require.New(t) diff --git a/pkg/config/types/types.go b/pkg/config/types/types.go index 8fa3105a..5b2b6930 100644 --- a/pkg/config/types/types.go +++ b/pkg/config/types/types.go @@ -70,24 +70,18 @@ func (q *BandwidthQuantity) UnmarshalString(s string) error { f float64 err error ) - switch { - case strings.HasSuffix(s, "MB"): + if fstr, ok := strings.CutSuffix(s, "MB"); ok { base = MB - fstr := strings.TrimSuffix(s, "MB") f, err = strconv.ParseFloat(fstr, 64) - if err != nil { - return err - } - case strings.HasSuffix(s, "KB"): + } else if fstr, ok := strings.CutSuffix(s, "KB"); ok { base = KB - fstr := strings.TrimSuffix(s, "KB") f, err = strconv.ParseFloat(fstr, 64) - if err != nil { - return err - } - default: + } else { return errors.New("unit not support") } + if err != nil { + return err + } q.s = s q.i = int64(f * float64(base)) diff --git a/pkg/config/types/types_test.go b/pkg/config/types/types_test.go index 8843de5a..c05ac9ee 100644 --- a/pkg/config/types/types_test.go +++ b/pkg/config/types/types_test.go @@ -39,6 +39,31 @@ func TestBandwidthQuantity(t *testing.T) { require.Equal(`{"b":"1KB","int":5}`, string(buf)) } +func TestBandwidthQuantity_MB(t *testing.T) { + require := require.New(t) + + var w Wrap + err := json.Unmarshal([]byte(`{"b":"2MB","int":1}`), &w) + require.NoError(err) + require.EqualValues(2*MB, w.B.Bytes()) + + buf, err := json.Marshal(&w) + require.NoError(err) + require.Equal(`{"b":"2MB","int":1}`, string(buf)) +} + +func TestBandwidthQuantity_InvalidUnit(t *testing.T) { + var w Wrap + err := json.Unmarshal([]byte(`{"b":"1GB","int":1}`), &w) + require.Error(t, err) +} + +func TestBandwidthQuantity_InvalidNumber(t *testing.T) { + var w Wrap + err := json.Unmarshal([]byte(`{"b":"abcKB","int":1}`), &w) + require.Error(t, err) +} + func TestPortsRangeSlice2String(t *testing.T) { require := require.New(t)