Files
netbird/management/server/http/handlers/proxy/auth_callback_integration_test.go
mlsmaycon b16d63643c Add group-based access control for SSO reverse proxy authentication
Implement user group validation during OAuth callback to ensure users
belong to allowed distribution groups before granting access to reverse
proxies. This provides account isolation and fine-grained access control.

Key changes:
- Add ValidateUserGroupAccess to ProxyServiceServer for group membership checks
- Redirect denied users to error page with access_denied parameter
- Handle OAuth error responses in proxy middleware
- Add comprehensive integration tests for auth callback flow
2026-02-10 16:25:00 +01:00

583 lines
17 KiB
Go

//go:build integration
package proxy
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/proto"
)
// fakeOIDCServer is a minimal in-process OIDC provider for the callback tests.
// It serves discovery, JWKS and token endpoints from an httptest.Server and
// mints EdDSA-signed ID tokens. Tests steer its behaviour by setting
// tokenSubject, tokenExpiry or failExchange before issuing a callback request.
type fakeOIDCServer struct {
	// server hosts the endpoints; issuer is its base URL.
	server *httptest.Server
	issuer string

	// Ed25519 key pair that signs ID tokens; the public half is published
	// through the JWKS endpoint under keyID.
	signingKey ed25519.PrivateKey
	publicKey  ed25519.PublicKey
	keyID      string

	// tokenSubject becomes the "sub" claim of minted ID tokens.
	tokenSubject string
	// tokenExpiry is added to the issue time to compute "exp"; a negative
	// value yields a token that is already expired.
	tokenExpiry time.Duration
	// failExchange makes the token endpoint reject every code exchange.
	failExchange bool
}
// newFakeOIDCServer starts a fakeOIDCServer on an httptest.Server with a
// freshly generated Ed25519 signing key (kid "test-key-1") and a one-hour
// token lifetime. The caller must Close it when done.
//
// It panics if key generation fails: without a signing key every minted token
// would be unverifiable, and continuing would only produce confusing failures
// further down the test.
func newFakeOIDCServer() *fakeOIDCServer {
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		panic("fakeOIDCServer: generating signing key: " + err.Error())
	}

	f := &fakeOIDCServer{
		signingKey:  priv,
		publicKey:   pub,
		keyID:       "test-key-1",
		tokenExpiry: time.Hour,
	}
	f.server = httptest.NewServer(f)
	// The issuer must equal the server's base URL so the discovery document and
	// the "iss" claim of minted tokens agree.
	f.issuer = f.server.URL
	return f
}
// ServeHTTP dispatches on the exact request path to the discovery, token and
// JWKS endpoints. Any other path, including the advertised /auth endpoint,
// gets a 404.
func (f *fakeOIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	routes := map[string]http.HandlerFunc{
		"/.well-known/openid-configuration": f.handleDiscovery,
		"/token":                            f.handleToken,
		"/keys":                             f.handleJWKS,
	}

	if handle, ok := routes[r.URL.Path]; ok {
		handle(w, r)
		return
	}
	http.NotFound(w, r)
}
// handleDiscovery serves the OpenID Provider metadata document. Every endpoint
// URL is rooted at f.issuer, and EdDSA is advertised as the only ID token
// signing algorithm because that is what createIDToken uses.
func (f *fakeOIDCServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) {
	endpoint := func(path string) string { return f.issuer + path }

	metadata := map[string]any{
		"issuer":                                f.issuer,
		"authorization_endpoint":                endpoint("/auth"),
		"token_endpoint":                        endpoint("/token"),
		"jwks_uri":                              endpoint("/keys"),
		"response_types_supported":              []string{"code", "id_token", "token id_token"},
		"subject_types_supported":               []string{"public"},
		"id_token_signing_alg_values_supported": []string{"EdDSA"},
	}

	w.Header().Set("Content-Type", "application/json")
	// An encode failure means the response write itself failed; there is
	// nothing left to send, so the error is deliberately ignored.
	_ = json.NewEncoder(w).Encode(metadata)
}
// handleToken implements the authorization-code grant of the OAuth 2.0 token
// endpoint. On success it returns a Bearer access token, a refresh token and an
// ID token minted by createIDToken.
//
// Failures are reported as RFC 6749 §5.2 JSON bodies ({"error": "..."}) with
// status 400, so OAuth clients parse them as they would a real provider's
// response:
//   - failExchange set: invalid_grant, regardless of the request;
//   - unparsable form body: invalid_request;
//   - grant_type other than authorization_code: unsupported_grant_type;
//   - missing code: invalid_request, so a client that fails to forward the
//     authorization code is caught instead of silently succeeding.
func (f *fakeOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) {
	writeJSON := func(status int, body map[string]any) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		// A failed write means the client went away; nothing left to report.
		_ = json.NewEncoder(w).Encode(body)
	}
	oauthError := func(code string) {
		writeJSON(http.StatusBadRequest, map[string]any{"error": code})
	}

	if f.failExchange {
		oauthError("invalid_grant")
		return
	}
	if err := r.ParseForm(); err != nil {
		oauthError("invalid_request")
		return
	}
	if r.PostForm.Get("grant_type") != "authorization_code" {
		oauthError("unsupported_grant_type")
		return
	}
	if r.PostForm.Get("code") == "" {
		oauthError("invalid_request")
		return
	}

	writeJSON(http.StatusOK, map[string]any{
		"access_token":  "test-access-token",
		"token_type":    "Bearer",
		"expires_in":    3600,
		"id_token":      f.createIDToken(),
		"refresh_token": "test-refresh-token",
	})
}
// createIDToken mints an ID token for f.tokenSubject, issued by f.issuer for
// audience "test-client-id", which matches the ClientID that
// setupAuthCallbackTest configures. iat and nbf are set to now and exp to
// now+f.tokenExpiry. The token is signed with EdDSA using f.signingKey, and its
// "kid" header is f.keyID so verifiers can select the key served by handleJWKS.
//
// It panics if signing fails: returning an empty token would otherwise surface
// later as a misleading token-validation error.
func (f *fakeOIDCServer) createIDToken() string {
	now := time.Now()
	claims := jwt.MapClaims{
		"iss": f.issuer,
		"sub": f.tokenSubject,
		"aud": "test-client-id",
		"exp": now.Add(f.tokenExpiry).Unix(),
		"iat": now.Unix(),
		"nbf": now.Unix(),
	}

	token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
	token.Header["kid"] = f.keyID

	signed, err := token.SignedString(f.signingKey)
	if err != nil {
		panic("fakeOIDCServer: signing ID token: " + err.Error())
	}
	return signed
}
// handleJWKS publishes the server's Ed25519 public key as a single-entry JWK
// Set: an "OKP"/"Ed25519" signing key identified by f.keyID.
func (f *fakeOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) {
	signingKey := map[string]any{
		"kty": "OKP",
		"crv": "Ed25519",
		"kid": f.keyID,
		// JWK "x" carries the raw public key bytes in unpadded base64url.
		"x":   base64.RawURLEncoding.EncodeToString(f.publicKey),
		"use": "sig",
	}
	keySet := map[string]any{
		"keys": []map[string]any{signingKey},
	}

	w.Header().Set("Content-Type", "application/json")
	// Nothing useful can be sent if writing the body fails; ignore the error.
	_ = json.NewEncoder(w).Encode(keySet)
}
// Close stops the httptest.Server hosting the fake provider.
func (f *fakeOIDCServer) Close() {
	srv := f.server
	srv.Close()
}
// testSetup bundles the dependencies wired together by setupAuthCallbackTest.
type testSetup struct {
	// store is the SQL-seeded test store that also holds the fixture proxies.
	store store.Store
	// oidcServer is the fake identity provider; tests adjust its fields to
	// steer the token exchange.
	oidcServer *fakeOIDCServer
	// proxyService backs the callback handler and mints OIDC state via
	// GetOIDCURL.
	proxyService *nbgrpc.ProxyServiceServer
	// handler is the callback handler under test.
	handler *AuthCallbackHandler
	// router has the handler's endpoints registered; tests drive it directly.
	router *mux.Router
	// cleanup tears down the store and the fake OIDC server.
	cleanup func()
}
// testAccessLogManager is a no-op access-log manager handed to
// NewProxyServiceServer: saved entries are discarded and queries return
// nothing.
type testAccessLogManager struct{}

