[management, reverse proxy] Add reverse proxy feature (#5291)

* implement reverse proxy


---------

Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk>
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
Co-authored-by: Eduard Gert <kontakt@eduardgert.de>
Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: Diego Noguês <diego.sure@gmail.com>
Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
Pascal Fischer
2026-02-13 19:37:43 +01:00
committed by GitHub
parent edce11b34d
commit f53155562f
225 changed files with 35513 additions and 235 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"net/netip"
"os"
"strconv"
"time"
@@ -12,9 +13,19 @@ import (
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -26,6 +37,8 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -60,7 +73,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -76,6 +89,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
// OAuth callback for proxy authentication
if err := bypass.AddBypassPath(types.ProxyCallbackEndpointFull); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
@@ -156,6 +173,15 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
}
// Register OAuth callback handler for proxy authentication
if proxyGRPCServer != nil {
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
oauthHandler.RegisterEndpoints(router)
}
// Mount embedded IdP handler at /oauth2 path if configured
if embeddedIdpEnabled {

View File

@@ -154,6 +154,11 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
return
}
if peer.ProxyMeta.Embedded {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
return
}
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
@@ -321,6 +326,9 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
if peer.ProxyMeta.Embedded {
continue
}
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
}

View File

@@ -0,0 +1,208 @@
package proxy
import (
"context"
"net"
"net/http"
"net/netip"
"net/url"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/proxy/auth"
)
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
type AuthCallbackHandler struct {
proxyService *nbgrpc.ProxyServiceServer
rateLimiter *middleware.APIRateLimiter
trustedProxies []netip.Prefix
}
// NewAuthCallbackHandler creates a new OAuth callback handler.
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler {
rateLimiterConfig := &middleware.RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 15,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
return &AuthCallbackHandler{
proxyService: proxyService,
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
trustedProxies: trustedProxies,
}
}
// RegisterEndpoints registers the OAuth callback endpoint.
func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
router.HandleFunc(types.ProxyCallbackEndpoint, h.handleCallback).Methods(http.MethodGet)
}
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
clientIP := h.resolveClientIP(r)
if !h.rateLimiter.Allow(clientIP) {
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
return
}
state := r.URL.Query().Get("state")
codeVerifier, originalURL, err := h.proxyService.ValidateState(state)
if err != nil {
log.WithError(err).Error("OAuth callback state validation failed")
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
return
}
redirectURL, err := url.Parse(originalURL)
if err != nil {
log.WithError(err).Error("Failed to parse redirect URL")
http.Error(w, "Invalid redirect URL", http.StatusBadRequest)
return
}
oidcConfig := h.proxyService.GetOIDCConfig()
provider, err := oidc.NewProvider(r.Context(), oidcConfig.Issuer)
if err != nil {
log.WithError(err).Error("Failed to create OIDC provider")
http.Error(w, "Failed to create OIDC provider", http.StatusInternalServerError)
return
}
token, err := (&oauth2.Config{
ClientID: oidcConfig.ClientID,
Endpoint: provider.Endpoint(),
RedirectURL: oidcConfig.CallbackURL,
}).Exchange(r.Context(), r.URL.Query().Get("code"), oauth2.VerifierOption(codeVerifier))
if err != nil {
log.WithError(err).Error("Failed to exchange code for token")
http.Error(w, "Failed to exchange code for token", http.StatusInternalServerError)
return
}
userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token)
if userID == "" {
log.Error("Failed to extract user ID from OIDC token")
http.Error(w, "Failed to validate token", http.StatusUnauthorized)
return
}
// Group validation is performed by the proxy via ValidateSession gRPC call.
// This allows the proxy to show 403 pages directly without redirect dance.
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
if err != nil {
log.WithError(err).Error("Failed to create session token")
redirectURL.Scheme = "https"
query := redirectURL.Query()
query.Set("error", "access_denied")
query.Set("error_description", "Service configuration error")
redirectURL.RawQuery = query.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
return
}
redirectURL.Scheme = "https"
query := redirectURL.Query()
query.Set("session_token", sessionToken)
redirectURL.RawQuery = query.Encode()
log.WithField("redirect", redirectURL.Host).Debug("OAuth callback: redirecting user with session token")
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
log.Warn("No id_token in OIDC response")
return ""
}
verifier := provider.Verifier(&oidc.Config{
ClientID: config.ClientID,
})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
log.WithError(err).Warn("Failed to verify ID token")
return ""
}
var claims struct {
Subject string `json:"sub"`
}
if err := idToken.Claims(&claims); err != nil {
log.WithError(err).Warn("Failed to extract claims from ID token")
return ""
}
return claims.Subject
}
// resolveClientIP extracts the real client IP from the request.
// When trustedProxies is non-empty and the direct peer is trusted,
// it walks X-Forwarded-For right-to-left skipping trusted IPs.
// Otherwise it returns RemoteAddr directly.
func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string {
remoteIP := extractHost(r.RemoteAddr)
if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) {
return remoteIP
}
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
return remoteIP
}
parts := strings.Split(xff, ",")
for i := len(parts) - 1; i >= 0; i-- {
ip := strings.TrimSpace(parts[i])
if ip == "" {
continue
}
if !isTrustedProxy(ip, h.trustedProxies) {
return ip
}
}
// All IPs in XFF are trusted; return the leftmost as best guess.
if first := strings.TrimSpace(parts[0]); first != "" {
return first
}
return remoteIP
}
func extractHost(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return host
}
func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
addr, err := netip.ParseAddr(ipStr)
if err != nil {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
}

