mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
refactor: add ValidateSession gRPC and streamline test setup
- Add ValidateSession gRPC method for proxy-side user validation - Move group access validation from REST callback to gRPC layer - Capture user info in access logs via CapturedData mutable pointer - Create validate_session_test.go for gRPC validation tests - Simplify auth_callback_integration_test.go to create accounts programmatically instead of using SQL file - SQL test data file now only used by validate_session_test.go
This commit is contained in:
@@ -24,6 +24,11 @@ type authenticator interface {
|
||||
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
|
||||
}
|
||||
|
||||
// SessionValidator validates session tokens and checks user access permissions.
|
||||
type SessionValidator interface {
|
||||
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
|
||||
}
|
||||
|
||||
type Scheme interface {
|
||||
Type() auth.Method
|
||||
// Authenticate should check the passed request and determine whether
|
||||
@@ -42,18 +47,23 @@ type DomainConfig struct {
|
||||
}
|
||||
|
||||
type Middleware struct {
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]DomainConfig
|
||||
logger *log.Logger
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]DomainConfig
|
||||
logger *log.Logger
|
||||
sessionValidator SessionValidator
|
||||
}
|
||||
|
||||
func NewMiddleware(logger *log.Logger) *Middleware {
|
||||
// NewMiddleware creates a new authentication middleware.
|
||||
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
|
||||
// locally without group access checks.
|
||||
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &Middleware{
|
||||
domains: make(map[string]DomainConfig),
|
||||
logger: logger,
|
||||
domains: make(map[string]DomainConfig),
|
||||
logger: logger,
|
||||
sessionValidator: sessionValidator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,9 +112,11 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
// Check for an existing session cookie (contains JWT)
|
||||
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
|
||||
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
|
||||
ctx := withAuthMethod(r.Context(), auth.Method(method))
|
||||
ctx = withAuthUser(ctx, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -114,13 +126,23 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData := scheme.Authenticate(r)
|
||||
if token != "" {
|
||||
if _, _, err := auth.ValidateSessionJWT(token, host, config.SessionPublicKey); err != nil {
|
||||
userID, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if userID == "" {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
@@ -191,6 +213,40 @@ func (mw *Middleware) RemoveDomain(domain string) {
|
||||
delete(mw.domains, domain)
|
||||
}
|
||||
|
||||
// validateSessionToken validates a session token, optionally checking group access via gRPC.
|
||||
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
|
||||
// For other auth methods (PIN, password), it validates the JWT locally.
|
||||
// Returns the user ID if valid, empty string if access denied, or error for invalid tokens.
|
||||
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (string, error) {
|
||||
// For OIDC with a session validator, call the gRPC service to check group access
|
||||
if method == auth.MethodOIDC && mw.sessionValidator != nil {
|
||||
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
|
||||
Domain: host,
|
||||
SessionToken: token,
|
||||
})
|
||||
if err != nil {
|
||||
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
|
||||
return "", fmt.Errorf("session validation failed")
|
||||
}
|
||||
if !resp.Valid {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"domain": host,
|
||||
"denied_reason": resp.DeniedReason,
|
||||
"user_id": resp.UserId,
|
||||
}).Debug("Session validation denied")
|
||||
return "", nil
|
||||
}
|
||||
return resp.UserId, nil
|
||||
}
|
||||
|
||||
// For non-OIDC methods or when no validator is configured, validate JWT locally
|
||||
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// stripSessionTokenParam returns the request URI with the session_token query
|
||||
// parameter removed so it doesn't linger in the browser's address bar or history.
|
||||
func stripSessionTokenParam(u *url.URL) string {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
)
|
||||
|
||||
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
|
||||
@@ -28,10 +29,10 @@ func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
|
||||
|
||||
// stubScheme is a minimal Scheme implementation for testing.
|
||||
type stubScheme struct {
|
||||
method auth.Method
|
||||
token string
|
||||
promptID string
|
||||
authFn func(*http.Request) (string, string)
|
||||
method auth.Method
|
||||
token string
|
||||
promptID string
|
||||
authFn func(*http.Request) (string, string)
|
||||
}
|
||||
|
||||
func (s *stubScheme) Type() auth.Method { return s.method }
|
||||
@@ -51,7 +52,7 @@ func newPassthroughHandler() http.Handler {
|
||||
}
|
||||
|
||||
func TestAddDomain_ValidKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -69,7 +70,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour)
|
||||
@@ -83,7 +84,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour)
|
||||
@@ -97,7 +98,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -112,7 +113,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour)
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
@@ -124,7 +125,7 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
@@ -143,7 +144,7 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRemoveDomain(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -158,7 +159,7 @@ func TestRemoveDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
|
||||
@@ -170,7 +171,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -184,7 +185,7 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -205,7 +206,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -226,7 +227,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -235,16 +236,18 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := UserFromContext(r.Context())
|
||||
method := MethodFromContext(r.Context())
|
||||
assert.Equal(t, "test-user", user)
|
||||
assert.Equal(t, auth.MethodPIN, method)
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd)
|
||||
assert.Equal(t, "test-user", cd.GetUserID())
|
||||
assert.Equal(t, "pin", cd.GetAuthMethod())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("authenticated"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
@@ -254,7 +257,7 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -280,7 +283,7 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -306,7 +309,7 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
@@ -333,7 +336,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
@@ -383,7 +386,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -406,7 +409,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
@@ -448,7 +451,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Return a garbage token that won't validate.
|
||||
@@ -470,7 +473,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
// 32 random bytes that happen to be valid base64 and correct size
|
||||
// but are actually a valid ed25519 public key length-wise.
|
||||
@@ -487,7 +490,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
Reference in New Issue
Block a user