From db3a9f0aa2d294c4a010f70b60dc9226d7aaf7ee Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 10:54:09 +0200 Subject: [PATCH 01/15] refactor jwt token validation and add PAT to middleware auth --- management/cmd/management.go | 26 +- management/server/account.go | 28 ++ management/server/grpcserver.go | 25 +- management/server/http/handler.go | 20 +- .../server/http/middleware/auth_middleware.go | 170 ++++++++++++ .../http/middleware/auth_midleware_test.go | 1 + management/server/http/middleware/jwt.go | 249 ------------------ management/server/http/util/util.go | 8 +- .../handler.go => jwtclaims/jwtValidator.go} | 90 ++++++- management/server/mock_server/account_mock.go | 9 + 10 files changed, 341 insertions(+), 285 deletions(-) create mode 100644 management/server/http/middleware/auth_middleware.go create mode 100644 management/server/http/middleware/auth_midleware_test.go delete mode 100644 management/server/http/middleware/jwt.go rename management/server/{http/middleware/handler.go => jwtclaims/jwtValidator.go} (55%) diff --git a/management/cmd/management.go b/management/cmd/management.go index f3210d88e..620a89f16 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -19,25 +19,28 @@ import ( "github.com/google/uuid" "github.com/miekg/dns" - "github.com/netbirdio/netbird/management/server/activity/sqlite" - httpapi "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/metrics" - "github.com/netbirdio/netbird/management/server/telemetry" "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "github.com/netbirdio/netbird/management/server/activity/sqlite" + httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/encryption" - mgmtProto "github.com/netbirdio/netbird/management/proto" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + + "github.com/netbirdio/netbird/encryption" + mgmtProto "github.com/netbirdio/netbird/management/proto" ) // ManagementLegacyPort is the port that was used before by the Management gRPC server. @@ -179,13 +182,22 @@ var ( tlsEnabled = true } + jwtValidator, err := jwtclaims.NewJWTValidator( + config.HttpConfig.AuthIssuer, + config.HttpConfig.AuthAudience, + config.HttpConfig.AuthKeysLocation, + ) + if err != nil { + return fmt.Errorf("failed creating JWT validator: %v", err) + } + httpAPIAuthCfg := httpapi.AuthCfg{ Issuer: config.HttpConfig.AuthIssuer, Audience: config.HttpConfig.AuthAudience, UserIDClaim: config.HttpConfig.AuthUserIDClaim, KeysLocation: config.HttpConfig.AuthKeysLocation, } - httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg) + httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 3db6a2fe0..e2cdf40b2 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -56,6 +56,7 @@ type AccountManager interface { GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) + MarkPATUsed(tokenID string) error IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeerByKey(peerKey string) (*Peer, error) @@ -1120,6 +1121,33 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e return nil } +func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { + unlock := am.Store.AcquireGlobalLock() + defer unlock() + + user, err := am.Store.GetUserByTokenID(tokenID) + log.Debugf("User: %v", user) + if err != nil { + return err + } + + account, err := am.Store.GetAccountByUser(user.Id) + if err != nil { + return err + } + + pat, ok := account.Users[user.Id].PATs[tokenID] + if !ok { + return fmt.Errorf("token not found") + } + + pat.LastUsed = time.Now() + + am.Store.SaveAccount(account) + + return nil +} + // GetAccountFromPAT returns Account and User associated with a personal access token func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index fa0e49ed3..34f74cff9 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,24 +3,25 @@ package server import ( "context" "fmt" - pb "github.com/golang/protobuf/proto" //nolint "strings" "time" + pb "github.com/golang/protobuf/proto" // nolint + "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/golang/protobuf/ptypes/timestamp" - "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/management/proto" - internalStatus "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" gRPCPeer "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/proto" + internalStatus "github.com/netbirdio/netbird/management/server/status" ) // GRPCServer an instance of a Management gRPC API server @@ -31,7 +32,7 @@ type GRPCServer struct { peersUpdateManager *PeersUpdateManager config *Config turnCredentialsManager TURNCredentialsManager - jwtMiddleware *middleware.JWTMiddleware + jwtValidator *jwtclaims.JWTValidator jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics } @@ -45,10 +46,10 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager return nil, err } - var jwtMiddleware *middleware.JWTMiddleware + var jwtValidator *jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { - jwtMiddleware, err = middleware.NewJwtMiddleware( + jwtValidator, err = jwtclaims.NewJWTValidator( config.HttpConfig.AuthIssuer, config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation) @@ -86,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager accountManager: accountManager, config: config, turnCredentialsManager: turnCredentialsManager, - jwtMiddleware: jwtMiddleware, + jwtValidator: jwtValidator, jwtClaimsExtractor: jwtClaimsExtractor, appMetrics: appMetrics, }, nil @@ -187,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) { } func (s *GRPCServer) validateToken(jwtToken string) (string, error) { - if s.jwtMiddleware == nil { - return "", status.Error(codes.Internal, "no jwt middleware set") + if s.jwtValidator == nil { + return "", status.Error(codes.Internal, "no jwt validator set") } - token, err := s.jwtMiddleware.ValidateAndParse(jwtToken) + token, err := s.jwtValidator.ValidateAndParse(jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 79028e6d2..d8117a436 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -8,6 +8,7 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -25,18 +26,17 @@ type apiHandler struct { AuthCfg AuthCfg } +// EmptyObject is an empty struct used to return empty JSON object type emptyObject struct { } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { - jwtMiddleware, err := middleware.NewJwtMiddleware( - authCfg.Issuer, - authCfg.Audience, - authCfg.KeysLocation) - if err != nil { - return nil, err - } +func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { + authMiddleware := middleware.NewAuthMiddleware( + accountManager.GetAccountFromPAT, + jwtValidator.ValidateAndParse, + accountManager.MarkPATUsed, + authCfg.Audience) corsMiddleware := cors.AllowAll() @@ -49,7 +49,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics metricsMiddleware := appMetrics.HTTPMiddleware() router := rootRouter.PathPrefix("/api").Subrouter() - router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) + router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) api := apiHandler{ Router: router, @@ -70,7 +70,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics api.addDNSSettingEndpoint() api.addEventsEndpoint() - err = api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { + err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { methods, err := route.GetMethods() if err != nil { return err diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go new file mode 100644 index 000000000..0b1478ef3 --- /dev/null +++ b/management/server/http/middleware/auth_middleware.go @@ -0,0 +1,170 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" +) + +type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) +type MarkPATUsedFunc func(token string) error + +// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens +type AuthMiddleware struct { + getAccountFromPAT GetAccountFromPATFunc + validateAndParseToken ValidateAndParseTokenFunc + markPATUsed MarkPATUsedFunc + audience string +} + +const ( + userProperty = "user" +) + +// NewAuthMiddleware instance constructor +func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware { + return &AuthMiddleware{ + getAccountFromPAT: getAccountFromPAT, + validateAndParseToken: validateAndParseToken, + markPATUsed: markPATUsed, + audience: audience, + } +} + +// Handler method of the middleware which authenticates a user either by JWT claims or by PAT +func (a *AuthMiddleware) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := strings.Split(r.Header.Get("Authorization"), " ") + authType := auth[0] + switch strings.ToLower(authType) { + case "bearer": + err := a.CheckJWTFromRequest(w, r) + if err != nil { + log.Debugf("Error when validating JWT claims: %s", err.Error()) + util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + return + } + h.ServeHTTP(w, r) + case "token": + err := a.CheckPATFromRequest(w, r) + if err != nil { + log.Debugf("Error when validating PAT claims: %s", err.Error()) + util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + return + } + h.ServeHTTP(w, r) + default: + util.WriteError(status.Errorf(status.Unauthorized, "No valid authentication provided"), w) + return + } + }) +} + +// CheckJWTFromRequest checks if the JWT is valid +func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error { + + token, err := getTokenFromJWTRequest(r) + + // If an error occurs, call the error handler and return an error + if err != nil { + return fmt.Errorf("Error extracting token: %w", err) + } + + validatedToken, err := m.validateAndParseToken(token) + if err != nil { + return err + } + + if validatedToken == nil { + return nil + } + + // If we get here, everything worked and we can set the + // user property in context. + newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint + // Update the current request with the new context information. + *r = *newRequest + return nil +} + +// CheckPATFromRequest checks if the PAT is valid +func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error { + token, err := getTokenFromPATRequest(r) + + // If an error occurs, call the error handler and return an error + if err != nil { + return fmt.Errorf("Error extracting token: %w", err) + } + + account, user, pat, err := m.getAccountFromPAT(token) + if err != nil { + util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + return fmt.Errorf("invalid Token: %w", err) + } + if time.Now().After(pat.ExpirationDate) { + util.WriteError(status.Errorf(status.Unauthorized, "Token expired"), w) + return fmt.Errorf("token expired") + } + + err = m.markPATUsed(pat.ID) + if err != nil { + return err + } + + claimMaps := jwt.MapClaims{} + claimMaps[jwtclaims.UserIDClaim] = user.Id + claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id + claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain + claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) + newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) + // Update the current request with the new context information. + *r = *newRequest + return nil +} + +// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts +// the JWT token from the Authorization header. +func getTokenFromJWTRequest(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", nil // No error, just no token + } + + // TODO: Make this a bit more robust, parsing-wise + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + +// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts +// the PAT token from the Authorization header. +func getTokenFromPATRequest(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", nil // No error, just no token + } + + // TODO: Make this a bit more robust, parsing-wise + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" { + return "", errors.New("Authorization header format must be Token {token}") + } + + return authHeaderParts[1], nil +} diff --git a/management/server/http/middleware/auth_midleware_test.go b/management/server/http/middleware/auth_midleware_test.go new file mode 100644 index 000000000..c870d7c16 --- /dev/null +++ b/management/server/http/middleware/auth_midleware_test.go @@ -0,0 +1 @@ +package middleware diff --git a/management/server/http/middleware/jwt.go b/management/server/http/middleware/jwt.go deleted file mode 100644 index feb00ec86..000000000 --- a/management/server/http/middleware/jwt.go +++ /dev/null @@ -1,249 +0,0 @@ -package middleware - -import ( - "context" - "errors" - "fmt" - "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/status" - "log" - "net/http" - "strings" -) - -// A function called whenever an error is encountered -type errorHandler func(w http.ResponseWriter, r *http.Request, err string) - -// TokenExtractor is a function that takes a request as input and returns -// either a token or an error. An error should only be returned if an attempt -// to specify a token was found, but the information was somehow incorrectly -// formed. In the case where a token is simply not present, this should not -// be treated as an error. An empty string should be returned in that case. -type TokenExtractor func(r *http.Request) (string, error) - -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // The function that will return the Key to validate the JWT. - // It can be either a shared secret or a public key. - // Default value: nil - ValidationKeyGetter jwt.Keyfunc - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there's an error validating the token - // Default value: - ErrorHandler errorHandler - // A boolean indicating if the credentials are required or not - // Default value: false - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Extractor TokenExtractor - // Debug flag turns on debugging output - // Default: false - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool - // When set, the middelware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - // Default: nil - SigningMethod jwt.SigningMethod -} - -type JWTMiddleware struct { - Options Options -} - -func OnError(w http.ResponseWriter, r *http.Request, err string) { - util.WriteError(status.Errorf(status.Unauthorized, ""), w) -} - -// New constructs a new Secure instance with supplied options. -func New(options ...Options) *JWTMiddleware { - - var opts Options - if len(options) == 0 { - opts = Options{} - } else { - opts = options[0] - } - - if opts.UserProperty == "" { - opts.UserProperty = "user" - } - - if opts.ErrorHandler == nil { - opts.ErrorHandler = OnError - } - - if opts.Extractor == nil { - opts.Extractor = FromAuthHeader - } - - return &JWTMiddleware{ - Options: opts, - } -} - -func (m *JWTMiddleware) logf(format string, args ...interface{}) { - if m.Options.Debug { - log.Printf(format, args...) - } -} - -// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere. -func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - err := m.CheckJWTFromRequest(w, r) - - // If there was an error, do not call next. - if err == nil && next != nil { - next(w, r) - } -} - -func (m *JWTMiddleware) Handler(h http.Handler) http.Handler { - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Let secure process the request. If it returns an error, - // that indicates the request should not continue. - err := m.CheckJWTFromRequest(w, r) - - // If there was an error, do not continue. - if err != nil { - return - } - - h.ServeHTTP(w, r) - }) -} - -// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts -// the JWT token from the Authorization header. -func FromAuthHeader(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no token - } - - // TODO: Make this a bit more robust, parsing-wise - authHeaderParts := strings.Fields(authHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") - } - - return authHeaderParts[1], nil -} - -// FromParameter returns a function that extracts the token from the specified -// query string parameter -func FromParameter(param string) TokenExtractor { - return func(r *http.Request) (string, error) { - return r.URL.Query().Get(param), nil - } -} - -// FromFirst returns a function that runs multiple token extractors and takes the -// first token it finds -func FromFirst(extractors ...TokenExtractor) TokenExtractor { - return func(r *http.Request) (string, error) { - for _, ex := range extractors { - token, err := ex(r) - if err != nil { - return "", err - } - if token != "" { - return token, nil - } - } - return "", nil - } -} - -func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error { - if !m.Options.EnableAuthOnOptions { - if r.Method == "OPTIONS" { - return nil - } - } - - // Use the specified token extractor to extract a token from the request - token, err := m.Options.Extractor(r) - - // If debugging is turned on, log the outcome - if err != nil { - m.logf("Error extracting JWT: %v", err) - } else { - m.logf("Token extracted: %s", token) - } - - // If an error occurs, call the error handler and return an error - if err != nil { - m.Options.ErrorHandler(w, r, err.Error()) - return fmt.Errorf("Error extracting token: %w", err) - } - - validatedToken, err := m.ValidateAndParse(token) - if err != nil { - m.Options.ErrorHandler(w, r, err.Error()) - return err - } - - if validatedToken == nil { - return nil - } - - // If we get here, everything worked and we can set the - // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validatedToken)) //nolint - // Update the current request with the new context information. - *r = *newRequest - return nil -} - -// ValidateAndParse validates and parses a given access token against jwt standards and signing methods -func (m *JWTMiddleware) ValidateAndParse(token string) (*jwt.Token, error) { - // If the token is empty... - if token == "" { - // Check if it was required - if m.Options.CredentialsOptional { - m.logf("no credentials found (CredentialsOptional=true)") - // No error, just no token (and that is ok given that CredentialsOptional is true) - return nil, nil - } - - // If we get here, the required token is missing - errorMsg := "required authorization token not found" - m.logf(" Error: No credentials found (CredentialsOptional=false)") - return nil, fmt.Errorf(errorMsg) - } - - // Now parse the token - parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter) - - // Check if there was an error in parsing... - if err != nil { - m.logf("error parsing token: %v", err) - return nil, fmt.Errorf("Error parsing token: %w", err) - } - - if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] { - errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", - m.Options.SigningMethod.Alg(), - parsedToken.Header["alg"]) - m.logf("error validating token algorithm: %s", errorMsg) - return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) - } - - // Check if the parsed token is valid... - if !parsedToken.Valid { - errorMsg := "token is invalid" - m.logf(errorMsg) - return nil, errors.New(errorMsg) - } - - return parsedToken, nil -} diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 0055511a2..c40daa1a3 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -4,10 +4,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" "net/http" "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/status" ) // WriteJSONObject simply writes object to the HTTP reponse in JSON format @@ -93,6 +95,8 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusInternalServerError case status.InvalidArgument: httpStatus = http.StatusUnprocessableEntity + case status.Unauthorized: + httpStatus = http.StatusUnauthorized default: } msg = err.Error() diff --git a/management/server/http/middleware/handler.go b/management/server/jwtclaims/jwtValidator.go similarity index 55% rename from management/server/http/middleware/handler.go rename to management/server/jwtclaims/jwtValidator.go index c647506bc..d324c5ab3 100644 --- a/management/server/http/middleware/handler.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -1,4 +1,4 @@ -package middleware +package jwtclaims import ( "bytes" @@ -17,6 +17,32 @@ import ( log "github.com/sirupsen/logrus" ) +// Options is a struct for specifying configuration options for the middleware. +type Options struct { + // The function that will return the Key to validate the JWT. + // It can be either a shared secret or a public key. + // Default value: nil + ValidationKeyGetter jwt.Keyfunc + // The name of the property in the request where the user information + // from the JWT will be stored. + // Default value: "user" + UserProperty string + // The function that will be called when there's an error validating the token + // Default value: + CredentialsOptional bool + // A function that extracts the token from the request + // Default: FromAuthHeader (i.e., from Authorization header as bearer token) + Debug bool + // When set, all requests with the OPTIONS method will use authentication + // Default: false + EnableAuthOnOptions bool + // When set, the middelware verifies that tokens are signed with the specific signing algorithm + // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks + // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ + // Default: nil + SigningMethod jwt.SigningMethod +} + // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { Keys []JSONWebKey `json:"keys"` @@ -32,14 +58,17 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } -// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header -func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { +type JWTValidator struct { + options Options +} + +func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) { keys, err := getPemKeys(keysLocation) if err != nil { return nil, err } - return New(Options{ + options := Options{ ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { // Verify 'aud' claim checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) @@ -62,7 +91,58 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT }, SigningMethod: jwt.SigningMethodRS256, EnableAuthOnOptions: false, - }), nil + } + + if options.UserProperty == "" { + options.UserProperty = "user" + } + + return &JWTValidator{ + options: options, + }, nil +} + +func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { + // If the token is empty... + if token == "" { + // Check if it was required + if m.options.CredentialsOptional { + log.Debugf("no credentials found (CredentialsOptional=true)") + // No error, just no token (and that is ok given that CredentialsOptional is true) + return nil, nil + } + + // If we get here, the required token is missing + errorMsg := "required authorization token not found" + log.Debugf(" Error: No credentials found (CredentialsOptional=false)") + return nil, fmt.Errorf(errorMsg) + } + + // Now parse the token + parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter) + + // Check if there was an error in parsing... + if err != nil { + log.Debugf("error parsing token: %v", err) + return nil, fmt.Errorf("Error parsing token: %w", err) + } + + if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] { + errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", + m.options.SigningMethod.Alg(), + parsedToken.Header["alg"]) + log.Debugf("error validating token algorithm: %s", errorMsg) + return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) + } + + // Check if the parsed token is valid... + if !parsedToken.Valid { + errorMsg := "token is invalid" + log.Debugf(errorMsg) + return nil, errors.New(errorMsg) + } + + return parsedToken, nil } func getPemKeys(keysLocation string) (*Jwks, error) { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 71870fd84..c91201344 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -48,6 +48,7 @@ type MockAccountManager struct { ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + MarkPATUsedFunc func(pat string) error UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) @@ -186,6 +187,14 @@ func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *s return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") } +// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface +func (am *MockAccountManager) MarkPATUsed(pat string) error { + if am.MarkPATUsedFunc != nil { + return am.MarkPATUsedFunc(pat) + } + return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented") +} + // AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface func (am *MockAccountManager) AddPATToUser(accountID string, userID string, pat *server.PersonalAccessToken) error { if am.AddPATToUserFunc != nil { From 6c8bb6063278a21b78915aa1e0ee9c3407861b91 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 16:06:46 +0200 Subject: [PATCH 02/15] fix merge --- management/server/http/middleware/jwt.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 management/server/http/middleware/jwt.go diff --git a/management/server/http/middleware/jwt.go b/management/server/http/middleware/jwt.go deleted file mode 100644 index e69de29bb..000000000 From e869882da11bb9e77627219264ad323f1c7f43f6 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 16:14:51 +0200 Subject: [PATCH 03/15] fix merge --- management/server/grpcserver.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 958b6729e..0c8dad246 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -46,10 +46,10 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager return nil, err } - var jwtMiddleware *middleware.JWTMiddleware + var jwtValidator *jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { - jwtMiddleware, err = middleware.NewJwtMiddleware( + jwtValidator, err = jwtclaims.NewJWTValidator( config.HttpConfig.AuthIssuer, config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation) @@ -87,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager accountManager: accountManager, config: config, turnCredentialsManager: turnCredentialsManager, - jwtMiddleware: jwtMiddleware, + jwtValidator: jwtValidator, jwtClaimsExtractor: jwtClaimsExtractor, appMetrics: appMetrics, }, nil @@ -188,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) { } func (s *GRPCServer) validateToken(jwtToken string) (string, error) { - if s.jwtMiddleware == nil { - return "", status.Error(codes.Internal, "no jwt middleware set") + if s.jwtValidator == nil { + return "", status.Error(codes.Internal, "no jwt validator set") } - token, err := s.jwtMiddleware.ValidateAndParse(jwtToken) + token, err := s.jwtValidator.ValidateAndParse(jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } From 2a799957064de411d4d962d94be418071ea601d1 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 16:22:15 +0200 Subject: [PATCH 04/15] fix linter --- management/server/account.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index ae61d34ec..be51e745d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1128,7 +1128,6 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { defer unlock() user, err := am.Store.GetUserByTokenID(tokenID) - log.Debugf("User: %v", user) if err != nil { return err } @@ -1145,9 +1144,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { pat.LastUsed = time.Now() - am.Store.SaveAccount(account) - - return nil + return am.Store.SaveAccount(account) } // GetAccountFromPAT returns Account and User associated with a personal access token From 1343a3f00ebbdac6e4eb91fa66076986400daf40 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 16:43:39 +0200 Subject: [PATCH 05/15] add test + codacy --- management/server/account_test.go | 38 +++++++++++++++++++ .../server/http/middleware/auth_middleware.go | 8 ++-- .../http/middleware/auth_midleware_test.go | 1 - management/server/jwtclaims/extractor.go | 20 +++++----- management/server/jwtclaims/extractor_test.go | 12 +++--- 5 files changed, 59 insertions(+), 20 deletions(-) delete mode 100644 management/server/http/middleware/auth_midleware_test.go diff --git a/management/server/account_test.go b/management/server/account_test.go index 25d501c8b..f21c93f0e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -495,6 +495,44 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) } +func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { + store := newStore(t) + account := newAccountWithId("account_id", "testuser", "") + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + account.Users["someUser"] = &User{ + Id: "someUser", + PATs: map[string]*PersonalAccessToken{ + "tokenId": { + ID: "tokenId", + HashedToken: encodedHashedToken, + LastUsed: time.Time{}, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + err = am.MarkPATUsed("tokenId") + if err != nil { + t.Fatalf("Error when marking PAT used: %s", err) + } + + account, err = am.Store.GetAccount("account_id") + if err != nil { + t.Fatalf("Error when getting account: %s", err) + } + assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) +} + func TestAccountManager_PrivateAccount(t *testing.T) { manager, err := createManager(t) if err != nil { diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 0b1478ef3..54889466e 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -124,10 +124,10 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ } claimMaps := jwt.MapClaims{} - claimMaps[jwtclaims.UserIDClaim] = user.Id - claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id - claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain - claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[string(jwtclaims.UserIDClaim)] = user.Id + claimMaps[m.audience+string(jwtclaims.AccountIDSuffix)] = account.Id + claimMaps[m.audience+string(jwtclaims.DomainIDSuffix)] = account.Domain + claimMaps[m.audience+string(jwtclaims.DomainCategorySuffix)] = account.DomainCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) // Update the current request with the new context information. diff --git a/management/server/http/middleware/auth_midleware_test.go b/management/server/http/middleware/auth_midleware_test.go deleted file mode 100644 index c870d7c16..000000000 --- a/management/server/http/middleware/auth_midleware_test.go +++ /dev/null @@ -1 +0,0 @@ -package middleware diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 9d60da335..5063d7b91 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -6,12 +6,14 @@ import ( "github.com/golang-jwt/jwt" ) +type key string + const ( - TokenUserProperty = "user" - AccountIDSuffix = "wt_account_id" - DomainIDSuffix = "wt_account_domain" - DomainCategorySuffix = "wt_account_domain_category" - UserIDClaim = "sub" + TokenUserProperty key = "user" + AccountIDSuffix key = "wt_account_id" + DomainIDSuffix key = "wt_account_domain" + DomainCategorySuffix key = "wt_account_domain_category" + UserIDClaim key = "sub" ) // Extract function type @@ -60,7 +62,7 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { ce.FromRequestContext = ce.fromRequestContext } if ce.userIDClaim == "" { - ce.userIDClaim = UserIDClaim + ce.userIDClaim = string(UserIDClaim) } return ce } @@ -74,15 +76,15 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { return jwtClaims } jwtClaims.UserId = userID - accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] + accountIDClaim, ok := claims[c.authAudience+string(AccountIDSuffix)] if ok { jwtClaims.AccountId = accountIDClaim.(string) } - domainClaim, ok := claims[c.authAudience+DomainIDSuffix] + domainClaim, ok := claims[c.authAudience+string(DomainIDSuffix)] if ok { jwtClaims.Domain = domainClaim.(string) } - domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] + domainCategoryClaim, ok := claims[c.authAudience+string(DomainCategorySuffix)] if ok { jwtClaims.DomainCategory = domainCategoryClaim.(string) } diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index d8acd79b6..d8476e039 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -12,21 +12,21 @@ import ( func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { claimMaps := jwt.MapClaims{} if claims.UserId != "" { - claimMaps[UserIDClaim] = claims.UserId + claimMaps[string(UserIDClaim)] = claims.UserId } if claims.AccountId != "" { - claimMaps[audiance+AccountIDSuffix] = claims.AccountId + claimMaps[audiance+string(AccountIDSuffix)] = claims.AccountId } if claims.Domain != "" { - claimMaps[audiance+DomainIDSuffix] = claims.Domain + claimMaps[audiance+string(DomainIDSuffix)] = claims.Domain } if claims.DomainCategory != "" { - claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory + claimMaps[audiance+string(DomainCategorySuffix)] = claims.DomainCategory } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) require.NoError(t, err, "creating testing request failed") - testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint + testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint return testRequest } @@ -124,7 +124,7 @@ func TestExtractClaimsSetOptions(t *testing.T) { t.Error("audience should be empty") return } - if c.extractor.userIDClaim != UserIDClaim { + if c.extractor.userIDClaim != string(UserIDClaim) { t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) return } From 454240ca05f0e3ff106e4c596c79ff937aaca42c Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 17:32:44 +0200 Subject: [PATCH 06/15] comments for codacy --- management/server/account.go | 1 + .../server/http/middleware/auth_middleware.go | 17 ++++++++++++----- management/server/jwtclaims/extractor.go | 13 +++++++++---- management/server/jwtclaims/jwtValidator.go | 3 +++ 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index be51e745d..27bf5606e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1123,6 +1123,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e return nil } +// MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { unlock := am.Store.AcquireGlobalLock() defer unlock() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 54889466e..d211a6ef2 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -17,8 +17,13 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +// GetAccountFromPATFunc function type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + +// ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) + +// MarkPATUsedFunc function type MarkPATUsedFunc func(token string) error // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens @@ -29,8 +34,10 @@ type AuthMiddleware struct { audience string } +type key string + const ( - userProperty = "user" + userProperty key = "user" ) // NewAuthMiddleware instance constructor @@ -44,13 +51,13 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse } // Handler method of the middleware which authenticates a user either by JWT claims or by PAT -func (a *AuthMiddleware) Handler(h http.Handler) http.Handler { +func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.Split(r.Header.Get("Authorization"), " ") authType := auth[0] switch strings.ToLower(authType) { case "bearer": - err := a.CheckJWTFromRequest(w, r) + err := m.CheckJWTFromRequest(w, r) if err != nil { log.Debugf("Error when validating JWT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) @@ -58,7 +65,7 @@ func (a *AuthMiddleware) Handler(h http.Handler) http.Handler { } h.ServeHTTP(w, r) case "token": - err := a.CheckPATFromRequest(w, r) + err := m.CheckPATFromRequest(w, r) if err != nil { log.Debugf("Error when validating PAT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) @@ -93,7 +100,7 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ // If we get here, everything worked and we can set the // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint + newRequest := r.WithContext(context.WithValue(r.Context(), string(userProperty), validatedToken)) // nolint // Update the current request with the new context information. *r = *newRequest return nil diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 5063d7b91..b5f30b8d6 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -9,11 +9,16 @@ import ( type key string const ( - TokenUserProperty key = "user" - AccountIDSuffix key = "wt_account_id" - DomainIDSuffix key = "wt_account_domain" + // TokenUserProperty key for the user property in the request context + TokenUserProperty key = "user" + // AccountIDSuffix suffix for the account id claim + AccountIDSuffix key = "wt_account_id" + // DomainIDSuffix suffix for the domain id claim + DomainIDSuffix key = "wt_account_domain" + // DomainCategorySuffix suffix for the domain category claim DomainCategorySuffix key = "wt_account_domain_category" - UserIDClaim key = "sub" + // UserIDClaim claim for the user id + UserIDClaim key = "sub" ) // Extract function type diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index d324c5ab3..ee9513c57 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -58,10 +58,12 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } +// JWTValidator struct to handle token validation and parsing type JWTValidator struct { options Options } +// NewJWTValidator constructor func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) { keys, err := getPemKeys(keysLocation) if err != nil { @@ -102,6 +104,7 @@ func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTV }, nil } +// ValidateAndParse validates the token and returns the parsed token func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { // If the token is empty... if token == "" { From e08af7fcdff28219419e566c90045f7aee3827dc Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 17:46:21 +0200 Subject: [PATCH 07/15] codacy --- management/server/http/middleware/auth_middleware.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index d211a6ef2..2901ccfbd 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -100,7 +100,7 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ // If we get here, everything worked and we can set the // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), string(userProperty), validatedToken)) // nolint + newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint // Update the current request with the new context information. *r = *newRequest return nil From f273fe9f519621140420e5bb9f8a200c5de899b3 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 18:54:55 +0200 Subject: [PATCH 08/15] revert codacy --- .../server/http/middleware/auth_middleware.go | 12 +++++------- management/server/jwtclaims/extractor.go | 18 ++++++++---------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 2901ccfbd..aeebf2593 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -34,10 +34,8 @@ type AuthMiddleware struct { audience string } -type key string - const ( - userProperty key = "user" + userProperty = "user" ) // NewAuthMiddleware instance constructor @@ -131,10 +129,10 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ } claimMaps := jwt.MapClaims{} - claimMaps[string(jwtclaims.UserIDClaim)] = user.Id - claimMaps[m.audience+string(jwtclaims.AccountIDSuffix)] = account.Id - claimMaps[m.audience+string(jwtclaims.DomainIDSuffix)] = account.Domain - claimMaps[m.audience+string(jwtclaims.DomainCategorySuffix)] = account.DomainCategory + claimMaps[jwtclaims.UserIDClaim] = user.Id + claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id + claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain + claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) // Update the current request with the new context information. diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index b5f30b8d6..3bd518d00 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -6,19 +6,17 @@ import ( "github.com/golang-jwt/jwt" ) -type key string - const ( // TokenUserProperty key for the user property in the request context - TokenUserProperty key = "user" + TokenUserProperty = "user" // AccountIDSuffix suffix for the account id claim - AccountIDSuffix key = "wt_account_id" + AccountIDSuffix = "wt_account_id" // DomainIDSuffix suffix for the domain id claim - DomainIDSuffix key = "wt_account_domain" + DomainIDSuffix = "wt_account_domain" // DomainCategorySuffix suffix for the domain category claim - DomainCategorySuffix key = "wt_account_domain_category" + DomainCategorySuffix = "wt_account_domain_category" // UserIDClaim claim for the user id - UserIDClaim key = "sub" + UserIDClaim = "sub" ) // Extract function type @@ -81,15 +79,15 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { return jwtClaims } jwtClaims.UserId = userID - accountIDClaim, ok := claims[c.authAudience+string(AccountIDSuffix)] + accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] if ok { jwtClaims.AccountId = accountIDClaim.(string) } - domainClaim, ok := claims[c.authAudience+string(DomainIDSuffix)] + domainClaim, ok := claims[c.authAudience+DomainIDSuffix] if ok { jwtClaims.Domain = domainClaim.(string) } - domainCategoryClaim, ok := claims[c.authAudience+string(DomainCategorySuffix)] + domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] if ok { jwtClaims.DomainCategory = domainCategoryClaim.(string) } From ce775d59aeb7f14bccbe9de45a7189d77cf628f6 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 18:59:35 +0200 Subject: [PATCH 09/15] revert codacy --- management/server/jwtclaims/extractor.go | 4 ++-- management/server/jwtclaims/extractor_test.go | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 3bd518d00..9aa00a004 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -19,7 +19,7 @@ const ( UserIDClaim = "sub" ) -// Extract function type +// ExtractClaims Extract function type type ExtractClaims func(r *http.Request) AuthorizationClaims // ClaimsExtractor struct that holds the extract function @@ -65,7 +65,7 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { ce.FromRequestContext = ce.fromRequestContext } if ce.userIDClaim == "" { - ce.userIDClaim = string(UserIDClaim) + ce.userIDClaim = UserIDClaim } return ce } diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index d8476e039..53f8818b1 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -12,16 +12,16 @@ import ( func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { claimMaps := jwt.MapClaims{} if claims.UserId != "" { - claimMaps[string(UserIDClaim)] = claims.UserId + claimMaps[UserIDClaim] = claims.UserId } if claims.AccountId != "" { - claimMaps[audiance+string(AccountIDSuffix)] = claims.AccountId + claimMaps[audiance+AccountIDSuffix] = claims.AccountId } if claims.Domain != "" { - claimMaps[audiance+string(DomainIDSuffix)] = claims.Domain + claimMaps[audiance+DomainIDSuffix] = claims.Domain } if claims.DomainCategory != "" { - claimMaps[audiance+string(DomainCategorySuffix)] = claims.DomainCategory + claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) @@ -124,7 +124,7 @@ func TestExtractClaimsSetOptions(t *testing.T) { t.Error("audience should be empty") return } - if c.extractor.userIDClaim != string(UserIDClaim) { + if c.extractor.userIDClaim != UserIDClaim { t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) return } From ca1dc5ac885bfc9b4804172b560335f8d58edb65 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 19:03:44 +0200 Subject: [PATCH 10/15] disable access control for token endpoint --- .../server/http/middleware/access_control.go | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 5e56f75ab..f1ab898a8 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -2,6 +2,9 @@ package middleware import ( "net/http" + "regexp" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" @@ -34,12 +37,23 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := a.claimsExtract.FromRequestContext(r) - ok, err := a.isUserAdmin(claims) + ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) + if err != nil { + log.Debugf("Regex failed") + util.WriteError(status.Errorf(status.Internal, ""), w) + return + } + if ok { + log.Debugf("Valid Path") + h.ServeHTTP(w, r) + return + } + + ok, err = a.isUserAdmin(claims) if err != nil { util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return } - if !ok { switch r.Method { case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: From 32c96c15b85d6675c22695fb151febf725b99dcf Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 10:30:05 +0200 Subject: [PATCH 11/15] disable linter errors by comment --- management/server/http/middleware/auth_middleware.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index aeebf2593..c3f9361dd 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -98,7 +98,7 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ // If we get here, everything worked and we can set the // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint + newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint // Update the current request with the new context information. *r = *newRequest return nil @@ -134,7 +134,7 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) - newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) + newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint // Update the current request with the new context information. *r = *newRequest return nil From 110067c00fc4c2a81dce37911ec081d4fa54dcf8 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 12:03:53 +0200 Subject: [PATCH 12/15] change order for access control checks and aquire account lock after global lock --- management/server/account.go | 10 ++++++- .../server/http/middleware/access_control.go | 27 ++++++++++--------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 27bf5606e..78c9237b8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1126,7 +1126,6 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { unlock := am.Store.AcquireGlobalLock() - defer unlock() user, err := am.Store.GetUserByTokenID(tokenID) if err != nil { @@ -1138,6 +1137,15 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { return err } + unlock() + unlock = am.Store.AcquireAccountLock(account.Id) + defer unlock() + + account, err = am.Store.GetAccountByUser(user.Id) + if err != nil { + return err + } + pat, ok := account.Users[user.Id].PATs[tokenID] if !ok { return fmt.Errorf("token not found") diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index f1ab898a8..5f8389dfa 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -37,19 +37,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := a.claimsExtract.FromRequestContext(r) - ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) - if err != nil { - log.Debugf("Regex failed") - util.WriteError(status.Errorf(status.Internal, ""), w) - return - } - if ok { - log.Debugf("Valid Path") - h.ServeHTTP(w, r) - return - } - - ok, err = a.isUserAdmin(claims) + ok, err := a.isUserAdmin(claims) if err != nil { util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return @@ -57,6 +45,19 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { if !ok { switch r.Method { case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: + + ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) + if err != nil { + log.Debugf("Regex failed") + util.WriteError(status.Errorf(status.Internal, ""), w) + return + } + if ok { + log.Debugf("Valid Path") + h.ServeHTTP(w, r) + return + } + util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w) return } From 2eaf4aa8d7b8972b95772374551466445d43d49c Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 12:44:22 +0200 Subject: [PATCH 13/15] add test for auth middleware --- .../http/middleware/auth_middleware_test.go | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 management/server/http/middleware/auth_middleware_test.go diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go new file mode 100644 index 000000000..e77beb21f --- /dev/null +++ b/management/server/http/middleware/auth_middleware_test.go @@ -0,0 +1,123 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt" + + "github.com/netbirdio/netbird/management/server" +) + +const ( + audience = "audience" + accountID = "accountID" + domain = "domain" + userID = "userID" + tokenID = "tokenID" + PAT = "PAT" + JWT = "JWT" + wrongToken = "wrongToken" +) + +var testAccount = &server.Account{ + Id: accountID, + Domain: domain, + Users: map[string]*server.User{ + userID: { + Id: userID, + PATs: map[string]*server.PersonalAccessToken{ + tokenID: { + ID: tokenID, + Name: "My first token", + HashedToken: "someHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: userID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + }, + }, +} + +func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { + if token == PAT { + return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil + } + return nil, nil, nil, fmt.Errorf("PAT invalid") +} + +func mockValidateAndParseToken(token string) (*jwt.Token, error) { + if token == JWT { + return &jwt.Token{}, nil + } + return nil, fmt.Errorf("JWT invalid") +} + +func mockMarkPATUsed(token string) error { + if token == tokenID { + return nil + } + return fmt.Errorf("Should never get reached") +} + +func TestAccounts_AccountsHandler(t *testing.T) { + tt := []struct { + name string + authHeader string + expectedStatusCode int + }{ + { + name: "Valid PAT Token", + authHeader: "Token " + PAT, + expectedStatusCode: 200, + }, + { + name: "Invalid PAT Token", + authHeader: "Token " + wrongToken, + expectedStatusCode: 401, + }, + { + name: "Valid JWT Token", + authHeader: "Bearer " + JWT, + expectedStatusCode: 200, + }, + { + name: "Invalid JWT Token", + authHeader: "Bearer " + wrongToken, + expectedStatusCode: 401, + }, + { + name: "Basic Auth", + authHeader: "Basic " + PAT, + expectedStatusCode: 401, + }, + } + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // do nothing + }) + + authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience) + + handlerToTest := authMiddleware.Handler(nextHandler) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set("Authorization", tc.authHeader) + rec := httptest.NewRecorder() + + handlerToTest.ServeHTTP(rec, req) + + if rec.Result().StatusCode != tc.expectedStatusCode { + t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode) + } + }) + } + +} From 931c20c8fe6a1c2c7aca4c80a58410e6bd3bdb85 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 12:45:10 +0200 Subject: [PATCH 14/15] fix test name --- management/server/http/middleware/auth_middleware_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index e77beb21f..5a5558fa5 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -65,7 +65,7 @@ func mockMarkPATUsed(token string) error { return fmt.Errorf("Should never get reached") } -func TestAccounts_AccountsHandler(t *testing.T) { +func TestAuthMiddleware_Handler(t *testing.T) { tt := []struct { name string authHeader string From d3de03596113ea62df42b3a19c04dbc856692965 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Sat, 1 Apr 2023 11:04:21 +0200 Subject: [PATCH 15/15] error responses always lower case + duplicate error response fix --- management/server/http/middleware/auth_middleware.go | 8 +++----- management/server/http/util/util.go | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index c3f9361dd..a8c81012a 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -58,7 +58,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { err := m.CheckJWTFromRequest(w, r) if err != nil { log.Debugf("Error when validating JWT claims: %s", err.Error()) - util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) return } h.ServeHTTP(w, r) @@ -66,12 +66,12 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { err := m.CheckPATFromRequest(w, r) if err != nil { log.Debugf("Error when validating PAT claims: %s", err.Error()) - util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) return } h.ServeHTTP(w, r) default: - util.WriteError(status.Errorf(status.Unauthorized, "No valid authentication provided"), w) + util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return } }) @@ -115,11 +115,9 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ account, user, pat, err := m.getAccountFromPAT(token) if err != nil { - util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) return fmt.Errorf("invalid Token: %w", err) } if time.Now().After(pat.ExpirationDate) { - util.WriteError(status.Errorf(status.Unauthorized, "Token expired"), w) return fmt.Errorf("token expired") } diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index c40daa1a3..407443251 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "strings" "time" log "github.com/sirupsen/logrus" @@ -99,7 +100,7 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusUnauthorized default: } - msg = err.Error() + msg = strings.ToLower(err.Error()) } else { unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error()) log.Error(unhandledMSG)