mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -14,10 +17,13 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
|
||||
@@ -52,11 +58,11 @@ func newPassthroughHandler() http.Handler {
|
||||
}
|
||||
|
||||
func TestAddDomain_ValidKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -70,10 +76,10 @@ func TestAddDomain_ValidKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -84,10 +90,10 @@ func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode session public key")
|
||||
|
||||
@@ -98,11 +104,11 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -113,9 +119,9 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -125,14 +131,14 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config := mw.domains["example.com"]
|
||||
@@ -144,11 +150,11 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRemoveDomain(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
mw.RemoveDomain("example.com")
|
||||
|
||||
@@ -159,7 +165,7 @@ func TestRemoveDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
|
||||
@@ -171,8 +177,8 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", ""))
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -185,11 +191,11 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -206,11 +212,11 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -227,16 +233,16 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd)
|
||||
@@ -257,11 +263,11 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Sign a token that expired 1 second ago.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
|
||||
@@ -283,11 +289,11 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Token signed for a different domain audience.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
|
||||
@@ -309,12 +315,12 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Token signed with a different private key.
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
@@ -336,7 +342,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
@@ -351,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -386,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -395,7 +401,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -409,7 +415,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
@@ -431,7 +437,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -451,7 +457,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Return a garbage token that won't validate.
|
||||
@@ -461,7 +467,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
return "invalid-jwt-token", "", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -473,7 +479,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
// 32 random bytes that happen to be valid base64 and correct size
|
||||
// but are actually a valid ed25519 public key length-wise.
|
||||
@@ -485,19 +491,19 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "")
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
|
||||
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Attempt to overwrite with an invalid key.
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
// The original valid config should still be intact.
|
||||
@@ -511,7 +517,7 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Scheme that always fails authentication (returns empty token)
|
||||
@@ -521,9 +527,9 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// Submit wrong PIN - should capture auth method
|
||||
@@ -539,7 +545,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -548,9 +554,9 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// Submit wrong password - should capture auth method
|
||||
@@ -566,7 +572,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -575,9 +581,9 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// No credentials submitted - should not capture auth method
|
||||
@@ -658,3 +664,271 @@ func TestWasCredentialSubmitted(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
wantCode int
|
||||
}{
|
||||
{"unparsable address denies", "not-an-ip:1234", http.StatusForbidden},
|
||||
{"empty address denies", "", http.StatusForbidden},
|
||||
{"allowed address passes", "10.1.2.3:5678", http.StatusOK},
|
||||
{"denied address blocked", "192.168.1.1:5678", http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
req.Host = "example.com"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, tt.wantCode, rr.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
|
||||
// When CapturedData is set (by the access log middleware, which resolves
|
||||
// trusted proxies), checkIPRestrictions should use that IP, not RemoteAddr.
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter([]string{"203.0.113.0/24"}, nil, nil, nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// RemoteAddr is a trusted proxy, but CapturedData has the real client IP.
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:5000"
|
||||
req.Host = "example.com"
|
||||
|
||||
cd := proxy.NewCapturedData("")
|
||||
cd.SetClientIP(netip.MustParseAddr("203.0.113.50"))
|
||||
ctx := proxy.WithCapturedData(req.Context(), cd)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "should use CapturedData IP (203.0.113.50), not RemoteAddr (10.0.0.1)")
|
||||
|
||||
// Same request but CapturedData has a blocked IP.
|
||||
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req2.RemoteAddr = "203.0.113.50:5000"
|
||||
req2.Host = "example.com"
|
||||
|
||||
cd2 := proxy.NewCapturedData("")
|
||||
cd2.SetClientIP(netip.MustParseAddr("10.0.0.1"))
|
||||
ctx2 := proxy.WithCapturedData(req2.Context(), cd2)
|
||||
req2 = req2.WithContext(ctx2)
|
||||
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
assert.Equal(t, http.StatusForbidden, rr2.Code, "should use CapturedData IP (10.0.0.1), not RemoteAddr (203.0.113.50)")
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
|
||||
// Geo is nil, country restrictions are configured: must deny (fail-close).
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(nil, nil, []string{"US"}, nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:5678"
|
||||
req.Host = "example.com"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
|
||||
}
|
||||
|
||||
// mockAuthenticator is a minimal mock for the authenticator gRPC interface
|
||||
// used by the Header scheme.
|
||||
type mockAuthenticator struct {
|
||||
fn func(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error)
|
||||
}
|
||||
|
||||
func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.AuthenticateRequest, _ ...grpc.CallOption) (*proto.AuthenticateResponse, error) {
|
||||
return m.fn(ctx, in)
|
||||
}
|
||||
|
||||
// newHeaderSchemeWithToken creates a Header scheme backed by a mock that
|
||||
// returns a signed session token when the expected header value is provided.
|
||||
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
ha := req.GetHeaderAuth()
|
||||
if ha != nil && ha.GetHeaderValue() == expectedValue {
|
||||
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
|
||||
}
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
return NewHeader(mock, "svc1", "acc1", headerName)
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
var backendCalled bool
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/path", nil)
|
||||
req.Header.Set("X-API-Key", "secret-key")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.True(t, backendCalled, "backend should be called directly for header auth (no redirect)")
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "ok", rec.Body.String())
|
||||
|
||||
// Session cookie should be set.
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
if c.Name == auth.SessionCookieName {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, sessionCookie, "session cookie should be set after successful header auth")
|
||||
assert.True(t, sessionCookie.HttpOnly)
|
||||
assert.True(t, sessionCookie.Secure)
|
||||
|
||||
assert.Equal(t, "header-user", capturedData.GetUserID())
|
||||
assert.Equal(t, "header", capturedData.GetAuthMethod())
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
// Also add a PIN scheme so we can verify fallthrough behavior.
|
||||
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// No X-API-Key header: should fall through to PIN login page (401).
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "missing header should fall through to login page")
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("X-API-Key", "wrong-key")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Equal(t, "header", capturedData.GetAuthMethod())
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
return nil, errors.New("gRPC unavailable")
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("X-API-Key", "some-key")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request with header auth.
|
||||
req1 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req1.Header.Set("X-API-Key", "secret-key")
|
||||
req1 = req1.WithContext(proxy.WithCapturedData(req1.Context(), proxy.NewCapturedData("")))
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
|
||||
// Extract session cookie.
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range rec1.Result().Cookies() {
|
||||
if c.Name == auth.SessionCookieName {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, sessionCookie)
|
||||
|
||||
// Second request with only the session cookie (no header).
|
||||
capturedData2 := proxy.NewCapturedData("")
|
||||
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/other", nil)
|
||||
req2.AddCookie(sessionCookie)
|
||||
req2 = req2.WithContext(proxy.WithCapturedData(req2.Context(), capturedData2))
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec2.Code)
|
||||
assert.Equal(t, "header-user", capturedData2.GetUserID())
|
||||
assert.Equal(t, "header", capturedData2.GetAuthMethod())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user