mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 02:06:39 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
@@ -10,22 +10,23 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
authManager auth.Manager
|
||||
authManager serverauth.Manager
|
||||
ensureAccount EnsureAccountFunc
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
@@ -34,7 +35,7 @@ type AuthMiddleware struct {
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(
|
||||
authManager auth.Manager,
|
||||
authManager serverauth.Manager,
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
@@ -61,18 +62,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
auth := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := strings.ToLower(auth[0])
|
||||
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := strings.ToLower(authHeader[0])
|
||||
|
||||
// fallback to token when receive pat as bearer
|
||||
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
|
||||
if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
|
||||
authType = "token"
|
||||
auth[0] = authType
|
||||
authHeader[0] = authType
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
request, err := m.checkJWTFromRequest(r, auth)
|
||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
@@ -81,7 +82,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
case "token":
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
request, err := m.checkPATFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
@@ -100,8 +101,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(auth)
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
@@ -151,8 +152,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
@@ -177,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, err
|
||||
}
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: user.Id,
|
||||
AccountId: user.AccountID,
|
||||
Domain: accDomain,
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
nbauth "github.com/netbirdio/netbird/shared/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -75,9 +76,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return nbcontext.UserAuth{
|
||||
return nbauth.UserAuth{
|
||||
UserId: userID,
|
||||
AccountId: accountID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -91,7 +92,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
|
||||
Valid: true,
|
||||
}, nil
|
||||
}
|
||||
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||
return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
@@ -101,7 +102,7 @@ func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return userAuth, nil
|
||||
}
|
||||
@@ -197,13 +198,13 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
@@ -255,13 +256,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -306,13 +307,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -348,13 +349,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -391,13 +392,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -454,13 +455,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -508,13 +509,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name string
|
||||
path string
|
||||
authHeader string
|
||||
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
|
||||
expectedUserAuth *nbauth.UserAuth // nil expects 401 response status
|
||||
}{
|
||||
{
|
||||
name: "Valid PAT Token",
|
||||
path: "/test",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -526,7 +527,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid PAT Token accesses child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: "xyz",
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -539,7 +540,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid JWT Token",
|
||||
path: "/test",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -551,7 +552,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid JWT Token with child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: "xyz",
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -570,13 +571,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
|
||||
Reference in New Issue
Block a user