[management, reverse proxy] Add reverse proxy feature (#5291)

* implement reverse proxy


---------

Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk>
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
Co-authored-by: Eduard Gert <kontakt@eduardgert.de>
Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: Diego Noguês <diego.sure@gmail.com>
Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
Pascal Fischer
2026-02-13 19:37:43 +01:00
committed by GitHub
parent edce11b34d
commit f53155562f
225 changed files with 35513 additions and 235 deletions

View File

@@ -0,0 +1,18 @@
<!doctype html>
{{ range $method, $value := .Methods }}
{{ if eq $method "pin" }}
<form>
<label for={{ $value }}>PIN:</label>
<input name={{ $value }} id={{ $value }} />
<button type=submit>Submit</button>
</form>
{{ else if eq $method "password" }}
<form>
<label for={{ $value }}>Password:</label>
<input name={{ $value }} id={{ $value }}/>
<button type=submit>Submit</button>
</form>
{{ else if eq $method "oidc" }}
<a href={{ $value }}>Click here to log in with SSO</a>
{{ end }}
{{ end }}

View File

@@ -0,0 +1,364 @@
package auth
import (
"context"
"crypto/ed25519"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
"github.com/netbirdio/netbird/shared/management/proto"
)
type authenticator interface {
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
}
// SessionValidator validates session tokens and checks user access permissions.
type SessionValidator interface {
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
}
// Scheme defines an authentication mechanism for a domain.
type Scheme interface {
Type() auth.Method
// Authenticate checks the request and determines whether it represents
// an authenticated user. An empty token indicates an unauthenticated
// request; optionally, promptData may be returned for the login UI.
// An error indicates an infrastructure failure (e.g. gRPC unavailable).
Authenticate(*http.Request) (token string, promptData string, err error)
}
type DomainConfig struct {
Schemes []Scheme
SessionPublicKey ed25519.PublicKey
SessionExpiration time.Duration
AccountID string
ServiceID string
}
type validationResult struct {
UserID string
Valid bool
DeniedReason string
}
type Middleware struct {
domainsMux sync.RWMutex
domains map[string]DomainConfig
logger *log.Logger
sessionValidator SessionValidator
}
// NewMiddleware creates a new authentication middleware.
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
// locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
if logger == nil {
logger = log.StandardLogger()
}
return &Middleware{
domains: make(map[string]DomainConfig),
logger: logger,
sessionValidator: sessionValidator,
}
}
// Protect applies authentication middleware to the passed handler.
// For each incoming request it will be checked against the middleware's
// internal list of protected domains.
// If the Host domain in the inbound request is not present, then it will
// simply be passed through.
// However, if the Host domain is present, then the specified authentication
// schemes for that domain will be applied to the request.
// In the event that no authentication schemes are defined for the domain,
// then the request will also be simply passed through.
func (mw *Middleware) Protect(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
config, exists := mw.getDomainConfig(host)
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
if !exists || len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
return
}
// Set account and service IDs in captured data for access logging.
setCapturedIDs(r, config)
if mw.handleOAuthCallbackError(w, r) {
return
}
if mw.forwardWithSessionCookie(w, r, host, config, next) {
return
}
mw.authenticateWithSchemes(w, r, host, config)
})
}
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
mw.domainsMux.RLock()
defer mw.domainsMux.RUnlock()
config, exists := mw.domains[host]
return config, exists
}
func setCapturedIDs(r *http.Request, config DomainConfig) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(types.AccountID(config.AccountID))
cd.SetServiceId(config.ServiceID)
}
}
// handleOAuthCallbackError checks for error query parameters from an OAuth
// callback and renders the access denied page if present.
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
errCode := r.URL.Query().Get("error")
if errCode == "" {
return false
}
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodOIDC.String())
requestID = cd.GetRequestID()
}
errDesc := r.URL.Query().Get("error_description")
if errDesc == "" {
errDesc = "An error occurred during authentication"
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
return true
}
// forwardWithSessionCookie checks for a valid session cookie and, if found,
// sets the user identity on the request context and forwards to the next handler.
func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
cookie, err := r.Cookie(auth.SessionCookieName)
if err != nil {
return false
}
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
if err != nil {
return false
}
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(userID)
cd.SetAuthMethod(method)
}
next.ServeHTTP(w, r)
return true
}
// authenticateWithSchemes tries each configured auth scheme in order.
// On success it sets a session cookie and redirects; on failure it renders the login page.
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
methods := make(map[string]string)
var attemptedMethod string
for _, scheme := range config.Schemes {
token, promptData, err := scheme.Authenticate(r)
if err != nil {
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return
}
// Track if credentials were submitted but auth failed
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
attemptedMethod = scheme.Type().String()
}
if token != "" {
mw.handleAuthenticatedToken(w, r, host, token, config, scheme)
return
}
methods[scheme.Type().String()] = promptData
}
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
if attemptedMethod != "" {
cd.SetAuthMethod(attemptedMethod)
}
}
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
}
// handleAuthenticatedToken validates the token, handles denied access, and on
// success sets a session cookie and redirects to the original URL.
func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) {
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
if err != nil {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(scheme.Type().String())
}
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !result.Valid {
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
requestID = cd.GetRequestID()
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
return
}
expiration := config.SessionExpiration
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
http.SetCookie(w, &http.Cookie{
Name: auth.SessionCookieName,
Value: token,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
switch method {
case auth.MethodPIN:
return r.FormValue("pin") != ""
case auth.MethodPassword:
return r.FormValue("password") != ""
case auth.MethodOIDC:
return r.URL.Query().Get("session_token") != ""
}
return false
}
// AddDomain registers authentication schemes for the given domain.
// If schemes are provided, a valid session public key is required to sign/verify
// session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error {
if len(schemes) == 0 {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
mw.domains[domain] = DomainConfig{
AccountID: accountID,
ServiceID: serviceID,
}
return nil
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
if err != nil {
return fmt.Errorf("decode session public key for domain %s: %w", domain, err)
}
if len(pubKeyBytes) != ed25519.PublicKeySize {
return fmt.Errorf("invalid session public key size for domain %s: got %d, want %d", domain, len(pubKeyBytes), ed25519.PublicKeySize)
}
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
mw.domains[domain] = DomainConfig{
Schemes: schemes,
SessionPublicKey: pubKeyBytes,
SessionExpiration: expiration,
AccountID: accountID,
ServiceID: serviceID,
}
return nil
}
func (mw *Middleware) RemoveDomain(domain string) {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
delete(mw.domains, domain)
}
// validateSessionToken validates a session token, optionally checking group access via gRPC.
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
// For other auth methods (PIN, password), it validates the JWT locally.
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
// For OIDC with a session validator, call the gRPC service to check group access
if method == auth.MethodOIDC && mw.sessionValidator != nil {
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
Domain: host,
SessionToken: token,
})
if err != nil {
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
return nil, fmt.Errorf("session validation failed")
}
if !resp.Valid {
mw.logger.WithFields(log.Fields{
"domain": host,
"denied_reason": resp.DeniedReason,
"user_id": resp.UserId,
}).Debug("Session validation denied")
return &validationResult{
UserID: resp.UserId,
Valid: false,
DeniedReason: resp.DeniedReason,
}, nil
}
return &validationResult{UserID: resp.UserId, Valid: true}, nil
}
// For non-OIDC methods or when no validator is configured, validate JWT locally
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil {
return nil, err
}
return &validationResult{UserID: userID, Valid: true}, nil
}
// stripSessionTokenParam returns the request URI with the session_token query
// parameter removed so it doesn't linger in the browser's address bar or history.
func stripSessionTokenParam(u *url.URL) string {
q := u.Query()
if !q.Has("session_token") {
return u.RequestURI()
}
q.Del("session_token")
clean := *u
clean.RawQuery = q.Encode()
return clean.RequestURI()
}

