diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 3a9365fbe..9e4b77cc1 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "net/url" "strings" "sync" "time" @@ -407,6 +408,29 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStat } func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) { + redirectURL, err := url.Parse(req.GetRedirectUrl()) + if err != nil { + // TODO: log + return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err) + } + // Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection. + proxies, err := s.reverseProxyStore.GetAccountReverseProxies(ctx, store.LockingStrengthNone, req.GetAccountId()) + if err != nil { + // TODO: log + return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err) + } + var found bool + for _, proxy := range proxies { + if proxy.Domain == redirectURL.Hostname() { + found = true + break + } + } + if !found { + // TODO: log + return nil, status.Errorf(codes.FailedPrecondition, "reverse proxy not found in store") + } + provider, err := oidc.NewProvider(ctx, s.oidcConfig.Issuer) if err != nil { // TODO: log @@ -420,9 +444,8 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU // Using an HMAC here to avoid redirection state being modified. // State format: base64(redirectURL)|hmac - redirectURL := req.GetRedirectUrl() - hmacSum := s.generateHMAC(redirectURL) - state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL)), hmacSum) + hmacSum := s.generateHMAC(redirectURL.String()) + state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), hmacSum) codeVerifier := oauth2.GenerateVerifier() s.pkceVerifiers.Store(state, codeVerifier) diff --git a/management/server/http/handlers/proxy/auth.go b/management/server/http/handlers/proxy/auth.go index 8d7c628ee..d525ca5fe 100644 --- a/management/server/http/handlers/proxy/auth.go +++ b/management/server/http/handlers/proxy/auth.go @@ -72,6 +72,9 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ } redirectURL.RawQuery = redirectQuery.Encode() - log.WithField("redirect", redirectURL).Debug("OAuth callback: redirecting user with token") + // Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here). + redirectURL.Scheme = "https" + + log.WithField("redirect", redirectURL.String()).Debug("OAuth callback: redirecting user with token") http.Redirect(w, r, redirectURL.String(), http.StatusFound) } diff --git a/proxy/internal/auth/oidc.go b/proxy/internal/auth/oidc.go index 55fd80367..4c5bfe8cb 100644 --- a/proxy/internal/auth/oidc.go +++ b/proxy/internal/auth/oidc.go @@ -4,8 +4,6 @@ import ( "context" "net/http" "net/url" - "strings" - "sync" "time" gojwt "github.com/golang-jwt/jwt/v5" @@ -15,8 +13,6 @@ import ( "github.com/netbirdio/netbird/shared/auth/jwt" ) -const stateExpiration = 10 * time.Minute - type urlGenerator interface { GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error) } @@ -41,13 +37,11 @@ type OIDC struct { validator *jwt.Validator maxTokenAgeSeconds int64 client urlGenerator - states map[string]*oidcState - statesMux sync.RWMutex } // NewOIDC creates a new OIDC authentication scheme func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC { - o := &OIDC{ + return &OIDC{ id: id, accountId: accountId, validator: jwt.NewValidator( @@ -58,12 +52,7 @@ func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC { ), maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds, client: client, - states: make(map[string]*oidcState), } - - go o.cleanupStates() - - return o } func (*OIDC) Type() Method { @@ -71,9 +60,8 @@ func (*OIDC) Type() Method { } func (o *OIDC) Authenticate(r *http.Request) (string, string) { - // Try Authorization: Bearer header - if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { - if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" { + if token := r.URL.Query().Get("access_token"); token != "" { + if userID := o.validateToken(r.Context(), token); userID != "" { return userID, "" } } @@ -116,7 +104,8 @@ func (o *OIDC) validateToken(ctx context.Context, token string) string { return "" } - if time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) { + // If max token age is 0 skip this check. + if o.maxTokenAgeSeconds > 0 && time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) { // TODO: log or return? return "" } @@ -147,17 +136,3 @@ func getUserIDFromClaims(claims gojwt.MapClaims) string { } return "unknown" } - -// cleanupStates periodically removes expired states -func (o *OIDC) cleanupStates() { - for range time.Tick(time.Minute) { - cutoff := time.Now().Add(-stateExpiration) - o.statesMux.Lock() - for k, v := range o.states { - if v.CreatedAt.Before(cutoff) { - delete(o.states, k) - } - } - o.statesMux.Unlock() - } -}