From 7b6294b62419f9e9ffe6aecd8c2827722cb0deb1 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 8 Feb 2026 23:24:43 +0800 Subject: [PATCH] Refuse to service a service if auth setup failed --- proxy/internal/auth/middleware.go | 22 +- proxy/internal/auth/middleware_test.go | 502 +++++++++++++++++++++++++ proxy/server.go | 5 +- 3 files changed, 523 insertions(+), 6 deletions(-) create mode 100644 proxy/internal/auth/middleware_test.go diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index edffc22cf..95c5955e7 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -4,6 +4,7 @@ import ( "context" "crypto/ed25519" "encoding/base64" + "fmt" "net" "net/http" "sync" @@ -134,15 +135,25 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { }) } -func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) { +// AddDomain registers authentication schemes for the given domain. +// If schemes are provided, a valid session public key is required to sign/verify +// session JWTs. Returns an error if the key is missing or invalid. +// Callers must not serve the domain if this returns an error, to avoid +// exposing an unauthenticated service. +func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) error { + if len(schemes) == 0 { + mw.domainsMux.Lock() + defer mw.domainsMux.Unlock() + mw.domains[domain] = DomainConfig{} + return nil + } + pubKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64) if err != nil { - // TODO: log - return + return fmt.Errorf("decode session public key for domain %s: %w", domain, err) } if len(pubKeyBytes) != ed25519.PublicKeySize { - // TODO: log - return + return fmt.Errorf("invalid session public key size for domain %s: got %d, want %d", domain, len(pubKeyBytes), ed25519.PublicKeySize) } mw.domainsMux.Lock() @@ -152,6 +163,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st SessionPublicKey: pubKeyBytes, SessionExpiration: expiration, } + return nil } func (mw *Middleware) RemoveDomain(domain string) { diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go new file mode 100644 index 000000000..f095b9e79 --- /dev/null +++ b/proxy/internal/auth/middleware_test.go @@ -0,0 +1,502 @@ +package auth + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" + "github.com/netbirdio/netbird/proxy/auth" +) + +func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair { + t.Helper() + kp, err := sessionkey.GenerateKeyPair() + require.NoError(t, err) + return kp +} + +// stubScheme is a minimal Scheme implementation for testing. +type stubScheme struct { + method auth.Method + token string + promptID string + authFn func(*http.Request) (string, string) +} + +func (s *stubScheme) Type() auth.Method { return s.method } + +func (s *stubScheme) Authenticate(r *http.Request) (string, string) { + if s.authFn != nil { + return s.authFn(r) + } + return s.token, s.promptID +} + +func newPassthroughHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("backend")) + }) +} + +func TestAddDomain_ValidKey(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour) + require.NoError(t, err) + + mw.domainsMux.RLock() + config, exists := mw.domains["example.com"] + mw.domainsMux.RUnlock() + + assert.True(t, exists, "domain should be registered") + assert.Len(t, config.Schemes, 1) + assert.Equal(t, ed25519.PublicKeySize, len(config.SessionPublicKey)) + assert.Equal(t, time.Hour, config.SessionExpiration) +} + +func TestAddDomain_EmptyKey(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid session public key size") + + mw.domainsMux.RLock() + _, exists := mw.domains["example.com"] + mw.domainsMux.RUnlock() + assert.False(t, exists, "domain must not be registered with an empty session key") +} + +func TestAddDomain_InvalidBase64(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour) + require.Error(t, err) + assert.Contains(t, err.Error(), "decode session public key") + + mw.domainsMux.RLock() + _, exists := mw.domains["example.com"] + mw.domainsMux.RUnlock() + assert.False(t, exists, "domain must not be registered with invalid base64 key") +} + +func TestAddDomain_WrongKeySize(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + + shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort")) + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid session public key size") + + mw.domainsMux.RLock() + _, exists := mw.domains["example.com"] + mw.domainsMux.RUnlock() + assert.False(t, exists, "domain must not be registered with a wrong-size key") +} + +func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + + err := mw.AddDomain("example.com", nil, "", time.Hour) + require.NoError(t, err, "domains with no auth schemes should not require a key") + + mw.domainsMux.RLock() + _, exists := mw.domains["example.com"] + mw.domainsMux.RUnlock() + assert.True(t, exists) +} + +func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp1 := generateTestKeyPair(t) + kp2 := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour)) + + mw.domainsMux.RLock() + config := mw.domains["example.com"] + mw.domainsMux.RUnlock() + + pubKeyBytes, _ := base64.StdEncoding.DecodeString(kp2.PublicKey) + assert.Equal(t, ed25519.PublicKey(pubKeyBytes), config.SessionPublicKey, "should use the latest key") + assert.Equal(t, 2*time.Hour, config.SessionExpiration) +} + +func TestRemoveDomain(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + mw.RemoveDomain("example.com") + + mw.domainsMux.RLock() + _, exists := mw.domains["example.com"] + mw.domainsMux.RUnlock() + assert.False(t, exists) +} + +func TestProtect_UnknownDomainPassesThrough(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "backend", rec.Body.String()) +} + +func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "backend", rec.Body.String()) +} + +func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + var backendCalled bool + backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + }) + handler := mw.Protect(backend) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.False(t, backendCalled, "unauthenticated request should not reach backend") +} + +func TestProtect_HostWithPortIsMatched(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + var backendCalled bool + backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + }) + handler := mw.Protect(backend) + + req := httptest.NewRequest(http.MethodGet, "http://example.com:8443/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.False(t, backendCalled, "host with port should still match the protected domain") +} + +func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + method := MethodFromContext(r.Context()) + assert.Equal(t, "test-user", user) + assert.Equal(t, auth.MethodPIN, method) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("authenticated")) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token}) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "authenticated", rec.Body.String()) +} + +func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + // Sign a token that expired 1 second ago. + token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second) + require.NoError(t, err) + + var backendCalled bool + backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + }) + handler := mw.Protect(backend) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token}) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.False(t, backendCalled, "expired session should not reach the backend") +} + +func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + // Token signed for a different domain audience. + token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour) + require.NoError(t, err) + + var backendCalled bool + backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + }) + handler := mw.Protect(backend) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token}) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.False(t, backendCalled, "cookie for wrong domain should be rejected") +} + +func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp1 := generateTestKeyPair(t) + kp2 := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour)) + + // Token signed with a different private key. + token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) + require.NoError(t, err) + + var backendCalled bool + backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + }) + handler := mw.Protect(backend) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token}) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.False(t, backendCalled, "cookie signed by wrong key should be rejected") +} + +func TestProtect_SchemeAuthSetsSessionCookie(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour) + require.NoError(t, err) + + scheme := &stubScheme{ + method: auth.MethodPIN, + authFn: func(r *http.Request) (string, string) { + if r.FormValue("pin") == "111111" { + return token, "" + } + return "", "pin" + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "pin-user", UserFromContext(r.Context())) + assert.Equal(t, auth.MethodPIN, MethodFromContext(r.Context())) + w.WriteHeader(http.StatusOK) + })) + + // Submit the PIN via form POST. + form := url.Values{"pin": {"111111"}} + req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + cookies := rec.Result().Cookies() + var sessionCookie *http.Cookie + for _, c := range cookies { + if c.Name == auth.SessionCookieName { + sessionCookie = c + break + } + } + require.NotNil(t, sessionCookie, "session cookie should be set after successful auth") + assert.True(t, sessionCookie.HttpOnly) + assert.True(t, sessionCookie.Secure) + assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite) +} + +func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{ + method: auth.MethodPIN, + authFn: func(_ *http.Request) (string, string) { + return "", "pin" + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + for _, c := range rec.Result().Cookies() { + assert.NotEqual(t, auth.SessionCookieName, c.Name, "no session cookie should be set on failed auth") + } +} + +func TestProtect_MultipleSchemes(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour) + require.NoError(t, err) + + // First scheme (PIN) always fails, second scheme (password) succeeds. + pinScheme := &stubScheme{ + method: auth.MethodPIN, + authFn: func(_ *http.Request) (string, string) { + return "", "pin" + }, + } + passwordScheme := &stubScheme{ + method: auth.MethodPassword, + authFn: func(r *http.Request) (string, string) { + if r.FormValue("password") == "secret" { + return token, "" + } + return "", "password" + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour)) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, auth.MethodPassword, MethodFromContext(r.Context())) + w.WriteHeader(http.StatusOK) + })) + + form := url.Values{"password": {"secret"}} + req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + // Return a garbage token that won't validate. + scheme := &stubScheme{ + method: auth.MethodPIN, + authFn: func(_ *http.Request) (string, string) { + return "invalid-jwt-token", "" + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + + // 32 random bytes that happen to be valid base64 and correct size + // but are actually a valid ed25519 public key length-wise. + // This should succeed because ed25519 public keys are just 32 bytes. + randomBytes := make([]byte, ed25519.PublicKeySize) + _, err := rand.Read(randomBytes) + require.NoError(t, err) + + key := base64.StdEncoding.EncodeToString(randomBytes) + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + + err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour) + require.NoError(t, err, "any 32-byte key should be accepted at registration time") +} + +func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { + mw := NewMiddleware(log.StandardLogger()) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + + // Attempt to overwrite with an invalid key. + err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour) + require.Error(t, err) + + // The original valid config should still be intact. + mw.domainsMux.RLock() + config, exists := mw.domains["example.com"] + mw.domainsMux.RUnlock() + + assert.True(t, exists, "original config should still exist") + assert.Len(t, config.Schemes, 1) + assert.Equal(t, time.Hour, config.SessionExpiration) +} diff --git a/proxy/server.go b/proxy/server.go index 780f9a8d4..06012938a 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -449,7 +449,10 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) } maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second - s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge) + if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge); err != nil { + s.Logger.WithField("domain", mapping.GetDomain()).WithError(err).Error("Auth setup failed, refusing to serve domain without authentication") + return + } s.proxy.AddMapping(s.protoToMapping(mapping)) }