diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 0d218a314..296201710 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -30,7 +30,7 @@ services: - $SIGNAL_VOLUMENAME:/var/lib/netbird ports: - 10000:80 - # # port and command for Let's Encrypt validation + # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] # Management @@ -45,7 +45,7 @@ services: - ./management.json:/etc/netbird/management.json ports: - $NETBIRD_MGMT_API_PORT:443 #API port - # # command for Let's Encrypt validation without dashboard container + # # command for Let's Encrypt validation without dashboard container # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"] # Coturn @@ -60,7 +60,6 @@ services: network_mode: host command: - -c /etc/turnserver.conf - volumes: $MGMT_VOLUMENAME: $SIGNAL_VOLUMENAME: diff --git a/infrastructure_files/management.json.tmpl b/infrastructure_files/management.json.tmpl index 1f48bd11c..f3b08101c 100644 --- a/infrastructure_files/management.json.tmpl +++ b/infrastructure_files/management.json.tmpl @@ -32,6 +32,7 @@ "AuthIssuer": "$NETBIRD_AUTH_AUTHORITY", "AuthAudience": "$NETBIRD_AUTH_AUDIENCE", "AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS", + "AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM", "CertFile":"$NETBIRD_MGMT_API_CERT_FILE", "CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE", "OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT" @@ -49,4 +50,4 @@ "DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT" } } -} \ No newline at end of file +} diff --git a/infrastructure_files/setup.env.example b/infrastructure_files/setup.env.example index 4afb322f9..09f407225 100644 --- a/infrastructure_files/setup.env.example +++ b/infrastructure_files/setup.env.example @@ -7,6 +7,8 @@ NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="" NETBIRD_AUTH_AUDIENCE="" # e.g. netbird-client NETBIRD_AUTH_CLIENT_ID="" +# if you want to use a custom claim for the user ID instead of 'sub', set it here +# NETBIRD_AUTH_USER_ID_CLAIM="" # indicates whether to use Auth0 or not: true or false NETBIRD_USE_AUTH0="false" NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none" diff --git a/management/cmd/management.go b/management/cmd/management.go index da8bfe5f9..f3210d88e 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -7,15 +7,6 @@ import ( "errors" "flag" "fmt" - "github.com/google/uuid" - "github.com/miekg/dns" - "github.com/netbirdio/netbird/management/server/activity/sqlite" - httpapi "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/metrics" - "github.com/netbirdio/netbird/management/server/telemetry" - "golang.org/x/crypto/acme/autocert" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "io" "io/fs" "net" @@ -26,6 +17,16 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/miekg/dns" + "github.com/netbirdio/netbird/management/server/activity/sqlite" + httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/telemetry" + "golang.org/x/crypto/acme/autocert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/util" @@ -178,8 +179,13 @@ var ( tlsEnabled = true } - httpAPIHandler, err := httpapi.APIHandler(accountManager, config.HttpConfig.AuthIssuer, - config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, appMetrics) + httpAPIAuthCfg := httpapi.AuthCfg{ + Issuer: config.HttpConfig.AuthIssuer, + Audience: config.HttpConfig.AuthAudience, + UserIDClaim: config.HttpConfig.AuthUserIDClaim, + KeysLocation: config.HttpConfig.AuthKeysLocation, + } + httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -415,7 +421,6 @@ type OIDCConfigResponse struct { // fetchOIDCConfig fetches OIDC configuration from the IDP func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { - res, err := http.Get(oidcEndpoint) if err != nil { return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err) @@ -445,7 +450,6 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { } return config, nil - } func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) { diff --git a/management/server/account.go b/management/server/account.go index 9c7ad0709..e584a79c4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -3,6 +3,15 @@ package server import ( "context" "fmt" + "math/rand" + "net" + "net/netip" + "reflect" + "regexp" + "strings" + "sync" + "time" + "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" nbdns "github.com/netbirdio/netbird/dns" @@ -14,14 +23,6 @@ import ( gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" - "math/rand" - "net" - "net/netip" - "reflect" - "regexp" - "strings" - "sync" - "time" ) const ( @@ -219,7 +220,6 @@ func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Rou // GetRoutesByPrefix return list of routes by account and route prefix func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route { - var routes []*route.Route for _, r := range a.Routes { if r.Network.String() == prefix.String() { @@ -243,7 +243,6 @@ func (a *Account) GetPeerByIP(peerIP string) *Peer { // GetPeerRules returns a list of source or destination rules of a given peer. func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) { - // Rules are group based so there is no direct access to peers. // First, find all groups that the given peer belongs to peerGroups := make(map[string]struct{}) @@ -490,7 +489,8 @@ func (a *Account) GetPeer(peerID string) *Peer { // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, - singleAccountModeDomain string, dnsDomain string, eventStore activity.Store) (*DefaultAccountManager, error) { + singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, +) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ Store: store, peersUpdateManager: peersUpdateManager, @@ -544,14 +544,13 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage err := am.warmupIDPCache() if err != nil { log.Warnf("failed warming up cache due to error: %v", err) - //todo retry? + // todo retry? return } }() } return am, nil - } // newAccount creates a new Account with a generated ID and generated default setup keys. @@ -669,7 +668,6 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI } return nil, nil - } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil @@ -768,7 +766,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a // updateAccountDomainAttributes updates the account domain attributes and then, saves the account func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, - primaryDomain bool) error { + primaryDomain bool, +) error { account.IsDomainPrimaryAccount = primaryDomain lowerDomain := strings.ToLower(claims.Domain) @@ -826,6 +825,9 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { + if claims.UserId == "" { + return nil, fmt.Errorf("user ID is empty") + } var ( account *Account err error @@ -897,7 +899,9 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e // GetAccountFromToken returns an account associated with this token func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { - + if claims.UserId == "" { + return nil, nil, fmt.Errorf("user ID is empty") + } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. @@ -943,6 +947,9 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) { + if claims.UserId == "" { + return nil, fmt.Errorf("user ID is empty") + } // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { @@ -995,7 +1002,6 @@ func isDomainValid(domain string) bool { // AccountExists checks whether account exists (returns true) or not (returns false) func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) { - unlock := am.Store.AcquireAccountLock(accountID) defer unlock() diff --git a/management/server/config.go b/management/server/config.go index bcdbad7e0..6a428c83b 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -7,8 +7,13 @@ import ( "github.com/netbirdio/netbird/util" ) -type Protocol string -type Provider string +type ( + // Protocol type + Protocol string + + // Provider authorization flow type + Provider string +) const ( UDP Protocol = "udp" @@ -45,14 +50,16 @@ type TURNConfig struct { // HttpServerConfig is a config of the HTTP Management service server type HttpServerConfig struct { LetsEncryptDomain string - //CertFile is the location of the certificate + // CertFile is the location of the certificate CertFile string - //CertKey is the location of the certificate private key + // CertKey is the location of the certificate private key CertKey string // AuthAudience identifies the recipients that the JWT is intended for (aud in JWT) AuthAudience string - // AuthIssuer identifies principal that issued the JWT. + // AuthIssuer identifies principal that issued the JWT AuthIssuer string + // AuthUserIDClaim is the name of the claim that used as user ID + AuthUserIDClaim string // AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT AuthKeysLocation string // OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 9a7ade111..785a2fdf0 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,11 +3,11 @@ package server import ( "context" "fmt" - "github.com/netbirdio/netbird/management/server/telemetry" - gPeer "google.golang.org/grpc/peer" "strings" "time" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -31,12 +31,14 @@ type GRPCServer struct { config *Config turnCredentialsManager TURNCredentialsManager jwtMiddleware *middleware.JWTMiddleware + jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics } // NewServer creates a new Management server func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, - turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics) (*GRPCServer, error) { + turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, +) (*GRPCServer, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err @@ -66,6 +68,16 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager } } + var audience, userIDClaim string + if config.HttpConfig != nil { + audience = config.HttpConfig.AuthAudience + userIDClaim = config.HttpConfig.AuthUserIDClaim + } + jwtClaimsExtractor := jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(audience), + jwtclaims.WithUserIDClaim(userIDClaim), + ) + return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -74,6 +86,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager config: config, turnCredentialsManager: turnCredentialsManager, jwtMiddleware: jwtMiddleware, + jwtClaimsExtractor: jwtClaimsExtractor, appMetrics: appMetrics, }, nil } @@ -113,7 +126,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi peer, err := s.accountManager.GetPeerByKey(peerKey.String()) if err != nil { - p, _ := gPeer.FromContext(srv.Context()) + p, _ := gRPCPeer.FromContext(srv.Context()) msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String()) log.Debug(msg) return msg @@ -122,7 +135,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi syncReq := &proto.SyncRequest{} err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq) if err != nil { - p, _ := gPeer.FromContext(srv.Context()) + p, _ := gRPCPeer.FromContext(srv.Context()) msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String()) log.Debug(msg) return msg @@ -200,7 +213,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) if err != nil { return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) } - claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience) + claims := s.jwtClaimsExtractor.FromToken(token) userID = claims.UserId // we need to call this method because if user is new, we will automatically add it to existing or create a new account _, _, err = s.accountManager.GetAccountFromToken(claims) @@ -305,7 +318,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p // peer doesn't exist -> check if setup key was provided if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" { // absent setup key or jwt -> permission denied - p, _ := gPeer.FromContext(ctx) + p, _ := gRPCPeer.FromContext(ctx) msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided,"+ " remote addr is %s", peerKey.String(), p.Addr.String()) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 17c71662f..7f4078c3a 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -46,6 +46,10 @@ components: type: array items: type: string + is_current: + description: Is true if authenticated user is the same as this user + type: boolean + readOnly: true required: - id - email @@ -1703,4 +1707,4 @@ paths: '403': "$ref": "#/components/responses/forbidden" '500': - "$ref": "#/components/responses/internal_error" \ No newline at end of file + "$ref": "#/components/responses/internal_error" diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index a7934c20f..2f12a2e87 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -566,6 +566,9 @@ type User struct { // Id User ID Id string `json:"id"` + // IsCurrent Is true if authenticated user is the same as this user + IsCurrent *bool `json:"is_current,omitempty"` + // Name User's name from idp provider Name string `json:"name"` diff --git a/management/server/http/dns_settings.go b/management/server/http/dns_settings.go index 92c5a8322..984faa988 100644 --- a/management/server/http/dns_settings.go +++ b/management/server/http/dns_settings.go @@ -2,33 +2,35 @@ package http import ( "encoding/json" + "net/http" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" log "github.com/sirupsen/logrus" - "net/http" ) // DNSSettings is a handler that returns the DNS settings of the account type DNSSettings struct { - jwtExtractor jwtclaims.ClaimsExtractor - accountManager server.AccountManager - authAudience string + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } // NewDNSSettings returns a new instance of DNSSettings handler -func NewDNSSettings(accountManager server.AccountManager, authAudience string) *DNSSettings { +func NewDNSSettings(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettings { return &DNSSettings{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } // GetDNSSettings returns the DNS settings for the account func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) @@ -51,7 +53,7 @@ func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) { // UpdateDNSSettings handles update to DNS settings of an account func (h *DNSSettings) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) diff --git a/management/server/http/dns_settings_test.go b/management/server/http/dns_settings_test.go index 58bec62f1..a82a9aae8 100644 --- a/management/server/http/dns_settings_test.go +++ b/management/server/http/dns_settings_test.go @@ -3,14 +3,15 @@ package http import ( "bytes" "encoding/json" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - "github.com/stretchr/testify/assert" "io" "net/http" "net/http/httptest" "testing" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" + "github.com/stretchr/testify/assert" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -52,16 +53,15 @@ func initDNSSettingsTestData() *DNSSettings { return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: "test_user", Domain: "hotmail.com", AccountId: testDNSSettingsAccountID, } - }, - }, + }), + ), } } diff --git a/management/server/http/events.go b/management/server/http/events.go index a635f1c3d..b39da173c 100644 --- a/management/server/http/events.go +++ b/management/server/http/events.go @@ -2,34 +2,36 @@ package http import ( "fmt" + "net/http" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" log "github.com/sirupsen/logrus" - "net/http" ) // Events HTTP handler type Events struct { - accountManager server.AccountManager - authAudience string - jwtExtractor jwtclaims.ClaimsExtractor + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } // NewEvents creates a new Events HTTP handler -func NewEvents(accountManager server.AccountManager, authAudience string) *Events { +func NewEvents(accountManager server.AccountManager, authCfg AuthCfg) *Events { return &Events{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } // GetEvents list of the given account func (h *Events) GetEvents(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) diff --git a/management/server/http/events_test.go b/management/server/http/events_test.go index c73d54c2f..707b7af45 100644 --- a/management/server/http/events_test.go +++ b/management/server/http/events_test.go @@ -2,6 +2,13 @@ package http import ( "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" @@ -9,12 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/stretchr/testify/assert" - "io" - "net/http" - "net/http/httptest" - "strconv" - "testing" - "time" ) func initEventsTestData(account string, user *server.User, events ...*activity.Event) *Events { @@ -36,16 +37,15 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E }, user, nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_account", } - }, - }, + }), + ), } } @@ -244,7 +244,6 @@ func TestEvents_GetEvents(t *testing.T) { assert.Equal(t, expected.Meta["some"], event.Meta["some"]) assert.True(t, expected.Timestamp.Equal(event.Timestamp)) } - }) } } diff --git a/management/server/http/groups.go b/management/server/http/groups.go index c0bd21ded..098776160 100644 --- a/management/server/http/groups.go +++ b/management/server/http/groups.go @@ -2,10 +2,11 @@ package http import ( "encoding/json" + "net/http" + "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" - "net/http" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -17,22 +18,23 @@ import ( // Groups is a handler that returns groups of the account type Groups struct { - jwtExtractor jwtclaims.ClaimsExtractor - accountManager server.AccountManager - authAudience string + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } -func NewGroups(accountManager server.AccountManager, authAudience string) *Groups { +func NewGroups(accountManager server.AccountManager, authCfg AuthCfg) *Groups { return &Groups{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } // GetAllGroupsHandler list for the account func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) @@ -50,7 +52,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { // UpdateGroupHandler handles update to a group identified by a given ID func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -119,7 +121,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { // PatchGroupHandler handles patch updates to a group identified by a given ID func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -223,7 +225,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { // CreateGroupHandler handles group creation request func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -265,7 +267,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { // DeleteGroupHandler handles group deletion request func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -301,7 +303,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { // GetGroupHandler returns a group func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) diff --git a/management/server/http/groups_test.go b/management/server/http/groups_test.go index 380b52546..71e4be9b6 100644 --- a/management/server/http/groups_test.go +++ b/management/server/http/groups_test.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" "io" "net" "net/http" @@ -13,6 +11,9 @@ import ( "strings" "testing" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -78,20 +79,20 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups { }, Groups: map[string]*server.Group{ "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, - "id-all": {ID: "id-all", Name: "All"}}, + "id-all": {ID: "id-all", Name: "All"}, + }, }, user, nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", } - }, - }, + }), + ), } } @@ -270,7 +271,8 @@ func TestWriteGroup(t *testing.T) { PeersCount: 2, Peers: []api.PeerMinimum{ {Id: "peer-A-ID"}, - {Id: "peer-B-ID"}}, + {Id: "peer-B-ID"}, + }, }, }, } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 5015d5650..a56df91ee 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -1,22 +1,29 @@ package http import ( + "net/http" + "github.com/gorilla/mux" s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/rs/cors" - "net/http" ) +// AuthCfg contains parameters for authentication middleware +type AuthCfg struct { + Issuer string + Audience string + UserIDClaim string + KeysLocation string +} + // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string, - appMetrics telemetry.AppMetrics) (http.Handler, error) { +func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { jwtMiddleware, err := middleware.NewJwtMiddleware( - authIssuer, - authAudience, - authKeysLocation, - ) + authCfg.Issuer, + authCfg.Audience, + authCfg.KeysLocation) if err != nil { return nil, err } @@ -24,7 +31,8 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience corsMiddleware := cors.AllowAll() acMiddleware := middleware.NewAccessControl( - authAudience, + authCfg.Audience, + authCfg.UserIDClaim, accountManager.IsUserAdmin) rootRouter := mux.NewRouter() @@ -33,15 +41,15 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience apiHandler := rootRouter.PathPrefix("/api").Subrouter() apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) - groupsHandler := NewGroups(accountManager, authAudience) - rulesHandler := NewRules(accountManager, authAudience) - peersHandler := NewPeers(accountManager, authAudience) - keysHandler := NewSetupKeysHandler(accountManager, authAudience) - userHandler := NewUserHandler(accountManager, authAudience) - routesHandler := NewRoutes(accountManager, authAudience) - nameserversHandler := NewNameservers(accountManager, authAudience) - eventsHandler := NewEvents(accountManager, authAudience) - dnsSettingsHandler := NewDNSSettings(accountManager, authAudience) + groupsHandler := NewGroups(accountManager, authCfg) + rulesHandler := NewRules(accountManager, authCfg) + peersHandler := NewPeers(accountManager, authCfg) + keysHandler := NewSetupKeysHandler(accountManager, authCfg) + userHandler := NewUserHandler(accountManager, authCfg) + routesHandler := NewRoutes(accountManager, authCfg) + nameserversHandler := NewNameservers(accountManager, authCfg) + eventsHandler := NewEvents(accountManager, authCfg) + dnsSettingsHandler := NewDNSSettings(accountManager, authCfg) apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer). @@ -88,7 +96,7 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS") - err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { + err = apiHandler.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { methods, err := route.GetMethods() if err != nil { return err @@ -110,5 +118,4 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience } return rootRouter, nil - } diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 3d12eb4a5..5e56f75ab 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -1,9 +1,10 @@ package middleware import ( + "net/http" + "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" - "net/http" "github.com/netbirdio/netbird/management/server/jwtclaims" ) @@ -12,17 +13,18 @@ type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { - jwtExtractor jwtclaims.ClaimsExtractor - isUserAdmin IsUserAdminFunc - audience string + isUserAdmin IsUserAdminFunc + claimsExtract jwtclaims.ClaimsExtractor } // NewAccessControl instance constructor -func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessControl { +func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl { return &AccessControl{ - isUserAdmin: isUserAdmin, - audience: audience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + isUserAdmin: isUserAdmin, + claimsExtract: *jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(audience), + jwtclaims.WithUserIDClaim(userIDClaim), + ), } } @@ -30,9 +32,9 @@ func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessContr // It also adds func (a *AccessControl) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience) + claims := a.claimsExtract.FromRequestContext(r) - ok, err := a.isUserAdmin(jwtClaims) + ok, err := a.isUserAdmin(claims) if err != nil { util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return @@ -40,7 +42,6 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { if !ok { switch r.Method { - case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w) return diff --git a/management/server/http/middleware/handler.go b/management/server/http/middleware/handler.go index 8725a3d27..c647506bc 100644 --- a/management/server/http/middleware/handler.go +++ b/management/server/http/middleware/handler.go @@ -10,10 +10,11 @@ import ( "encoding/pem" "errors" "fmt" - "github.com/golang-jwt/jwt" - log "github.com/sirupsen/logrus" "math/big" "net/http" + + "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" ) // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation @@ -33,7 +34,6 @@ type JSONWebKey struct { // NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { - keys, err := getPemKeys(keysLocation) if err != nil { return nil, err @@ -67,13 +67,12 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT func getPemKeys(keysLocation string) (*Jwks, error) { resp, err := http.Get(keysLocation) - if err != nil { return nil, err } defer resp.Body.Close() - var jwks = &Jwks{} + jwks := &Jwks{} err = json.NewDecoder(resp.Body).Decode(jwks) if err != nil { return jwks, err diff --git a/management/server/http/nameservers.go b/management/server/http/nameservers.go index 05ff45ece..3e0df5201 100644 --- a/management/server/http/nameservers.go +++ b/management/server/http/nameservers.go @@ -3,6 +3,8 @@ package http import ( "encoding/json" "fmt" + "net/http" + "github.com/gorilla/mux" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" @@ -11,28 +13,28 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" - "net/http" ) // Nameservers is the nameserver group handler of the account type Nameservers struct { - jwtExtractor jwtclaims.ClaimsExtractor - accountManager server.AccountManager - authAudience string + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } // NewNameservers returns a new instance of Nameservers handler -func NewNameservers(accountManager server.AccountManager, authAudience string) *Nameservers { +func NewNameservers(accountManager server.AccountManager, authCfg AuthCfg) *Nameservers { return &Nameservers{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } // GetAllNameserversHandler returns the list of nameserver groups for the account func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) @@ -56,7 +58,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re // CreateNameserverGroupHandler handles nameserver group creation request func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -89,7 +91,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt // UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -139,7 +141,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt // PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -221,7 +223,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http // DeleteNameserverGroupHandler handles nameserver group deletion request func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -245,7 +247,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt // GetNameserverGroupHandler handles a nameserver group Get request identified by ID func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) @@ -268,7 +270,6 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R resp := toNameserverGroupResponse(nsGroup) util.WriteJSONObject(w, &resp) - } func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { diff --git a/management/server/http/nameservers_test.go b/management/server/http/nameservers_test.go index 2a2954e7d..78d8430ac 100644 --- a/management/server/http/nameservers_test.go +++ b/management/server/http/nameservers_test.go @@ -3,16 +3,17 @@ package http import ( "bytes" "encoding/json" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - "github.com/stretchr/testify/assert" "io" "net/http" "net/http/httptest" "net/netip" "testing" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" + "github.com/stretchr/testify/assert" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -113,16 +114,15 @@ func initNameserversTestData() *Nameservers { return testingNSAccount, testingAccount.Users["test_user"], nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: "test_user", Domain: "hotmail.com", AccountId: testNSGroupAccountID, } - }, - }, + }), + ), } } diff --git a/management/server/http/peers.go b/management/server/http/peers.go index 910802d4b..e9d21832b 100644 --- a/management/server/http/peers.go +++ b/management/server/http/peers.go @@ -3,27 +3,29 @@ package http import ( "encoding/json" "fmt" + "net/http" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" - "net/http" ) // Peers is a handler that returns peers of the account type Peers struct { - accountManager server.AccountManager - authAudience string - jwtExtractor jwtclaims.ClaimsExtractor + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } -func NewPeers(accountManager server.AccountManager, authAudience string) *Peers { +func NewPeers(accountManager server.AccountManager, authCfg AuthCfg) *Peers { return &Peers{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } @@ -55,7 +57,7 @@ func (h *Peers) deletePeer(accountID, userID string, peerID string, w http.Respo } func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -78,13 +80,12 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { default: util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) } - } func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) diff --git a/management/server/http/peers_test.go b/management/server/http/peers_test.go index e7690b976..026fc668f 100644 --- a/management/server/http/peers_test.go +++ b/management/server/http/peers_test.go @@ -2,13 +2,14 @@ package http import ( "encoding/json" - "github.com/netbirdio/netbird/management/server/http/api" "io" "net" "net/http" "net/http/httptest" "testing" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/magiconair/properties/assert" @@ -36,16 +37,15 @@ func initTestMetaData(peers ...*server.Peer) *Peers { }, user, nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", } - }, - }, + }), + ), } } diff --git a/management/server/http/routes.go b/management/server/http/routes.go index 93cd5ff78..36104415b 100644 --- a/management/server/http/routes.go +++ b/management/server/http/routes.go @@ -2,6 +2,9 @@ package http import ( "encoding/json" + "net/http" + "unicode/utf8" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" @@ -9,29 +12,28 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" - "net/http" - "unicode/utf8" ) // Routes is the routes handler of the account type Routes struct { - jwtExtractor jwtclaims.ClaimsExtractor - accountManager server.AccountManager - authAudience string + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } // NewRoutes returns a new instance of Routes handler -func NewRoutes(accountManager server.AccountManager, authAudience string) *Routes { +func NewRoutes(accountManager server.AccountManager, authCfg AuthCfg) *Routes { return &Routes{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } // GetAllRoutesHandler returns the list of routes for the account func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -53,7 +55,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { // CreateRouteHandler handles route creation request func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -92,7 +94,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { // UpdateRouteHandler handles update to a route identified by a given ID func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -158,7 +160,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { // PatchRouteHandler handles patch updates to a route identified by a given ID func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -299,7 +301,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { // DeleteRouteHandler handles route deletion request func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -323,7 +325,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { // GetRouteHandler handles a route Get request identified by ID func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) diff --git a/management/server/http/routes_test.go b/management/server/http/routes_test.go index f4dbec6cd..f7994c3f6 100644 --- a/management/server/http/routes_test.go +++ b/management/server/http/routes_test.go @@ -4,9 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/route" "io" "net/http" "net/http/httptest" @@ -14,6 +11,10 @@ import ( "strconv" "testing" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/route" + "github.com/gorilla/mux" "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/server" @@ -142,16 +143,15 @@ func initRoutesTestData() *Routes { return testingAccount, testingAccount.Users["test_user"], nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: "test_user", Domain: "hotmail.com", AccountId: testAccountID, } - }, - }, + }), + ), } } diff --git a/management/server/http/rules.go b/management/server/http/rules.go index 42f9a25b2..dd281970b 100644 --- a/management/server/http/rules.go +++ b/management/server/http/rules.go @@ -2,6 +2,8 @@ package http import ( "encoding/json" + "net/http" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" @@ -9,27 +11,27 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" - "net/http" ) // Rules is a handler that returns rules of the account type Rules struct { - jwtExtractor jwtclaims.ClaimsExtractor - accountManager server.AccountManager - authAudience string + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } -func NewRules(accountManager server.AccountManager, authAudience string) *Rules { +func NewRules(accountManager server.AccountManager, authCfg AuthCfg) *Rules { return &Rules{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } // GetAllRulesHandler list for the account func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -51,7 +53,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { // UpdateRuleHandler handles update to a rule identified by a given ID func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -122,7 +124,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { // PatchRuleHandler handles patch updates to a rule identified by a given ID func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -253,7 +255,6 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { } rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations) - if err != nil { util.WriteError(err, w) return @@ -266,7 +267,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { // CreateRuleHandler handles rule creation request func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -325,7 +326,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { // DeleteRuleHandler handles rule deletion request func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -350,7 +351,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { // GetRuleHandler handles a group Get request identified by ID func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) diff --git a/management/server/http/rules_test.go b/management/server/http/rules_test.go index 9eade4c1f..b30e7576b 100644 --- a/management/server/http/rules_test.go +++ b/management/server/http/rules_test.go @@ -4,13 +4,14 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/netbirdio/netbird/management/server/http/api" "io" "net/http" "net/http/httptest" "strings" "testing" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -71,7 +72,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", - Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}}, + Rules: map[string]*server.Rule{"id-existed": {ID: "id-existed"}}, Groups: map[string]*server.Group{ "F": {ID: "F"}, "G": {ID: "G"}, @@ -82,16 +83,15 @@ func initRulesTestData(rules ...*server.Rule) *Rules { }, user, nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", } - }, - }, + }), + ), } } @@ -264,7 +264,8 @@ func TestRulesWriteRule(t *testing.T) { Flow: server.TrafficFlowBidirectString, Sources: []api.GroupMinimum{ {Id: "G"}, - {Id: "F"}}, + {Id: "F"}, + }, }, }, } @@ -306,7 +307,6 @@ func TestRulesWriteRule(t *testing.T) { } assert.Equal(t, got, tc.expectedRule) - }) } } diff --git a/management/server/http/setupkeys.go b/management/server/http/setupkeys.go index cecc86036..aec7160b0 100644 --- a/management/server/http/setupkeys.go +++ b/management/server/http/setupkeys.go @@ -2,34 +2,36 @@ package http import ( "encoding/json" + "net/http" + "time" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" - "net/http" - "time" ) // SetupKeys is a handler that returns a list of setup keys of the account type SetupKeys struct { - accountManager server.AccountManager - jwtExtractor jwtclaims.ClaimsExtractor - authAudience string + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } -func NewSetupKeysHandler(accountManager server.AccountManager, authAudience string) *SetupKeys { +func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeys { return &SetupKeys{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } // CreateSetupKeyHandler is a POST requests that creates a new SetupKey func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -72,7 +74,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request // GetSetupKeyHandler is a GET request to get a SetupKey by ID func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -97,7 +99,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { // UpdateSetupKeyHandler is a PUT request to update server.SetupKey func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -144,8 +146,7 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request // GetAllSetupKeysHandler is a GET request that returns a list of SetupKey func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) { - - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) diff --git a/management/server/http/setupkeys_test.go b/management/server/http/setupkeys_test.go index 9e780b86f..861e102ed 100644 --- a/management/server/http/setupkeys_test.go +++ b/management/server/http/setupkeys_test.go @@ -4,16 +4,17 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - "github.com/stretchr/testify/assert" "io" "net/http" "net/http/httptest" "testing" "time" + "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" + "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server" @@ -28,7 +29,8 @@ const ( ) func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, - user *server.User) *SetupKeys { + user *server.User, +) *SetupKeys { return &SetupKeys{ accountManager: &mock_server.MockAccountManager{ GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { @@ -43,11 +45,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup }, Groups: map[string]*server.Group{ "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, - "id-all": {ID: "id-all", Name: "All"}}, + "id-all": {ID: "id-all", Name: "All"}, + }, }, user, nil }, CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, - _ int, _ string) (*server.SetupKey, error) { + _ int, _ string, + ) (*server.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { return newKey, nil } @@ -75,16 +79,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup return []*server.SetupKey{defaultKey}, nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: user.Id, Domain: "hotmail.com", AccountId: testAccountID, } - }, - }, + }), + ), } } @@ -209,7 +212,6 @@ func TestSetupKeysHandlers(t *testing.T) { assertKeys(t, got[0], tc.expectedSetupKeys[0]) return } - }) } } diff --git a/management/server/http/users.go b/management/server/http/users.go index d89c7981f..7cc57e3fd 100644 --- a/management/server/http/users.go +++ b/management/server/http/users.go @@ -2,27 +2,29 @@ package http import ( "encoding/json" + "net/http" + "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" - "net/http" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" ) type UserHandler struct { - accountManager server.AccountManager - authAudience string - jwtExtractor jwtclaims.ClaimsExtractor + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor } -func NewUserHandler(accountManager server.AccountManager, authAudience string) *UserHandler { +func NewUserHandler(accountManager server.AccountManager, authCfg AuthCfg) *UserHandler { return &UserHandler{ accountManager: accountManager, - authAudience: authAudience, - jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), } } @@ -33,7 +35,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -69,8 +71,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { util.WriteError(err, w) return } - util.WriteJSONObject(w, toUserResponse(newUser)) - + util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) } // CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite). @@ -80,7 +81,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) return } - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -109,7 +110,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) util.WriteError(err, w) return } - util.WriteJSONObject(w, toUserResponse(newUser)) + util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) } // GetUsers returns a list of users of the account this user belongs to. @@ -120,7 +121,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { return } - claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) @@ -135,14 +136,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { users := make([]*api.User, 0) for _, r := range data { - users = append(users, toUserResponse(r)) + users = append(users, toUserResponse(r, claims.UserId)) } util.WriteJSONObject(w, users) } -func toUserResponse(user *server.UserInfo) *api.User { - +func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { autoGroups := user.AutoGroups if autoGroups == nil { autoGroups = []string{} @@ -158,6 +158,7 @@ func toUserResponse(user *server.UserInfo) *api.User { userStatus = api.UserStatusDisabled } + isCurrent := user.ID == currenUserID return &api.User{ Id: user.ID, Name: user.Name, @@ -165,5 +166,6 @@ func toUserResponse(user *server.UserInfo) *api.User { Role: user.Role, AutoGroups: autoGroups, Status: userStatus, + IsCurrent: &isCurrent, } } diff --git a/management/server/http/users_test.go b/management/server/http/users_test.go index 806c6152d..44494b699 100644 --- a/management/server/http/users_test.go +++ b/management/server/http/users_test.go @@ -40,16 +40,15 @@ func initUsers(user ...*server.User) *UserHandler { return users, nil }, }, - authAudience: "", - jwtExtractor: jwtclaims.ClaimsExtractor{ - ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: "1", Domain: "hotmail.com", AccountId: "test_id", } - }, - }, + }), + ), } } @@ -57,7 +56,7 @@ func TestGetUsers(t *testing.T) { users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}} userHandler := initUsers(users...) - var tt = []struct { + tt := []struct { name string expectedStatus int requestType string diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index a794dd3b4..9d60da335 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -1,8 +1,9 @@ package jwtclaims import ( - "github.com/golang-jwt/jwt" "net/http" + + "github.com/golang-jwt/jwt" ) const ( @@ -14,51 +15,85 @@ const ( ) // Extract function type -type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims +type ExtractClaims func(r *http.Request) AuthorizationClaims // ClaimsExtractor struct that holds the extract function type ClaimsExtractor struct { - ExtractClaimsFromRequestContext ExtractClaims + authAudience string + userIDClaim string + + FromRequestContext ExtractClaims +} + +// ClaimsExtractorOption is a function that configures the ClaimsExtractor +type ClaimsExtractorOption func(*ClaimsExtractor) + +// WithAudience sets the audience for the extractor +func WithAudience(audience string) ClaimsExtractorOption { + return func(c *ClaimsExtractor) { + c.authAudience = audience + } +} + +// WithUserIDClaim sets the user id claim for the extractor +func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption { + return func(c *ClaimsExtractor) { + c.userIDClaim = userIDClaim + } +} + +// WithFromRequestContext sets the function that extracts claims from the request context +func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption { + return func(c *ClaimsExtractor) { + c.FromRequestContext = ec + } } // NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature, // then it will use that logic. Uses ExtractClaimsFromRequestContext by default -func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor { - var extractFunc ExtractClaims - if extractFunc = e; extractFunc == nil { - extractFunc = ExtractClaimsFromRequestContext +func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { + ce := &ClaimsExtractor{} + for _, option := range options { + option(ce) } - - return &ClaimsExtractor{ - ExtractClaimsFromRequestContext: extractFunc, + if ce.FromRequestContext == nil { + ce.FromRequestContext = ce.fromRequestContext } + if ce.userIDClaim == "" { + ce.userIDClaim = UserIDClaim + } + return ce } -// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) -func ExtractClaimsFromRequestContext(r *http.Request, authAudience string) AuthorizationClaims { - if r.Context().Value(TokenUserProperty) == nil { - return AuthorizationClaims{} - } - token := r.Context().Value(TokenUserProperty).(*jwt.Token) - return ExtractClaimsWithToken(token, authAudience) -} - -// ExtractClaimsWithToken extracts claims from the token (after auth) -func ExtractClaimsWithToken(token *jwt.Token, authAudience string) AuthorizationClaims { +// FromToken extracts claims from the token (after auth) +func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { claims := token.Claims.(jwt.MapClaims) jwtClaims := AuthorizationClaims{} - jwtClaims.UserId = claims[UserIDClaim].(string) - accountIdClaim, ok := claims[authAudience+AccountIDSuffix] - if ok { - jwtClaims.AccountId = accountIdClaim.(string) + userID, ok := claims[c.userIDClaim].(string) + if !ok { + return jwtClaims } - domainClaim, ok := claims[authAudience+DomainIDSuffix] + jwtClaims.UserId = userID + accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] + if ok { + jwtClaims.AccountId = accountIDClaim.(string) + } + domainClaim, ok := claims[c.authAudience+DomainIDSuffix] if ok { jwtClaims.Domain = domainClaim.(string) } - domainCategoryClaim, ok := claims[authAudience+DomainCategorySuffix] + domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] if ok { jwtClaims.DomainCategory = domainCategoryClaim.(string) } return jwtClaims } + +// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) +func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims { + if r.Context().Value(TokenUserProperty) == nil { + return AuthorizationClaims{} + } + token := r.Context().Value(TokenUserProperty).(*jwt.Token) + return c.FromToken(token) +} diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index 9f4d7c7d3..d8acd79b6 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -2,10 +2,11 @@ package jwtclaims import ( "context" - "github.com/golang-jwt/jwt" - "github.com/stretchr/testify/require" "net/http" "testing" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/require" ) func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { @@ -31,7 +32,6 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st } func TestExtractClaimsFromRequestContext(t *testing.T) { - type test struct { name string inputAuthorizationClaims AuthorizationClaims @@ -99,12 +99,84 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} { t.Run(testCase.name, func(t *testing.T) { - request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance) - extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance) + extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance)) + extractedClaims := extractor.FromRequestContext(request) testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG) }) } } + +func TestExtractClaimsSetOptions(t *testing.T) { + type test struct { + name string + extractor *ClaimsExtractor + check func(t *testing.T, c test) + } + + testCase1 := test{ + name: "No custom options", + extractor: NewClaimsExtractor(), + check: func(t *testing.T, c test) { + if c.extractor.authAudience != "" { + t.Error("audience should be empty") + return + } + if c.extractor.userIDClaim != UserIDClaim { + t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) + return + } + if c.extractor.FromRequestContext == nil { + t.Error("from request context should not be nil") + return + } + }, + } + + testCase2 := test{ + name: "Custom audience", + extractor: NewClaimsExtractor(WithAudience("https://login/")), + check: func(t *testing.T, c test) { + if c.extractor.authAudience != "https://login/" { + t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience) + return + } + }, + } + + testCase3 := test{ + name: "Custom user id claim", + extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")), + check: func(t *testing.T, c test) { + if c.extractor.userIDClaim != "customUserId" { + t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim) + return + } + }, + } + + testCase4 := test{ + name: "Custom extractor from request context", + extractor: NewClaimsExtractor( + WithFromRequestContext(func(r *http.Request) AuthorizationClaims { + return AuthorizationClaims{ + UserId: "testCustomRequest", + } + })), + check: func(t *testing.T, c test) { + claims := c.extractor.FromRequestContext(&http.Request{}) + if claims.UserId != "testCustomRequest" { + t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId) + return + } + }, + } + + for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { + t.Run(testCase.name, func(t *testing.T) { + testCase.check(t, testCase) + }) + } +}