[client,management] Rewrite the SSH feature (#4015)

This commit is contained in:
Viktor Liu
2025-11-17 17:10:41 +01:00
committed by GitHub
parent 0d79301141
commit d71a82769c
170 changed files with 18744 additions and 2853 deletions

View File

@@ -10,22 +10,23 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/auth"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
authManager auth.Manager
authManager serverauth.Manager
ensureAccount EnsureAccountFunc
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
@@ -34,7 +35,7 @@ type AuthMiddleware struct {
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(
authManager auth.Manager,
authManager serverauth.Manager,
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
@@ -61,18 +62,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return
}
auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(auth[0])
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(authHeader[0])
// fallback to token when receive pat as bearer
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
authType = "token"
auth[0] = authType
authHeader[0] = authType
}
switch authType {
case "bearer":
request, err := m.checkJWTFromRequest(r, auth)
request, err := m.checkJWTFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
@@ -81,7 +82,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, request)
case "token":
request, err := m.checkPATFromRequest(r, auth)
request, err := m.checkPATFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
// Check if it's a status error, otherwise default to Unauthorized
@@ -100,8 +101,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(auth)
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(authHeaderParts)
// If an error occurs, call the error handler and return an error
if err != nil {
@@ -151,8 +152,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(auth)
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil {
return r, fmt.Errorf("error extracting token: %w", err)
}
@@ -177,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, err
}
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
UserId: user.Id,
AccountId: user.AccountID,
Domain: accDomain,

View File

@@ -12,11 +12,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/auth"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbauth "github.com/netbirdio/netbird/shared/auth"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -75,9 +76,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) {
if token == JWT {
return nbcontext.UserAuth{
return nbauth.UserAuth{
UserId: userID,
AccountId: accountID,
Domain: testAccount.Domain,
@@ -91,7 +92,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
Valid: true,
}, nil
}
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(_ context.Context, token string) error {
@@ -101,7 +102,7 @@ func mockMarkPATUsed(_ context.Context, token string) error {
return fmt.Errorf("Should never get reached")
}
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}
@@ -197,13 +198,13 @@ func TestAuthMiddleware_Handler(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
@@ -255,13 +256,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -306,13 +307,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -348,13 +349,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -391,13 +392,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -454,13 +455,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -508,13 +509,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name string
path string
authHeader string
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
expectedUserAuth *nbauth.UserAuth // nil expects 401 response status
}{
{
name: "Valid PAT Token",
path: "/test",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
@@ -526,7 +527,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid PAT Token accesses child",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
@@ -539,7 +540,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid JWT Token",
path: "/test",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
@@ -551,7 +552,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid JWT Token with child",
path: "/test?account=xyz",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
@@ -570,13 +571,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,