diff --git a/go.mod b/go.mod index 1f8eec24e..76e592f73 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20231017101406-322cbabed3da github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pion/logging v0.2.2 diff --git a/go.sum b/go.sum index 15e69283c..561d3e17e 100644 --- a/go.sum +++ b/go.sum @@ -495,6 +495,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20231017101406-322cbabed3da h1:S1RoPhLTw3+IhHGnyfcQlj4aqIIaQdVd3SqaiK+MYFY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20231017101406-322cbabed3da/go.mod h1:KSqjzHcqlodTWiuap5lRXxt5KT3vtYRoksL0KIrTK40= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw= diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 6e9b029c7..0d415a087 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -6,6 +6,7 @@ import ( "github.com/gorilla/mux" "github.com/rs/cors" + "github.com/netbirdio/management-integrations/integrations" s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -58,6 +59,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid AuthCfg: authCfg, } + integrations.RegisterHandlers(api.Router, accountManager) api.addAccountsEndpoint() api.addPeersEndpoint() api.addUsersEndpoint() @@ -73,8 +75,8 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { methods, err := route.GetMethods() - if err != nil { - return err + if err != nil { // we may have wildcard routes from integrations without methods, skip them for now + methods = []string{} } for _, method := range methods { template, err := route.GetPathTemplate() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 710723124..99482bfb7 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -57,10 +57,17 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.Split(r.Header.Get("Authorization"), " ") - authType := auth[0] - switch strings.ToLower(authType) { + authType := strings.ToLower(auth[0]) + + // fallback to token when receive pat as bearer + if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") { + authType = "token" + auth[0] = authType + } + + switch authType { case "bearer": - err := m.CheckJWTFromRequest(w, r) + err := m.checkJWTFromRequest(w, r, auth) if err != nil { log.Errorf("Error when validating JWT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) @@ -68,7 +75,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } h.ServeHTTP(w, r) case "token": - err := m.CheckPATFromRequest(w, r) + err := m.checkPATFromRequest(w, r, auth) if err != nil { log.Debugf("Error when validating PAT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) @@ -83,9 +90,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error { - - token, err := getTokenFromJWTRequest(r) +func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { + token, err := getTokenFromJWTRequest(auth) // If an error occurs, call the error handler and return an error if err != nil { @@ -110,8 +116,8 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error { - token, err := getTokenFromPATRequest(r) +func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { + token, err := getTokenFromPATRequest(auth) // If an error occurs, call the error handler and return an error if err != nil { @@ -143,16 +149,9 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ return nil } -// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts +// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts // the JWT token from the Authorization header. -func getTokenFromJWTRequest(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no token - } - - // TODO: Make this a bit more robust, parsing-wise - authHeaderParts := strings.Fields(authHeader) +func getTokenFromJWTRequest(authHeaderParts []string) (string, error) { if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { return "", errors.New("Authorization header format must be Bearer {token}") } @@ -160,16 +159,9 @@ func getTokenFromJWTRequest(r *http.Request) (string, error) { return authHeaderParts[1], nil } -// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts +// getTokenFromPATRequest is a "TokenExtractor" that takes auth header parts and extracts // the PAT token from the Authorization header. -func getTokenFromPATRequest(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no token - } - - // TODO: Make this a bit more robust, parsing-wise - authHeaderParts := strings.Fields(authHeader) +func getTokenFromPATRequest(authHeaderParts []string) (string, error) { if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" { return "", errors.New("Authorization header format must be Token {token}") } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 608bf42fa..55e5de260 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -19,7 +19,7 @@ const ( domain = "domain" userID = "userID" tokenID = "tokenID" - PAT = "PAT" + PAT = "nbp_PAT" JWT = "JWT" wrongToken = "wrongToken" ) @@ -82,6 +82,11 @@ func TestAuthMiddleware_Handler(t *testing.T) { authHeader: "Token " + wrongToken, expectedStatusCode: 401, }, + { + name: "Fallback to PAT Token", + authHeader: "Bearer " + PAT, + expectedStatusCode: 200, + }, { name: "Valid JWT Token", authHeader: "Bearer " + JWT,