|
|
|
@@ -12,6 +12,7 @@ import (
|
|
|
|
"go.opentelemetry.io/otel/metric"
|
|
|
|
"go.opentelemetry.io/otel/metric"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/netbirdio/management-integrations/integrations"
|
|
|
|
"github.com/netbirdio/management-integrations/integrations"
|
|
|
|
|
|
|
|
|
|
|
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
|
|
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
|
|
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
|
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
|
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
|
|
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
|
|
|
@@ -87,17 +88,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|
|
|
|
|
|
|
|
|
|
|
switch authType {
|
|
|
|
switch authType {
|
|
|
|
case "bearer":
|
|
|
|
case "bearer":
|
|
|
|
request, err := m.checkJWTFromRequest(r, authHeader)
|
|
|
|
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
|
|
|
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
|
|
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
|
|
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
h.ServeHTTP(w, r)
|
|
|
|
h.ServeHTTP(w, request)
|
|
|
|
|
|
|
|
case "token":
|
|
|
|
case "token":
|
|
|
|
request, err := m.checkPATFromRequest(r, authHeader)
|
|
|
|
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
|
|
|
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
|
|
|
// Check if it's a status error, otherwise default to Unauthorized
|
|
|
|
// Check if it's a status error, otherwise default to Unauthorized
|
|
|
|
if _, ok := status.FromError(err); !ok {
|
|
|
|
if _, ok := status.FromError(err); !ok {
|
|
|
|
@@ -106,7 +104,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|
|
|
util.WriteError(r.Context(), err, w)
|
|
|
|
util.WriteError(r.Context(), err, w)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
h.ServeHTTP(w, request)
|
|
|
|
h.ServeHTTP(w, r)
|
|
|
|
default:
|
|
|
|
default:
|
|
|
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
|
|
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
@@ -115,19 +113,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// CheckJWTFromRequest checks if the JWT is valid
|
|
|
|
// CheckJWTFromRequest checks if the JWT is valid
|
|
|
|
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
|
|
|
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
|
|
|
token, err := getTokenFromJWTRequest(authHeaderParts)
|
|
|
|
token, err := getTokenFromJWTRequest(authHeaderParts)
|
|
|
|
|
|
|
|
|
|
|
|
// If an error occurs, call the error handler and return an error
|
|
|
|
// If an error occurs, call the error handler and return an error
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return r, fmt.Errorf("error extracting token: %w", err)
|
|
|
|
return fmt.Errorf("error extracting token: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ctx := r.Context()
|
|
|
|
ctx := r.Context()
|
|
|
|
|
|
|
|
|
|
|
|
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
|
|
|
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return r, err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
|
|
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
|
|
|
@@ -143,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|
|
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
|
|
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
|
|
|
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
|
|
|
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return r, err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if userAuth.AccountId != accountId {
|
|
|
|
if userAuth.AccountId != accountId {
|
|
|
|
@@ -153,7 +151,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|
|
|
|
|
|
|
|
|
|
|
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
|
|
|
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return r, err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
err = m.syncUserJWTGroups(ctx, userAuth)
|
|
|
|
err = m.syncUserJWTGroups(ctx, userAuth)
|
|
|
|
@@ -164,17 +162,19 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|
|
|
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
|
|
|
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
|
|
|
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
|
|
|
return r, err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
|
|
|
// propagates ctx change to upstream middleware
|
|
|
|
|
|
|
|
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// CheckPATFromRequest checks if the PAT is valid
|
|
|
|
// CheckPATFromRequest checks if the PAT is valid
|
|
|
|
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
|
|
|
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
|
|
|
token, err := getTokenFromPATRequest(authHeaderParts)
|
|
|
|
token, err := getTokenFromPATRequest(authHeaderParts)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return r, fmt.Errorf("error extracting token: %w", err)
|
|
|
|
return fmt.Errorf("error extracting token: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if m.patUsageTracker != nil {
|
|
|
|
if m.patUsageTracker != nil {
|
|
|
|
@@ -183,22 +183,22 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|
|
|
|
|
|
|
|
|
|
|
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
|
|
|
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
|
|
|
if !m.rateLimiter.Allow(token) {
|
|
|
|
if !m.rateLimiter.Allow(token) {
|
|
|
|
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
|
|
|
return status.Errorf(status.TooManyRequests, "too many requests")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ctx := r.Context()
|
|
|
|
ctx := r.Context()
|
|
|
|
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
|
|
|
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return r, fmt.Errorf("invalid Token: %w", err)
|
|
|
|
return fmt.Errorf("invalid Token: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if time.Now().After(pat.GetExpirationDate()) {
|
|
|
|
if time.Now().After(pat.GetExpirationDate()) {
|
|
|
|
return r, fmt.Errorf("token expired")
|
|
|
|
return fmt.Errorf("token expired")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
|
|
|
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return r, err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
userAuth := auth.UserAuth{
|
|
|
|
userAuth := auth.UserAuth{
|
|
|
|
@@ -216,7 +216,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
|
|
|
// propagates ctx change to upstream middleware
|
|
|
|
|
|
|
|
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func isTerraformRequest(r *http.Request) bool {
|
|
|
|
func isTerraformRequest(r *http.Request) bool {
|
|
|
|
|