refactor: add ValidateSession gRPC and streamline test setup

- Add ValidateSession gRPC method for proxy-side user validation
- Move group access validation from REST callback to gRPC layer
- Capture user info in access logs via CapturedData mutable pointer
- Create validate_session_test.go for gRPC validation tests
- Simplify auth_callback_integration_test.go to create accounts
  programmatically instead of using SQL file
- SQL test data file now only used by validate_session_test.go
This commit is contained in:
mlsmaycon
2026-02-10 20:31:03 +01:00
parent 0cb02bd906
commit eea6120cd0
15 changed files with 955 additions and 238 deletions

View File

@@ -76,26 +76,18 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
return
}
if err := h.proxyService.ValidateUserGroupAccess(r.Context(), redirectURL.Hostname(), userID); err != nil {
log.WithFields(log.Fields{
"user_id": userID,
"domain": redirectURL.Hostname(),
"error": err.Error(),
}).Warn("User denied access to reverse proxy")
redirectURL.Scheme = "https"
query := redirectURL.Query()
query.Set("error", "access_denied")
query.Set("error_description", "You are not authorized to access this service")
redirectURL.RawQuery = query.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
return
}
// Group validation is performed by the proxy via ValidateSession gRPC call.
// This allows the proxy to show 403 pages directly without redirect dance.
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
if err != nil {
log.WithError(err).Error("Failed to create session token")
http.Error(w, "Failed to create session", http.StatusInternalServerError)
redirectURL.Scheme = "https"
query := redirectURL.Query()
query.Set("error", "access_denied")
query.Set("error_description", "Service configuration error")
redirectURL.RawQuery = query.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
return
}

View File

@@ -22,6 +22,7 @@ import (
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -164,14 +165,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string)
return nil, nil
}
func setupAuthCallbackTest(t *testing.T, sqlFile string) *testSetup {
func setupAuthCallbackTest(t *testing.T) *testSetup {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, sqlFile, t.TempDir())
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
createTestAccountsAndUsers(t, ctx, testStore)
createTestReverseProxies(t, ctx, testStore)
oidcServer := newFakeOIDCServer()
@@ -307,6 +309,37 @@ func strPtr(s string) *string {
return &s
}
func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
testAccount := &types.Account{
Id: "testAccountId",
Domain: "test.com",
DomainCategory: "private",
IsDomainPrimaryAccount: true,
CreatedAt: time.Now(),
}
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
allowedGroup := &types.Group{
ID: "allowedGroupId",
AccountID: "testAccountId",
Name: "Allowed Group",
Issued: "api",
}
require.NoError(t, testStore.CreateGroup(ctx, allowedGroup))
allowedUser := &types.User{
Id: "allowedUserId",
AccountID: "testAccountId",
Role: types.UserRoleUser,
AutoGroups: []string{"allowedGroupId"},
CreatedAt: time.Now(),
Issued: "api",
}
require.NoError(t, testStore.SaveUser(ctx, allowedUser))
}
// testReverseProxyManager is a minimal implementation for testing.
type testReverseProxyManager struct {
store store.Store
@@ -360,6 +393,10 @@ func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context,
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
}
func (m *testReverseProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()
@@ -376,7 +413,7 @@ func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL st
}
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
@@ -401,81 +438,8 @@ func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
}
func TestAuthCallback_UserNotInAllowedGroup(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
defer setup.cleanup()
setup.oidcServer.tokenSubject = "nonGroupUserId"
state := createTestState(t, setup.proxyService, "https://restricted-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
require.NotEmpty(t, location)
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "restricted-proxy.example.com", parsedLocation.Host)
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
}
func TestAuthCallback_ProxyInDifferentAccount(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
defer setup.cleanup()
setup.oidcServer.tokenSubject = "otherAccountUserId"
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
require.NotEmpty(t, location)
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
}
func TestAuthCallback_UserNotFound(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
defer setup.cleanup()
setup.oidcServer.tokenSubject = "nonExistentUserId"
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
}
func TestAuthCallback_ProxyNotFound(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
@@ -499,7 +463,7 @@ func TestAuthCallback_ProxyNotFound(t *testing.T) {
}
func TestAuthCallback_InvalidToken(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.failExchange = true
@@ -516,7 +480,7 @@ func TestAuthCallback_InvalidToken(t *testing.T) {
}
func TestAuthCallback_ExpiredToken(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
@@ -534,7 +498,7 @@ func TestAuthCallback_ExpiredToken(t *testing.T) {
}
func TestAuthCallback_InvalidState(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
@@ -547,7 +511,7 @@ func TestAuthCallback_InvalidState(t *testing.T) {
}
func TestAuthCallback_MissingState(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
@@ -557,26 +521,3 @@ func TestAuthCallback_MissingState(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAuthCallback_BearerAuthDisabled(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
state := createTestState(t, setup.proxyService, "https://no-auth-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.NotEmpty(t, parsedLocation.Query().Get("session_token"))
require.Empty(t, parsedLocation.Query().Get("error"))
}

View File

@@ -0,0 +1,17 @@
-- Schema definitions (must match GORM auto-migrate order)
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
-- Test accounts
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
-- Test groups
INSERT INTO "groups" VALUES('allowedGroupId','testAccountId','Allowed Group','api','[]',0,'');
INSERT INTO "groups" VALUES('restrictedGroupId','testAccountId','Restricted Group','api','[]',0,'');
-- Test users
INSERT INTO users VALUES('allowedUserId','testAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('nonGroupUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherAccountUserId','otherAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');