mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
@@ -10,22 +10,23 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
authManager auth.Manager
|
||||
authManager serverauth.Manager
|
||||
ensureAccount EnsureAccountFunc
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
@@ -34,7 +35,7 @@ type AuthMiddleware struct {
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(
|
||||
authManager auth.Manager,
|
||||
authManager serverauth.Manager,
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
@@ -61,18 +62,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
auth := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := strings.ToLower(auth[0])
|
||||
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := strings.ToLower(authHeader[0])
|
||||
|
||||
// fallback to token when receive pat as bearer
|
||||
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
|
||||
if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
|
||||
authType = "token"
|
||||
auth[0] = authType
|
||||
authHeader[0] = authType
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
request, err := m.checkJWTFromRequest(r, auth)
|
||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
@@ -81,7 +82,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
case "token":
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
request, err := m.checkPATFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
@@ -100,8 +101,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(auth)
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
@@ -151,8 +152,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
@@ -177,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, err
|
||||
}
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: user.Id,
|
||||
AccountId: user.AccountID,
|
||||
Domain: accDomain,
|
||||
|
||||
Reference in New Issue
Block a user