diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index d9377f32..9aeaf3c5 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -23,6 +23,7 @@ import ( "net/url" "os" "slices" + "sync" "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" @@ -205,7 +206,8 @@ type OidcAuthConsumer struct { additionalAuthScopes []v1.AuthScope verifier TokenVerifier - subjectsFromLogin []string + mu sync.RWMutex + subjectsFromLogin map[string]struct{} } func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier { @@ -226,7 +228,7 @@ func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVeri return &OidcAuthConsumer{ additionalAuthScopes: additionalAuthScopes, verifier: verifier, - subjectsFromLogin: []string{}, + subjectsFromLogin: make(map[string]struct{}), } } @@ -235,9 +237,9 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) { if err != nil { return fmt.Errorf("invalid OIDC token in login: %v", err) } - if !slices.Contains(auth.subjectsFromLogin, token.Subject) { - auth.subjectsFromLogin = append(auth.subjectsFromLogin, token.Subject) - } + auth.mu.Lock() + auth.subjectsFromLogin[token.Subject] = struct{}{} + auth.mu.Unlock() return nil } @@ -246,11 +248,13 @@ func (auth *OidcAuthConsumer) verifyPostLoginToken(privilegeKey string) (err err if err != nil { return fmt.Errorf("invalid OIDC token in ping: %v", err) } - if !slices.Contains(auth.subjectsFromLogin, token.Subject) { + auth.mu.RLock() + _, ok := auth.subjectsFromLogin[token.Subject] + auth.mu.RUnlock() + if !ok { return fmt.Errorf("received different OIDC subject in login and ping. "+ - "original subjects: %s, "+ "new subject: %s", - auth.subjectsFromLogin, token.Subject) + token.Subject) } return nil } diff --git a/pkg/util/net/websocket.go b/pkg/util/net/websocket.go index 3ca8b332..6c2f39c4 100644 --- a/pkg/util/net/websocket.go +++ b/pkg/util/net/websocket.go @@ -26,6 +26,7 @@ type WebsocketListener struct { // ln: tcp listener for websocket connections func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) { wl = &WebsocketListener{ + ln: ln, acceptCh: make(chan net.Conn), }