mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Refuse to service a service if auth setup failed
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -134,15 +135,25 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) {
|
||||
// AddDomain registers authentication schemes for the given domain.
|
||||
// If schemes are provided, a valid session public key is required to sign/verify
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = DomainConfig{}
|
||||
return nil
|
||||
}
|
||||
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
|
||||
if err != nil {
|
||||
// TODO: log
|
||||
return
|
||||
return fmt.Errorf("decode session public key for domain %s: %w", domain, err)
|
||||
}
|
||||
if len(pubKeyBytes) != ed25519.PublicKeySize {
|
||||
// TODO: log
|
||||
return
|
||||
return fmt.Errorf("invalid session public key size for domain %s: got %d, want %d", domain, len(pubKeyBytes), ed25519.PublicKeySize)
|
||||
}
|
||||
|
||||
mw.domainsMux.Lock()
|
||||
@@ -152,6 +163,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
SessionPublicKey: pubKeyBytes,
|
||||
SessionExpiration: expiration,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mw *Middleware) RemoveDomain(domain string) {
|
||||
|
||||
502
proxy/internal/auth/middleware_test.go
Normal file
502
proxy/internal/auth/middleware_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
)
|
||||
|
||||
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
|
||||
t.Helper()
|
||||
kp, err := sessionkey.GenerateKeyPair()
|
||||
require.NoError(t, err)
|
||||
return kp
|
||||
}
|
||||
|
||||
// stubScheme is a minimal Scheme implementation for testing.
|
||||
type stubScheme struct {
|
||||
method auth.Method
|
||||
token string
|
||||
promptID string
|
||||
authFn func(*http.Request) (string, string)
|
||||
}
|
||||
|
||||
func (s *stubScheme) Type() auth.Method { return s.method }
|
||||
|
||||
func (s *stubScheme) Authenticate(r *http.Request) (string, string) {
|
||||
if s.authFn != nil {
|
||||
return s.authFn(r)
|
||||
}
|
||||
return s.token, s.promptID
|
||||
}
|
||||
|
||||
func newPassthroughHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("backend"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddDomain_ValidKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
assert.True(t, exists, "domain should be registered")
|
||||
assert.Len(t, config.Schemes, 1)
|
||||
assert.Equal(t, ed25519.PublicKeySize, len(config.SessionPublicKey))
|
||||
assert.Equal(t, time.Hour, config.SessionExpiration)
|
||||
}
|
||||
|
||||
func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.False(t, exists, "domain must not be registered with an empty session key")
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode session public key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.False(t, exists, "domain must not be registered with invalid base64 key")
|
||||
}
|
||||
|
||||
func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.False(t, exists, "domain must not be registered with a wrong-size key")
|
||||
}
|
||||
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour)
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour))
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
pubKeyBytes, _ := base64.StdEncoding.DecodeString(kp2.PublicKey)
|
||||
assert.Equal(t, ed25519.PublicKey(pubKeyBytes), config.SessionPublicKey, "should use the latest key")
|
||||
assert.Equal(t, 2*time.Hour, config.SessionExpiration)
|
||||
}
|
||||
|
||||
func TestRemoveDomain(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
mw.RemoveDomain("example.com")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "backend", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "backend", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "unauthenticated request should not reach backend")
|
||||
}
|
||||
|
||||
func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com:8443/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "host with port should still match the protected domain")
|
||||
}
|
||||
|
||||
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := UserFromContext(r.Context())
|
||||
method := MethodFromContext(r.Context())
|
||||
assert.Equal(t, "test-user", user)
|
||||
assert.Equal(t, auth.MethodPIN, method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("authenticated"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "authenticated", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
// Sign a token that expired 1 second ago.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "expired session should not reach the backend")
|
||||
}
|
||||
|
||||
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
// Token signed for a different domain audience.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "cookie for wrong domain should be rejected")
|
||||
}
|
||||
|
||||
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour))
|
||||
|
||||
// Token signed with a different private key.
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "cookie signed by wrong key should be rejected")
|
||||
}
|
||||
|
||||
func TestProtect_SchemeAuthSetsSessionCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(r *http.Request) (string, string) {
|
||||
if r.FormValue("pin") == "111111" {
|
||||
return token, ""
|
||||
}
|
||||
return "", "pin"
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "pin-user", UserFromContext(r.Context()))
|
||||
assert.Equal(t, auth.MethodPIN, MethodFromContext(r.Context()))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Submit the PIN via form POST.
|
||||
form := url.Values{"pin": {"111111"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range cookies {
|
||||
if c.Name == auth.SessionCookieName {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, sessionCookie, "session cookie should be set after successful auth")
|
||||
assert.True(t, sessionCookie.HttpOnly)
|
||||
assert.True(t, sessionCookie.Secure)
|
||||
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(_ *http.Request) (string, string) {
|
||||
return "", "pin"
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
assert.NotEqual(t, auth.SessionCookieName, c.Name, "no session cookie should be set on failed auth")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First scheme (PIN) always fails, second scheme (password) succeeds.
|
||||
pinScheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(_ *http.Request) (string, string) {
|
||||
return "", "pin"
|
||||
},
|
||||
}
|
||||
passwordScheme := &stubScheme{
|
||||
method: auth.MethodPassword,
|
||||
authFn: func(r *http.Request) (string, string) {
|
||||
if r.FormValue("password") == "secret" {
|
||||
return token, ""
|
||||
}
|
||||
return "", "password"
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, auth.MethodPassword, MethodFromContext(r.Context()))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
form := url.Values{"password": {"secret"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Return a garbage token that won't validate.
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(_ *http.Request) (string, string) {
|
||||
return "invalid-jwt-token", ""
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
|
||||
// 32 random bytes that happen to be valid base64 and correct size
|
||||
// but are actually a valid ed25519 public key length-wise.
|
||||
// This should succeed because ed25519 public keys are just 32 bytes.
|
||||
randomBytes := make([]byte, ed25519.PublicKeySize)
|
||||
_, err := rand.Read(randomBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
key := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour)
|
||||
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
|
||||
// Attempt to overwrite with an invalid key.
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour)
|
||||
require.Error(t, err)
|
||||
|
||||
// The original valid config should still be intact.
|
||||
mw.domainsMux.RLock()
|
||||
config, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
assert.True(t, exists, "original config should still exist")
|
||||
assert.Len(t, config.Schemes, 1)
|
||||
assert.Equal(t, time.Hour, config.SessionExpiration)
|
||||
}
|
||||
@@ -449,7 +449,10 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
}
|
||||
|
||||
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
||||
s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge)
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge); err != nil {
|
||||
s.Logger.WithField("domain", mapping.GetDomain()).WithError(err).Error("Auth setup failed, refusing to serve domain without authentication")
|
||||
return
|
||||
}
|
||||
s.proxy.AddMapping(s.protoToMapping(mapping))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user