mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Add management-integrations (#1227)
This commit is contained in:
@@ -57,10 +57,17 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
||||
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := auth[0]
|
||||
switch strings.ToLower(authType) {
|
||||
authType := strings.ToLower(auth[0])
|
||||
|
||||
// fallback to token when receive pat as bearer
|
||||
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
|
||||
authType = "token"
|
||||
auth[0] = authType
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
err := m.CheckJWTFromRequest(w, r)
|
||||
err := m.checkJWTFromRequest(w, r, auth)
|
||||
if err != nil {
|
||||
log.Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
@@ -68,7 +75,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
err := m.CheckPATFromRequest(w, r)
|
||||
err := m.checkPATFromRequest(w, r, auth)
|
||||
if err != nil {
|
||||
log.Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
@@ -83,9 +90,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
|
||||
token, err := getTokenFromJWTRequest(r)
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromJWTRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
@@ -110,8 +116,8 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
token, err := getTokenFromPATRequest(r)
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
@@ -143,16 +149,9 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func getTokenFromJWTRequest(r *http.Request) (string, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return "", nil // No error, just no token
|
||||
}
|
||||
|
||||
// TODO: Make this a bit more robust, parsing-wise
|
||||
authHeaderParts := strings.Fields(authHeader)
|
||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
||||
return "", errors.New("Authorization header format must be Bearer {token}")
|
||||
}
|
||||
@@ -160,16 +159,9 @@ func getTokenFromJWTRequest(r *http.Request) (string, error) {
|
||||
return authHeaderParts[1], nil
|
||||
}
|
||||
|
||||
// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts
|
||||
// getTokenFromPATRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||
// the PAT token from the Authorization header.
|
||||
func getTokenFromPATRequest(r *http.Request) (string, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return "", nil // No error, just no token
|
||||
}
|
||||
|
||||
// TODO: Make this a bit more robust, parsing-wise
|
||||
authHeaderParts := strings.Fields(authHeader)
|
||||
func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
|
||||
return "", errors.New("Authorization header format must be Token {token}")
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ const (
|
||||
domain = "domain"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
PAT = "PAT"
|
||||
PAT = "nbp_PAT"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
@@ -82,6 +82,11 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
authHeader: "Token " + wrongToken,
|
||||
expectedStatusCode: 401,
|
||||
},
|
||||
{
|
||||
name: "Fallback to PAT Token",
|
||||
authHeader: "Bearer " + PAT,
|
||||
expectedStatusCode: 200,
|
||||
},
|
||||
{
|
||||
name: "Valid JWT Token",
|
||||
authHeader: "Bearer " + JWT,
|
||||
|
||||
Reference in New Issue
Block a user