View File

@@ -0,0 +1,523 @@
//go:build integration
package proxy
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/proto"
)
// fakeOIDCServer creates a minimal OIDC provider for testing.
type fakeOIDCServer struct {
server *httptest.Server
issuer string
signingKey ed25519.PrivateKey
publicKey ed25519.PublicKey
keyID string
tokenSubject string
tokenExpiry time.Duration
failExchange bool
}
func newFakeOIDCServer() *fakeOIDCServer {
pub, priv, _ := ed25519.GenerateKey(rand.Reader)
f := &fakeOIDCServer{
signingKey: priv,
publicKey: pub,
keyID: "test-key-1",
tokenExpiry: time.Hour,
}
f.server = httptest.NewServer(f)
f.issuer = f.server.URL
return f
}
func (f *fakeOIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
f.handleDiscovery(w, r)
case "/token":
f.handleToken(w, r)
case "/keys":
f.handleJWKS(w, r)
default:
http.NotFound(w, r)
}
}
func (f *fakeOIDCServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) {
discovery := map[string]interface{}{
"issuer": f.issuer,
"authorization_endpoint": f.issuer + "/auth",
"token_endpoint": f.issuer + "/token",
"jwks_uri": f.issuer + "/keys",
"response_types_supported": []string{
"code",
"id_token",
"token id_token",
},
"subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": []string{"EdDSA"},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(discovery)
}
func (f *fakeOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) {
if f.failExchange {
http.Error(w, "invalid_grant", http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
idToken := f.createIDToken()
response := map[string]interface{}{
"access_token": "test-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"id_token": idToken,
"refresh_token": "test-refresh-token",
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
func (f *fakeOIDCServer) createIDToken() string {
now := time.Now()
claims := jwt.MapClaims{
"iss": f.issuer,
"sub": f.tokenSubject,
"aud": "test-client-id",
"exp": now.Add(f.tokenExpiry).Unix(),
"iat": now.Unix(),
"nbf": now.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
token.Header["kid"] = f.keyID
signed, _ := token.SignedString(f.signingKey)
return signed
}
func (f *fakeOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) {
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "OKP",
"crv": "Ed25519",
"kid": f.keyID,
"x": base64.RawURLEncoding.EncodeToString(f.publicKey),
"use": "sig",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(jwks)
}
func (f *fakeOIDCServer) Close() {
f.server.Close()
}
// testSetup contains all test dependencies.
type testSetup struct {
store store.Store
oidcServer *fakeOIDCServer
proxyService *nbgrpc.ProxyServiceServer
handler *AuthCallbackHandler
router *mux.Router
cleanup func()
}
// testAccessLogManager is a minimal mock for accesslogs.Manager.
type testAccessLogManager struct{}
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
return nil
}
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, _ *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
return nil, 0, nil
}
func setupAuthCallbackTest(t *testing.T) *testSetup {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
createTestAccountsAndUsers(t, ctx, testStore)
createTestReverseProxies(t, ctx, testStore)
oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
usersManager := users.NewManager(testStore)
oidcConfig := nbgrpc.ProxyOIDCConfig{
Issuer: oidcServer.issuer,
ClientID: "test-client-id",
Scopes: []string{"openid", "profile", "email"},
CallbackURL: "https://management.example.com/reverse-proxy/callback",
HMACKey: []byte("test-hmac-key-for-state-signing"),
}
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
oidcConfig,
nil,
usersManager,
)
proxyService.SetProxyManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil)
router := mux.NewRouter()
handler.RegisterEndpoints(router)
return &testSetup{
store: testStore,
oidcServer: oidcServer,
proxyService: proxyService,
handler: handler,
router: router,
cleanup: func() {
cleanup()
oidcServer.Close()
},
}
}
func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
Protocol: "http",
TargetId: "peer1",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
},
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
Protocol: "http",
TargetId: "peer1",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"restrictedGroupId"},
},
},
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
noAuthProxy := &reverseproxy.Service{
ID: "noAuthProxyId",
AccountID: "testAccountId",
Name: "No Auth Proxy",
Domain: "no-auth-proxy.example.com",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
Protocol: "http",
TargetId: "peer1",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: false,
},
},
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateService(ctx, noAuthProxy))
}
func strPtr(s string) *string {
return &s
}
func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
testAccount := &types.Account{
Id: "testAccountId",
Domain: "test.com",
DomainCategory: "private",
IsDomainPrimaryAccount: true,
CreatedAt: time.Now(),
}
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
allowedGroup := &types.Group{
ID: "allowedGroupId",
AccountID: "testAccountId",
Name: "Allowed Group",
Issued: "api",
}
require.NoError(t, testStore.CreateGroup(ctx, allowedGroup))
allowedUser := &types.User{
Id: "allowedUserId",
AccountID: "testAccountId",
Role: types.UserRoleUser,
AutoGroups: []string{"allowedGroupId"},
CreatedAt: time.Now(),
Issued: "api",
}
require.NoError(t, testStore.SaveUser(ctx, allowedUser))
}
// testServiceManager is a minimal implementation for testing.
type testServiceManager struct {
store store.Store
}
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
return nil
}
func (m *testServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()
resp, err := ps.GetOIDCURL(context.Background(), &proto.GetOIDCURLRequest{
RedirectUrl: redirectURL,
AccountId: "testAccountId",
})
require.NoError(t, err)
parsedURL, err := url.Parse(resp.Url)
require.NoError(t, err)
return parsedURL.Query().Get("state")
}
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/dashboard")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
require.NotEmpty(t, location)
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "test-proxy.example.com", parsedLocation.Host)
require.NotEmpty(t, parsedLocation.Query().Get("session_token"), "Should include session token")
require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
}
func TestAuthCallback_ProxyNotFound(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
require.NoError(t, setup.store.DeleteService(context.Background(), "testAccountId", "testProxyId"))
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
}
func TestAuthCallback_InvalidToken(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.failExchange = true
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=invalid-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
require.Contains(t, rec.Body.String(), "Failed to exchange code")
}
func TestAuthCallback_ExpiredToken(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
setup.oidcServer.tokenExpiry = -time.Hour
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
require.Contains(t, rec.Body.String(), "Failed to validate token")
}
func TestAuthCallback_InvalidState(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid state")
}
func TestAuthCallback_MissingState(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -0,0 +1,185 @@
package proxy
import (
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
)
func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
req.RemoteAddr = "192.168.1.100:12345"
t.Run("allows requests under limit", func(t *testing.T) {
for i := 0; i < 15; i++ {
allowed := handler.rateLimiter.Allow("192.168.1.100")
assert.True(t, allowed, "Request %d should be allowed", i+1)
}
})
t.Run("blocks requests over limit", func(t *testing.T) {
handler.rateLimiter.Reset("192.168.1.200")
for i := 0; i < 15; i++ {
handler.rateLimiter.Allow("192.168.1.200")
}
allowed := handler.rateLimiter.Allow("192.168.1.200")
assert.False(t, allowed, "Request over limit should be blocked")
})
t.Run("different IPs have separate limits", func(t *testing.T) {
ip1 := "192.168.1.201"
ip2 := "192.168.1.202"
handler.rateLimiter.Reset(ip1)
handler.rateLimiter.Reset(ip2)
for i := 0; i < 15; i++ {
handler.rateLimiter.Allow(ip1)
}
assert.False(t, handler.rateLimiter.Allow(ip1), "IP1 should be blocked")
assert.True(t, handler.rateLimiter.Allow(ip2), "IP2 should be allowed")
})
}
func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
testIP := "10.0.0.50"
handler.rateLimiter.Reset(testIP)
t.Run("returns 429 when rate limited", func(t *testing.T) {
for i := 0; i < 15; i++ {
handler.rateLimiter.Allow(testIP)
}
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
req.RemoteAddr = testIP + ":12345"
rr := httptest.NewRecorder()
handler.handleCallback(rr, req)
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Should return 429 status code")
assert.Contains(t, rr.Body.String(), "Too many requests", "Should contain rate limit message")
})
}
func TestResolveClientIP(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
}
tests := []struct {
name string
remoteAddr string
xForwardedFor string
trustedProxy []netip.Prefix
expectedIP string
}{
{
name: "no trusted proxies returns RemoteAddr",
remoteAddr: "203.0.113.50:9999",
xForwardedFor: "1.2.3.4",
trustedProxy: nil,
expectedIP: "203.0.113.50",
},
{
name: "untrusted RemoteAddr ignores XFF",
remoteAddr: "203.0.113.50:9999",
xForwardedFor: "1.2.3.4, 10.0.0.1",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "trusted RemoteAddr with single client in XFF",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "203.0.113.50",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "trusted RemoteAddr walks past trusted entries in XFF",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "203.0.113.50, 10.0.0.2, 172.16.0.5",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
remoteAddr: "10.0.0.1:5000",
trustedProxy: trusted,
expectedIP: "10.0.0.1",
},
{
name: "all XFF IPs trusted returns leftmost",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "10.0.0.2, 172.16.0.1, 10.0.0.3",
trustedProxy: trusted,
expectedIP: "10.0.0.2",
},
{
name: "XFF with whitespace",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: " 203.0.113.50 , 10.0.0.2 ",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "multi-hop with mixed trust",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "8.8.8.8, 203.0.113.50, 172.16.0.1",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "RemoteAddr without port",
remoteAddr: "192.168.1.100",
expectedIP: "192.168.1.100",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, tt.trustedProxy)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
ip := handler.resolveClientIP(req)
assert.Equal(t, tt.expectedIP, ip)
})
}
}
func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
testIP := "192.168.1.250"
handler.rateLimiter.Reset(testIP)
for i := 0; i < 15; i++ {
allowed := handler.rateLimiter.Allow(testIP)
assert.True(t, allowed, "Should allow request %d within burst limit", i+1)
}
allowed := handler.rateLimiter.Allow(testIP)
assert.False(t, allowed, "Should block request that exceeds burst limit")
}

View File

@@ -10,6 +10,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
@@ -86,6 +90,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
t.Fatalf("Failed to create manager: %v", err)
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(reverseProxyManager)
am.SetServiceManager(reverseProxyManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &serverauth.MockManager{
@@ -102,7 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}