refactor: add ValidateSession gRPC and streamline test setup

- Add ValidateSession gRPC method for proxy-side user validation
- Move group access validation from REST callback to gRPC layer
- Capture user info in access logs via CapturedData mutable pointer
- Create validate_session_test.go for gRPC validation tests
- Simplify auth_callback_integration_test.go to create accounts
  programmatically instead of using SQL file
- SQL test data file now only used by validate_session_test.go
This commit is contained in:
mlsmaycon
2026-02-10 20:31:03 +01:00
parent 0cb02bd906
commit eea6120cd0
15 changed files with 955 additions and 238 deletions

View File

@@ -24,6 +24,11 @@ type authenticator interface {
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
}
// SessionValidator validates session tokens and checks user access permissions.
type SessionValidator interface {
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
}
type Scheme interface {
Type() auth.Method
// Authenticate should check the passed request and determine whether
@@ -42,18 +47,23 @@ type DomainConfig struct {
}
type Middleware struct {
domainsMux sync.RWMutex
domains map[string]DomainConfig
logger *log.Logger
domainsMux sync.RWMutex
domains map[string]DomainConfig
logger *log.Logger
sessionValidator SessionValidator
}
func NewMiddleware(logger *log.Logger) *Middleware {
// NewMiddleware creates a new authentication middleware.
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
// locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
if logger == nil {
logger = log.StandardLogger()
}
return &Middleware{
domains: make(map[string]DomainConfig),
logger: logger,
domains: make(map[string]DomainConfig),
logger: logger,
sessionValidator: sessionValidator,
}
}
@@ -102,9 +112,11 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
// Check for an existing session cookie (contains JWT)
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
ctx := withAuthMethod(r.Context(), auth.Method(method))
ctx = withAuthUser(ctx, userID)
next.ServeHTTP(w, r.WithContext(ctx))
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(userID)
cd.SetAuthMethod(method)
}
next.ServeHTTP(w, r)
return
}
}
@@ -114,13 +126,23 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
for _, scheme := range config.Schemes {
token, promptData := scheme.Authenticate(r)
if token != "" {
if _, _, err := auth.ValidateSessionJWT(token, host, config.SessionPublicKey); err != nil {
userID, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
if err != nil {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if userID == "" {
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
requestID = cd.GetRequestID()
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
return
}
expiration := config.SessionExpiration
if expiration == 0 {
@@ -191,6 +213,40 @@ func (mw *Middleware) RemoveDomain(domain string) {
delete(mw.domains, domain)
}
// validateSessionToken validates a session token, optionally checking group access via gRPC.
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
// For other auth methods (PIN, password), it validates the JWT locally.
// Returns the user ID if valid, empty string if access denied, or error for invalid tokens.
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (string, error) {
// For OIDC with a session validator, call the gRPC service to check group access
if method == auth.MethodOIDC && mw.sessionValidator != nil {
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
Domain: host,
SessionToken: token,
})
if err != nil {
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
return "", fmt.Errorf("session validation failed")
}
if !resp.Valid {
mw.logger.WithFields(log.Fields{
"domain": host,
"denied_reason": resp.DeniedReason,
"user_id": resp.UserId,
}).Debug("Session validation denied")
return "", nil
}
return resp.UserId, nil
}
// For non-OIDC methods or when no validator is configured, validate JWT locally
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil {
return "", err
}
return userID, nil
}
// stripSessionTokenParam returns the request URI with the session_token query
// parameter removed so it doesn't linger in the browser's address bar or history.
func stripSessionTokenParam(u *url.URL) string {

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
)
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
@@ -28,10 +29,10 @@ func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
// stubScheme is a minimal Scheme implementation for testing.
type stubScheme struct {
method auth.Method
token string
promptID string
authFn func(*http.Request) (string, string)
method auth.Method
token string
promptID string
authFn func(*http.Request) (string, string)
}
func (s *stubScheme) Type() auth.Method { return s.method }
@@ -51,7 +52,7 @@ func newPassthroughHandler() http.Handler {
}
func TestAddDomain_ValidKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -69,7 +70,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
}
func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour)
@@ -83,7 +84,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
}
func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour)
@@ -97,7 +98,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
}
func TestAddDomain_WrongKeySize(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -112,7 +113,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
}
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
err := mw.AddDomain("example.com", nil, "", time.Hour)
require.NoError(t, err, "domains with no auth schemes should not require a key")
@@ -124,7 +125,7 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
}
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
@@ -143,7 +144,7 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
}
func TestRemoveDomain(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -158,7 +159,7 @@ func TestRemoveDomain(t *testing.T) {
}
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
@@ -170,7 +171,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
}
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour))
handler := mw.Protect(newPassthroughHandler())
@@ -184,7 +185,7 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
}
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -205,7 +206,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
}
func TestProtect_HostWithPortIsMatched(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -226,7 +227,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
}
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -235,16 +236,18 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
capturedData := &proxy.CapturedData{}
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := UserFromContext(r.Context())
method := MethodFromContext(r.Context())
assert.Equal(t, "test-user", user)
assert.Equal(t, auth.MethodPIN, method)
cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd)
assert.Equal(t, "test-user", cd.GetUserID())
assert.Equal(t, "pin", cd.GetAuthMethod())
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("authenticated"))
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -254,7 +257,7 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -280,7 +283,7 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
}
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -306,7 +309,7 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
}
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
@@ -333,7 +336,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
}
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
@@ -383,7 +386,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
}
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
@@ -406,7 +409,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
}
func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
@@ -448,7 +451,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
}
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
// Return a garbage token that won't validate.
@@ -470,7 +473,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
}
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
// 32 random bytes that happen to be valid base64 and correct size
// but are actually a valid ed25519 public key length-wise.
@@ -487,7 +490,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
}
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}