// SaveAccessLog drops the entry and always reports success.
func (*testAccessLogManager) SaveAccessLog(context.Context, *accesslogs.AccessLogEntry) error {
	return nil
}

// GetAllAccessLogs always returns a nil slice and a nil error.
func (*testAccessLogManager) GetAllAccessLogs(context.Context, string, string) ([]*accesslogs.AccessLogEntry, error) {
	return nil, nil
}
// setupAuthCallbackTest builds the environment for exercising the
// reverse-proxy OIDC callback. It creates a test store from sqlFile, seeds it
// with the fixture reverse proxies, starts a fake OIDC provider, and wires a
// ProxyServiceServer and AuthCallbackHandler onto a fresh mux router.
//
// Teardown is registered with t.Cleanup as each resource is acquired. Without
// that, a failing require.* call during setup would abort the test via
// t.FailNow before the caller could defer cleanup, leaking the store. The
// returned cleanup runs the same teardown; it is safe to call it and still
// have t.Cleanup fire, because teardown runs at most once.
func setupAuthCallbackTest(t *testing.T, sqlFile string) *testSetup {
	t.Helper()

	ctx := context.Background()

	// Create the temp dir before registering teardown. t.Cleanup runs LIFO, so
	// the store is closed before its directory is removed.
	dataDir := t.TempDir()

	// releasers accumulate teardown steps in acquisition order. release runs
	// them once; every call happens on the test goroutine, so a plain flag
	// suffices.
	var releasers []func()
	released := false
	release := func() {
		if released {
			return
		}
		released = true
		for _, fn := range releasers {
			fn()
		}
	}
	t.Cleanup(release)

	testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, sqlFile, dataDir)
	require.NoError(t, err)
	releasers = append(releasers, storeCleanup)

	createTestReverseProxies(t, ctx, testStore)

	oidcServer := newFakeOIDCServer()
	releasers = append(releasers, oidcServer.Close)

	tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
	usersManager := users.NewManager(testStore)

	oidcConfig := nbgrpc.ProxyOIDCConfig{
		Issuer:      oidcServer.issuer,
		ClientID:    "test-client-id",
		Scopes:      []string{"openid", "profile", "email"},
		CallbackURL: "https://management.example.com/reverse-proxy/callback",
		HMACKey:     []byte("test-hmac-key-for-state-signing"),
	}

	proxyService := nbgrpc.NewProxyServiceServer(
		&testAccessLogManager{},
		tokenStore,
		oidcConfig,
		nil,
		usersManager,
	)
	proxyService.SetProxyManager(&testReverseProxyManager{store: testStore})

	handler := NewAuthCallbackHandler(proxyService)
	router := mux.NewRouter()
	handler.RegisterEndpoints(router)

	return &testSetup{
		store:        testStore,
		oidcServer:   oidcServer,
		proxyService: proxyService,
		handler:      handler,
		router:       router,
		cleanup:      release,
	}
}
// createTestReverseProxies stores three fixture reverse proxies in account
// "testAccountId". All of them share one freshly generated Ed25519 session key
// pair and the same single target (peer1, localhost:8080 over http, path "/"):
//   - testProxyId (test-proxy.example.com): bearer auth enabled,
//     DistributionGroups = [allowedGroupId]
//   - restrictedProxyId (restricted-proxy.example.com): bearer auth enabled,
//     DistributionGroups = [restrictedGroupId]
//   - noAuthProxyId (no-auth-proxy.example.com): bearer auth disabled
func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store.Store) {
	t.Helper()

	publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)
	encodedPublic := base64.StdEncoding.EncodeToString(publicKey)
	encodedPrivate := base64.StdEncoding.EncodeToString(privateKey)

	// newProxy builds one fixture. Only identity, domain and bearer-auth
	// settings differ between fixtures; each call gets its own Targets slice.
	newProxy := func(id, name, domain string, bearer *reverseproxy.BearerAuthConfig) *reverseproxy.ReverseProxy {
		return &reverseproxy.ReverseProxy{
			ID:        id,
			AccountID: "testAccountId",
			Name:      name,
			Domain:    domain,
			Targets: []*reverseproxy.Target{{
				Path:       strPtr("/"),
				Host:       "localhost",
				Port:       8080,
				Protocol:   "http",
				TargetId:   "peer1",
				TargetType: "peer",
				Enabled:    true,
			}},
			Enabled:           true,
			Auth:              reverseproxy.AuthConfig{BearerAuth: bearer},
			SessionPrivateKey: encodedPrivate,
			SessionPublicKey:  encodedPublic,
		}
	}

	fixtures := []*reverseproxy.ReverseProxy{
		newProxy("testProxyId", "Test Proxy", "test-proxy.example.com",
			&reverseproxy.BearerAuthConfig{Enabled: true, DistributionGroups: []string{"allowedGroupId"}}),
		newProxy("restrictedProxyId", "Restricted Proxy", "restricted-proxy.example.com",
			&reverseproxy.BearerAuthConfig{Enabled: true, DistributionGroups: []string{"restrictedGroupId"}}),
		newProxy("noAuthProxyId", "No Auth Proxy", "no-auth-proxy.example.com",
			&reverseproxy.BearerAuthConfig{Enabled: false}),
	}

	for _, fixture := range fixtures {
		require.NoError(t, testStore.CreateReverseProxy(ctx, fixture))
	}
}
// strPtr returns a pointer to a new copy of s. It is used to populate optional
// *string fields in fixtures.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// testReverseProxyManager is the reverse-proxy manager installed on the proxy
// service in tests. Only GetGlobalReverseProxies, GetProxyByID and
// GetAccountReverseProxies read from store. Every other method is a stub that
// returns zero values and a nil error.
type testReverseProxyManager struct {
	store store.Store
}

// GetAllReverseProxies is a stub; it returns no proxies.
func (*testReverseProxyManager) GetAllReverseProxies(context.Context, string, string) ([]*reverseproxy.ReverseProxy, error) {
	return nil, nil
}

// GetReverseProxy is a stub; it returns no proxy.
func (*testReverseProxyManager) GetReverseProxy(context.Context, string, string, string) (*reverseproxy.ReverseProxy, error) {
	return nil, nil
}

// CreateReverseProxy is a stub; it persists nothing.
func (*testReverseProxyManager) CreateReverseProxy(context.Context, string, string, *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
	return nil, nil
}

// UpdateReverseProxy is a stub; it persists nothing.
func (*testReverseProxyManager) UpdateReverseProxy(context.Context, string, string, *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
	return nil, nil
}

// DeleteReverseProxy is a stub; it deletes nothing.
func (*testReverseProxyManager) DeleteReverseProxy(context.Context, string, string, string) error {
	return nil
}

// SetCertificateIssuedAt is a stub.
func (*testReverseProxyManager) SetCertificateIssuedAt(context.Context, string, string) error {
	return nil
}

// SetStatus is a stub.
func (*testReverseProxyManager) SetStatus(context.Context, string, string, reverseproxy.ProxyStatus) error {
	return nil
}

