mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
refactor: add ValidateSession gRPC and streamline test setup
- Add ValidateSession gRPC method for proxy-side user validation - Move group access validation from REST callback to gRPC layer - Capture user info in access logs via CapturedData mutable pointer - Create validate_session_test.go for gRPC validation tests - Simplify auth_callback_integration_test.go to create accounts programmatically instead of using SQL file - SQL test data file now only used by validate_session_test.go
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
||||
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -164,14 +165,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func setupAuthCallbackTest(t *testing.T, sqlFile string) *testSetup {
|
||||
func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, sqlFile, t.TempDir())
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
createTestAccountsAndUsers(t, ctx, testStore)
|
||||
createTestReverseProxies(t, ctx, testStore)
|
||||
|
||||
oidcServer := newFakeOIDCServer()
|
||||
@@ -307,6 +309,37 @@ func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) {
|
||||
t.Helper()
|
||||
|
||||
testAccount := &types.Account{
|
||||
Id: "testAccountId",
|
||||
Domain: "test.com",
|
||||
DomainCategory: "private",
|
||||
IsDomainPrimaryAccount: true,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
|
||||
|
||||
allowedGroup := &types.Group{
|
||||
ID: "allowedGroupId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Allowed Group",
|
||||
Issued: "api",
|
||||
}
|
||||
require.NoError(t, testStore.CreateGroup(ctx, allowedGroup))
|
||||
|
||||
allowedUser := &types.User{
|
||||
Id: "allowedUserId",
|
||||
AccountID: "testAccountId",
|
||||
Role: types.UserRoleUser,
|
||||
AutoGroups: []string{"allowedGroupId"},
|
||||
CreatedAt: time.Now(),
|
||||
Issued: "api",
|
||||
}
|
||||
require.NoError(t, testStore.SaveUser(ctx, allowedUser))
|
||||
}
|
||||
|
||||
// testReverseProxyManager is a minimal implementation for testing.
|
||||
type testReverseProxyManager struct {
|
||||
store store.Store
|
||||
@@ -360,6 +393,10 @@ func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context,
|
||||
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) GetProxyIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
|
||||
t.Helper()
|
||||
|
||||
@@ -376,7 +413,7 @@ func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL st
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
@@ -401,81 +438,8 @@ func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
|
||||
require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserNotInAllowedGroup(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "nonGroupUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://restricted-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
require.NotEmpty(t, location)
|
||||
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "restricted-proxy.example.com", parsedLocation.Host)
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
|
||||
}
|
||||
|
||||
func TestAuthCallback_ProxyInDifferentAccount(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "otherAccountUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
require.NotEmpty(t, location)
|
||||
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserNotFound(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "nonExistentUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
}
|
||||
|
||||
func TestAuthCallback_ProxyNotFound(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
@@ -499,7 +463,7 @@ func TestAuthCallback_ProxyNotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthCallback_InvalidToken(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.failExchange = true
|
||||
@@ -516,7 +480,7 @@ func TestAuthCallback_InvalidToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthCallback_ExpiredToken(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
@@ -534,7 +498,7 @@ func TestAuthCallback_ExpiredToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthCallback_InvalidState(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
|
||||
@@ -547,7 +511,7 @@ func TestAuthCallback_InvalidState(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthCallback_MissingState(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
setup := setupAuthCallbackTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
|
||||
@@ -557,26 +521,3 @@ func TestAuthCallback_MissingState(t *testing.T) {
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAuthCallback_BearerAuthDisabled(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://no-auth-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, parsedLocation.Query().Get("session_token"))
|
||||
require.Empty(t, parsedLocation.Query().Get("error"))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user