mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
refactor: add ValidateSession gRPC and streamline test setup
- Add ValidateSession gRPC method for proxy-side user validation - Move group access validation from REST callback to gRPC layer - Capture user info in access logs via CapturedData mutable pointer - Create validate_session_test.go for gRPC validation tests - Simplify auth_callback_integration_test.go to create accounts programmatically instead of using SQL file - SQL test data file now only used by validate_session_test.go
This commit is contained in:
@@ -76,26 +76,18 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.proxyService.ValidateUserGroupAccess(r.Context(), redirectURL.Hostname(), userID); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": userID,
|
||||
"domain": redirectURL.Hostname(),
|
||||
"error": err.Error(),
|
||||
}).Warn("User denied access to reverse proxy")
|
||||
|
||||
redirectURL.Scheme = "https"
|
||||
query := redirectURL.Query()
|
||||
query.Set("error", "access_denied")
|
||||
query.Set("error_description", "You are not authorized to access this service")
|
||||
redirectURL.RawQuery = query.Encode()
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
// Group validation is performed by the proxy via ValidateSession gRPC call.
|
||||
// This allows the proxy to show 403 pages directly without redirect dance.
|
||||
|
||||
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create session token")
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
redirectURL.Scheme = "https"
|
||||
query := redirectURL.Query()
|
||||
query.Set("error", "access_denied")
|
||||
query.Set("error_description", "Service configuration error")
|
||||
redirectURL.RawQuery = query.Encode()
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -164,14 +165,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func setupAuthCallbackTest(t *testing.T, sqlFile string) *testSetup {
|
||||
func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, sqlFile, t.TempDir())
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
createTestAccountsAndUsers(t, ctx, testStore)
|
||||
createTestReverseProxies(t, ctx, testStore)
|
||||
|
||||
oidcServer := newFakeOIDCServer()
|
||||
@@ -307,6 +309,37 @@ func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) {
|
||||
t.Helper()
|
||||
|
||||
testAccount := &types.Account{
|
||||
Id: "testAccountId",
|
||||
Domain: "test.com",
|
||||
DomainCategory: "private",
|
||||
IsDomainPrimaryAccount: true,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
|
||||
|
||||
allowedGroup := &types.Group{
|
||||
ID: "allowedGroupId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Allowed Group",
|
||||
Issued: "api",
|
||||
}
|
||||
require.NoError(t, testStore.CreateGroup(ctx, allowedGroup))
|
||||
|
||||
allowedUser := &types.User{
|
||||
Id: "allowedUserId",
|
||||
AccountID: "testAccountId",
|
||||
Role: types.UserRoleUser,
|
||||
AutoGroups: []string{"allowedGroupId"},
|
||||
CreatedAt: time.Now(),
|
||||
Issued: "api",
|
||||
}
|
||||
require.NoError(t, testStore.SaveUser(ctx, allowedUser))
|
||||
}
|
||||
|
||||
// testReverseProxyManager is a minimal implementation for testing.
|
||||
type testReverseProxyManager struct {
|
||||
store store.Store
|
||||
@@ -360,6 +393,10 @@ func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context,
|
||||
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
|
||||
t.Helper()
|
||||
|
||||
@@ -376,7 +413,7 @@ func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL st
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
@@ -401,81 +438,8 @@ func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
|
||||
require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserNotInAllowedGroup(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "nonGroupUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://restricted-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
require.NotEmpty(t, location)
|
||||
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "restricted-proxy.example.com", parsedLocation.Host)
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
|
||||
}
|
||||
|
||||
func TestAuthCallback_ProxyInDifferentAccount(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "otherAccountUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
require.NotEmpty(t, location)
|
||||
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserNotFound(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "nonExistentUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
}
|
||||
|
||||
func TestAuthCallback_ProxyNotFound(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
@@ -499,7 +463,7 @@ func TestAuthCallback_ProxyNotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthCallback_InvalidToken(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.failExchange = true
|
||||
@@ -516,7 +480,7 @@ func TestAuthCallback_InvalidToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthCallback_ExpiredToken(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
@@ -534,7 +498,7 @@ func TestAuthCallback_ExpiredToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthCallback_InvalidState(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
|
||||
@@ -547,7 +511,7 @@ func TestAuthCallback_InvalidState(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthCallback_MissingState(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
|
||||
@@ -557,26 +521,3 @@ func TestAuthCallback_MissingState(t *testing.T) {
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAuthCallback_BearerAuthDisabled(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://no-auth-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, parsedLocation.Query().Get("session_token"))
|
||||
require.Empty(t, parsedLocation.Query().Get("error"))
|
||||
}
|
||||
|
||||
17
management/server/testdata/auth_callback.sql
vendored
Normal file
17
management/server/testdata/auth_callback.sql
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
-- Schema definitions (must match GORM auto-migrate order)
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
|
||||
-- Test accounts
|
||||
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
|
||||
-- Test groups
|
||||
INSERT INTO "groups" VALUES('allowedGroupId','testAccountId','Allowed Group','api','[]',0,'');
|
||||
INSERT INTO "groups" VALUES('restrictedGroupId','testAccountId','Restricted Group','api','[]',0,'');
|
||||
|
||||
-- Test users
|
||||
INSERT INTO users VALUES('allowedUserId','testAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||
INSERT INTO users VALUES('nonGroupUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||
INSERT INTO users VALUES('otherAccountUserId','otherAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||
Reference in New Issue
Block a user