diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index c01d7b316..36e28504d 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -163,7 +163,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager()) s.AfterInit(func(s *BaseServer) { proxyService.SetProxyManager(s.ReverseProxyManager()) }) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 7e49ece55..c99403684 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" + "github.com/netbirdio/netbird/management/server/users" proxyauth "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -69,6 +70,9 @@ type ProxyServiceServer struct { // Manager for peers peersManager peers.Manager + // Manager for users + usersManager users.Manager + // Store for one-time authentication tokens tokenStore *OneTimeTokenStore @@ -90,14 +94,15 @@ type proxyConnection struct { mu sync.RWMutex } -// NewProxyServiceServer creates a new proxy service server -func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager) *ProxyServiceServer { +// NewProxyServiceServer creates a new proxy service server. +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer { return &ProxyServiceServer{ updatesChan: make(chan *proto.ProxyMapping, 100), accessLogManager: accessLogMgr, oidcConfig: oidcConfig, tokenStore: tokenStore, peersManager: peersManager, + usersManager: usersManager, } } @@ -733,3 +738,60 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u proxyauth.DefaultSessionExpiry, ) } + +// ValidateUserGroupAccess checks if a user has access to a reverse proxy. +// It looks up the proxy within the user's account only, then optionally checks +// group membership if BearerAuth with DistributionGroups is configured. +func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain, userID string) error { + user, err := s.usersManager.GetUser(ctx, userID) + if err != nil { + return fmt.Errorf("user not found: %s", userID) + } + + proxy, err := s.getAccountProxyByDomain(ctx, user.AccountID, domain) + if err != nil { + return err + } + + if proxy.Auth.BearerAuth == nil || !proxy.Auth.BearerAuth.Enabled { + return nil + } + + allowedGroups := proxy.Auth.BearerAuth.DistributionGroups + if len(allowedGroups) == 0 { + return nil + } + + allowedSet := make(map[string]bool, len(allowedGroups)) + for _, groupID := range allowedGroups { + allowedSet[groupID] = true + } + + for _, groupID := range user.AutoGroups { + if allowedSet[groupID] { + log.WithFields(log.Fields{ + "user_id": user.Id, + "group_id": groupID, + "domain": domain, + }).Debug("User granted access via group membership") + return nil + } + } + + return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain) +} + +func (s *ProxyServiceServer) getAccountProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) { + proxies, err := s.reverseProxyManager.GetAccountReverseProxies(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("get account reverse proxies: %w", err) + } + + for _, proxy := range proxies { + if proxy.Domain == domain { + return proxy, nil + } + } + + return nil, fmt.Errorf("reverse proxy not found for domain %s in account %s", domain, accountID) +} diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go new file mode 100644 index 000000000..ba208dd77 --- /dev/null +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -0,0 +1,377 @@ +package grpc + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/server/types" +) + +type mockReverseProxyManager struct { + proxiesByAccount map[string][]*reverseproxy.ReverseProxy + err error +} + +func (m *mockReverseProxyManager) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) { + if m.err != nil { + return nil, m.err + } + return m.proxiesByAccount[accountID], nil +} + +func (m *mockReverseProxyManager) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *mockReverseProxyManager) GetAllReverseProxies(ctx context.Context, accountID, userID string) ([]*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *mockReverseProxyManager) GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *mockReverseProxyManager) CreateReverseProxy(ctx context.Context, accountID, userID string, rp *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *mockReverseProxyManager) UpdateReverseProxy(ctx context.Context, accountID, userID string, rp *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *mockReverseProxyManager) DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error { + return nil +} + +func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error { + return nil +} + +func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error { + return nil +} + +func (m *mockReverseProxyManager) ReloadAllReverseProxiesForAccount(ctx context.Context, accountID string) error { + return nil +} + +func (m *mockReverseProxyManager) ReloadReverseProxy(ctx context.Context, accountID, reverseProxyID string) error { + return nil +} + +func (m *mockReverseProxyManager) GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +type mockUsersManager struct { + users map[string]*types.User + err error +} + +func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) { + if m.err != nil { + return nil, m.err + } + user, ok := m.users[userID] + if !ok { + return nil, errors.New("user not found") + } + return user, nil +} + +func TestValidateUserGroupAccess(t *testing.T) { + tests := []struct { + name string + domain string + userID string + proxiesByAccount map[string][]*reverseproxy.ReverseProxy + users map[string]*types.User + proxyErr error + userErr error + expectErr bool + expectErrMsg string + }{ + { + name: "user not found", + domain: "app.example.com", + userID: "unknown-user", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": {{Domain: "app.example.com", AccountID: "account1"}}, + }, + users: map[string]*types.User{}, + expectErr: true, + expectErrMsg: "user not found", + }, + { + name: "proxy not found in user's account", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{}, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1"}, + }, + expectErr: true, + expectErrMsg: "reverse proxy not found", + }, + { + name: "proxy exists in different account - not accessible", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account2": {{Domain: "app.example.com", AccountID: "account2"}}, + }, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1"}, + }, + expectErr: true, + expectErrMsg: "reverse proxy not found", + }, + { + name: "no bearer auth configured - same account allows access", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}}, + }, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1"}, + }, + expectErr: false, + }, + { + name: "bearer auth disabled - same account allows access", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": {{ + Domain: "app.example.com", + AccountID: "account1", + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false}, + }, + }}, + }, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1"}, + }, + expectErr: false, + }, + { + name: "bearer auth enabled but no groups configured - same account allows access", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": {{ + Domain: "app.example.com", + AccountID: "account1", + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{}, + }, + }, + }}, + }, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1"}, + }, + expectErr: false, + }, + { + name: "user not in allowed groups", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": {{ + Domain: "app.example.com", + AccountID: "account1", + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{"group1", "group2"}, + }, + }, + }}, + }, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}}, + }, + expectErr: true, + expectErrMsg: "not in allowed groups", + }, + { + name: "user in one of the allowed groups - allow access", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": {{ + Domain: "app.example.com", + AccountID: "account1", + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{"group1", "group2"}, + }, + }, + }}, + }, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}}, + }, + expectErr: false, + }, + { + name: "user in all allowed groups - allow access", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": {{ + Domain: "app.example.com", + AccountID: "account1", + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{"group1", "group2"}, + }, + }, + }}, + }, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}}, + }, + expectErr: false, + }, + { + name: "proxy manager error", + domain: "app.example.com", + userID: "user1", + proxiesByAccount: nil, + proxyErr: errors.New("database error"), + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1"}, + }, + expectErr: true, + expectErrMsg: "get account reverse proxies", + }, + { + name: "multiple proxies in account - finds correct one", + domain: "app2.example.com", + userID: "user1", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": { + {Domain: "app1.example.com", AccountID: "account1"}, + {Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}, + {Domain: "app3.example.com", AccountID: "account1"}, + }, + }, + users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "account1"}, + }, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &ProxyServiceServer{ + reverseProxyManager: &mockReverseProxyManager{ + proxiesByAccount: tt.proxiesByAccount, + err: tt.proxyErr, + }, + usersManager: &mockUsersManager{ + users: tt.users, + err: tt.userErr, + }, + } + + err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID) + + if tt.expectErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErrMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestGetAccountProxyByDomain(t *testing.T) { + tests := []struct { + name string + accountID string + domain string + proxiesByAccount map[string][]*reverseproxy.ReverseProxy + err error + expectProxy bool + expectErr bool + }{ + { + name: "proxy found", + accountID: "account1", + domain: "app.example.com", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": { + {Domain: "other.example.com", AccountID: "account1"}, + {Domain: "app.example.com", AccountID: "account1"}, + }, + }, + expectProxy: true, + expectErr: false, + }, + { + name: "proxy not found in account", + accountID: "account1", + domain: "unknown.example.com", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{ + "account1": {{Domain: "app.example.com", AccountID: "account1"}}, + }, + expectProxy: false, + expectErr: true, + }, + { + name: "empty proxy list for account", + accountID: "account1", + domain: "app.example.com", + proxiesByAccount: map[string][]*reverseproxy.ReverseProxy{}, + expectProxy: false, + expectErr: true, + }, + { + name: "manager error", + accountID: "account1", + domain: "app.example.com", + proxiesByAccount: nil, + err: errors.New("database error"), + expectProxy: false, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &ProxyServiceServer{ + reverseProxyManager: &mockReverseProxyManager{ + proxiesByAccount: tt.proxiesByAccount, + err: tt.err, + }, + } + + proxy, err := server.getAccountProxyByDomain(context.Background(), tt.accountID, tt.domain) + + if tt.expectErr { + require.Error(t, err) + assert.Nil(t, proxy) + } else { + require.NoError(t, err) + require.NotNil(t, proxy) + assert.Equal(t, tt.domain, proxy.Domain) + } + }) + } +} diff --git a/management/server/http/handlers/proxy/auth.go b/management/server/http/handlers/proxy/auth.go index 29ed3ea52..82075bda3 100644 --- a/management/server/http/handlers/proxy/auth.go +++ b/management/server/http/handlers/proxy/auth.go @@ -7,24 +7,27 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server/types" log "github.com/sirupsen/logrus" "golang.org/x/oauth2" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/proxy/auth" ) +// AuthCallbackHandler handles OAuth callbacks for proxy authentication. type AuthCallbackHandler struct { proxyService *nbgrpc.ProxyServiceServer } +// NewAuthCallbackHandler creates a new OAuth callback handler. func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler { return &AuthCallbackHandler{ proxyService: proxyService, } } +// RegisterEndpoints registers the OAuth callback endpoint. func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) { router.HandleFunc(types.ProxyCallbackEndpoint, h.handleCallback).Methods(http.MethodGet) } @@ -46,10 +49,8 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ return } - // Get OIDC configuration oidcConfig := h.proxyService.GetOIDCConfig() - // Create OIDC provider to discover endpoints provider, err := oidc.NewProvider(r.Context(), oidcConfig.Issuer) if err != nil { log.WithError(err).Error("Failed to create OIDC provider") @@ -68,7 +69,6 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ return } - // Extract user ID from the OIDC token userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token) if userID == "" { log.Error("Failed to extract user ID from OIDC token") @@ -76,7 +76,22 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ return } - // Generate session JWT instead of passing OIDC access_token + if err := h.proxyService.ValidateUserGroupAccess(r.Context(), redirectURL.Hostname(), userID); err != nil { + log.WithFields(log.Fields{ + "user_id": userID, + "domain": redirectURL.Hostname(), + "error": err.Error(), + }).Warn("User denied access to reverse proxy") + + redirectURL.Scheme = "https" + query := redirectURL.Query() + query.Set("error", "access_denied") + query.Set("error_description", "You are not authorized to access this service") + redirectURL.RawQuery = query.Encode() + http.Redirect(w, r, redirectURL.String(), http.StatusFound) + return + } + sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC) if err != nil { log.WithError(err).Error("Failed to create session token") @@ -84,13 +99,8 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ return } - // Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here). redirectURL.Scheme = "https" - // Pass the session token in the URL query parameter. The proxy middleware will - // extract it, validate it, set its own cookie, and redirect to remove the token from the URL. - // We cannot set the cookie here because cookies are domain-scoped (RFC 6265) and the - // management server cannot set cookies for the proxy's domain. query := redirectURL.Query() query.Set("session_token", sessionToken) redirectURL.RawQuery = query.Encode() @@ -99,9 +109,7 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ http.Redirect(w, r, redirectURL.String(), http.StatusFound) } -// extractUserIDFromToken extracts the user ID from an OIDC token. func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string { - // Try to get ID token from the oauth2 token extras rawIDToken, ok := token.Extra("id_token").(string) if !ok { log.Warn("No id_token in OIDC response") @@ -118,27 +126,13 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config return "" } - // Extract claims var claims struct { Subject string `json:"sub"` - Email string `json:"email"` - UserID string `json:"user_id"` } if err := idToken.Claims(&claims); err != nil { log.WithError(err).Warn("Failed to extract claims from ID token") return "" } - // Prefer subject, fall back to user_id or email - if claims.Subject != "" { - return claims.Subject - } - if claims.UserID != "" { - return claims.UserID - } - if claims.Email != "" { - return claims.Email - } - - return "" + return claims.Subject } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go new file mode 100644 index 000000000..58e337392 --- /dev/null +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -0,0 +1,582 @@ +//go:build integration + +package proxy + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/gorilla/mux" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// fakeOIDCServer creates a minimal OIDC provider for testing. +type fakeOIDCServer struct { + server *httptest.Server + issuer string + signingKey ed25519.PrivateKey + publicKey ed25519.PublicKey + keyID string + tokenSubject string + tokenExpiry time.Duration + failExchange bool +} + +func newFakeOIDCServer() *fakeOIDCServer { + pub, priv, _ := ed25519.GenerateKey(rand.Reader) + f := &fakeOIDCServer{ + signingKey: priv, + publicKey: pub, + keyID: "test-key-1", + tokenExpiry: time.Hour, + } + f.server = httptest.NewServer(f) + f.issuer = f.server.URL + return f +} + +func (f *fakeOIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + f.handleDiscovery(w, r) + case "/token": + f.handleToken(w, r) + case "/keys": + f.handleJWKS(w, r) + default: + http.NotFound(w, r) + } +} + +func (f *fakeOIDCServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) { + discovery := map[string]interface{}{ + "issuer": f.issuer, + "authorization_endpoint": f.issuer + "/auth", + "token_endpoint": f.issuer + "/token", + "jwks_uri": f.issuer + "/keys", + "response_types_supported": []string{ + "code", + "id_token", + "token id_token", + }, + "subject_types_supported": []string{"public"}, + "id_token_signing_alg_values_supported": []string{"EdDSA"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(discovery) +} + +func (f *fakeOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) { + if f.failExchange { + http.Error(w, "invalid_grant", http.StatusBadRequest) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + idToken := f.createIDToken() + + response := map[string]interface{}{ + "access_token": "test-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "id_token": idToken, + "refresh_token": "test-refresh-token", + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +func (f *fakeOIDCServer) createIDToken() string { + now := time.Now() + claims := jwt.MapClaims{ + "iss": f.issuer, + "sub": f.tokenSubject, + "aud": "test-client-id", + "exp": now.Add(f.tokenExpiry).Unix(), + "iat": now.Unix(), + "nbf": now.Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + token.Header["kid"] = f.keyID + signed, _ := token.SignedString(f.signingKey) + return signed +} + +func (f *fakeOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) { + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "OKP", + "crv": "Ed25519", + "kid": f.keyID, + "x": base64.RawURLEncoding.EncodeToString(f.publicKey), + "use": "sig", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(jwks) +} + +func (f *fakeOIDCServer) Close() { + f.server.Close() +} + +// testSetup contains all test dependencies. +type testSetup struct { + store store.Store + oidcServer *fakeOIDCServer + proxyService *nbgrpc.ProxyServiceServer + handler *AuthCallbackHandler + router *mux.Router + cleanup func() +} + +// testAccessLogManager is a minimal mock for accesslogs.Manager. +type testAccessLogManager struct{} + +func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error { + return nil +} + +func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string) ([]*accesslogs.AccessLogEntry, error) { + return nil, nil +} + +func setupAuthCallbackTest(t *testing.T, sqlFile string) *testSetup { + t.Helper() + + ctx := context.Background() + + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, sqlFile, t.TempDir()) + require.NoError(t, err) + + createTestReverseProxies(t, ctx, testStore) + + oidcServer := newFakeOIDCServer() + + tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute) + + usersManager := users.NewManager(testStore) + + oidcConfig := nbgrpc.ProxyOIDCConfig{ + Issuer: oidcServer.issuer, + ClientID: "test-client-id", + Scopes: []string{"openid", "profile", "email"}, + CallbackURL: "https://management.example.com/reverse-proxy/callback", + HMACKey: []byte("test-hmac-key-for-state-signing"), + } + + proxyService := nbgrpc.NewProxyServiceServer( + &testAccessLogManager{}, + tokenStore, + oidcConfig, + nil, + usersManager, + ) + + proxyService.SetProxyManager(&testReverseProxyManager{store: testStore}) + + handler := NewAuthCallbackHandler(proxyService) + + router := mux.NewRouter() + handler.RegisterEndpoints(router) + + return &testSetup{ + store: testStore, + oidcServer: oidcServer, + proxyService: proxyService, + handler: handler, + router: router, + cleanup: func() { + cleanup() + oidcServer.Close() + }, + } +} + +func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store.Store) { + t.Helper() + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := base64.StdEncoding.EncodeToString(pub) + privKey := base64.StdEncoding.EncodeToString(priv) + + testProxy := &reverseproxy.ReverseProxy{ + ID: "testProxyId", + AccountID: "testAccountId", + Name: "Test Proxy", + Domain: "test-proxy.example.com", + Targets: []*reverseproxy.Target{{ + Path: strPtr("/"), + Host: "localhost", + Port: 8080, + Protocol: "http", + TargetId: "peer1", + TargetType: "peer", + Enabled: true, + }}, + Enabled: true, + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{"allowedGroupId"}, + }, + }, + SessionPrivateKey: privKey, + SessionPublicKey: pubKey, + } + require.NoError(t, testStore.CreateReverseProxy(ctx, testProxy)) + + restrictedProxy := &reverseproxy.ReverseProxy{ + ID: "restrictedProxyId", + AccountID: "testAccountId", + Name: "Restricted Proxy", + Domain: "restricted-proxy.example.com", + Targets: []*reverseproxy.Target{{ + Path: strPtr("/"), + Host: "localhost", + Port: 8080, + Protocol: "http", + TargetId: "peer1", + TargetType: "peer", + Enabled: true, + }}, + Enabled: true, + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{"restrictedGroupId"}, + }, + }, + SessionPrivateKey: privKey, + SessionPublicKey: pubKey, + } + require.NoError(t, testStore.CreateReverseProxy(ctx, restrictedProxy)) + + noAuthProxy := &reverseproxy.ReverseProxy{ + ID: "noAuthProxyId", + AccountID: "testAccountId", + Name: "No Auth Proxy", + Domain: "no-auth-proxy.example.com", + Targets: []*reverseproxy.Target{{ + Path: strPtr("/"), + Host: "localhost", + Port: 8080, + Protocol: "http", + TargetId: "peer1", + TargetType: "peer", + Enabled: true, + }}, + Enabled: true, + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: false, + }, + }, + SessionPrivateKey: privKey, + SessionPublicKey: pubKey, + } + require.NoError(t, testStore.CreateReverseProxy(ctx, noAuthProxy)) +} + +func strPtr(s string) *string { + return &s +} + +// testReverseProxyManager is a minimal implementation for testing. +type testReverseProxyManager struct { + store store.Store +} + +func (m *testReverseProxyManager) GetAllReverseProxies(_ context.Context, _, _ string) ([]*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *testReverseProxyManager) GetReverseProxy(_ context.Context, _, _, _ string) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *testReverseProxyManager) CreateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *testReverseProxyManager) UpdateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *testReverseProxyManager) DeleteReverseProxy(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testReverseProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { + return nil +} + +func (m *testReverseProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { + return nil +} + +func (m *testReverseProxyManager) ReloadAllReverseProxiesForAccount(_ context.Context, _ string) error { + return nil +} + +func (m *testReverseProxyManager) ReloadReverseProxy(_ context.Context, _, _ string) error { + return nil +} + +func (m *testReverseProxyManager) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) { + return m.store.GetReverseProxies(ctx, store.LockingStrengthNone) +} + +func (m *testReverseProxyManager) GetProxyByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) { + return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, proxyID) +} + +func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) { + return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID) +} + +func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string { + t.Helper() + + resp, err := ps.GetOIDCURL(context.Background(), &proto.GetOIDCURLRequest{ + RedirectUrl: redirectURL, + AccountId: "testAccountId", + }) + require.NoError(t, err) + + parsedURL, err := url.Parse(resp.Url) + require.NoError(t, err) + + return parsedURL.Query().Get("state") +} + +func TestAuthCallback_UserAllowedToLogin(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + setup.oidcServer.tokenSubject = "allowedUserId" + + state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/dashboard") + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusFound, rec.Code) + + location := rec.Header().Get("Location") + require.NotEmpty(t, location) + + parsedLocation, err := url.Parse(location) + require.NoError(t, err) + + require.Equal(t, "test-proxy.example.com", parsedLocation.Host) + require.NotEmpty(t, parsedLocation.Query().Get("session_token"), "Should include session token") + require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter") +} + +func TestAuthCallback_UserNotInAllowedGroup(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + setup.oidcServer.tokenSubject = "nonGroupUserId" + + state := createTestState(t, setup.proxyService, "https://restricted-proxy.example.com/") + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusFound, rec.Code) + + location := rec.Header().Get("Location") + require.NotEmpty(t, location) + + parsedLocation, err := url.Parse(location) + require.NoError(t, err) + + require.Equal(t, "restricted-proxy.example.com", parsedLocation.Host) + require.Equal(t, "access_denied", parsedLocation.Query().Get("error")) + require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized") +} + +func TestAuthCallback_ProxyInDifferentAccount(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + setup.oidcServer.tokenSubject = "otherAccountUserId" + + state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/") + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusFound, rec.Code) + + location := rec.Header().Get("Location") + require.NotEmpty(t, location) + + parsedLocation, err := url.Parse(location) + require.NoError(t, err) + + require.Equal(t, "access_denied", parsedLocation.Query().Get("error")) + require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized") +} + +func TestAuthCallback_UserNotFound(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + setup.oidcServer.tokenSubject = "nonExistentUserId" + + state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/") + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusFound, rec.Code) + + location := rec.Header().Get("Location") + parsedLocation, err := url.Parse(location) + require.NoError(t, err) + + require.Equal(t, "access_denied", parsedLocation.Query().Get("error")) +} + +func TestAuthCallback_ProxyNotFound(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + setup.oidcServer.tokenSubject = "allowedUserId" + + state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/") + + require.NoError(t, setup.store.DeleteReverseProxy(context.Background(), "testAccountId", "testProxyId")) + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusFound, rec.Code) + + location := rec.Header().Get("Location") + parsedLocation, err := url.Parse(location) + require.NoError(t, err) + + require.Equal(t, "access_denied", parsedLocation.Query().Get("error")) +} + +func TestAuthCallback_InvalidToken(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + setup.oidcServer.failExchange = true + + state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/") + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=invalid-code&state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusInternalServerError, rec.Code) + require.Contains(t, rec.Body.String(), "Failed to exchange code") +} + +func TestAuthCallback_ExpiredToken(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + setup.oidcServer.tokenSubject = "allowedUserId" + setup.oidcServer.tokenExpiry = -time.Hour + + state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/") + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + require.Contains(t, rec.Body.String(), "Failed to validate token") +} + +func TestAuthCallback_InvalidState(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid state") +} + +func TestAuthCallback_MissingState(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAuthCallback_BearerAuthDisabled(t *testing.T) { + setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + defer setup.cleanup() + + setup.oidcServer.tokenSubject = "allowedUserId" + + state := createTestState(t, setup.proxyService, "https://no-auth-proxy.example.com/") + + req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + setup.router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusFound, rec.Code) + + location := rec.Header().Get("Location") + parsedLocation, err := url.Parse(location) + require.NoError(t, err) + + require.NotEmpty(t, parsedLocation.Query().Get("session_token")) + require.Empty(t, parsedLocation.Query().Get("error")) +} diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 92b4a74a8..a8871cd84 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -84,6 +84,21 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { return } + // Check for error from OAuth callback (e.g., access denied) + if errCode := r.URL.Query().Get("error"); errCode != "" { + var requestID string + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + requestID = cd.GetRequestID() + } + errDesc := r.URL.Query().Get("error_description") + if errDesc == "" { + errDesc = "An error occurred during authentication" + } + web.ServeErrorPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID, web.ErrorStatus{Proxy: true, Destination: true}) + return + } + // Check for an existing session cookie (contains JWT) if cookie, err := r.Cookie(auth.SessionCookieName); err == nil { if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {