mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
Add group-based access control for SSO reverse proxy authentication
Implement user group validation during OAuth callback to ensure users belong to allowed distribution groups before granting access to reverse proxies. This provides account isolation and fine-grained access control. Key changes: - Add ValidateUserGroupAccess to ProxyServiceServer for group membership checks - Redirect denied users to error page with access_denied parameter - Handle OAuth error responses in proxy middleware - Add comprehensive integration tests for auth callback flow
This commit is contained in:
@@ -7,24 +7,27 @@ import (
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
)
|
||||
|
||||
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
|
||||
type AuthCallbackHandler struct {
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
}
|
||||
|
||||
// NewAuthCallbackHandler creates a new OAuth callback handler.
|
||||
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
|
||||
return &AuthCallbackHandler{
|
||||
proxyService: proxyService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers the OAuth callback endpoint.
|
||||
func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
|
||||
router.HandleFunc(types.ProxyCallbackEndpoint, h.handleCallback).Methods(http.MethodGet)
|
||||
}
|
||||
@@ -46,10 +49,8 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
// Get OIDC configuration
|
||||
oidcConfig := h.proxyService.GetOIDCConfig()
|
||||
|
||||
// Create OIDC provider to discover endpoints
|
||||
provider, err := oidc.NewProvider(r.Context(), oidcConfig.Issuer)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create OIDC provider")
|
||||
@@ -68,7 +69,6 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
// Extract user ID from the OIDC token
|
||||
userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token)
|
||||
if userID == "" {
|
||||
log.Error("Failed to extract user ID from OIDC token")
|
||||
@@ -76,7 +76,22 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session JWT instead of passing OIDC access_token
|
||||
if err := h.proxyService.ValidateUserGroupAccess(r.Context(), redirectURL.Hostname(), userID); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": userID,
|
||||
"domain": redirectURL.Hostname(),
|
||||
"error": err.Error(),
|
||||
}).Warn("User denied access to reverse proxy")
|
||||
|
||||
redirectURL.Scheme = "https"
|
||||
query := redirectURL.Query()
|
||||
query.Set("error", "access_denied")
|
||||
query.Set("error_description", "You are not authorized to access this service")
|
||||
redirectURL.RawQuery = query.Encode()
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create session token")
|
||||
@@ -84,13 +99,8 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here).
|
||||
redirectURL.Scheme = "https"
|
||||
|
||||
// Pass the session token in the URL query parameter. The proxy middleware will
|
||||
// extract it, validate it, set its own cookie, and redirect to remove the token from the URL.
|
||||
// We cannot set the cookie here because cookies are domain-scoped (RFC 6265) and the
|
||||
// management server cannot set cookies for the proxy's domain.
|
||||
query := redirectURL.Query()
|
||||
query.Set("session_token", sessionToken)
|
||||
redirectURL.RawQuery = query.Encode()
|
||||
@@ -99,9 +109,7 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// extractUserIDFromToken extracts the user ID from an OIDC token.
|
||||
func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string {
|
||||
// Try to get ID token from the oauth2 token extras
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
log.Warn("No id_token in OIDC response")
|
||||
@@ -118,27 +126,13 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config
|
||||
return ""
|
||||
}
|
||||
|
||||
// Extract claims
|
||||
var claims struct {
|
||||
Subject string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
log.WithError(err).Warn("Failed to extract claims from ID token")
|
||||
return ""
|
||||
}
|
||||
|
||||
// Prefer subject, fall back to user_id or email
|
||||
if claims.Subject != "" {
|
||||
return claims.Subject
|
||||
}
|
||||
if claims.UserID != "" {
|
||||
return claims.UserID
|
||||
}
|
||||
if claims.Email != "" {
|
||||
return claims.Email
|
||||
}
|
||||
|
||||
return ""
|
||||
return claims.Subject
|
||||
}
|
||||
|
||||
@@ -0,0 +1,582 @@
|
||||
//go:build integration
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// fakeOIDCServer creates a minimal OIDC provider for testing.
|
||||
type fakeOIDCServer struct {
|
||||
server *httptest.Server
|
||||
issuer string
|
||||
signingKey ed25519.PrivateKey
|
||||
publicKey ed25519.PublicKey
|
||||
keyID string
|
||||
tokenSubject string
|
||||
tokenExpiry time.Duration
|
||||
failExchange bool
|
||||
}
|
||||
|
||||
func newFakeOIDCServer() *fakeOIDCServer {
|
||||
pub, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
f := &fakeOIDCServer{
|
||||
signingKey: priv,
|
||||
publicKey: pub,
|
||||
keyID: "test-key-1",
|
||||
tokenExpiry: time.Hour,
|
||||
}
|
||||
f.server = httptest.NewServer(f)
|
||||
f.issuer = f.server.URL
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *fakeOIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
f.handleDiscovery(w, r)
|
||||
case "/token":
|
||||
f.handleToken(w, r)
|
||||
case "/keys":
|
||||
f.handleJWKS(w, r)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeOIDCServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) {
|
||||
discovery := map[string]interface{}{
|
||||
"issuer": f.issuer,
|
||||
"authorization_endpoint": f.issuer + "/auth",
|
||||
"token_endpoint": f.issuer + "/token",
|
||||
"jwks_uri": f.issuer + "/keys",
|
||||
"response_types_supported": []string{
|
||||
"code",
|
||||
"id_token",
|
||||
"token id_token",
|
||||
},
|
||||
"subject_types_supported": []string{"public"},
|
||||
"id_token_signing_alg_values_supported": []string{"EdDSA"},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(discovery)
|
||||
}
|
||||
|
||||
func (f *fakeOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||
if f.failExchange {
|
||||
http.Error(w, "invalid_grant", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
idToken := f.createIDToken()
|
||||
|
||||
response := map[string]interface{}{
|
||||
"access_token": "test-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"id_token": idToken,
|
||||
"refresh_token": "test-refresh-token",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func (f *fakeOIDCServer) createIDToken() string {
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": f.issuer,
|
||||
"sub": f.tokenSubject,
|
||||
"aud": "test-client-id",
|
||||
"exp": now.Add(f.tokenExpiry).Unix(),
|
||||
"iat": now.Unix(),
|
||||
"nbf": now.Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||
token.Header["kid"] = f.keyID
|
||||
signed, _ := token.SignedString(f.signingKey)
|
||||
return signed
|
||||
}
|
||||
|
||||
func (f *fakeOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) {
|
||||
jwks := map[string]interface{}{
|
||||
"keys": []map[string]interface{}{
|
||||
{
|
||||
"kty": "OKP",
|
||||
"crv": "Ed25519",
|
||||
"kid": f.keyID,
|
||||
"x": base64.RawURLEncoding.EncodeToString(f.publicKey),
|
||||
"use": "sig",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}
|
||||
|
||||
func (f *fakeOIDCServer) Close() {
|
||||
f.server.Close()
|
||||
}
|
||||
|
||||
// testSetup contains all test dependencies.
|
||||
type testSetup struct {
|
||||
store store.Store
|
||||
oidcServer *fakeOIDCServer
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
handler *AuthCallbackHandler
|
||||
router *mux.Router
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
// testAccessLogManager is a minimal mock for accesslogs.Manager.
|
||||
type testAccessLogManager struct{}
|
||||
|
||||
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string) ([]*accesslogs.AccessLogEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func setupAuthCallbackTest(t *testing.T, sqlFile string) *testSetup {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, sqlFile, t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
createTestReverseProxies(t, ctx, testStore)
|
||||
|
||||
oidcServer := newFakeOIDCServer()
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
|
||||
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||
Issuer: oidcServer.issuer,
|
||||
ClientID: "test-client-id",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
CallbackURL: "https://management.example.com/reverse-proxy/callback",
|
||||
HMACKey: []byte("test-hmac-key-for-state-signing"),
|
||||
}
|
||||
|
||||
proxyService := nbgrpc.NewProxyServiceServer(
|
||||
&testAccessLogManager{},
|
||||
tokenStore,
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
)
|
||||
|
||||
proxyService.SetProxyManager(&testReverseProxyManager{store: testStore})
|
||||
|
||||
handler := NewAuthCallbackHandler(proxyService)
|
||||
|
||||
router := mux.NewRouter()
|
||||
handler.RegisterEndpoints(router)
|
||||
|
||||
return &testSetup{
|
||||
store: testStore,
|
||||
oidcServer: oidcServer,
|
||||
proxyService: proxyService,
|
||||
handler: handler,
|
||||
router: router,
|
||||
cleanup: func() {
|
||||
cleanup()
|
||||
oidcServer.Close()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store.Store) {
|
||||
t.Helper()
|
||||
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
testProxy := &reverseproxy.ReverseProxy{
|
||||
ID: "testProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Test Proxy",
|
||||
Domain: "test-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
TargetId: "peer1",
|
||||
TargetType: "peer",
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"allowedGroupId"},
|
||||
},
|
||||
},
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
}
|
||||
require.NoError(t, testStore.CreateReverseProxy(ctx, testProxy))
|
||||
|
||||
restrictedProxy := &reverseproxy.ReverseProxy{
|
||||
ID: "restrictedProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Restricted Proxy",
|
||||
Domain: "restricted-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
TargetId: "peer1",
|
||||
TargetType: "peer",
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"restrictedGroupId"},
|
||||
},
|
||||
},
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
}
|
||||
require.NoError(t, testStore.CreateReverseProxy(ctx, restrictedProxy))
|
||||
|
||||
noAuthProxy := &reverseproxy.ReverseProxy{
|
||||
ID: "noAuthProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "No Auth Proxy",
|
||||
Domain: "no-auth-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
TargetId: "peer1",
|
||||
TargetType: "peer",
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
}
|
||||
require.NoError(t, testStore.CreateReverseProxy(ctx, noAuthProxy))
|
||||
}
|
||||
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// testReverseProxyManager is a minimal implementation for testing.
|
||||
type testReverseProxyManager struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) GetAllReverseProxies(_ context.Context, _, _ string) ([]*reverseproxy.ReverseProxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) GetReverseProxy(_ context.Context, _, _, _ string) (*reverseproxy.ReverseProxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) CreateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) UpdateReverseProxy(_ context.Context, _, _ string, _ *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) DeleteReverseProxy(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) ReloadAllReverseProxiesForAccount(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) ReloadReverseProxy(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) {
|
||||
return m.store.GetReverseProxies(ctx, store.LockingStrengthNone)
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) GetProxyByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) {
|
||||
return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||
}
|
||||
|
||||
func (m *testReverseProxyManager) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
|
||||
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
|
||||
t.Helper()
|
||||
|
||||
resp, err := ps.GetOIDCURL(context.Background(), &proto.GetOIDCURLRequest{
|
||||
RedirectUrl: redirectURL,
|
||||
AccountId: "testAccountId",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedURL, err := url.Parse(resp.Url)
|
||||
require.NoError(t, err)
|
||||
|
||||
return parsedURL.Query().Get("state")
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/dashboard")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
require.NotEmpty(t, location)
|
||||
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "test-proxy.example.com", parsedLocation.Host)
|
||||
require.NotEmpty(t, parsedLocation.Query().Get("session_token"), "Should include session token")
|
||||
require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserNotInAllowedGroup(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "nonGroupUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://restricted-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
require.NotEmpty(t, location)
|
||||
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "restricted-proxy.example.com", parsedLocation.Host)
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
|
||||
}
|
||||
|
||||
func TestAuthCallback_ProxyInDifferentAccount(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "otherAccountUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
require.NotEmpty(t, location)
|
||||
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
require.Contains(t, parsedLocation.Query().Get("error_description"), "not authorized")
|
||||
}
|
||||
|
||||
func TestAuthCallback_UserNotFound(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "nonExistentUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
}
|
||||
|
||||
func TestAuthCallback_ProxyNotFound(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
require.NoError(t, setup.store.DeleteReverseProxy(context.Background(), "testAccountId", "testProxyId"))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||
}
|
||||
|
||||
func TestAuthCallback_InvalidToken(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.failExchange = true
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=invalid-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Failed to exchange code")
|
||||
}
|
||||
|
||||
func TestAuthCallback_ExpiredToken(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
setup.oidcServer.tokenExpiry = -time.Hour
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Failed to validate token")
|
||||
}
|
||||
|
||||
func TestAuthCallback_InvalidState(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Invalid state")
|
||||
}
|
||||
|
||||
func TestAuthCallback_MissingState(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAuthCallback_BearerAuthDisabled(t *testing.T) {
|
||||
setup := setupAuthCallbackTest(t, "../../../http/testing/testdata/auth_callback.sql")
|
||||
defer setup.cleanup()
|
||||
|
||||
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||
|
||||
state := createTestState(t, setup.proxyService, "https://no-auth-proxy.example.com/")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
setup.router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
|
||||
location := rec.Header().Get("Location")
|
||||
parsedLocation, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEmpty(t, parsedLocation.Query().Get("session_token"))
|
||||
require.Empty(t, parsedLocation.Query().Get("error"))
|
||||
}
|
||||
Reference in New Issue
Block a user