View File

@@ -0,0 +1,660 @@
package auth
import (
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
)
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
t.Helper()
kp, err := sessionkey.GenerateKeyPair()
require.NoError(t, err)
return kp
}
// stubScheme is a minimal Scheme implementation for testing.
type stubScheme struct {
method auth.Method
token string
promptID string
authFn func(*http.Request) (string, string, error)
}
func (s *stubScheme) Type() auth.Method { return s.method }
func (s *stubScheme) Authenticate(r *http.Request) (string, string, error) {
if s.authFn != nil {
return s.authFn(r)
}
return s.token, s.promptID, nil
}
func newPassthroughHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("backend"))
})
}
func TestAddDomain_ValidKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")
require.NoError(t, err)
mw.domainsMux.RLock()
config, exists := mw.domains["example.com"]
mw.domainsMux.RUnlock()
assert.True(t, exists, "domain should be registered")
assert.Len(t, config.Schemes, 1)
assert.Equal(t, ed25519.PublicKeySize, len(config.SessionPublicKey))
assert.Equal(t, time.Hour, config.SessionExpiration)
}
func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
mw.domainsMux.RLock()
_, exists := mw.domains["example.com"]
mw.domainsMux.RUnlock()
assert.False(t, exists, "domain must not be registered with an empty session key")
}
func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "")
require.Error(t, err)
assert.Contains(t, err.Error(), "decode session public key")
mw.domainsMux.RLock()
_, exists := mw.domains["example.com"]
mw.domainsMux.RUnlock()
assert.False(t, exists, "domain must not be registered with invalid base64 key")
}
func TestAddDomain_WrongKeySize(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
mw.domainsMux.RLock()
_, exists := mw.domains["example.com"]
mw.domainsMux.RUnlock()
assert.False(t, exists, "domain must not be registered with a wrong-size key")
}
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "")
require.NoError(t, err, "domains with no auth schemes should not require a key")
mw.domainsMux.RLock()
_, exists := mw.domains["example.com"]
mw.domainsMux.RUnlock()
assert.True(t, exists)
}
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", ""))
mw.domainsMux.RLock()
config := mw.domains["example.com"]
mw.domainsMux.RUnlock()
pubKeyBytes, _ := base64.StdEncoding.DecodeString(kp2.PublicKey)
assert.Equal(t, ed25519.PublicKey(pubKeyBytes), config.SessionPublicKey, "should use the latest key")
assert.Equal(t, 2*time.Hour, config.SessionExpiration)
}
func TestRemoveDomain(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
mw.RemoveDomain("example.com")
mw.domainsMux.RLock()
_, exists := mw.domains["example.com"]
mw.domainsMux.RUnlock()
assert.False(t, exists)
}
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "backend", rec.Body.String())
}
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", ""))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "backend", rec.Body.String())
}
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
})
handler := mw.Protect(backend)
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, backendCalled, "unauthenticated request should not reach backend")
}
func TestProtect_HostWithPortIsMatched(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
})
handler := mw.Protect(backend)
req := httptest.NewRequest(http.MethodGet, "http://example.com:8443/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, backendCalled, "host with port should still match the protected domain")
}
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
capturedData := &proxy.CapturedData{}
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd)
assert.Equal(t, "test-user", cd.GetUserID())
assert.Equal(t, "pin", cd.GetAuthMethod())
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("authenticated"))
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "authenticated", rec.Body.String())
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
// Sign a token that expired 1 second ago.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
require.NoError(t, err)
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
})
handler := mw.Protect(backend)
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, backendCalled, "expired session should not reach the backend")
}
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
// Token signed for a different domain audience.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
})
handler := mw.Protect(backend)
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, backendCalled, "cookie for wrong domain should be rejected")
}
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
// Token signed with a different private key.
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
})
handler := mw.Protect(backend)
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, backendCalled, "cookie signed by wrong key should be rejected")
}
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(r *http.Request) (string, string, error) {
if r.FormValue("pin") == "111111" {
return token, "", nil
}
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
})
handler := mw.Protect(backend)
// Submit the PIN via form POST.
form := url.Values{"pin": {"111111"}}
req := httptest.NewRequest(http.MethodPost, "http://example.com/somepath", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, backendCalled, "backend should not be called during auth, only a redirect should be returned")
assert.Equal(t, http.StatusSeeOther, rec.Code)
assert.Equal(t, "/somepath", rec.Header().Get("Location"), "redirect should point to the original request URI")
cookies := rec.Result().Cookies()
var sessionCookie *http.Cookie
for _, c := range cookies {
if c.Name == auth.SessionCookieName {
sessionCookie = c
break
}
}
require.NotNil(t, sessionCookie, "session cookie should be set after successful auth")
assert.True(t, sessionCookie.HttpOnly)
assert.True(t, sessionCookie.Secure)
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
}
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string, error) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
for _, c := range rec.Result().Cookies() {
assert.NotEqual(t, auth.SessionCookieName, c.Name, "no session cookie should be set on failed auth")
}
}
func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
require.NoError(t, err)
// First scheme (PIN) always fails, second scheme (password) succeeds.
pinScheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string, error) {
return "", "pin", nil
},
}
passwordScheme := &stubScheme{
method: auth.MethodPassword,
authFn: func(r *http.Request) (string, string, error) {
if r.FormValue("password") == "secret" {
return token, "", nil
}
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", ""))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
})
handler := mw.Protect(backend)
form := url.Values{"password": {"secret"}}
req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, backendCalled, "backend should not be called during auth")
assert.Equal(t, http.StatusSeeOther, rec.Code)
}
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
// Return a garbage token that won't validate.
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string, error) {
return "invalid-jwt-token", "", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
// 32 random bytes that happen to be valid base64 and correct size
// but are actually a valid ed25519 public key length-wise.
// This should succeed because ed25519 public keys are just 32 bytes.
randomBytes := make([]byte, ed25519.PublicKeySize)
_, err := rand.Read(randomBytes)
require.NoError(t, err)
key := base64.StdEncoding.EncodeToString(randomBytes)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "")
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
}
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
// Attempt to overwrite with an invalid key.
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "")
require.Error(t, err)
// The original valid config should still be intact.
mw.domainsMux.RLock()
config, exists := mw.domains["example.com"]
mw.domainsMux.RUnlock()
assert.True(t, exists, "original config should still exist")
assert.Len(t, config.Schemes, 1)
assert.Equal(t, time.Hour, config.SessionExpiration)
}
func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
// Scheme that always fails authentication (returns empty token)
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string, error) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
capturedData := &proxy.CapturedData{}
handler := mw.Protect(newPassthroughHandler())
// Submit wrong PIN - should capture auth method
form := url.Values{"pin": {"wrong-pin"}}
req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Equal(t, "pin", capturedData.GetAuthMethod(), "Auth method should be captured for failed PIN auth")
}
func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodPassword,
authFn: func(_ *http.Request) (string, string, error) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
capturedData := &proxy.CapturedData{}
handler := mw.Protect(newPassthroughHandler())
// Submit wrong password - should capture auth method
form := url.Values{"password": {"wrong-password"}}
req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Equal(t, "password", capturedData.GetAuthMethod(), "Auth method should be captured for failed password auth")
}
func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodPIN,
authFn: func(_ *http.Request) (string, string, error) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
capturedData := &proxy.CapturedData{}
handler := mw.Protect(newPassthroughHandler())
// No credentials submitted - should not capture auth method
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Empty(t, capturedData.GetAuthMethod(), "Auth method should not be captured when no credentials submitted")
}
func TestWasCredentialSubmitted(t *testing.T) {
tests := []struct {
name string
method auth.Method
formData url.Values
query url.Values
expected bool
}{
{
name: "PIN submitted",
method: auth.MethodPIN,
formData: url.Values{"pin": {"123456"}},
expected: true,
},
{
name: "PIN not submitted",
method: auth.MethodPIN,
formData: url.Values{},
expected: false,
},
{
name: "Password submitted",
method: auth.MethodPassword,
formData: url.Values{"password": {"secret"}},
expected: true,
},
{
name: "Password not submitted",
method: auth.MethodPassword,
formData: url.Values{},
expected: false,
},
{
name: "OIDC token in query",
method: auth.MethodOIDC,
query: url.Values{"session_token": {"abc123"}},
expected: true,
},
{
name: "OIDC token not in query",
method: auth.MethodOIDC,
query: url.Values{},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqURL := "http://example.com/"
if len(tt.query) > 0 {
reqURL += "?" + tt.query.Encode()
}
var body *strings.Reader
if len(tt.formData) > 0 {
body = strings.NewReader(tt.formData.Encode())
} else {
body = strings.NewReader("")
}
req := httptest.NewRequest(http.MethodPost, reqURL, body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
result := wasCredentialSubmitted(req, tt.method)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -0,0 +1,65 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
type urlGenerator interface {
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
}
type OIDC struct {
id string
accountId string
forwardedProto string
client urlGenerator
}
// NewOIDC creates a new OIDC authentication scheme
func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC {
return OIDC{
id: id,
accountId: accountId,
forwardedProto: forwardedProto,
client: client,
}
}
func (OIDC) Type() auth.Method {
return auth.MethodOIDC
}
// Authenticate checks for an OIDC session token or obtains the OIDC redirect URL.
func (o OIDC) Authenticate(r *http.Request) (string, string, error) {
// Check for the session_token query param (from OIDC redirects).
// The management server passes the token in the URL because it cannot set
// cookies for the proxy's domain (cookies are domain-scoped per RFC 6265).
if token := r.URL.Query().Get("session_token"); token != "" {
return token, "", nil
}
redirectURL := &url.URL{
Scheme: auth.ResolveProto(o.forwardedProto, r.TLS),
Host: r.Host,
Path: r.URL.Path,
}
res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{
Id: o.id,
AccountId: o.accountId,
RedirectUrl: redirectURL.String(),
})
if err != nil {
return "", "", fmt.Errorf("get OIDC URL: %w", err)
}
return "", res.GetUrl(), nil
}

View File

@@ -0,0 +1,61 @@
package auth
import (
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
const passwordFormId = "password"
type Password struct {
id, accountId string
client authenticator
}
func NewPassword(client authenticator, id, accountId string) Password {
return Password{
id: id,
accountId: accountId,
client: client,
}
}
func (Password) Type() auth.Method {
return auth.MethodPassword
}
// Authenticate attempts to authenticate the request using a form
// value passed in the request.
// If authentication fails, the required HTTP form ID is returned
// so that it can be injected into a request from the UI so that
// authentication may be successful.
func (p Password) Authenticate(r *http.Request) (string, string, error) {
password := r.FormValue(passwordFormId)
if password == "" {
// No password submitted; return the form ID so the UI can prompt the user.
return "", passwordFormId, nil
}
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
Id: p.id,
AccountId: p.accountId,
Request: &proto.AuthenticateRequest_Password{
Password: &proto.PasswordRequest{
Password: password,
},
},
})
if err != nil {
return "", "", fmt.Errorf("authenticate password: %w", err)
}
if res.GetSuccess() {
return res.GetSessionToken(), "", nil
}
return "", passwordFormId, nil
}

View File

@@ -0,0 +1,61 @@
package auth
import (
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
const pinFormId = "pin"
type Pin struct {
id, accountId string
client authenticator
}
func NewPin(client authenticator, id, accountId string) Pin {
return Pin{
id: id,
accountId: accountId,
client: client,
}
}
func (Pin) Type() auth.Method {
return auth.MethodPIN
}
// Authenticate attempts to authenticate the request using a form
// value passed in the request.
// If authentication fails, the required HTTP form ID is returned
// so that it can be injected into a request from the UI so that
// authentication may be successful.
func (p Pin) Authenticate(r *http.Request) (string, string, error) {
pin := r.FormValue(pinFormId)
if pin == "" {
// No PIN submitted; return the form ID so the UI can prompt the user.
return "", pinFormId, nil
}
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
Id: p.id,
AccountId: p.accountId,
Request: &proto.AuthenticateRequest_Pin{
Pin: &proto.PinRequest{
Pin: pin,
},
},
})
if err != nil {
return "", "", fmt.Errorf("authenticate pin: %w", err)
}
if res.GetSuccess() {
return res.GetSessionToken(), "", nil
}
return "", pinFormId, nil
}