// ReloadAllReverseProxiesForAccount is a stub.
func (*testReverseProxyManager) ReloadAllReverseProxiesForAccount(context.Context, string) error {
	return nil
}

// ReloadReverseProxy is a stub.
func (*testReverseProxyManager) ReloadReverseProxy(context.Context, string, string) error {
	return nil
}

// GetGlobalReverseProxies returns every reverse proxy in the store, without
// locking.
func (m *testReverseProxyManager) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) {
	return m.store.GetReverseProxies(ctx, store.LockingStrengthNone)
}

// GetProxyByID loads one reverse proxy of accountID from the store, without
// locking.
func (m *testReverseProxyManager) GetProxyByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) {
	return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}

// GetAccountReverseProxies returns all reverse proxies of accountID from the
// store, without locking.
func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
	return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
}
// createTestState asks ps for an OIDC authorization URL for account
// "testAccountId" with the given post-login redirectURL. It returns that URL's
// "state" query parameter, which tests send back to the callback endpoint
// together with the authorization code.
//
// It fails the test immediately if the response is nil or the URL has no
// state. Otherwise the problem would only surface later as a misleading
// "Invalid state" response.
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
	t.Helper()

	resp, err := ps.GetOIDCURL(context.Background(), &proto.GetOIDCURLRequest{
		RedirectUrl: redirectURL,
		AccountId:   "testAccountId",
	})
	require.NoError(t, err)
	require.NotNil(t, resp)

	authURL, err := url.Parse(resp.Url)
	require.NoError(t, err)

	state := authURL.Query().Get("state")
	require.NotEmpty(t, state, "OIDC URL %q carries no state parameter", resp.Url)
	return state
}
// TestAuthCallback_UserAllowedToLogin checks the happy path. After a callback
// for allowedUserId, the user is redirected to test-proxy.example.com with a
// session token and no error. allowedUserId is assumed, from the fixture
// naming, to belong to allowedGroupId in auth_callback.sql.
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	setup.oidcServer.tokenSubject = "allowedUserId"
	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/dashboard")

	query := url.Values{"code": {"test-auth-code"}, "state": {state}}
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil))

	require.Equal(t, http.StatusFound, rec.Code)

	location := rec.Header().Get("Location")
	require.NotEmpty(t, location)
	redirect, err := url.Parse(location)
	require.NoError(t, err)

	params := redirect.Query()
	require.Equal(t, "test-proxy.example.com", redirect.Host)
	require.NotEmpty(t, params.Get("session_token"), "Should include session token")
	require.Empty(t, params.Get("error"), "Should not have error parameter")
}
// TestAuthCallback_UserNotInAllowedGroup checks that nonGroupUserId is
// redirected back to restricted-proxy.example.com with an access_denied error
// and no session token. From the fixture naming, nonGroupUserId is assumed not
// to be a member of restrictedGroupId.
func TestAuthCallback_UserNotInAllowedGroup(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	setup.oidcServer.tokenSubject = "nonGroupUserId"
	state := createTestState(t, setup.proxyService, "https://restricted-proxy.example.com/")

	query := url.Values{"code": {"test-auth-code"}, "state": {state}}
	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil)
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusFound, rec.Code)
	location := rec.Header().Get("Location")
	require.NotEmpty(t, location)
	parsedLocation, err := url.Parse(location)
	require.NoError(t, err)

	params := parsedLocation.Query()
	require.Equal(t, "restricted-proxy.example.com", parsedLocation.Host)
	require.Equal(t, "access_denied", params.Get("error"))
	require.Contains(t, params.Get("error_description"), "not authorized")
	// A denied login must never hand the proxy a usable session.
	require.Empty(t, params.Get("session_token"), "denied users must not receive a session token")
}
// TestAuthCallback_ProxyInDifferentAccount checks that otherAccountUserId
// cannot log in to a proxy owned by testAccountId. otherAccountUserId is
// assumed, from the fixture naming, to belong to a different account. The
// callback must redirect with access_denied and no session token.
func TestAuthCallback_ProxyInDifferentAccount(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	setup.oidcServer.tokenSubject = "otherAccountUserId"
	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")

	query := url.Values{"code": {"test-auth-code"}, "state": {state}}
	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil)
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusFound, rec.Code)
	location := rec.Header().Get("Location")
	require.NotEmpty(t, location)
	parsedLocation, err := url.Parse(location)
	require.NoError(t, err)

	params := parsedLocation.Query()
	require.Equal(t, "access_denied", params.Get("error"))
	require.Contains(t, params.Get("error_description"), "not authorized")
	// Account isolation: no session may be issued across accounts.
	require.Empty(t, params.Get("session_token"), "denied users must not receive a session token")
}
// TestAuthCallback_UserNotFound checks the case where the ID token's subject
// does not exist in the store. The callback must redirect with access_denied
// and must not issue a session token.
func TestAuthCallback_UserNotFound(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	setup.oidcServer.tokenSubject = "nonExistentUserId"
	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")

	query := url.Values{"code": {"test-auth-code"}, "state": {state}}
	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil)
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusFound, rec.Code)
	location := rec.Header().Get("Location")
	require.NotEmpty(t, location)
	parsedLocation, err := url.Parse(location)
	require.NoError(t, err)

	params := parsedLocation.Query()
	require.Equal(t, "access_denied", params.Get("error"))
	require.Empty(t, params.Get("session_token"), "unknown users must not receive a session token")
}
// TestAuthCallback_ProxyNotFound checks a callback for test-proxy.example.com
// that arrives after that proxy has been deleted. The login state was issued
// while the proxy still existed. The callback must redirect with access_denied
// and must not issue a session token.
func TestAuthCallback_ProxyNotFound(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	setup.oidcServer.tokenSubject = "allowedUserId"
	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
	// Remove the target proxy only after the state has been issued.
	require.NoError(t, setup.store.DeleteReverseProxy(context.Background(), "testAccountId", "testProxyId"))

	query := url.Values{"code": {"test-auth-code"}, "state": {state}}
	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil)
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusFound, rec.Code)
	location := rec.Header().Get("Location")
	require.NotEmpty(t, location)
	parsedLocation, err := url.Parse(location)
	require.NoError(t, err)

	params := parsedLocation.Query()
	require.Equal(t, "access_denied", params.Get("error"))
	require.Empty(t, params.Get("session_token"), "no session may be issued for a deleted proxy")
}
// TestAuthCallback_InvalidToken checks the path where the provider rejects the
// authorization-code exchange (failExchange). The callback must answer 500
// with "Failed to exchange code". Despite the name, what fails is the token
// exchange, not token validation.
func TestAuthCallback_InvalidToken(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	setup.oidcServer.failExchange = true
	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")

	query := url.Values{"code": {"invalid-code"}, "state": {state}}
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil))

	require.Equal(t, http.StatusInternalServerError, rec.Code)
	require.Contains(t, rec.Body.String(), "Failed to exchange code")
}
// TestAuthCallback_ExpiredToken checks that an ID token whose exp lies one hour
// in the past is rejected with 401 "Failed to validate token".
func TestAuthCallback_ExpiredToken(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	provider := setup.oidcServer
	provider.tokenSubject = "allowedUserId"
	provider.tokenExpiry = -time.Hour

	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")

	query := url.Values{"code": {"test-auth-code"}, "state": {state}}
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil))

	require.Equal(t, http.StatusUnauthorized, rec.Code)
	require.Contains(t, rec.Body.String(), "Failed to validate token")
}
// TestAuthCallback_InvalidState checks that a state value the proxy service
// never issued is rejected with 400 "Invalid state".
func TestAuthCallback_InvalidState(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	query := url.Values{"code": {"test-auth-code"}, "state": {"invalid-state"}}
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil))

	require.Equal(t, http.StatusBadRequest, rec.Code)
	require.Contains(t, rec.Body.String(), "Invalid state")
}
// TestAuthCallback_MissingState checks that a callback without any state
// parameter is rejected with 400.
func TestAuthCallback_MissingState(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	query := url.Values{"code": {"test-auth-code"}}
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil))

	require.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestAuthCallback_BearerAuthDisabled checks that for a proxy with bearer auth
// disabled (no distribution groups configured), a successful login redirects
// back to that proxy's host with a session token and no error.
func TestAuthCallback_BearerAuthDisabled(t *testing.T) {
	setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
	defer setup.cleanup()

	setup.oidcServer.tokenSubject = "allowedUserId"
	state := createTestState(t, setup.proxyService, "https://no-auth-proxy.example.com/")

	query := url.Values{"code": {"test-auth-code"}, "state": {state}}
	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?"+query.Encode(), nil)
	rec := httptest.NewRecorder()
	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusFound, rec.Code)
	location := rec.Header().Get("Location")
	require.NotEmpty(t, location)
	parsedLocation, err := url.Parse(location)
	require.NoError(t, err)

	params := parsedLocation.Query()
	require.Equal(t, "no-auth-proxy.example.com", parsedLocation.Host)
	require.NotEmpty(t, params.Get("session_token"))
	require.Empty(t, params.Get("error"))
}