mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
refactor: add ValidateSession gRPC and streamline test setup
- Add ValidateSession gRPC method for proxy-side user validation - Move group access validation from REST callback to gRPC layer - Capture user info in access logs via CapturedData mutable pointer - Create validate_session_test.go for gRPC validation tests - Simplify auth_callback_integration_test.go to create accounts programmatically instead of using SQL file - SQL test data file now only used by validate_session_test.go
This commit is contained in:
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
)
|
||||
|
||||
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
|
||||
@@ -28,10 +29,10 @@ func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
|
||||
|
||||
// stubScheme is a minimal Scheme implementation for testing.
|
||||
type stubScheme struct {
|
||||
method auth.Method
|
||||
token string
|
||||
promptID string
|
||||
authFn func(*http.Request) (string, string)
|
||||
method auth.Method
|
||||
token string
|
||||
promptID string
|
||||
authFn func(*http.Request) (string, string)
|
||||
}
|
||||
|
||||
func (s *stubScheme) Type() auth.Method { return s.method }
|
||||
@@ -51,7 +52,7 @@ func newPassthroughHandler() http.Handler {
|
||||
}
|
||||
|
||||
func TestAddDomain_ValidKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -69,7 +70,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour)
|
||||
@@ -83,7 +84,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour)
|
||||
@@ -97,7 +98,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -112,7 +113,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour)
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
@@ -124,7 +125,7 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
@@ -143,7 +144,7 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRemoveDomain(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -158,7 +159,7 @@ func TestRemoveDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
|
||||
@@ -170,7 +171,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -184,7 +185,7 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -205,7 +206,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -226,7 +227,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -235,16 +236,18 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := UserFromContext(r.Context())
|
||||
method := MethodFromContext(r.Context())
|
||||
assert.Equal(t, "test-user", user)
|
||||
assert.Equal(t, auth.MethodPIN, method)
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd)
|
||||
assert.Equal(t, "test-user", cd.GetUserID())
|
||||
assert.Equal(t, "pin", cd.GetAuthMethod())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("authenticated"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
@@ -254,7 +257,7 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -280,7 +283,7 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
@@ -306,7 +309,7 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
@@ -333,7 +336,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
@@ -383,7 +386,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -406,7 +409,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
@@ -448,7 +451,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Return a garbage token that won't validate.
|
||||
@@ -470,7 +473,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
// 32 random bytes that happen to be valid base64 and correct size
|
||||
// but are actually a valid ed25519 public key length-wise.
|
||||
@@ -487,7 +490,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger())
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
Reference in New Issue
Block a user