diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index c99403684..7d64eee2f 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" proxyauth "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/shared/management/proto" @@ -795,3 +796,143 @@ func (s *ProxyServiceServer) getAccountProxyByDomain(ctx context.Context, accoun return nil, fmt.Errorf("reverse proxy not found for domain %s in account %s", domain, accountID) } + +// ValidateSession validates a session token and checks if the user has access to the domain. +func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.ValidateSessionRequest) (*proto.ValidateSessionResponse, error) { + domain := req.GetDomain() + sessionToken := req.GetSessionToken() + + if domain == "" || sessionToken == "" { + return &proto.ValidateSessionResponse{ + Valid: false, + DeniedReason: "missing domain or session_token", + }, nil + } + + proxy, err := s.getProxyByDomain(ctx, domain) + if err != nil { + log.WithFields(log.Fields{ + "domain": domain, + "error": err.Error(), + }).Debug("ValidateSession: proxy not found") + return &proto.ValidateSessionResponse{ + Valid: false, + DeniedReason: "proxy_not_found", + }, nil + } + + pubKeyBytes, err := base64.StdEncoding.DecodeString(proxy.SessionPublicKey) + if err != nil { + log.WithFields(log.Fields{ + "domain": domain, + "error": err.Error(), + }).Error("ValidateSession: decode public key") + return &proto.ValidateSessionResponse{ + Valid: false, + DeniedReason: "invalid_proxy_config", + }, nil + } + + userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes) + if err != nil { + log.WithFields(log.Fields{ + "domain": domain, + "error": err.Error(), + }).Debug("ValidateSession: invalid session token") + return &proto.ValidateSessionResponse{ + Valid: false, + DeniedReason: "invalid_token", + }, nil + } + + user, err := s.usersManager.GetUser(ctx, userID) + if err != nil { + log.WithFields(log.Fields{ + "domain": domain, + "user_id": userID, + "error": err.Error(), + }).Debug("ValidateSession: user not found") + return &proto.ValidateSessionResponse{ + Valid: false, + DeniedReason: "user_not_found", + }, nil + } + + if user.AccountID != proxy.AccountID { + log.WithFields(log.Fields{ + "domain": domain, + "user_id": userID, + "user_account": user.AccountID, + "proxy_account": proxy.AccountID, + }).Debug("ValidateSession: user account mismatch") + return &proto.ValidateSessionResponse{ + Valid: false, + DeniedReason: "account_mismatch", + }, nil + } + + if err := s.checkGroupAccess(proxy, user); err != nil { + log.WithFields(log.Fields{ + "domain": domain, + "user_id": userID, + "error": err.Error(), + }).Debug("ValidateSession: access denied") + return &proto.ValidateSessionResponse{ + Valid: false, + UserId: user.Id, + UserEmail: user.Email, + DeniedReason: "not_in_group", + }, nil + } + + log.WithFields(log.Fields{ + "domain": domain, + "user_id": userID, + "email": user.Email, + }).Debug("ValidateSession: access granted") + + return &proto.ValidateSessionResponse{ + Valid: true, + UserId: user.Id, + UserEmail: user.Email, + }, nil +} + +func (s *ProxyServiceServer) getProxyByDomain(ctx context.Context, domain string) (*reverseproxy.ReverseProxy, error) { + proxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx) + if err != nil { + return nil, fmt.Errorf("get reverse proxies: %w", err) + } + + for _, proxy := range proxies { + if proxy.Domain == domain { + return proxy, nil + } + } + + return nil, fmt.Errorf("reverse proxy not found for domain: %s", domain) +} + +func (s *ProxyServiceServer) checkGroupAccess(proxy *reverseproxy.ReverseProxy, user *types.User) error { + if proxy.Auth.BearerAuth == nil || !proxy.Auth.BearerAuth.Enabled { + return nil + } + + allowedGroups := proxy.Auth.BearerAuth.DistributionGroups + if len(allowedGroups) == 0 { + return nil + } + + allowedSet := make(map[string]bool, len(allowedGroups)) + for _, groupID := range allowedGroups { + allowedSet[groupID] = true + } + + for _, groupID := range user.AutoGroups { + if allowedSet[groupID] { + return nil + } + } + + return fmt.Errorf("user not in allowed groups") +} diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index ba208dd77..060be00f4 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -68,6 +68,10 @@ func (m *mockReverseProxyManager) GetProxyByID(ctx context.Context, accountID, r return nil, nil } +func (m *mockReverseProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) { + return "", nil +} + type mockUsersManager struct { users map[string]*types.User err error diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go new file mode 100644 index 000000000..2047600a1 --- /dev/null +++ b/management/internals/shared/grpc/validate_session_test.go @@ -0,0 +1,304 @@ +//go:build integration + +package grpc + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/shared/management/proto" +) + +type validateSessionTestSetup struct { + proxyService *ProxyServiceServer + store store.Store + cleanup func() +} + +func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { + t.Helper() + + ctx := context.Background() + testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir()) + require.NoError(t, err) + + proxyManager := &testValidateSessionProxyManager{store: testStore} + usersManager := &testValidateSessionUsersManager{store: testStore} + + proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager) + proxyService.SetProxyManager(proxyManager) + + createTestProxies(t, ctx, testStore) + + return &validateSessionTestSetup{ + proxyService: proxyService, + store: testStore, + cleanup: storeCleanup, + } +} + +func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) { + t.Helper() + + pubKey, privKey := generateSessionKeyPair(t) + + testProxy := &reverseproxy.ReverseProxy{ + ID: "testProxyId", + AccountID: "testAccountId", + Name: "Test Proxy", + Domain: "test-proxy.example.com", + Enabled: true, + SessionPrivateKey: privKey, + SessionPublicKey: pubKey, + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: true, + }, + }, + } + require.NoError(t, testStore.CreateReverseProxy(ctx, testProxy)) + + restrictedProxy := &reverseproxy.ReverseProxy{ + ID: "restrictedProxyId", + AccountID: "testAccountId", + Name: "Restricted Proxy", + Domain: "restricted-proxy.example.com", + Enabled: true, + SessionPrivateKey: privKey, + SessionPublicKey: pubKey, + Auth: reverseproxy.AuthConfig{ + BearerAuth: &reverseproxy.BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{"allowedGroupId"}, + }, + }, + } + require.NoError(t, testStore.CreateReverseProxy(ctx, restrictedProxy)) +} + +func generateSessionKeyPair(t *testing.T) (string, string) { + t.Helper() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv) +} + +func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string { + t.Helper() + token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour) + require.NoError(t, err) + return token +} + +func TestValidateSession_UserAllowed(t *testing.T) { + setup := setupValidateSessionTest(t) + defer setup.cleanup() + + proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId") + require.NoError(t, err) + + token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com") + + resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{ + Domain: "test-proxy.example.com", + SessionToken: token, + }) + + require.NoError(t, err) + assert.True(t, resp.Valid, "User should be allowed access") + assert.Equal(t, "allowedUserId", resp.UserId) + assert.Empty(t, resp.DeniedReason) +} + +func TestValidateSession_UserNotInAllowedGroup(t *testing.T) { + setup := setupValidateSessionTest(t) + defer setup.cleanup() + + proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId") + require.NoError(t, err) + + token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com") + + resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{ + Domain: "restricted-proxy.example.com", + SessionToken: token, + }) + + require.NoError(t, err) + assert.False(t, resp.Valid, "User not in group should be denied") + assert.Equal(t, "not_in_group", resp.DeniedReason) + assert.Equal(t, "nonGroupUserId", resp.UserId) +} + +func TestValidateSession_UserInDifferentAccount(t *testing.T) { + setup := setupValidateSessionTest(t) + defer setup.cleanup() + + proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId") + require.NoError(t, err) + + token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com") + + resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{ + Domain: "test-proxy.example.com", + SessionToken: token, + }) + + require.NoError(t, err) + assert.False(t, resp.Valid, "User in different account should be denied") + assert.Equal(t, "account_mismatch", resp.DeniedReason) +} + +func TestValidateSession_UserNotFound(t *testing.T) { + setup := setupValidateSessionTest(t) + defer setup.cleanup() + + proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId") + require.NoError(t, err) + + token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com") + + resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{ + Domain: "test-proxy.example.com", + SessionToken: token, + }) + + require.NoError(t, err) + assert.False(t, resp.Valid, "Non-existent user should be denied") + assert.Equal(t, "user_not_found", resp.DeniedReason) +} + +func TestValidateSession_ProxyNotFound(t *testing.T) { + setup := setupValidateSessionTest(t) + defer setup.cleanup() + + proxy, err := setup.store.GetReverseProxyByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId") + require.NoError(t, err) + + token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com") + + resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{ + Domain: "unknown-proxy.example.com", + SessionToken: token, + }) + + require.NoError(t, err) + assert.False(t, resp.Valid, "Unknown proxy should be denied") + assert.Equal(t, "proxy_not_found", resp.DeniedReason) +} + +func TestValidateSession_InvalidToken(t *testing.T) { + setup := setupValidateSessionTest(t) + defer setup.cleanup() + + resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{ + Domain: "test-proxy.example.com", + SessionToken: "invalid-token", + }) + + require.NoError(t, err) + assert.False(t, resp.Valid, "Invalid token should be denied") + assert.Equal(t, "invalid_token", resp.DeniedReason) +} + +func TestValidateSession_MissingDomain(t *testing.T) { + setup := setupValidateSessionTest(t) + defer setup.cleanup() + + resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{ + SessionToken: "some-token", + }) + + require.NoError(t, err) + assert.False(t, resp.Valid) + assert.Contains(t, resp.DeniedReason, "missing") +} + +func TestValidateSession_MissingToken(t *testing.T) { + setup := setupValidateSessionTest(t) + defer setup.cleanup() + + resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{ + Domain: "test-proxy.example.com", + }) + + require.NoError(t, err) + assert.False(t, resp.Valid) + assert.Contains(t, resp.DeniedReason, "missing") +} + +type testValidateSessionProxyManager struct { + store store.Store +} + +func (m *testValidateSessionProxyManager) GetAllReverseProxies(_ context.Context, _, _ string) ([]*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) GetReverseProxy(_ context.Context, _, _, _ string) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) CreateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) UpdateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) DeleteReverseProxy(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { + return nil +} + +func (m *testValidateSessionProxyManager) ReloadAllReverseProxiesForAccount(_ context.Context, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) ReloadReverseProxy(_ context.Context, _, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) { + return m.store.GetReverseProxies(ctx, store.LockingStrengthNone) +} + +func (m *testValidateSessionProxyManager) GetProxyByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) { + return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, proxyID) +} + +func (m *testValidateSessionProxyManager) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) { + return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID) +} + +func (m *testValidateSessionProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) { + return "", nil +} + +type testValidateSessionUsersManager struct { + store store.Store +} + +func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) { + return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) +} diff --git a/management/server/http/handlers/proxy/auth.go b/management/server/http/handlers/proxy/auth.go index 82075bda3..27c64ecae 100644 --- a/management/server/http/handlers/proxy/auth.go +++ b/management/server/http/handlers/proxy/auth.go @@ -76,26 +76,18 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ return } - if err := h.proxyService.ValidateUserGroupAccess(r.Context(), redirectURL.Hostname(), userID); err != nil { - log.WithFields(log.Fields{ - "user_id": userID, - "domain": redirectURL.Hostname(), - "error": err.Error(), - }).Warn("User denied access to reverse proxy") - - redirectURL.Scheme = "https" - query := redirectURL.Query() - query.Set("error", "access_denied") - query.Set("error_description", "You are not authorized to access this service") - redirectURL.RawQuery = query.Encode() - http.Redirect(w, r, redirectURL.String(), http.StatusFound) - return - } + // Group validation is performed by the proxy via ValidateSession gRPC call. + // This allows the proxy to show 403 pages directly without redirect dance. sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC) if err != nil { log.WithError(err).Error("Failed to create session token") - http.Error(w, "Failed to create session", http.StatusInternalServerError) + redirectURL.Scheme = "https" + query := redirectURL.Query() + query.Set("error", "access_denied") + query.Set("error_description", "Service configuration error") + redirectURL.RawQuery = query.Encode() + http.Redirect(w, r, redirectURL.String(), http.StatusFound) return } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 58e337392..59f9f8d14 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -22,6 +22,7 @@ import ( accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -164,14 +165,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string) return nil, nil } -func setupAuthCallbackTest(t *testing.T, sqlFile string) *testSetup { +func setupAuthCallbackTest(t *testing.T) *testSetup { t.Helper() ctx := context.Background() - testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, sqlFile, t.TempDir()) + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) require.NoError(t, err) + createTestAccountsAndUsers(t, ctx, testStore) createTestReverseProxies(t, ctx, testStore) oidcServer := newFakeOIDCServer() @@ -307,6 +309,37 @@ func strPtr(s string) *string { return &s } +func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) { + t.Helper() + + testAccount := &types.Account{ + Id: "testAccountId", + Domain: "test.com", + DomainCategory: "private", + IsDomainPrimaryAccount: true, + CreatedAt: time.Now(), + } + require.NoError(t, testStore.SaveAccount(ctx, testAccount)) + + allowedGroup := &types.Group{ + ID: "allowedGroupId", + AccountID: "testAccountId", + Name: "Allowed Group", + Issued: "api", + } + require.NoError(t, testStore.CreateGroup(ctx, allowedGroup)) + + allowedUser := &types.User{ + Id: "allowedUserId", + AccountID: "testAccountId", + Role: types.UserRoleUser, + AutoGroups: []string{"allowedGroupId"}, + CreatedAt: time.Now(), + Issued: "api", + } + require.NoError(t, testStore.SaveUser(ctx, allowedUser)) +} + // testReverseProxyManager is a minimal implementation for testing. type testReverseProxyManager struct { store store.Store @@ -360,6 +393,10 @@ func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context, return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID) } +func (m *testReverseProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) { + return "", nil +} + func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string { t.Helper() @@ -376,7 +413,7 @@ func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL st } func TestAuthCallback_UserAllowedToLogin(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + setup := setupAuthCallbackTest(t) defer setup.cleanup() setup.oidcServer.tokenSubject = "allowedUserId" @@ -401,81 +438,8 @@ func TestAuthCallback_UserAllowedToLogin(t *testing.T) { require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter") } -func TestAuthCallback_UserNotInAllowedGroup(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") - defer setup.cleanup() - - setup.oidcServer.tokenSubject = "nonGroupUserId" - - state := createTestState(t, setup.proxyService, "https://restricted-proxy.example.com/") - - req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) - rec := httptest.NewRecorder() - - setup.router.ServeHTTP(rec, req) - - require.Equal(t, http.StatusFound, rec.Code) - - location := rec.Header().Get("Location") - require.NotEmpty(t, location) - - parsedLocation, err := url.Parse(location) - require.NoError(t, err) - - require.Equal(t, "restricted-proxy.example.com", parsedLocation.Host) - require.Equal(t, "access_denied", parsedLocation.Query().Get("error")) - require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized") -} - -func TestAuthCallback_ProxyInDifferentAccount(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") - defer setup.cleanup() - - setup.oidcServer.tokenSubject = "otherAccountUserId" - - state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/") - - req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) - rec := httptest.NewRecorder() - - setup.router.ServeHTTP(rec, req) - - require.Equal(t, http.StatusFound, rec.Code) - - location := rec.Header().Get("Location") - require.NotEmpty(t, location) - - parsedLocation, err := url.Parse(location) - require.NoError(t, err) - - require.Equal(t, "access_denied", parsedLocation.Query().Get("error")) - require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized") -} - -func TestAuthCallback_UserNotFound(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") - defer setup.cleanup() - - setup.oidcServer.tokenSubject = "nonExistentUserId" - - state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/") - - req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) - rec := httptest.NewRecorder() - - setup.router.ServeHTTP(rec, req) - - require.Equal(t, http.StatusFound, rec.Code) - - location := rec.Header().Get("Location") - parsedLocation, err := url.Parse(location) - require.NoError(t, err) - - require.Equal(t, "access_denied", parsedLocation.Query().Get("error")) -} - func TestAuthCallback_ProxyNotFound(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + setup := setupAuthCallbackTest(t) defer setup.cleanup() setup.oidcServer.tokenSubject = "allowedUserId" @@ -499,7 +463,7 @@ func TestAuthCallback_ProxyNotFound(t *testing.T) { } func TestAuthCallback_InvalidToken(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + setup := setupAuthCallbackTest(t) defer setup.cleanup() setup.oidcServer.failExchange = true @@ -516,7 +480,7 @@ func TestAuthCallback_InvalidToken(t *testing.T) { } func TestAuthCallback_ExpiredToken(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + setup := setupAuthCallbackTest(t) defer setup.cleanup() setup.oidcServer.tokenSubject = "allowedUserId" @@ -534,7 +498,7 @@ func TestAuthCallback_ExpiredToken(t *testing.T) { } func TestAuthCallback_InvalidState(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + setup := setupAuthCallbackTest(t) defer setup.cleanup() req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil) @@ -547,7 +511,7 @@ func TestAuthCallback_InvalidState(t *testing.T) { } func TestAuthCallback_MissingState(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") + setup := setupAuthCallbackTest(t) defer setup.cleanup() req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil) @@ -557,26 +521,3 @@ func TestAuthCallback_MissingState(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } - -func TestAuthCallback_BearerAuthDisabled(t *testing.T) { - setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql") - defer setup.cleanup() - - setup.oidcServer.tokenSubject = "allowedUserId" - - state := createTestState(t, setup.proxyService, "https://no-auth-proxy.example.com/") - - req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil) - rec := httptest.NewRecorder() - - setup.router.ServeHTTP(rec, req) - - require.Equal(t, http.StatusFound, rec.Code) - - location := rec.Header().Get("Location") - parsedLocation, err := url.Parse(location) - require.NoError(t, err) - - require.NotEmpty(t, parsedLocation.Query().Get("session_token")) - require.Empty(t, parsedLocation.Query().Get("error")) -} diff --git a/management/server/testdata/auth_callback.sql b/management/server/testdata/auth_callback.sql new file mode 100644 index 000000000..fdd91a6d5 --- /dev/null +++ b/management/server/testdata/auth_callback.sql @@ -0,0 +1,17 @@ +-- Schema definitions (must match GORM auto-migrate order) +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +-- Test accounts +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); + +-- Test groups +INSERT INTO "groups" VALUES('allowedGroupId','testAccountId','Allowed Group','api','[]',0,''); +INSERT INTO "groups" VALUES('restrictedGroupId','testAccountId','Restricted Group','api','[]',0,''); + +-- Test users +INSERT INTO users VALUES('allowedUserId','testAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('nonGroupUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherAccountUserId','otherAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 48d6b61b3..c48e853bb 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -7,7 +7,6 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/proxy/internal/auth" "github.com/netbirdio/netbird/proxy/internal/proxy" ) @@ -55,8 +54,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { Method: r.Method, ResponseCode: int32(sw.status), SourceIp: sourceIp, - AuthMechanism: auth.MethodFromContext(r.Context()).String(), - UserId: auth.UserFromContext(r.Context()), + AuthMechanism: capturedData.GetAuthMethod(), + UserId: capturedData.GetUserID(), AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 092030cb8..9621a58c7 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -24,6 +24,11 @@ type authenticator interface { Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error) } +// SessionValidator validates session tokens and checks user access permissions. +type SessionValidator interface { + ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error) +} + type Scheme interface { Type() auth.Method // Authenticate should check the passed request and determine whether @@ -42,18 +47,23 @@ type DomainConfig struct { } type Middleware struct { - domainsMux sync.RWMutex - domains map[string]DomainConfig - logger *log.Logger + domainsMux sync.RWMutex + domains map[string]DomainConfig + logger *log.Logger + sessionValidator SessionValidator } -func NewMiddleware(logger *log.Logger) *Middleware { +// NewMiddleware creates a new authentication middleware. +// The sessionValidator is optional; if nil, OIDC session tokens will be validated +// locally without group access checks. +func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware { if logger == nil { logger = log.StandardLogger() } return &Middleware{ - domains: make(map[string]DomainConfig), - logger: logger, + domains: make(map[string]DomainConfig), + logger: logger, + sessionValidator: sessionValidator, } } @@ -102,9 +112,11 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { // Check for an existing session cookie (contains JWT) if cookie, err := r.Cookie(auth.SessionCookieName); err == nil { if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil { - ctx := withAuthMethod(r.Context(), auth.Method(method)) - ctx = withAuthUser(ctx, userID) - next.ServeHTTP(w, r.WithContext(ctx)) + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetUserID(userID) + cd.SetAuthMethod(method) + } + next.ServeHTTP(w, r) return } } @@ -114,13 +126,23 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { for _, scheme := range config.Schemes { token, promptData := scheme.Authenticate(r) if token != "" { - if _, _, err := auth.ValidateSessionJWT(token, host, config.SessionPublicKey); err != nil { + userID, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type()) + if err != nil { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(proxy.OriginAuth) } http.Error(w, err.Error(), http.StatusBadRequest) return } + if userID == "" { + var requestID string + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + requestID = cd.GetRequestID() + } + web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID) + return + } expiration := config.SessionExpiration if expiration == 0 { @@ -191,6 +213,40 @@ func (mw *Middleware) RemoveDomain(domain string) { delete(mw.domains, domain) } +// validateSessionToken validates a session token, optionally checking group access via gRPC. +// For OIDC tokens with a configured validator, it calls ValidateSession to check group access. +// For other auth methods (PIN, password), it validates the JWT locally. +// Returns the user ID if valid, empty string if access denied, or error for invalid tokens. +func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (string, error) { + // For OIDC with a session validator, call the gRPC service to check group access + if method == auth.MethodOIDC && mw.sessionValidator != nil { + resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{ + Domain: host, + SessionToken: token, + }) + if err != nil { + mw.logger.WithError(err).Error("ValidateSession gRPC call failed") + return "", fmt.Errorf("session validation failed") + } + if !resp.Valid { + mw.logger.WithFields(log.Fields{ + "domain": host, + "denied_reason": resp.DeniedReason, + "user_id": resp.UserId, + }).Debug("Session validation denied") + return "", nil + } + return resp.UserId, nil + } + + // For non-OIDC methods or when no validator is configured, validate JWT locally + userID, _, err := auth.ValidateSessionJWT(token, host, publicKey) + if err != nil { + return "", err + } + return userID, nil +} + // stripSessionTokenParam returns the request URI with the session_token query // parameter removed so it doesn't linger in the browser's address bar or history. func stripSessionTokenParam(u *url.URL) string { diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index 503d6faf9..eac4749d5 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -17,6 +17,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/proxy" ) func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair { @@ -28,10 +29,10 @@ func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair { // stubScheme is a minimal Scheme implementation for testing. type stubScheme struct { - method auth.Method - token string - promptID string - authFn func(*http.Request) (string, string) + method auth.Method + token string + promptID string + authFn func(*http.Request) (string, string) } func (s *stubScheme) Type() auth.Method { return s.method } @@ -51,7 +52,7 @@ func newPassthroughHandler() http.Handler { } func TestAddDomain_ValidKey(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} @@ -69,7 +70,7 @@ func TestAddDomain_ValidKey(t *testing.T) { } func TestAddDomain_EmptyKey(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour) @@ -83,7 +84,7 @@ func TestAddDomain_EmptyKey(t *testing.T) { } func TestAddDomain_InvalidBase64(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour) @@ -97,7 +98,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) { } func TestAddDomain_WrongKeySize(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort")) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} @@ -112,7 +113,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) { } func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) err := mw.AddDomain("example.com", nil, "", time.Hour) require.NoError(t, err, "domains with no auth schemes should not require a key") @@ -124,7 +125,7 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { } func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp1 := generateTestKeyPair(t) kp2 := generateTestKeyPair(t) @@ -143,7 +144,7 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { } func TestRemoveDomain(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} @@ -158,7 +159,7 @@ func TestRemoveDomain(t *testing.T) { } func TestProtect_UnknownDomainPassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) handler := mw.Protect(newPassthroughHandler()) req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil) @@ -170,7 +171,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) { } func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour)) handler := mw.Protect(newPassthroughHandler()) @@ -184,7 +185,7 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { } func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} @@ -205,7 +206,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { } func TestProtect_HostWithPortIsMatched(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} @@ -226,7 +227,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) { } func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} @@ -235,16 +236,18 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) require.NoError(t, err) + capturedData := &proxy.CapturedData{} handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := UserFromContext(r.Context()) - method := MethodFromContext(r.Context()) - assert.Equal(t, "test-user", user) - assert.Equal(t, auth.MethodPIN, method) + cd := proxy.CapturedDataFromContext(r.Context()) + require.NotNil(t, cd) + assert.Equal(t, "test-user", cd.GetUserID()) + assert.Equal(t, "pin", cd.GetAuthMethod()) w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("authenticated")) })) req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData)) req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token}) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) @@ -254,7 +257,7 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { } func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} @@ -280,7 +283,7 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { } func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} @@ -306,7 +309,7 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { } func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp1 := generateTestKeyPair(t) kp2 := generateTestKeyPair(t) @@ -333,7 +336,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { } func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour) @@ -383,7 +386,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { } func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{ @@ -406,7 +409,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { } func TestProtect_MultipleSchemes(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour) @@ -448,7 +451,7 @@ func TestProtect_MultipleSchemes(t *testing.T) { } func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) // Return a garbage token that won't validate. @@ -470,7 +473,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { } func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) // 32 random bytes that happen to be valid base64 and correct size // but are actually a valid ed25519 public key length-wise. @@ -487,7 +490,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { } func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { - mw := NewMiddleware(log.StandardLogger()) + mw := NewMiddleware(log.StandardLogger(), nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index 460f04ed0..22ebbf371 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -45,12 +45,14 @@ func (o ResponseOrigin) String() string { // CapturedData is a mutable struct that allows downstream handlers // to pass data back up the middleware chain. type CapturedData struct { - mu sync.RWMutex - RequestID string - ServiceId string - AccountId types.AccountID - Origin ResponseOrigin - ClientIP string + mu sync.RWMutex + RequestID string + ServiceId string + AccountId types.AccountID + Origin ResponseOrigin + ClientIP string + UserID string + AuthMethod string } // GetRequestID safely gets the request ID @@ -116,6 +118,34 @@ func (c *CapturedData) GetClientIP() string { return c.ClientIP } +// SetUserID safely sets the authenticated user ID. +func (c *CapturedData) SetUserID(userID string) { + c.mu.Lock() + defer c.mu.Unlock() + c.UserID = userID +} + +// GetUserID safely gets the authenticated user ID. +func (c *CapturedData) GetUserID() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.UserID +} + +// SetAuthMethod safely sets the authentication method used. +func (c *CapturedData) SetAuthMethod(method string) { + c.mu.Lock() + defer c.mu.Unlock() + c.AuthMethod = method +} + +// GetAuthMethod safely gets the authentication method used. +func (c *CapturedData) GetAuthMethod() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.AuthMethod +} + // WithCapturedData adds a CapturedData struct to the context func WithCapturedData(ctx context.Context, data *CapturedData) context.Context { return context.WithValue(ctx, capturedDataKey, data) diff --git a/proxy/server.go b/proxy/server.go index f24a19160..b5ec50906 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -240,8 +240,8 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. s.proxy = proxy.NewReverseProxy(s.netbird, s.ForwardedProto, s.TrustedProxies, s.Logger) - // Configure the authentication middleware. - s.auth = auth.NewMiddleware(s.Logger) + // Configure the authentication middleware with session validator for OIDC group checks. + s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) // Configure Access logs to management server. accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 44838fc16..5f45d3fdd 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v6.33.3 // source: management.proto package proto diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 703c65a73..95de71d1a 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v6.33.3 // source: proxy_service.proto package proto @@ -1341,6 +1341,132 @@ func (x *GetOIDCURLResponse) GetUrl() string { return "" } +type ValidateSessionRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` + SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` +} + +func (x *ValidateSessionRequest) Reset() { + *x = ValidateSessionRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ValidateSessionRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ValidateSessionRequest) ProtoMessage() {} + +func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[18] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ValidateSessionRequest.ProtoReflect.Descriptor instead. +func (*ValidateSessionRequest) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{18} +} + +func (x *ValidateSessionRequest) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +func (x *ValidateSessionRequest) GetSessionToken() string { + if x != nil { + return x.SessionToken + } + return "" +} + +type ValidateSessionResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Valid bool `protobuf:"varint,1,opt,name=valid,proto3" json:"valid,omitempty"` + UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + UserEmail string `protobuf:"bytes,3,opt,name=user_email,json=userEmail,proto3" json:"user_email,omitempty"` + DeniedReason string `protobuf:"bytes,4,opt,name=denied_reason,json=deniedReason,proto3" json:"denied_reason,omitempty"` +} + +func (x *ValidateSessionResponse) Reset() { + *x = ValidateSessionResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ValidateSessionResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ValidateSessionResponse) ProtoMessage() {} + +func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[19] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ValidateSessionResponse.ProtoReflect.Descriptor instead. +func (*ValidateSessionResponse) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{19} +} + +func (x *ValidateSessionResponse) GetValid() bool { + if x != nil { + return x.Valid + } + return false +} + +func (x *ValidateSessionResponse) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *ValidateSessionResponse) GetUserEmail() string { + if x != nil { + return x.UserEmail + } + return "" +} + +func (x *ValidateSessionResponse) GetDeniedReason() string { + if x != nil { + return x.DeniedReason + } + return "" +} + var File_proxy_service_proto protoreflect.FileDescriptor var file_proxy_service_proto_rawDesc = []byte{ @@ -1500,61 +1626,81 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, - 0x72, 0x6c, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, - 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, - 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, - 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, - 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, - 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, - 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, - 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, - 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, - 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, - 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, - 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, - 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, - 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, - 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, + 0x72, 0x6c, 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, + 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, + 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x8c, 0x01, 0x0a, 0x17, 0x56, 0x61, + 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, + 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, + 0x65, 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, + 0x69, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, + 0x61, 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, + 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, + 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, + 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, + 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, + 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, + 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0xc8, + 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, + 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, + 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, + 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, + 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, + 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, + 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, - 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, - 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, - 0x52, 0x10, 0x05, 0x32, 0xa0, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, + 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, + 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, + 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, + 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, + 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, 0xfc, 0x04, 0x0a, 0x0c, 0x50, 0x72, + 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, + 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, + 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, + 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, + 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, + 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, + 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, - 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, - 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, - 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, - 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, - 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, - 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, - 0x0a, 0x10, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, - 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, - 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, - 0x0f, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, - 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, - 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, - 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, - 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, + 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, + 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, + 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, + 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, + 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, + 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, + 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -1570,7 +1716,7 @@ func file_proxy_service_proto_rawDescGZIP() []byte { } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 18) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 20) var file_proxy_service_proto_goTypes = []interface{}{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType (ProxyStatus)(0), // 1: management.ProxyStatus @@ -1592,16 +1738,18 @@ var file_proxy_service_proto_goTypes = []interface{}{ (*CreateProxyPeerResponse)(nil), // 17: management.CreateProxyPeerResponse (*GetOIDCURLRequest)(nil), // 18: management.GetOIDCURLRequest (*GetOIDCURLResponse)(nil), // 19: management.GetOIDCURLResponse - (*timestamppb.Timestamp)(nil), // 20: google.protobuf.Timestamp + (*ValidateSessionRequest)(nil), // 20: management.ValidateSessionRequest + (*ValidateSessionResponse)(nil), // 21: management.ValidateSessionResponse + (*timestamppb.Timestamp)(nil), // 22: google.protobuf.Timestamp } var file_proxy_service_proto_depIdxs = []int32{ - 20, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 22, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp 6, // 1: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping 0, // 2: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType 4, // 3: management.ProxyMapping.path:type_name -> management.PathMapping 5, // 4: management.ProxyMapping.auth:type_name -> management.Authentication 9, // 5: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 20, // 6: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 22, // 6: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp 11, // 7: management.AuthenticateRequest.password:type_name -> management.PasswordRequest 12, // 8: management.AuthenticateRequest.pin:type_name -> management.PinRequest 1, // 9: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus @@ -1611,14 +1759,16 @@ var file_proxy_service_proto_depIdxs = []int32{ 14, // 13: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest 16, // 14: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest 18, // 15: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 3, // 16: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 8, // 17: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 13, // 18: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 15, // 19: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 17, // 20: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 19, // 21: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 16, // [16:22] is the sub-list for method output_type - 10, // [10:16] is the sub-list for method input_type + 20, // 16: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 3, // 17: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 8, // 18: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 13, // 19: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 15, // 20: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 17, // 21: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 19, // 22: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 21, // 23: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 17, // [17:24] is the sub-list for method output_type + 10, // [10:17] is the sub-list for method input_type 10, // [10:10] is the sub-list for extension type_name 10, // [10:10] is the sub-list for extension extendee 0, // [0:10] is the sub-list for field type_name @@ -1846,6 +1996,30 @@ func file_proxy_service_proto_init() { return nil } } + file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ValidateSessionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ValidateSessionResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_proxy_service_proto_msgTypes[8].OneofWrappers = []interface{}{ (*AuthenticateRequest_Password)(nil), @@ -1859,7 +2033,7 @@ func file_proxy_service_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, NumEnums: 2, - NumMessages: 18, + NumMessages: 20, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index 14a8ebc76..617f2b2e4 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -20,6 +20,10 @@ service ProxyService { rpc CreateProxyPeer(CreateProxyPeerRequest) returns (CreateProxyPeerResponse); rpc GetOIDCURL(GetOIDCURLRequest) returns (GetOIDCURLResponse); + + // ValidateSession validates a session token and checks user access permissions. + // Called by the proxy after receiving a session token from OIDC callback. + rpc ValidateSession(ValidateSessionRequest) returns (ValidateSessionResponse); } // GetMappingUpdateRequest is sent to initialise a mapping stream. @@ -164,3 +168,15 @@ message GetOIDCURLRequest { message GetOIDCURLResponse { string url = 1; } + +message ValidateSessionRequest { + string domain = 1; + string session_token = 2; +} + +message ValidateSessionResponse { + bool valid = 1; + string user_id = 2; + string user_email = 3; + string denied_reason = 4; +} diff --git a/shared/management/proto/proxy_service_grpc.pb.go b/shared/management/proto/proxy_service_grpc.pb.go index 9abeaf219..627b217d8 100644 --- a/shared/management/proto/proxy_service_grpc.pb.go +++ b/shared/management/proto/proxy_service_grpc.pb.go @@ -24,6 +24,9 @@ type ProxyServiceClient interface { SendStatusUpdate(ctx context.Context, in *SendStatusUpdateRequest, opts ...grpc.CallOption) (*SendStatusUpdateResponse, error) CreateProxyPeer(ctx context.Context, in *CreateProxyPeerRequest, opts ...grpc.CallOption) (*CreateProxyPeerResponse, error) GetOIDCURL(ctx context.Context, in *GetOIDCURLRequest, opts ...grpc.CallOption) (*GetOIDCURLResponse, error) + // ValidateSession validates a session token and checks user access permissions. + // Called by the proxy after receiving a session token from OIDC callback. + ValidateSession(ctx context.Context, in *ValidateSessionRequest, opts ...grpc.CallOption) (*ValidateSessionResponse, error) } type proxyServiceClient struct { @@ -111,6 +114,15 @@ func (c *proxyServiceClient) GetOIDCURL(ctx context.Context, in *GetOIDCURLReque return out, nil } +func (c *proxyServiceClient) ValidateSession(ctx context.Context, in *ValidateSessionRequest, opts ...grpc.CallOption) (*ValidateSessionResponse, error) { + out := new(ValidateSessionResponse) + err := c.cc.Invoke(ctx, "/management.ProxyService/ValidateSession", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // ProxyServiceServer is the server API for ProxyService service. // All implementations must embed UnimplementedProxyServiceServer // for forward compatibility @@ -121,6 +133,9 @@ type ProxyServiceServer interface { SendStatusUpdate(context.Context, *SendStatusUpdateRequest) (*SendStatusUpdateResponse, error) CreateProxyPeer(context.Context, *CreateProxyPeerRequest) (*CreateProxyPeerResponse, error) GetOIDCURL(context.Context, *GetOIDCURLRequest) (*GetOIDCURLResponse, error) + // ValidateSession validates a session token and checks user access permissions. + // Called by the proxy after receiving a session token from OIDC callback. + ValidateSession(context.Context, *ValidateSessionRequest) (*ValidateSessionResponse, error) mustEmbedUnimplementedProxyServiceServer() } @@ -146,6 +161,9 @@ func (UnimplementedProxyServiceServer) CreateProxyPeer(context.Context, *CreateP func (UnimplementedProxyServiceServer) GetOIDCURL(context.Context, *GetOIDCURLRequest) (*GetOIDCURLResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetOIDCURL not implemented") } +func (UnimplementedProxyServiceServer) ValidateSession(context.Context, *ValidateSessionRequest) (*ValidateSessionResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ValidateSession not implemented") +} func (UnimplementedProxyServiceServer) mustEmbedUnimplementedProxyServiceServer() {} // UnsafeProxyServiceServer may be embedded to opt out of forward compatibility for this service. @@ -270,6 +288,24 @@ func _ProxyService_GetOIDCURL_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } +func _ProxyService_ValidateSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ValidateSessionRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ProxyServiceServer).ValidateSession(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ProxyService/ValidateSession", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ProxyServiceServer).ValidateSession(ctx, req.(*ValidateSessionRequest)) + } + return interceptor(ctx, in, info, handler) +} + // ProxyService_ServiceDesc is the grpc.ServiceDesc for ProxyService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -297,6 +333,10 @@ var ProxyService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetOIDCURL", Handler: _ProxyService_GetOIDCURL_Handler, }, + { + MethodName: "ValidateSession", + Handler: _ProxyService_ValidateSession_Handler, + }, }, Streams: []grpc.StreamDesc{ {