package proxy import ( "context" "net/http" "net/url" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "golang.org/x/oauth2" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/proxy/auth" ) type AuthCallbackHandler struct { proxyService *nbgrpc.ProxyServiceServer } func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler { return &AuthCallbackHandler{ proxyService: proxyService, } } func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) { router.HandleFunc("/oauth/callback", h.handleCallback).Methods(http.MethodGet) } func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) { state := r.URL.Query().Get("state") codeVerifier, originalURL, err := h.proxyService.ValidateState(state) if err != nil { log.WithError(err).Error("OAuth callback state validation failed") http.Error(w, "Invalid state parameter", http.StatusBadRequest) return } redirectURL, err := url.Parse(originalURL) if err != nil { log.WithError(err).Error("Failed to parse redirect URL") http.Error(w, "Invalid redirect URL", http.StatusBadRequest) return } // Get OIDC configuration oidcConfig := h.proxyService.GetOIDCConfig() // Create OIDC provider to discover endpoints provider, err := oidc.NewProvider(r.Context(), oidcConfig.Issuer) if err != nil { log.WithError(err).Error("Failed to create OIDC provider") http.Error(w, "Failed to create OIDC provider", http.StatusInternalServerError) return } token, err := (&oauth2.Config{ ClientID: oidcConfig.ClientID, Endpoint: provider.Endpoint(), RedirectURL: oidcConfig.CallbackURL, }).Exchange(r.Context(), r.URL.Query().Get("code"), oauth2.VerifierOption(codeVerifier)) if err != nil { log.WithError(err).Error("Failed to exchange code for token") http.Error(w, "Failed to exchange code for token", http.StatusInternalServerError) return } // Extract user ID from the OIDC token userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token) if userID == "" { log.Error("Failed to extract user ID from OIDC token") http.Error(w, "Failed to validate token", http.StatusUnauthorized) return } // Generate session JWT instead of passing OIDC access_token sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC) if err != nil { log.WithError(err).Error("Failed to create session token") http.Error(w, "Failed to create session", http.StatusInternalServerError) return } // Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here). redirectURL.Scheme = "https" // Pass the session token in the URL query parameter. The proxy middleware will // extract it, validate it, set its own cookie, and redirect to remove the token from the URL. // We cannot set the cookie here because cookies are domain-scoped (RFC 6265) and the // management server cannot set cookies for the proxy's domain. query := redirectURL.Query() query.Set("session_token", sessionToken) redirectURL.RawQuery = query.Encode() log.WithField("redirect", redirectURL.Host).Debug("OAuth callback: redirecting user with session token") http.Redirect(w, r, redirectURL.String(), http.StatusFound) } // extractUserIDFromToken extracts the user ID from an OIDC token. func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string { // Try to get ID token from the oauth2 token extras rawIDToken, ok := token.Extra("id_token").(string) if !ok { log.Warn("No id_token in OIDC response") return "" } verifier := provider.Verifier(&oidc.Config{ ClientID: config.ClientID, }) idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { log.WithError(err).Warn("Failed to verify ID token") return "" } // Extract claims var claims struct { Subject string `json:"sub"` Email string `json:"email"` UserID string `json:"user_id"` } if err := idToken.Claims(&claims); err != nil { log.WithError(err).Warn("Failed to extract claims from ID token") return "" } // Prefer subject, fall back to user_id or email if claims.Subject != "" { return claims.Subject } if claims.UserID != "" { return claims.UserID } if claims.Email != "" { return claims.Email } return "" }