refactor: add ValidateSession gRPC and streamline test setup

- Add ValidateSession gRPC method for proxy-side user validation
- Move group access validation from REST callback to gRPC layer
- Capture user info in access logs via CapturedData mutable pointer
- Create validate_session_test.go for gRPC validation tests
- Simplify auth_callback_integration_test.go to create accounts
  programmatically instead of using SQL file
- SQL test data file now only used by validate_session_test.go
This commit is contained in:
mlsmaycon
2026-02-10 20:31:03 +01:00
parent 0cb02bd906
commit eea6120cd0
15 changed files with 955 additions and 238 deletions

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
)
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
@@ -28,10 +29,10 @@ func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
// stubScheme is a minimal Scheme implementation for testing.
type stubScheme struct {
method auth.Method
token string
promptID string
authFn func(*http.Request) (string, string)
method auth.Method
token string
promptID string
authFn func(*http.Request) (string, string)
}
func (s *stubScheme) Type() auth.Method { return s.method }
@@ -51,7 +52,7 @@ func newPassthroughHandler() http.Handler {
}
func TestAddDomain_ValidKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -69,7 +70,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
}
func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour)
@@ -83,7 +84,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
}
func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour)
@@ -97,7 +98,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
}
func TestAddDomain_WrongKeySize(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -112,7 +113,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
}
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
err := mw.AddDomain("example.com", nil, "", time.Hour)
require.NoError(t, err, "domains with no auth schemes should not require a key")
@@ -124,7 +125,7 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
}
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
@@ -143,7 +144,7 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
}
func TestRemoveDomain(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -158,7 +159,7 @@ func TestRemoveDomain(t *testing.T) {
}
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
@@ -170,7 +171,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
}
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour))
handler := mw.Protect(newPassthroughHandler())
@@ -184,7 +185,7 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
}
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -205,7 +206,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
}
func TestProtect_HostWithPortIsMatched(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -226,7 +227,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
}
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -235,16 +236,18 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
capturedData := &proxy.CapturedData{}
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := UserFromContext(r.Context())
method := MethodFromContext(r.Context())
assert.Equal(t, "test-user", user)
assert.Equal(t, auth.MethodPIN, method)
cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd)
assert.Equal(t, "test-user", cd.GetUserID())
assert.Equal(t, "pin", cd.GetAuthMethod())
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("authenticated"))
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -254,7 +257,7 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -280,7 +283,7 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
}
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
@@ -306,7 +309,7 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
}
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
@@ -333,7 +336,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
}
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
@@ -383,7 +386,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
}
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
@@ -406,7 +409,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
}
func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
@@ -448,7 +451,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
}
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
// Return a garbage token that won't validate.
@@ -470,7 +473,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
}
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
// 32 random bytes that happen to be valid base64 and correct size
// but are actually a valid ed25519 public key length-wise.
@@ -487,7 +490,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
}
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger())
mw := NewMiddleware(log.StandardLogger(), nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}