refactor: add ValidateSession gRPC and streamline test setup

- Add ValidateSession gRPC method for proxy-side user validation
- Move group access validation from REST callback to gRPC layer
- Capture user info in access logs via CapturedData mutable pointer
- Create validate_session_test.go for gRPC validation tests
- Simplify auth_callback_integration_test.go to create accounts
  programmatically instead of using SQL file
- SQL test data file now only used by validate_session_test.go
This commit is contained in:
mlsmaycon
2026-02-10 20:31:03 +01:00
parent 0cb02bd906
commit eea6120cd0
15 changed files with 955 additions and 238 deletions

View File

@@ -22,6 +22,7 @@ import (
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -164,14 +165,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string)
return nil, nil
}
func setupAuthCallbackTest(t *testing.T, sqlFile string) *testSetup {
func setupAuthCallbackTest(t *testing.T) *testSetup {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, sqlFile, t.TempDir())
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
createTestAccountsAndUsers(t, ctx, testStore)
createTestReverseProxies(t, ctx, testStore)
oidcServer := newFakeOIDCServer()
@@ -307,6 +309,37 @@ func strPtr(s string) *string {
return &s
}
func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
testAccount := &types.Account{
Id: "testAccountId",
Domain: "test.com",
DomainCategory: "private",
IsDomainPrimaryAccount: true,
CreatedAt: time.Now(),
}
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
allowedGroup := &types.Group{
ID: "allowedGroupId",
AccountID: "testAccountId",
Name: "Allowed Group",
Issued: "api",
}
require.NoError(t, testStore.CreateGroup(ctx, allowedGroup))
allowedUser := &types.User{
Id: "allowedUserId",
AccountID: "testAccountId",
Role: types.UserRoleUser,
AutoGroups: []string{"allowedGroupId"},
CreatedAt: time.Now(),
Issued: "api",
}
require.NoError(t, testStore.SaveUser(ctx, allowedUser))
}
// testReverseProxyManager is a minimal implementation for testing.
type testReverseProxyManager struct {
store store.Store
@@ -360,6 +393,10 @@ func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context,
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
}
func (m *testReverseProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()
@@ -376,7 +413,7 @@ func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL st
}
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
@@ -401,81 +438,8 @@ func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
}
func TestAuthCallback_UserNotInAllowedGroup(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
defer setup.cleanup()
setup.oidcServer.tokenSubject = "nonGroupUserId"
state := createTestState(t, setup.proxyService, "https://restricted-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
require.NotEmpty(t, location)
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "restricted-proxy.example.com", parsedLocation.Host)
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
}
func TestAuthCallback_ProxyInDifferentAccount(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
defer setup.cleanup()
setup.oidcServer.tokenSubject = "otherAccountUserId"
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
require.NotEmpty(t, location)
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
}
func TestAuthCallback_UserNotFound(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
defer setup.cleanup()
setup.oidcServer.tokenSubject = "nonExistentUserId"
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
}
func TestAuthCallback_ProxyNotFound(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
@@ -499,7 +463,7 @@ func TestAuthCallback_ProxyNotFound(t *testing.T) {
}
func TestAuthCallback_InvalidToken(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.failExchange = true
@@ -516,7 +480,7 @@ func TestAuthCallback_InvalidToken(t *testing.T) {
}
func TestAuthCallback_ExpiredToken(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
@@ -534,7 +498,7 @@ func TestAuthCallback_ExpiredToken(t *testing.T) {
}
func TestAuthCallback_InvalidState(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
@@ -547,7 +511,7 @@ func TestAuthCallback_InvalidState(t *testing.T) {
}
func TestAuthCallback_MissingState(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
@@ -557,26 +521,3 @@ func TestAuthCallback_MissingState(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAuthCallback_BearerAuthDisabled(t *testing.T) {
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
state := createTestState(t, setup.proxyService, "https://no-auth-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.NotEmpty(t, parsedLocation.Query().Get("session_token"))
require.Empty(t, parsedLocation.Query().Get("error"))
}